cse151b-final-project/dqn_wordle.ipynb

1161 lines
55 KiB
Plaintext
Raw Normal View History

2024-03-13 18:04:30 +00:00
{
"cells": [
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 1,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n",
2024-03-13 20:57:23 +00:00
"import numpy as np\n",
"import tqdm"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 2,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Monitor<WordleEnv instance>>\n"
]
}
],
2024-03-13 18:04:30 +00:00
"source": [
"# Build the Wordle env and wrap it in SB3's Monitor so per-episode stats (ep_len_mean / ep_rew_mean) are recorded during training.\n",
"env = common.monitor.Monitor(gym_wordle.wordle.WordleEnv())\n",
"\n",
"print(env)"
]
},
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 3,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
2024-03-14 22:00:19 +00:00
"Wrapping the env in a DummyVecEnv.\n",
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 784 |\n",
"| iterations | 1 |\n",
"| time_elapsed | 2 |\n",
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"------------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 654 |\n",
"| iterations | 2 |\n",
"| time_elapsed | 6 |\n",
"| total_timesteps | 4096 |\n",
"| train/ | |\n",
"| approx_kl | 0.0065368367 |\n",
"| clip_fraction | 0.027 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.00416 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 2.17e+03 |\n",
"| n_updates | 10 |\n",
"| policy_gradient_loss | -0.0517 |\n",
"| value_loss | 5.21e+03 |\n",
"------------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 612 |\n",
"| iterations | 3 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 6144 |\n",
"| train/ | |\n",
"| approx_kl | 0.006391342 |\n",
"| clip_fraction | 0.0312 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.0923 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 1.97e+03 |\n",
"| n_updates | 20 |\n",
"| policy_gradient_loss | -0.0487 |\n",
"| value_loss | 4.27e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 590 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 13 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
"| approx_kl | 0.007267303 |\n",
"| clip_fraction | 0.0477 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.0297 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 1.41e+03 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.0542 |\n",
"| value_loss | 3.64e+03 |\n",
"-----------------------------------------\n",
"------------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 579 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 17 |\n",
"| total_timesteps | 10240 |\n",
"| train/ | |\n",
"| approx_kl | 0.0079803215 |\n",
"| clip_fraction | 0.0604 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.01 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 1.19e+03 |\n",
"| n_updates | 40 |\n",
"| policy_gradient_loss | -0.0581 |\n",
"| value_loss | 3.17e+03 |\n",
"------------------------------------------\n",
"------------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 572 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 21 |\n",
"| total_timesteps | 12288 |\n",
"| train/ | |\n",
"| approx_kl | 0.0095686875 |\n",
"| clip_fraction | 0.0843 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.00451 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 1.07e+03 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.0646 |\n",
"| value_loss | 2.68e+03 |\n",
"------------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 567 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 25 |\n",
"| total_timesteps | 14336 |\n",
"| train/ | |\n",
"| approx_kl | 0.011415122 |\n",
"| clip_fraction | 0.114 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.00251 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 928 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.0713 |\n",
"| value_loss | 2.17e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5.95 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 563 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 29 |\n",
"| total_timesteps | 16384 |\n",
"| train/ | |\n",
"| approx_kl | 0.012646846 |\n",
"| clip_fraction | 0.152 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.00144 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 989 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.0767 |\n",
"| value_loss | 1.92e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 560 |\n",
"| iterations | 9 |\n",
"| time_elapsed | 32 |\n",
"| total_timesteps | 18432 |\n",
"| train/ | |\n",
"| approx_kl | 0.015274222 |\n",
"| clip_fraction | 0.209 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.000909 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 838 |\n",
"| n_updates | 80 |\n",
"| policy_gradient_loss | -0.0848 |\n",
"| value_loss | 1.57e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 558 |\n",
"| iterations | 10 |\n",
"| time_elapsed | 36 |\n",
"| total_timesteps | 20480 |\n",
"| train/ | |\n",
"| approx_kl | 0.017125849 |\n",
"| clip_fraction | 0.269 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.000586 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 634 |\n",
"| n_updates | 90 |\n",
"| policy_gradient_loss | -0.0913 |\n",
"| value_loss | 1.38e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 556 |\n",
"| iterations | 11 |\n",
"| time_elapsed | 40 |\n",
"| total_timesteps | 22528 |\n",
"| train/ | |\n",
"| approx_kl | 0.02106983 |\n",
"| clip_fraction | 0.345 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | -0.000377 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 538 |\n",
"| n_updates | 100 |\n",
"| policy_gradient_loss | -0.0985 |\n",
"| value_loss | 1.25e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 554 |\n",
"| iterations | 12 |\n",
"| time_elapsed | 44 |\n",
"| total_timesteps | 24576 |\n",
"| train/ | |\n",
"| approx_kl | 0.031943925 |\n",
"| clip_fraction | 0.478 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | -0.000226 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 487 |\n",
"| n_updates | 110 |\n",
"| policy_gradient_loss | -0.109 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 553 |\n",
"| iterations | 13 |\n",
"| time_elapsed | 48 |\n",
"| total_timesteps | 26624 |\n",
"| train/ | |\n",
"| approx_kl | 0.05316051 |\n",
"| clip_fraction | 0.627 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | -8.43e-05 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 521 |\n",
"| n_updates | 120 |\n",
"| policy_gradient_loss | -0.128 |\n",
"| value_loss | 1.12e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5.96 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 551 |\n",
"| iterations | 14 |\n",
"| time_elapsed | 51 |\n",
"| total_timesteps | 28672 |\n",
"| train/ | |\n",
"| approx_kl | 0.055725098 |\n",
"| clip_fraction | 0.711 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | -3.3e-05 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 541 |\n",
"| n_updates | 130 |\n",
"| policy_gradient_loss | -0.138 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -127 |\n",
"| time/ | |\n",
"| fps | 550 |\n",
"| iterations | 15 |\n",
"| time_elapsed | 55 |\n",
"| total_timesteps | 30720 |\n",
"| train/ | |\n",
"| approx_kl | 0.057101354 |\n",
"| clip_fraction | 0.73 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | -1.74e-05 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 598 |\n",
"| n_updates | 140 |\n",
"| policy_gradient_loss | -0.141 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 549 |\n",
"| iterations | 16 |\n",
"| time_elapsed | 59 |\n",
"| total_timesteps | 32768 |\n",
"| train/ | |\n",
"| approx_kl | 0.06564396 |\n",
"| clip_fraction | 0.73 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.44 |\n",
"| explained_variance | -1.05e-05 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 506 |\n",
"| n_updates | 150 |\n",
"| policy_gradient_loss | -0.141 |\n",
"| value_loss | 1.17e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 549 |\n",
"| iterations | 17 |\n",
"| time_elapsed | 63 |\n",
"| total_timesteps | 34816 |\n",
"| train/ | |\n",
"| approx_kl | 0.05891238 |\n",
"| clip_fraction | 0.735 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.43 |\n",
"| explained_variance | -7.15e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 564 |\n",
"| n_updates | 160 |\n",
"| policy_gradient_loss | -0.141 |\n",
"| value_loss | 1.15e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -126 |\n",
"| time/ | |\n",
"| fps | 548 |\n",
"| iterations | 18 |\n",
"| time_elapsed | 67 |\n",
"| total_timesteps | 36864 |\n",
"| train/ | |\n",
"| approx_kl | 0.06345081 |\n",
"| clip_fraction | 0.753 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.43 |\n",
"| explained_variance | -5.13e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 625 |\n",
"| n_updates | 170 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.14e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -125 |\n",
"| time/ | |\n",
"| fps | 547 |\n",
"| iterations | 19 |\n",
"| time_elapsed | 71 |\n",
"| total_timesteps | 38912 |\n",
"| train/ | |\n",
"| approx_kl | 0.061060853 |\n",
"| clip_fraction | 0.71 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.42 |\n",
"| explained_variance | -3.7e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 715 |\n",
"| n_updates | 180 |\n",
"| policy_gradient_loss | -0.139 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"---------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -126 |\n",
"| time/ | |\n",
"| fps | 548 |\n",
"| iterations | 20 |\n",
"| time_elapsed | 74 |\n",
"| total_timesteps | 40960 |\n",
"| train/ | |\n",
"| approx_kl | 0.0626598 |\n",
"| clip_fraction | 0.721 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.42 |\n",
"| explained_variance | -2.86e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 665 |\n",
"| n_updates | 190 |\n",
"| policy_gradient_loss | -0.141 |\n",
"| value_loss | 1.17e+03 |\n",
"---------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 547 |\n",
"| iterations | 21 |\n",
"| time_elapsed | 78 |\n",
"| total_timesteps | 43008 |\n",
"| train/ | |\n",
"| approx_kl | 0.066325136 |\n",
"| clip_fraction | 0.743 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.41 |\n",
"| explained_variance | -2.15e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 670 |\n",
"| n_updates | 200 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.18e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 544 |\n",
"| iterations | 22 |\n",
"| time_elapsed | 82 |\n",
"| total_timesteps | 45056 |\n",
"| train/ | |\n",
"| approx_kl | 0.06377452 |\n",
"| clip_fraction | 0.721 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.4 |\n",
"| explained_variance | -1.91e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 695 |\n",
"| n_updates | 210 |\n",
"| policy_gradient_loss | -0.142 |\n",
"| value_loss | 1.15e+03 |\n",
"----------------------------------------\n",
"---------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 543 |\n",
"| iterations | 23 |\n",
"| time_elapsed | 86 |\n",
"| total_timesteps | 47104 |\n",
"| train/ | |\n",
"| approx_kl | 0.0712228 |\n",
"| clip_fraction | 0.734 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.39 |\n",
"| explained_variance | -1.43e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 691 |\n",
"| n_updates | 220 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.15e+03 |\n",
"---------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 542 |\n",
"| iterations | 24 |\n",
"| time_elapsed | 90 |\n",
"| total_timesteps | 49152 |\n",
"| train/ | |\n",
"| approx_kl | 0.06300552 |\n",
"| clip_fraction | 0.737 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.38 |\n",
"| explained_variance | -1.19e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 475 |\n",
"| n_updates | 230 |\n",
"| policy_gradient_loss | -0.144 |\n",
"| value_loss | 1.17e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -126 |\n",
"| time/ | |\n",
"| fps | 540 |\n",
"| iterations | 25 |\n",
"| time_elapsed | 94 |\n",
"| total_timesteps | 51200 |\n",
"| train/ | |\n",
"| approx_kl | 0.07357139 |\n",
"| clip_fraction | 0.738 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.37 |\n",
"| explained_variance | -1.07e-06 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 687 |\n",
"| n_updates | 240 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.18e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -124 |\n",
"| time/ | |\n",
"| fps | 539 |\n",
"| iterations | 26 |\n",
"| time_elapsed | 98 |\n",
"| total_timesteps | 53248 |\n",
"| train/ | |\n",
"| approx_kl | 0.07966857 |\n",
"| clip_fraction | 0.756 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.34 |\n",
"| explained_variance | -8.34e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 408 |\n",
"| n_updates | 250 |\n",
"| policy_gradient_loss | -0.144 |\n",
"| value_loss | 1.17e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 539 |\n",
"| iterations | 27 |\n",
"| time_elapsed | 102 |\n",
"| total_timesteps | 55296 |\n",
"| train/ | |\n",
"| approx_kl | 0.07282317 |\n",
"| clip_fraction | 0.743 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.31 |\n",
"| explained_variance | -5.96e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 471 |\n",
"| n_updates | 260 |\n",
"| policy_gradient_loss | -0.145 |\n",
"| value_loss | 1.16e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 539 |\n",
"| iterations | 28 |\n",
"| time_elapsed | 106 |\n",
"| total_timesteps | 57344 |\n",
"| train/ | |\n",
"| approx_kl | 0.07441577 |\n",
"| clip_fraction | 0.731 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.28 |\n",
"| explained_variance | -5.96e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 501 |\n",
"| n_updates | 270 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.16e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5.97 |\n",
"| ep_rew_mean | -125 |\n",
"| time/ | |\n",
"| fps | 539 |\n",
"| iterations | 29 |\n",
"| time_elapsed | 110 |\n",
"| total_timesteps | 59392 |\n",
"| train/ | |\n",
"| approx_kl | 0.07315491 |\n",
"| clip_fraction | 0.757 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.23 |\n",
"| explained_variance | -4.77e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 569 |\n",
"| n_updates | 280 |\n",
"| policy_gradient_loss | -0.146 |\n",
"| value_loss | 1.13e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 538 |\n",
"| iterations | 30 |\n",
"| time_elapsed | 114 |\n",
"| total_timesteps | 61440 |\n",
"| train/ | |\n",
"| approx_kl | 0.06323205 |\n",
"| clip_fraction | 0.716 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.18 |\n",
"| explained_variance | -3.58e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 612 |\n",
"| n_updates | 290 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.17e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 538 |\n",
"| iterations | 31 |\n",
"| time_elapsed | 117 |\n",
"| total_timesteps | 63488 |\n",
"| train/ | |\n",
"| approx_kl | 0.07022999 |\n",
"| clip_fraction | 0.735 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.11 |\n",
"| explained_variance | -4.77e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 653 |\n",
"| n_updates | 300 |\n",
"| policy_gradient_loss | -0.143 |\n",
"| value_loss | 1.14e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 32 |\n",
"| time_elapsed | 121 |\n",
"| total_timesteps | 65536 |\n",
"| train/ | |\n",
"| approx_kl | 0.07809374 |\n",
"| clip_fraction | 0.734 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.02 |\n",
"| explained_variance | -4.77e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 505 |\n",
"| n_updates | 310 |\n",
"| policy_gradient_loss | -0.14 |\n",
"| value_loss | 1.14e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 33 |\n",
"| time_elapsed | 125 |\n",
"| total_timesteps | 67584 |\n",
"| train/ | |\n",
"| approx_kl | 0.07546057 |\n",
"| clip_fraction | 0.731 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.93 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 549 |\n",
"| n_updates | 320 |\n",
"| policy_gradient_loss | -0.139 |\n",
"| value_loss | 1.14e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 536 |\n",
"| iterations | 34 |\n",
"| time_elapsed | 129 |\n",
"| total_timesteps | 69632 |\n",
"| train/ | |\n",
"| approx_kl | 0.06426198 |\n",
"| clip_fraction | 0.71 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.85 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 546 |\n",
"| n_updates | 330 |\n",
"| policy_gradient_loss | -0.139 |\n",
"| value_loss | 1.16e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 35 |\n",
"| time_elapsed | 133 |\n",
"| total_timesteps | 71680 |\n",
"| train/ | |\n",
"| approx_kl | 0.067716464 |\n",
"| clip_fraction | 0.693 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.76 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 577 |\n",
"| n_updates | 340 |\n",
"| policy_gradient_loss | -0.138 |\n",
"| value_loss | 1.16e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 536 |\n",
"| iterations | 36 |\n",
"| time_elapsed | 137 |\n",
"| total_timesteps | 73728 |\n",
"| train/ | |\n",
"| approx_kl | 0.07348621 |\n",
"| clip_fraction | 0.719 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.72 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 594 |\n",
"| n_updates | 350 |\n",
"| policy_gradient_loss | -0.138 |\n",
"| value_loss | 1.15e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -125 |\n",
"| time/ | |\n",
"| fps | 536 |\n",
"| iterations | 37 |\n",
"| time_elapsed | 141 |\n",
"| total_timesteps | 75776 |\n",
"| train/ | |\n",
"| approx_kl | 0.06314379 |\n",
"| clip_fraction | 0.663 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.69 |\n",
"| explained_variance | -3.58e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 546 |\n",
"| n_updates | 360 |\n",
"| policy_gradient_loss | -0.135 |\n",
"| value_loss | 1.15e+03 |\n",
"----------------------------------------\n",
"---------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 535 |\n",
"| iterations | 38 |\n",
"| time_elapsed | 145 |\n",
"| total_timesteps | 77824 |\n",
"| train/ | |\n",
"| approx_kl | 0.0605906 |\n",
"| clip_fraction | 0.678 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.58 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 569 |\n",
"| n_updates | 370 |\n",
"| policy_gradient_loss | -0.137 |\n",
"| value_loss | 1.15e+03 |\n",
"---------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 535 |\n",
"| iterations | 39 |\n",
"| time_elapsed | 149 |\n",
"| total_timesteps | 79872 |\n",
"| train/ | |\n",
"| approx_kl | 0.06449682 |\n",
"| clip_fraction | 0.682 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.48 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 602 |\n",
"| n_updates | 380 |\n",
"| policy_gradient_loss | -0.134 |\n",
"| value_loss | 1.14e+03 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 535 |\n",
"| iterations | 40 |\n",
"| time_elapsed | 153 |\n",
"| total_timesteps | 81920 |\n",
"| train/ | |\n",
"| approx_kl | 0.05844091 |\n",
"| clip_fraction | 0.633 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.33 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 575 |\n",
"| n_updates | 390 |\n",
"| policy_gradient_loss | -0.13 |\n",
"| value_loss | 1.16e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -120 |\n",
"| time/ | |\n",
"| fps | 535 |\n",
"| iterations | 41 |\n",
"| time_elapsed | 156 |\n",
"| total_timesteps | 83968 |\n",
"| train/ | |\n",
"| approx_kl | 0.049975857 |\n",
"| clip_fraction | 0.601 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -8.15 |\n",
"| explained_variance | -1.19e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 581 |\n",
"| n_updates | 400 |\n",
"| policy_gradient_loss | -0.127 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 534 |\n",
"| iterations | 42 |\n",
"| time_elapsed | 160 |\n",
"| total_timesteps | 86016 |\n",
"| train/ | |\n",
"| approx_kl | 0.055807307 |\n",
"| clip_fraction | 0.616 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -7.91 |\n",
"| explained_variance | 1.19e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 590 |\n",
"| n_updates | 410 |\n",
"| policy_gradient_loss | -0.123 |\n",
"| value_loss | 1.14e+03 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 534 |\n",
"| iterations | 43 |\n",
"| time_elapsed | 164 |\n",
"| total_timesteps | 88064 |\n",
"| train/ | |\n",
"| approx_kl | 0.05585126 |\n",
"| clip_fraction | 0.592 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -7.69 |\n",
"| explained_variance | 0 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 512 |\n",
"| n_updates | 420 |\n",
"| policy_gradient_loss | -0.117 |\n",
"| value_loss | 1.17e+03 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 534 |\n",
"| iterations | 44 |\n",
"| time_elapsed | 168 |\n",
"| total_timesteps | 90112 |\n",
"| train/ | |\n",
"| approx_kl | 0.050130654 |\n",
"| clip_fraction | 0.546 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -7.48 |\n",
"| explained_variance | 0 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 608 |\n",
"| n_updates | 430 |\n",
"| policy_gradient_loss | -0.106 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -122 |\n",
"| time/ | |\n",
"| fps | 534 |\n",
"| iterations | 45 |\n",
"| time_elapsed | 172 |\n",
"| total_timesteps | 92160 |\n",
"| train/ | |\n",
"| approx_kl | 0.050139036 |\n",
"| clip_fraction | 0.547 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -7.23 |\n",
"| explained_variance | 0 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 628 |\n",
"| n_updates | 440 |\n",
"| policy_gradient_loss | -0.103 |\n",
"| value_loss | 1.14e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -126 |\n",
"| time/ | |\n",
"| fps | 533 |\n",
"| iterations | 46 |\n",
"| time_elapsed | 176 |\n",
"| total_timesteps | 94208 |\n",
"| train/ | |\n",
"| approx_kl | 0.045009457 |\n",
"| clip_fraction | 0.495 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -7.04 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 557 |\n",
"| n_updates | 450 |\n",
"| policy_gradient_loss | -0.0923 |\n",
"| value_loss | 1.13e+03 |\n",
"-----------------------------------------\n",
"---------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 533 |\n",
"| iterations | 47 |\n",
"| time_elapsed | 180 |\n",
"| total_timesteps | 96256 |\n",
"| train/ | |\n",
"| approx_kl | 0.0487647 |\n",
"| clip_fraction | 0.499 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -6.85 |\n",
"| explained_variance | -2.38e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 531 |\n",
"| n_updates | 460 |\n",
"| policy_gradient_loss | -0.0868 |\n",
"| value_loss | 1.17e+03 |\n",
"---------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -123 |\n",
"| time/ | |\n",
"| fps | 533 |\n",
"| iterations | 48 |\n",
"| time_elapsed | 184 |\n",
"| total_timesteps | 98304 |\n",
"| train/ | |\n",
"| approx_kl | 0.052709702 |\n",
"| clip_fraction | 0.507 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -6.58 |\n",
"| explained_variance | -1.19e-07 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 547 |\n",
"| n_updates | 470 |\n",
"| policy_gradient_loss | -0.0865 |\n",
"| value_loss | 1.14e+03 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | -121 |\n",
"| time/ | |\n",
"| fps | 533 |\n",
"| iterations | 49 |\n",
"| time_elapsed | 188 |\n",
"| total_timesteps | 100352 |\n",
"| train/ | |\n",
"| approx_kl | 0.048466682 |\n",
"| clip_fraction | 0.511 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -6.52 |\n",
"| explained_variance | 0 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 624 |\n",
"| n_updates | 480 |\n",
"| policy_gradient_loss | -0.0834 |\n",
"| value_loss | 1.15e+03 |\n",
"-----------------------------------------\n"
]
2024-03-14 22:00:19 +00:00
},
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7facf85ac970>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
2024-03-13 18:04:30 +00:00
"source": [
"# Train PPO on the Monitor-wrapped env. A fixed seed makes Restart & Run All reproduce the same run.\n",
"SEED = 0\n",
"total_timesteps = 100000\n",
"model = PPO(\"MlpPolicy\", env, verbose=1, seed=SEED)\n",
"model.learn(total_timesteps=total_timesteps);"
]
},
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 4,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"# NOTE: the checkpoint keeps its historical name \"dqn_wordle\", but the model trained above is PPO.\n",
"model.save(path=\"dqn_wordle\")"
]
},
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 5,
2024-03-13 18:04:30 +00:00
"metadata": {},
2024-03-13 20:57:23 +00:00
"outputs": [],
2024-03-13 18:04:30 +00:00
"source": [
"# Reload the PPO checkpoint saved above; no env is attached because the model is only used for predict().\n",
"model = PPO.load(path=\"dqn_wordle\")"
]
},
{
"cell_type": "code",
2024-03-14 22:00:19 +00:00
"execution_count": 9,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-03-14 22:00:19 +00:00
"0\n"
]
}
],
2024-03-13 20:57:23 +00:00
"source": [
"# Evaluate the trained policy greedily on fresh episodes and report the win rate.\n",
"N_EVAL_EPISODES = 100\n",
"\n",
"# Separate evaluation env so the Monitor-wrapped training `env` is not shadowed.\n",
"eval_env = gym_wordle.wordle.WordleEnv()\n",
"\n",
"# Initialise the counter once, outside the episode loop, so wins accumulate across episodes.\n",
"wins = 0\n",
"\n",
"for _ in range(N_EVAL_EPISODES):\n",
"    state = eval_env.reset()\n",
"    done = False\n",
"    won = False\n",
"\n",
"    while not done:\n",
"        action, _states = model.predict(state, deterministic=True)\n",
"        state, reward, done, info = eval_env.step(action)\n",
"        # Latch the result so an episode counts as at most one win.\n",
"        won = won or bool(info.get(\"correct\", False))\n",
"\n",
"    if won:\n",
"        wins += 1\n",
"\n",
"print(f\"wins: {wins}/{N_EVAL_EPISODES} ({wins / N_EVAL_EPISODES:.1%})\")\n"
]
2024-03-13 18:04:30 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}