From bbe9a1891c6bfb3de2b764a8d9df86228b7ae54a Mon Sep 17 00:00:00 2001 From: Ethan Shapiro <46407744+Ethan-Shapiro@users.noreply.github.com> Date: Fri, 15 Mar 2024 18:19:58 -0700 Subject: [PATCH] updated wordle to gymnasium env --- .gitignore | 3 +- dqn_wordle.ipynb | 1158 ++++++++---------------------------------- gym_wordle/utils.py | 26 +- gym_wordle/wordle.py | 138 +++-- 4 files changed, 293 insertions(+), 1032 deletions(-) diff --git a/.gitignore b/.gitignore index 6072b32..3d6ff5e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ **/data/* **/*.zip -**/__pycache__ \ No newline at end of file +**/__pycache__ +/env \ No newline at end of file diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb index 9e79659..1e31b94 100644 --- a/dqn_wordle.ipynb +++ b/dqn_wordle.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -35,1055 +35,236 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using cuda device\n", + "Using cpu device\n", "Wrapping the env in a DummyVecEnv.\n", "---------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | -126 |\n", + "| ep_rew_mean | 5.59 |\n", "| time/ | |\n", - "| fps | 455 |\n", + "| fps | 544 |\n", "| iterations | 1 |\n", - "| time_elapsed | 4 |\n", + "| time_elapsed | 3 |\n", "| total_timesteps | 2048 |\n", "---------------------------------\n", "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | -123 |\n", + "| ep_rew_mean | 1.77 |\n", "| time/ | |\n", - "| fps | 376 |\n", + "| fps | 245 |\n", "| iterations | 2 |\n", - "| time_elapsed | 10 |\n", + "| time_elapsed | 16 |\n", "| total_timesteps | 4096 |\n", "| train/ | |\n", - "| approx_kl | 0.006769434 |\n", - "| clip_fraction | 0.0309 |\n", + "| approx_kl | 0.021515464 |\n", + "| clip_fraction | 0.335 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.00119 |\n", + "| explained_variance | 0.00118 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 1.87e+03 |\n", + "| loss | 89.5 |\n", "| n_updates | 10 |\n", - "| policy_gradient_loss | -0.0533 |\n", - "| value_loss | 5.21e+03 |\n", + "| policy_gradient_loss | -0.0854 |\n", + "| value_loss | 262 |\n", "-----------------------------------------\n", "----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | -126 |\n", + "| ep_rew_mean | 1.31 |\n", "| time/ | |\n", - "| fps | 357 |\n", + "| fps | 211 |\n", "| iterations | 3 |\n", - "| time_elapsed | 17 |\n", + "| time_elapsed | 29 |\n", "| total_timesteps | 6144 |\n", "| train/ | |\n", - "| approx_kl | 0.00641025 |\n", - "| clip_fraction | 0.0321 |\n", + "| approx_kl | 0.02457875 |\n", + "| clip_fraction | 0.465 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -9.47 |\n", - "| explained_variance | -0.0916 |\n", + "| explained_variance | 0.161 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 2.06e+03 |\n", + "| loss | 118 |\n", "| n_updates | 20 |\n", - "| policy_gradient_loss | -0.0489 |\n", - "| value_loss | 4.36e+03 |\n", + "| policy_gradient_loss | -0.0987 |\n", + "| value_loss | 217 |\n", "----------------------------------------\n", - "------------------------------------------\n", - "| 
[... notebook output diff condensed: the remaining per-iteration training tables of the old CUDA run (iterations 4-49, ep_rew_mean stuck between roughly -118 and -128, ending at 100,352 timesteps) are removed, and the remaining tables of the new run (iterations 4-10, ep_rew_mean rising to roughly 6-8 by 20,480 timesteps) are added; the tail of the old run's final table follows ...]
iterations | 49 |\n", - "| time_elapsed | 310 |\n", - "| total_timesteps | 100352 |\n", - "| train/ | |\n", - "| approx_kl | 0.039147377 |\n", - "| clip_fraction | 0.465 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -6.07 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 523 |\n", - "| n_updates | 480 |\n", - "| policy_gradient_loss | -0.0778 |\n", - "| value_loss | 1.16e+03 |\n", + "| loss | 117 |\n", + "| n_updates | 90 |\n", + "| policy_gradient_loss | -0.108 |\n", + "| value_loss | 212 |\n", "-----------------------------------------\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 11, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "total_timesteps = 100000\n", - "model = PPO(\"MlpPolicy\", env, verbose=1)\n", + "total_timesteps = 20_000\n", + "model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n", "model.learn(total_timesteps=total_timesteps)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -1092,22 +273,95 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n", + "Exception: code() argument 13 must be str, not int\n", + " warnings.warn(\n", + "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. 
Consider using `custom_objects` argument to replace this object.\n", + "Exception: code() argument 13 must be str, not int\n", + " warnings.warn(\n" + ] + } + ], "source": [ "model = PPO.load(\"dqn_wordle\")" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "[[16 1 9 19 5 3 2 3 3 1]\n", + " [18 8 5 13 5 2 3 2 3 1]\n", + " [16 1 9 19 5 3 2 3 3 1]\n", + " [16 1 9 19 5 3 2 3 3 1]\n", + " [16 1 9 19 5 3 2 3 3 1]\n", + " [16 1 9 19 5 3 2 3 3 1]]\n", + "[[16 1 9 19 5 3 3 3 3 3]\n", + " [18 8 5 13 5 3 2 3 3 3]\n", + " [16 1 9 19 5 3 3 3 3 3]\n", + " [16 1 9 19 5 3 3 3 3 3]\n", + " [16 1 9 19 5 3 3 3 3 3]\n", + " [16 1 9 19 5 3 3 3 3 3]]\n", + "[[16 1 9 19 5 3 3 1 3 3]\n", + " [18 8 5 13 5 3 3 3 3 3]\n", + " [16 1 9 19 5 3 3 1 3 3]\n", + " [16 1 9 19 5 3 3 1 3 3]\n", + " [16 1 9 19 5 3 3 1 3 3]\n", + " [16 1 9 19 5 3 3 1 3 3]]\n", + "[[16 1 9 19 5 1 2 2 3 3]\n", + " [18 8 5 13 5 3 3 3 3 3]\n", + " [16 1 9 19 5 1 2 2 3 3]\n", + " [16 1 9 19 5 1 2 2 3 3]\n", + " [16 1 9 19 5 1 2 2 3 3]\n", + " [16 1 9 19 5 1 2 2 3 3]]\n", + "[[16 1 9 19 5 1 1 3 1 1]\n", + " [18 1 11 5 4 2 1 3 2 3]\n", + " [16 1 9 19 5 1 1 3 1 1]\n", + " [16 1 9 19 5 1 1 3 1 1]\n", + " [16 1 9 19 5 1 1 3 1 1]\n", + " [16 1 9 19 5 1 1 3 1 1]]\n", + "[[16 1 9 19 5 3 3 1 1 3]\n", + " [18 1 11 5 4 3 3 3 3 3]\n", + " [16 1 9 19 5 3 3 1 1 3]\n", + " [16 1 9 19 5 3 3 1 1 3]\n", + " [16 1 9 19 5 3 3 1 1 3]\n", + " [16 1 9 19 5 3 3 1 1 3]]\n", + "[[16 1 9 19 5 3 3 3 2 1]\n", + " [18 1 11 5 4 3 3 3 2 3]\n", + " [16 1 9 19 5 3 3 3 2 1]\n", + " [16 1 9 19 5 3 3 3 2 1]\n", + " [16 1 9 19 5 3 3 3 2 1]\n", + " [16 1 9 19 5 3 3 3 2 1]]\n", + "[[16 1 9 19 5 3 2 3 3 3]\n", + " [18 8 5 13 5 1 3 3 2 3]\n", + " [16 1 9 19 5 3 2 3 3 3]\n", + " [16 1 9 19 5 3 2 3 3 3]\n", + " [16 1 9 19 5 3 2 3 3 3]\n", + " [16 1 9 19 5 3 2 3 3 3]]\n", + "[[16 1 9 19 5 3 1 3 3 3]\n", + " [18 8 5 13 5 3 3 3 3 3]\n", + " [16 1 9 19 5 3 1 3 3 3]\n", + " [16 1 9 19 5 3 1 3 3 3]\n", + " [16 1 9 19 5 3 1 3 3 3]\n", + " [16 1 9 19 5 3 1 3 3 3]]\n", + "[[16 1 9 19 5 3 3 3 3 2]\n", + " [18 8 5 13 5 3 3 1 1 2]\n", + " [16 1 9 19 5 3 3 3 3 2]\n", + " [16 1 9 19 5 3 3 3 3 2]\n", + " [16 1 9 19 5 3 3 3 3 2]\n", + " [16 1 9 19 5 3 3 3 3 2]]\n", "0\n" ] } @@ -1115,9 +369,9 @@ "source": [ "env = gym_wordle.wordle.WordleEnv()\n", "\n", - "for i in range(1000):\n", + "for i in range(10):\n", " \n", - " state = env.reset()\n", + " state, info = env.reset()\n", "\n", " done = False\n", "\n", @@ -1127,13 +381,29 @@ "\n", " action, _states = model.predict(state, deterministic=True)\n", "\n", - " state, reward, done, info = env.step(action)\n", + " state, reward, done, truncated, info = env.step(action)\n", "\n", " if info[\"correct\"]:\n", " wins += 1\n", "\n", "print(wins)\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1152,7 +422,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/gym_wordle/utils.py b/gym_wordle/utils.py index e52601a..2b04f1e 100644 --- a/gym_wordle/utils.py +++ b/gym_wordle/utils.py @@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]: return np.array([_char_d[c] for c in word]) -def get_words(category: str, build: bool=False) -> 
npt.NDArray[np.int64]: +def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]: """Loads a list of words in array form. If specified, this will recompute the list from the human-readable list of @@ -53,14 +53,14 @@ def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: five. """ assert category in {'guess', 'solution'} - + arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' if build: - list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' + list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' - with open(list_path, 'r') as f: - words = np.array([to_array(line.strip()) for line in f]) - np.save(arr_path, words) + with open(list_path, 'r') as f: + words = np.array([to_array(line.strip()) for line in f]) + np.save(arr_path, words) return np.load(arr_path) @@ -69,16 +69,16 @@ def play(): """Play Wordle yourself!""" import gym import gym_wordle - + env = gym.make('Wordle-v0') # load the environment - + env.reset() solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking! done = False - + while not done: - action = -1 + action = -1 # in general, the environment won't be forgiving if you input an # invalid word, but for this function I want to let you screw up user @@ -86,8 +86,8 @@ def play(): while not env.action_space.contains(action): guess = input('Guess: ') action = env.unwrapped.action_space.index_of(to_array(guess)) - + state, reward, done, info = env.step(action) env.render() - - print(f"The word was {solution}") \ No newline at end of file + + print(f"The word was {solution}") diff --git a/gym_wordle/wordle.py b/gym_wordle/wordle.py index 37de045..ff6586a 100644 --- a/gym_wordle/wordle.py +++ b/gym_wordle/wordle.py @@ -1,17 +1,17 @@ -import gym +import gymnasium as gym import numpy as np import numpy.typing as npt from sty import fg, bg, ef, rs -from collections import Counter +from collections import Counter from gym_wordle.utils import to_english, to_array, get_words -from typing import Optional +from typing import Optional + class WordList(gym.spaces.Discrete): """Super class for defining a space of valid words according to a specified list. - TODO: Fix these paragraphs The space is a subclass of gym.spaces.Discrete, where each element corresponds to an index of a valid word in the word list. The obfuscation is necessary for more direct implementation of RL algorithms, which expect @@ -66,16 +66,15 @@ class SolutionList(WordList): In the game Wordle, there are two different collections of words: - * "guesses", which the game accepts as valid words to use to guess the - answer. - * "solutions", which the game uses to choose solutions from. + * "guesses", which the game accepts as valid words to use to guess the + answer. + * "solutions", which the game uses to choose solutions from. Of course, the set of solutions is a strict subset of the set of guesses. - Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ - This class represents the set of solution words. """ + def __init__(self, **kwargs): """ Args: @@ -87,7 +86,7 @@ class SolutionList(WordList): class WordleObsSpace(gym.spaces.Box): """Implementation of the state (observation) space in terms of gym - primatives, in this case, gym.spaces.Box. + primitives, in this case, gym.spaces.Box. 
The Wordle observation space can be thought of as a 6x5 array with two channels: @@ -100,20 +99,11 @@ class WordleObsSpace(gym.spaces.Box): where there are 6 rows, one for each turn in the game, and 5 columns, since the solution will always be a word of length 5. - For simplicity, and compatibility with the stable_baselines algorithms, + For simplicity, and compatibility with stable_baselines algorithms, this multichannel is modeled as a 6x10 array, where the two channels are horizontally appended (along columns). Thus each row in the observation - should be interpreted as - - c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 - - when the word is c0...c4 and its associated flags are f0...f4. - - While the superclass method `sample` is available to the WordleObsSpace, it - should be emphasized that the output of `sample` will (almost surely) not - correspond to a real game configuration, because the sampling is not out of - possible game configurations. Instead, the Box superclass just samples the - integer array space uniformly. + should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is + c0...c4 and its associated flags are f0...f4. """ def __init__(self, **kwargs): @@ -130,20 +120,11 @@ class WordleObsSpace(gym.spaces.Box): class GuessList(WordList): - """Space for *solution* words to the Wordle environment. - - In the game Wordle, there are two different collections of words: - - * "guesses", which the game accepts as valid words to use to guess the - answer. - * "solutions", which the game uses to choose solutions from. - - Of course, the set of solutions is a strict subset of the set of guesses. - - Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ + """Space for *guess* words to the Wordle environment. This class represents the set of guess words. """ + def __init__(self, **kwargs): """ Args: @@ -154,10 +135,9 @@ class GuessList(WordList): class WordleEnv(gym.Env): - metadata = {'render.modes': ['human']} - # character flag codes + # Character flag codes no_char = 0 right_pos = 1 wrong_pos = 2 @@ -166,7 +146,6 @@ class WordleEnv(gym.Env): def __init__(self): super().__init__() - self.seed() self.action_space = GuessList() self.solution_space = SolutionList() @@ -181,6 +160,7 @@ class WordleEnv(gym.Env): self.n_rounds = 6 self.n_letters = 5 + self.info = {'correct': False, 'guesses': set()} def _highlighter(self, char: str, flag: int) -> str: """Terminal renderer functionality. Properly highlights a character @@ -201,73 +181,74 @@ class WordleEnv(gym.Env): front, back = self._highlights[flag] return front + char + back - def reset(self): + def reset(self, seed=None, options=None): + """Reset the environment to an initial state and returns an initial + observation. + + Note: The observation space instance should be a Box space. + + Returns: + state (object): The initial observation of the space. + """ self.round = 0 self.solution = self.solution_space.sample() self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) - return self.state + self.info = {'correct': False, 'guesses': set()} - def render(self, mode: str ='human'): + return self.state, self.info + + def render(self, mode: str = 'human'): """Renders the Wordle environment. Currently supported render modes: - - human: renders the Wordle game to the terminal. Args: - mode: the mode to render with + mode: the mode to render with. 
""" if mode == 'human': - for row in self.states: + for row in self.state: text = ''.join(map( - self._highlighter, - to_english(row[:self.n_letters]).upper(), + self._highlighter, + to_english(row[:self.n_letters]).upper(), row[self.n_letters:] )) - print(text) else: - super(WordleEnv, self).render(mode=mode) - + super().render(mode=mode) + def step(self, action): """Run one step of the Wordle game. Every game must be previously initialized by a call to the `reset` method. Args: action: Word guessed by the agent. + Returns: state (object): Wordle game state after the guess. - reward (float): Reward associated with the guess (-1 for incorrect, - 0 for correct) - done (bool): Whether the game has ended (by a correct guess or - after six guesses). - info (dict): Auxiliary diagnostic information (empty). + reward (float): Reward associated with the guess. + done (bool): Whether the game has ended. + info (dict): Auxiliary diagnostic information. """ assert self.action_space.contains(action), 'Invalid word!' - # transform the action, solution indices to their words - action = self.action_space[action] + action = self.action_space[action] solution = self.solution_space[self.solution] - # populate the word chars into the row (character channel) self.state[self.round][:self.n_letters] = action - # populate the flag characters into the row (flag channel) counter = Counter() for i, char in enumerate(action): - flag_i = i + self.n_letters # starts at 5 + flag_i = i + self.n_letters counter[char] += 1 - if char == solution[i]: # character is in correct position + if char == solution[i]: self.state[self.round, flag_i] = self.right_pos - elif counter[char] <= (char == solution).sum(): - # current character has been seen within correct number of - # occurrences + elif counter[char] <= (char == solution).sum(): self.state[self.round, flag_i] = self.wrong_pos else: - # wrong character, or "correct" character too many times self.state[self.round, flag_i] = self.wrong_char self.round += 1 @@ -277,20 +258,29 @@ class WordleEnv(gym.Env): done = correct or game_over - # Total reward equals -(number of incorrect guesses) - # reward = 0. if correct else -1. - - # correct +10 - # guesses new letter +1 - # guesses correct letter +1 - # spent another guess -1 - reward = 0 - reward += np.sum(self.state[:, 5:] == 1) * 1 - reward += np.sum(self.state[:, 5:] == 2) * 0.5 + # correct spot + reward += np.sum(self.state[:, 5:] == 1) * 2 + + # correct letter not correct spot + reward += np.sum(self.state[:, 5:] == 2) * 1 + + # incorrect letter reward += np.sum(self.state[:, 5:] == 3) * -1 + + # guess same word as before + hashable_action = tuple(action) + if hashable_action in self.info['guesses']: + reward += -10 + else: # guess different word + reward += 10 + + self.info['guesses'].add(hashable_action) + + # for game ending in win or loss reward += 10 if correct else -10 if done else 0 - info = {'correct': correct} + self.info['correct'] = correct - return self.state, reward, done, info \ No newline at end of file + # observation, reward, terminated, truncated, info + return self.state, reward, done, False, self.info