diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb index 1e31b94..2f3f31b 100644 --- a/dqn_wordle.ipynb +++ b/dqn_wordle.ipynb @@ -35,240 +35,495 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7c52630b65904d5e8e200be505d2121a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "Using cpu device\n", - "Wrapping the env in a DummyVecEnv.\n", - "---------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.59 |\n", - "| time/ | |\n", - "| fps | 544 |\n", - "| iterations | 1 |\n", - "| time_elapsed | 3 |\n", - "| total_timesteps | 2048 |\n", - "---------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.77 |\n", - "| time/ | |\n", - "| fps | 245 |\n", - "| iterations | 2 |\n", - "| time_elapsed | 16 |\n", - "| total_timesteps | 4096 |\n", - "| train/ | |\n", - "| approx_kl | 0.021515464 |\n", - "| clip_fraction | 0.335 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.00118 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 89.5 |\n", - "| n_updates | 10 |\n", - "| policy_gradient_loss | -0.0854 |\n", - "| value_loss | 262 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.31 |\n", - "| time/ | |\n", - "| fps | 211 |\n", - "| iterations | 3 |\n", - "| time_elapsed | 29 |\n", - "| total_timesteps | 6144 |\n", - "| train/ | |\n", - "| approx_kl | 0.02457875 |\n", - "| clip_fraction | 0.465 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.161 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 118 |\n", - "| n_updates | 20 |\n", - "| policy_gradient_loss | -0.0987 |\n", - "| value_loss | 217 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 5.96 |\n", - "| ep_rew_mean | 5.79 |\n", - "| time/ | |\n", - "| fps | 196 |\n", - "| iterations | 4 |\n", - "| time_elapsed | 41 |\n", - "| total_timesteps | 8192 |\n", - "| train/ | |\n", - "| approx_kl | 0.02515613 |\n", - "| clip_fraction | 0.447 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.151 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 138 |\n", - "| n_updates | 30 |\n", - "| policy_gradient_loss | -0.103 |\n", - "| value_loss | 242 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 4.9 |\n", - "| time/ | |\n", - "| fps | 177 |\n", - "| iterations | 5 |\n", - "| time_elapsed | 57 |\n", - "| total_timesteps | 10240 |\n", - "| train/ | |\n", - "| approx_kl | 0.026685718 |\n", - "| clip_fraction | 0.444 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.176 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 96.7 |\n", - "| n_updates | 40 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 211 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.19 |\n", - "| time/ | |\n", - "| fps | 164 |\n", - "| iterations | 6 |\n", - "| time_elapsed | 74 |\n", - "| total_timesteps | 12288 |\n", - "| train/ | |\n", - "| approx_kl | 0.02762504 |\n", - "| clip_fraction | 0.463 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.186 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 103 |\n", - "| n_updates | 50 |\n", - "| policy_gradient_loss | -0.115 |\n", - "| value_loss | 200 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.5 |\n", - "| time/ | |\n", - "| fps | 155 |\n", - "| iterations | 7 |\n", - "| time_elapsed | 92 |\n", - "| total_timesteps | 14336 |\n", - "| train/ | |\n", - "| approx_kl | 0.02694263 |\n", - "| clip_fraction | 0.458 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.15 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 84.1 |\n", - "| n_updates | 60 |\n", - "| policy_gradient_loss | -0.116 |\n", - "| value_loss | 225 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.27 |\n", - "| time/ | |\n", - "| fps | 154 |\n", - "| iterations | 8 |\n", - "| time_elapsed | 106 |\n", - "| total_timesteps | 16384 |\n", - "| train/ | |\n", - "| approx_kl | 0.024316464 |\n", - "| clip_fraction | 0.412 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.173 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 126 |\n", - "| n_updates | 70 |\n", - "| policy_gradient_loss | -0.112 |\n", - "| value_loss | 227 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.8 |\n", - "| time/ | |\n", - "| fps | 151 |\n", - "| iterations | 9 |\n", - "| time_elapsed | 121 |\n", - "| total_timesteps | 18432 |\n", - "| train/ | |\n", - "| approx_kl | 0.022988513 |\n", - "| clip_fraction | 0.391 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.206 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 139 |\n", - "| n_updates | 80 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 228 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 6.14 |\n", - "| time/ | |\n", - "| fps | 153 |\n", - "| iterations | 10 |\n", - "| time_elapsed | 133 |\n", - "| total_timesteps | 20480 |\n", - "| train/ | |\n", - "| approx_kl | 0.022813996 |\n", - "| clip_fraction | 0.372 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.199 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 117 |\n", - "| n_updates | 90 |\n", - "| policy_gradient_loss | -0.108 |\n", - "| value_loss | 212 |\n", - "-----------------------------------------\n" + "Using cuda device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -175 |\n", + "| exploration_rate | 0.525 |\n", + "| time/ | |\n", + "| episodes | 10000 |\n", + "| fps | 4606 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 49989 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -208 |\n", + "| exploration_rate | 0.0502 |\n", + "| time/ | |\n", + "| episodes | 20000 |\n", + "| fps | 1118 |\n", + "| time_elapsed | 89 |\n", + "| total_timesteps | 99980 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 24.6 |\n", + "| n_updates | 12494 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -230 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 30000 |\n", + "| fps | 856 |\n", + "| time_elapsed | 175 |\n", + "| total_timesteps | 149974 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 18.7 |\n", + "| n_updates | 24993 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -242 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 40000 |\n", + "| fps | 766 |\n", + "| time_elapsed | 260 |\n", + "| total_timesteps | 199967 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 24 |\n", + "| n_updates | 37491 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -186 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 50000 |\n", + "| fps | 722 |\n", + "| time_elapsed | 346 |\n", + "| total_timesteps | 249962 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.5 |\n", + "| n_updates | 49990 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -183 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 60000 |\n", + "| fps | 694 |\n", + "| time_elapsed | 431 |\n", + "| total_timesteps | 299957 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.6 |\n", + "| n_updates | 62489 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -181 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 70000 |\n", + "| fps | 675 |\n", + "| time_elapsed | 517 |\n", + "| total_timesteps | 349953 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 26.8 |\n", + "| n_updates | 74988 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -196 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 80000 |\n", + "| fps | 663 |\n", + "| time_elapsed | 603 |\n", + "| total_timesteps | 399936 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.5 |\n", + "| n_updates | 87483 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -174 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 90000 |\n", + "| fps | 653 |\n", + "| time_elapsed | 688 |\n", + "| total_timesteps | 449928 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.1 |\n", + "| n_updates | 99981 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -155 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 100000 |\n", + "| fps | 645 |\n", + "| time_elapsed | 774 |\n", + "| total_timesteps | 499920 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.8 |\n", + "| n_updates | 112479 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -153 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 110000 |\n", + "| fps | 638 |\n", + "| time_elapsed | 860 |\n", + "| total_timesteps | 549916 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 16 |\n", + "| n_updates | 124978 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -164 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 120000 |\n", + "| fps | 633 |\n", + "| time_elapsed | 947 |\n", + "| total_timesteps | 599915 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.8 |\n", + "| n_updates | 137478 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -145 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 130000 |\n", + "| fps | 628 |\n", + "| time_elapsed | 1033 |\n", + "| total_timesteps | 649910 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.8 |\n", + "| n_updates | 149977 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -154 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 140000 |\n", + "| fps | 624 |\n", + "| time_elapsed | 1120 |\n", + "| total_timesteps | 699902 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 20.9 |\n", + "| n_updates | 162475 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -192 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 150000 |\n", + "| fps | 621 |\n", + "| time_elapsed | 1206 |\n", + "| total_timesteps | 749884 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 18.3 |\n", + "| n_updates | 174970 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -170 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 160000 |\n", + "| fps | 618 |\n", + "| time_elapsed | 1293 |\n", + "| total_timesteps | 799869 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.7 |\n", + "| n_updates | 187467 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -233 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 170000 |\n", + "| fps | 615 |\n", + "| time_elapsed | 1380 |\n", + "| total_timesteps | 849855 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.6 |\n", + "| n_updates | 199963 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -146 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 180000 |\n", + "| fps | 613 |\n", + "| time_elapsed | 1466 |\n", + "| total_timesteps | 899847 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 19.4 |\n", + "| n_updates | 212461 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -142 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 190000 |\n", + "| fps | 611 |\n", + "| time_elapsed | 1553 |\n", + "| total_timesteps | 949846 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.9 |\n", + "| n_updates | 224961 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -171 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 200000 |\n", + "| fps | 609 |\n", + "| time_elapsed | 1640 |\n", + "| total_timesteps | 999839 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 20.3 |\n", + "| n_updates | 237459 |\n", + "----------------------------------\n" ] }, { "data": { + "text/html": [ + "
\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], "text/plain": [ - "