diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb index 36185ba..85faf54 100644 --- a/dqn_wordle.ipynb +++ b/dqn_wordle.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 83, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -43,19 +43,1047 @@ "output_type": "stream", "text": [ "Using cuda device\n", - "Wrapping the env in a DummyVecEnv.\n" + "Wrapping the env in a DummyVecEnv.\n", + "---------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 784 |\n", + "| iterations | 1 |\n", + "| time_elapsed | 2 |\n", + "| total_timesteps | 2048 |\n", + "---------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 654 |\n", + "| iterations | 2 |\n", + "| time_elapsed | 6 |\n", + "| total_timesteps | 4096 |\n", + "| train/ | |\n", + "| approx_kl | 0.0065368367 |\n", + "| clip_fraction | 0.027 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | 0.00416 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 2.17e+03 |\n", + "| n_updates | 10 |\n", + "| policy_gradient_loss | -0.0517 |\n", + "| value_loss | 5.21e+03 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 612 |\n", + "| iterations | 3 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 6144 |\n", + "| train/ | |\n", + "| approx_kl | 0.006391342 |\n", + "| clip_fraction | 0.0312 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.0923 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.97e+03 |\n", + "| n_updates | 20 |\n", + "| policy_gradient_loss | -0.0487 |\n", + "| value_loss | 4.27e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 590 |\n", + "| iterations | 4 |\n", + "| time_elapsed | 13 |\n", + "| total_timesteps | 8192 |\n", + "| train/ | |\n", + "| approx_kl | 0.007267303 |\n", + "| clip_fraction | 0.0477 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.0297 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.41e+03 |\n", + "| n_updates | 30 |\n", + "| policy_gradient_loss | -0.0542 |\n", + "| value_loss | 3.64e+03 |\n", + "-----------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 579 |\n", + "| iterations | 5 |\n", + "| time_elapsed | 17 |\n", + "| total_timesteps | 10240 |\n", + "| train/ | |\n", + "| approx_kl | 0.0079803215 |\n", + "| clip_fraction | 0.0604 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.01 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.19e+03 |\n", + "| n_updates | 40 |\n", + "| policy_gradient_loss | -0.0581 |\n", + "| value_loss | 3.17e+03 |\n", + "------------------------------------------\n", + "------------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 572 |\n", + "| iterations | 6 |\n", + "| time_elapsed | 21 |\n", + "| total_timesteps | 12288 |\n", + "| train/ | |\n", + "| approx_kl | 0.0095686875 |\n", + "| clip_fraction | 0.0843 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.00451 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 1.07e+03 |\n", + "| n_updates | 50 |\n", + "| policy_gradient_loss | -0.0646 |\n", + "| value_loss | 2.68e+03 |\n", + "------------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 567 |\n", + "| iterations | 7 |\n", + "| time_elapsed | 25 |\n", + "| total_timesteps | 14336 |\n", + "| train/ | |\n", + "| approx_kl | 0.011415122 |\n", + "| clip_fraction | 0.114 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.00251 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 928 |\n", + "| n_updates | 60 |\n", + "| policy_gradient_loss | -0.0713 |\n", + "| value_loss | 2.17e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5.95 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 563 |\n", + "| iterations | 8 |\n", + "| time_elapsed | 29 |\n", + "| total_timesteps | 16384 |\n", + "| train/ | |\n", + "| approx_kl | 0.012646846 |\n", + "| clip_fraction | 0.152 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.00144 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 989 |\n", + "| n_updates | 70 |\n", + "| policy_gradient_loss | -0.0767 |\n", + "| value_loss | 1.92e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 560 |\n", + "| iterations | 9 |\n", + "| time_elapsed | 32 |\n", + "| total_timesteps | 18432 |\n", + "| train/ | |\n", + "| approx_kl | 0.015274222 |\n", + "| clip_fraction | 0.209 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.000909 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 838 |\n", + "| n_updates | 80 |\n", + "| policy_gradient_loss | -0.0848 |\n", + "| value_loss | 1.57e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 558 |\n", + "| iterations | 10 |\n", + "| time_elapsed | 36 |\n", + "| total_timesteps | 20480 |\n", + "| train/ | |\n", + "| approx_kl | 0.017125849 |\n", + "| clip_fraction | 0.269 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | -0.000586 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 634 |\n", + "| n_updates | 90 |\n", + "| policy_gradient_loss | -0.0913 |\n", + "| value_loss | 1.38e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 556 |\n", + "| iterations | 11 |\n", + "| time_elapsed | 40 |\n", + "| total_timesteps | 22528 |\n", + "| train/ | |\n", + "| approx_kl | 0.02106983 |\n", + "| clip_fraction | 0.345 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.46 |\n", + "| explained_variance | -0.000377 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 538 |\n", + "| n_updates | 100 |\n", + "| policy_gradient_loss | -0.0985 |\n", + "| value_loss | 1.25e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 554 |\n", + "| iterations | 12 |\n", + "| time_elapsed | 44 |\n", + "| total_timesteps | 24576 |\n", + "| train/ | |\n", + "| approx_kl | 0.031943925 |\n", + "| clip_fraction | 0.478 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.46 |\n", + "| explained_variance | -0.000226 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 487 |\n", + "| n_updates | 110 |\n", + "| policy_gradient_loss | -0.109 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 553 |\n", + "| iterations | 13 |\n", + "| time_elapsed | 48 |\n", + "| total_timesteps | 26624 |\n", + "| train/ | |\n", + "| approx_kl | 0.05316051 |\n", + "| clip_fraction | 0.627 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.46 |\n", + "| explained_variance | -8.43e-05 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 521 |\n", + "| n_updates | 120 |\n", + "| policy_gradient_loss | -0.128 |\n", + "| value_loss | 1.12e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5.96 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 551 |\n", + "| iterations | 14 |\n", + "| time_elapsed | 51 |\n", + "| total_timesteps | 28672 |\n", + "| train/ | |\n", + "| approx_kl | 0.055725098 |\n", + "| clip_fraction | 0.711 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.45 |\n", + "| explained_variance | -3.3e-05 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 541 |\n", + "| n_updates | 130 |\n", + "| policy_gradient_loss | -0.138 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -127 |\n", + "| time/ | |\n", + "| fps | 550 |\n", + "| iterations | 15 |\n", + "| time_elapsed | 55 |\n", + "| total_timesteps | 30720 |\n", + "| train/ | |\n", + "| approx_kl | 0.057101354 |\n", + "| clip_fraction | 0.73 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.45 |\n", + "| explained_variance | -1.74e-05 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 598 |\n", + "| n_updates | 140 |\n", + "| policy_gradient_loss | -0.141 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 549 |\n", + "| iterations | 16 |\n", + "| time_elapsed | 59 |\n", + "| total_timesteps | 32768 |\n", + "| train/ | |\n", + "| approx_kl | 0.06564396 |\n", + "| clip_fraction | 0.73 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.44 |\n", + "| explained_variance | -1.05e-05 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 506 |\n", + "| n_updates | 150 |\n", + "| policy_gradient_loss | -0.141 |\n", + "| value_loss | 1.17e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 549 |\n", + "| iterations | 17 |\n", + "| time_elapsed | 63 |\n", + "| total_timesteps | 34816 |\n", + "| train/ | |\n", + "| approx_kl | 0.05891238 |\n", + "| clip_fraction | 0.735 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.43 |\n", + "| explained_variance | -7.15e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 564 |\n", + "| n_updates | 160 |\n", + "| policy_gradient_loss | -0.141 |\n", + "| value_loss | 1.15e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -126 |\n", + "| time/ | |\n", + "| fps | 548 |\n", + "| iterations | 18 |\n", + "| time_elapsed | 67 |\n", + "| total_timesteps | 36864 |\n", + "| train/ | |\n", + "| approx_kl | 0.06345081 |\n", + "| clip_fraction | 0.753 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.43 |\n", + "| explained_variance | -5.13e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 625 |\n", + "| n_updates | 170 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.14e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -125 |\n", + "| time/ | |\n", + "| fps | 547 |\n", + "| iterations | 19 |\n", + "| time_elapsed | 71 |\n", + "| total_timesteps | 38912 |\n", + "| train/ | |\n", + "| approx_kl | 0.061060853 |\n", + "| clip_fraction | 0.71 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.42 |\n", + "| explained_variance | -3.7e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 715 |\n", + "| n_updates | 180 |\n", + "| policy_gradient_loss | -0.139 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "---------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -126 |\n", + "| time/ | |\n", + "| fps | 548 |\n", + "| iterations | 20 |\n", + "| time_elapsed | 74 |\n", + "| total_timesteps | 40960 |\n", + "| train/ | |\n", + "| approx_kl | 0.0626598 |\n", + "| clip_fraction | 0.721 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.42 |\n", + "| explained_variance | -2.86e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 665 |\n", + "| n_updates | 190 |\n", + "| policy_gradient_loss | -0.141 |\n", + "| value_loss | 1.17e+03 |\n", + "---------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 547 |\n", + "| iterations | 21 |\n", + "| time_elapsed | 78 |\n", + "| total_timesteps | 43008 |\n", + "| train/ | |\n", + "| approx_kl | 0.066325136 |\n", + "| clip_fraction | 0.743 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.41 |\n", + "| explained_variance | -2.15e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 670 |\n", + "| n_updates | 200 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.18e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 544 |\n", + "| iterations | 22 |\n", + "| time_elapsed | 82 |\n", + "| total_timesteps | 45056 |\n", + "| train/ | |\n", + "| approx_kl | 0.06377452 |\n", + "| clip_fraction | 0.721 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.4 |\n", + "| explained_variance | -1.91e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 695 |\n", + "| n_updates | 210 |\n", + "| policy_gradient_loss | -0.142 |\n", + "| value_loss | 1.15e+03 |\n", + "----------------------------------------\n", + "---------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 543 |\n", + "| iterations | 23 |\n", + "| time_elapsed | 86 |\n", + "| total_timesteps | 47104 |\n", + "| train/ | |\n", + "| approx_kl | 0.0712228 |\n", + "| clip_fraction | 0.734 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.39 |\n", + "| explained_variance | -1.43e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 691 |\n", + "| n_updates | 220 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.15e+03 |\n", + "---------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 542 |\n", + "| iterations | 24 |\n", + "| time_elapsed | 90 |\n", + "| total_timesteps | 49152 |\n", + "| train/ | |\n", + "| approx_kl | 0.06300552 |\n", + "| clip_fraction | 0.737 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.38 |\n", + "| explained_variance | -1.19e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 475 |\n", + "| n_updates | 230 |\n", + "| policy_gradient_loss | -0.144 |\n", + "| value_loss | 1.17e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -126 |\n", + "| time/ | |\n", + "| fps | 540 |\n", + "| iterations | 25 |\n", + "| time_elapsed | 94 |\n", + "| total_timesteps | 51200 |\n", + "| train/ | |\n", + "| approx_kl | 0.07357139 |\n", + "| clip_fraction | 0.738 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.37 |\n", + "| explained_variance | -1.07e-06 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 687 |\n", + "| n_updates | 240 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.18e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -124 |\n", + "| time/ | |\n", + "| fps | 539 |\n", + "| iterations | 26 |\n", + "| time_elapsed | 98 |\n", + "| total_timesteps | 53248 |\n", + "| train/ | |\n", + "| approx_kl | 0.07966857 |\n", + "| clip_fraction | 0.756 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.34 |\n", + "| explained_variance | -8.34e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 408 |\n", + "| n_updates | 250 |\n", + "| policy_gradient_loss | -0.144 |\n", + "| value_loss | 1.17e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 539 |\n", + "| iterations | 27 |\n", + "| time_elapsed | 102 |\n", + "| total_timesteps | 55296 |\n", + "| train/ | |\n", + "| approx_kl | 0.07282317 |\n", + "| clip_fraction | 0.743 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.31 |\n", + "| explained_variance | -5.96e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 471 |\n", + "| n_updates | 260 |\n", + "| policy_gradient_loss | -0.145 |\n", + "| value_loss | 1.16e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 539 |\n", + "| iterations | 28 |\n", + "| time_elapsed | 106 |\n", + "| total_timesteps | 57344 |\n", + "| train/ | |\n", + "| approx_kl | 0.07441577 |\n", + "| clip_fraction | 0.731 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.28 |\n", + "| explained_variance | -5.96e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 501 |\n", + "| n_updates | 270 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.16e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5.97 |\n", + "| ep_rew_mean | -125 |\n", + "| time/ | |\n", + "| fps | 539 |\n", + "| iterations | 29 |\n", + "| time_elapsed | 110 |\n", + "| total_timesteps | 59392 |\n", + "| train/ | |\n", + "| approx_kl | 0.07315491 |\n", + "| clip_fraction | 0.757 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.23 |\n", + "| explained_variance | -4.77e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 569 |\n", + "| n_updates | 280 |\n", + "| policy_gradient_loss | -0.146 |\n", + "| value_loss | 1.13e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 538 |\n", + "| iterations | 30 |\n", + "| time_elapsed | 114 |\n", + "| total_timesteps | 61440 |\n", + "| train/ | |\n", + "| approx_kl | 0.06323205 |\n", + "| clip_fraction | 0.716 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.18 |\n", + "| explained_variance | -3.58e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 612 |\n", + "| n_updates | 290 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.17e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 538 |\n", + "| iterations | 31 |\n", + "| time_elapsed | 117 |\n", + "| total_timesteps | 63488 |\n", + "| train/ | |\n", + "| approx_kl | 0.07022999 |\n", + "| clip_fraction | 0.735 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.11 |\n", + "| explained_variance | -4.77e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 653 |\n", + "| n_updates | 300 |\n", + "| policy_gradient_loss | -0.143 |\n", + "| value_loss | 1.14e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 537 |\n", + "| iterations | 32 |\n", + "| time_elapsed | 121 |\n", + "| total_timesteps | 65536 |\n", + "| train/ | |\n", + "| approx_kl | 0.07809374 |\n", + "| clip_fraction | 0.734 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.02 |\n", + "| explained_variance | -4.77e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 505 |\n", + "| n_updates | 310 |\n", + "| policy_gradient_loss | -0.14 |\n", + "| value_loss | 1.14e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 537 |\n", + "| iterations | 33 |\n", + "| time_elapsed | 125 |\n", + "| total_timesteps | 67584 |\n", + "| train/ | |\n", + "| approx_kl | 0.07546057 |\n", + "| clip_fraction | 0.731 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.93 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 549 |\n", + "| n_updates | 320 |\n", + "| policy_gradient_loss | -0.139 |\n", + "| value_loss | 1.14e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 536 |\n", + "| iterations | 34 |\n", + "| time_elapsed | 129 |\n", + "| total_timesteps | 69632 |\n", + "| train/ | |\n", + "| approx_kl | 0.06426198 |\n", + "| clip_fraction | 0.71 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.85 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 546 |\n", + "| n_updates | 330 |\n", + "| policy_gradient_loss | -0.139 |\n", + "| value_loss | 1.16e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 537 |\n", + "| iterations | 35 |\n", + "| time_elapsed | 133 |\n", + "| total_timesteps | 71680 |\n", + "| train/ | |\n", + "| approx_kl | 0.067716464 |\n", + "| clip_fraction | 0.693 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.76 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 577 |\n", + "| n_updates | 340 |\n", + "| policy_gradient_loss | -0.138 |\n", + "| value_loss | 1.16e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 536 |\n", + "| iterations | 36 |\n", + "| time_elapsed | 137 |\n", + "| total_timesteps | 73728 |\n", + "| train/ | |\n", + "| approx_kl | 0.07348621 |\n", + "| clip_fraction | 0.719 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.72 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 594 |\n", + "| n_updates | 350 |\n", + "| policy_gradient_loss | -0.138 |\n", + "| value_loss | 1.15e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -125 |\n", + "| time/ | |\n", + "| fps | 536 |\n", + "| iterations | 37 |\n", + "| time_elapsed | 141 |\n", + "| total_timesteps | 75776 |\n", + "| train/ | |\n", + "| approx_kl | 0.06314379 |\n", + "| clip_fraction | 0.663 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.69 |\n", + "| explained_variance | -3.58e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 546 |\n", + "| n_updates | 360 |\n", + "| policy_gradient_loss | -0.135 |\n", + "| value_loss | 1.15e+03 |\n", + "----------------------------------------\n", + "---------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 535 |\n", + "| iterations | 38 |\n", + "| time_elapsed | 145 |\n", + "| total_timesteps | 77824 |\n", + "| train/ | |\n", + "| approx_kl | 0.0605906 |\n", + "| clip_fraction | 0.678 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.58 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 569 |\n", + "| n_updates | 370 |\n", + "| policy_gradient_loss | -0.137 |\n", + "| value_loss | 1.15e+03 |\n", + "---------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 535 |\n", + "| iterations | 39 |\n", + "| time_elapsed | 149 |\n", + "| total_timesteps | 79872 |\n", + "| train/ | |\n", + "| approx_kl | 0.06449682 |\n", + "| clip_fraction | 0.682 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.48 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 602 |\n", + "| n_updates | 380 |\n", + "| policy_gradient_loss | -0.134 |\n", + "| value_loss | 1.14e+03 |\n", + "----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 535 |\n", + "| iterations | 40 |\n", + "| time_elapsed | 153 |\n", + "| total_timesteps | 81920 |\n", + "| train/ | |\n", + "| approx_kl | 0.05844091 |\n", + "| clip_fraction | 0.633 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.33 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 575 |\n", + "| n_updates | 390 |\n", + "| policy_gradient_loss | -0.13 |\n", + "| value_loss | 1.16e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -120 |\n", + "| time/ | |\n", + "| fps | 535 |\n", + "| iterations | 41 |\n", + "| time_elapsed | 156 |\n", + "| total_timesteps | 83968 |\n", + "| train/ | |\n", + "| approx_kl | 0.049975857 |\n", + "| clip_fraction | 0.601 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -8.15 |\n", + "| explained_variance | -1.19e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 581 |\n", + "| n_updates | 400 |\n", + "| policy_gradient_loss | -0.127 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 534 |\n", + "| iterations | 42 |\n", + "| time_elapsed | 160 |\n", + "| total_timesteps | 86016 |\n", + "| train/ | |\n", + "| approx_kl | 0.055807307 |\n", + "| clip_fraction | 0.616 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.91 |\n", + "| explained_variance | 1.19e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 590 |\n", + "| n_updates | 410 |\n", + "| policy_gradient_loss | -0.123 |\n", + "| value_loss | 1.14e+03 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 534 |\n", + "| iterations | 43 |\n", + "| time_elapsed | 164 |\n", + "| total_timesteps | 88064 |\n", + "| train/ | |\n", + "| approx_kl | 0.05585126 |\n", + "| clip_fraction | 0.592 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.69 |\n", + "| explained_variance | 0 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 512 |\n", + "| n_updates | 420 |\n", + "| policy_gradient_loss | -0.117 |\n", + "| value_loss | 1.17e+03 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 534 |\n", + "| iterations | 44 |\n", + "| time_elapsed | 168 |\n", + "| total_timesteps | 90112 |\n", + "| train/ | |\n", + "| approx_kl | 0.050130654 |\n", + "| clip_fraction | 0.546 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.48 |\n", + "| explained_variance | 0 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 608 |\n", + "| n_updates | 430 |\n", + "| policy_gradient_loss | -0.106 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -122 |\n", + "| time/ | |\n", + "| fps | 534 |\n", + "| iterations | 45 |\n", + "| time_elapsed | 172 |\n", + "| total_timesteps | 92160 |\n", + "| train/ | |\n", + "| approx_kl | 0.050139036 |\n", + "| clip_fraction | 0.547 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.23 |\n", + "| explained_variance | 0 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 628 |\n", + "| n_updates | 440 |\n", + "| policy_gradient_loss | -0.103 |\n", + "| value_loss | 1.14e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -126 |\n", + "| time/ | |\n", + "| fps | 533 |\n", + "| iterations | 46 |\n", + "| time_elapsed | 176 |\n", + "| total_timesteps | 94208 |\n", + "| train/ | |\n", + "| approx_kl | 0.045009457 |\n", + "| clip_fraction | 0.495 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -7.04 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 557 |\n", + "| n_updates | 450 |\n", + "| policy_gradient_loss | -0.0923 |\n", + "| value_loss | 1.13e+03 |\n", + "-----------------------------------------\n", + "---------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 533 |\n", + "| iterations | 47 |\n", + "| time_elapsed | 180 |\n", + "| total_timesteps | 96256 |\n", + "| train/ | |\n", + "| approx_kl | 0.0487647 |\n", + "| clip_fraction | 0.499 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -6.85 |\n", + "| explained_variance | -2.38e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 531 |\n", + "| n_updates | 460 |\n", + "| policy_gradient_loss | -0.0868 |\n", + "| value_loss | 1.17e+03 |\n", + "---------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -123 |\n", + "| time/ | |\n", + "| fps | 533 |\n", + "| iterations | 48 |\n", + "| time_elapsed | 184 |\n", + "| total_timesteps | 98304 |\n", + "| train/ | |\n", + "| approx_kl | 0.052709702 |\n", + "| clip_fraction | 0.507 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -6.58 |\n", + "| explained_variance | -1.19e-07 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 547 |\n", + "| n_updates | 470 |\n", + "| policy_gradient_loss | -0.0865 |\n", + "| value_loss | 1.14e+03 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | -121 |\n", + "| time/ | |\n", + "| fps | 533 |\n", + "| iterations | 49 |\n", + "| time_elapsed | 188 |\n", + "| total_timesteps | 100352 |\n", + "| train/ | |\n", + "| approx_kl | 0.048466682 |\n", + "| clip_fraction | 0.511 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -6.52 |\n", + "| explained_variance | 0 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 624 |\n", + "| n_updates | 480 |\n", + "| policy_gradient_loss | -0.0834 |\n", + "| value_loss | 1.15e+03 |\n", + "-----------------------------------------\n" ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "total_timesteps = 1000\n", + "total_timesteps = 100000\n", "model = PPO(\"MlpPolicy\", env, verbose=1)\n", "model.learn(total_timesteps=total_timesteps)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -64,7 +1092,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -73,60 +1101,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]] -1.0 False {}\n", - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]] -1.0 False {}\n", - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]] -1.0 False {}\n", - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [ 0 0 0 0 0 0 0 0 0 0]\n", - " [ 0 0 0 0 0 0 0 0 0 0]] -1.0 False {}\n", - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [ 0 0 0 0 0 0 0 0 0 0]] -1.0 False {}\n", - "[[16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]\n", - " [16 18 5 15 14 3 3 1 3 3]] -1.0 True {}\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'correct'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[82], line 19\u001b[0m\n\u001b[1;32m 15\u001b[0m state, reward, done, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(action)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mprint\u001b[39m(state, reward, done, info)\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43minfo\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcorrect\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m:\n\u001b[1;32m 20\u001b[0m wins \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m#end_rewards.append(reward == 0)\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m#return np.sum(end_rewards) / len(end_rewards)\u001b[39;00m\n", - "\u001b[0;31mKeyError\u001b[0m: 'correct'" + "0\n" ] } ], @@ -147,22 +1129,11 @@ "\n", " state, reward, done, info = env.step(action)\n", "\n", - " print(state, reward, done, info)\n", - "\n", " if info[\"correct\"]:\n", " wins += 1\n", - " \n", - " #end_rewards.append(reward == 0)\n", - " \n", - "#return np.sum(end_rewards) / len(end_rewards)\n" + "\n", + "print(wins)\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {