1 Commits

Author SHA1 Message Date
Arthur Lu
cf977e4797 try penalizing duplicate guesses 2024-03-15 18:48:21 -07:00
3 changed files with 293 additions and 839 deletions

View File

@@ -6,11 +6,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import gym\n",
"import gym_wordle\n", "import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n", "from stable_baselines3 import DQN, PPO, common\n",
"import numpy as np\n", "import numpy as np\n",
"import tqdm" "from tqdm import tqdm"
] ]
}, },
{ {
@@ -35,517 +34,249 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c52630b65904d5e8e200be505d2121a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Using cuda device\n", "Using cuda device\n",
"Wrapping the env with a `Monitor` wrapper\n", "Wrapping the env in a DummyVecEnv.\n",
"Wrapping the env in a DummyVecEnv.\n" "---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 2.14 |\n",
"| time/ | |\n",
"| fps | 750 |\n",
"| iterations | 1 |\n",
"| time_elapsed | 2 |\n",
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 4.59 |\n",
"| time/ | |\n",
"| fps | 625 |\n",
"| iterations | 2 |\n",
"| time_elapsed | 6 |\n",
"| total_timesteps | 4096 |\n",
"| train/ | |\n",
"| approx_kl | 0.022059526 |\n",
"| clip_fraction | 0.331 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.0118 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 130 |\n",
"| n_updates | 10 |\n",
"| policy_gradient_loss | -0.0851 |\n",
"| value_loss | 253 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 5.86 |\n",
"| time/ | |\n",
"| fps | 585 |\n",
"| iterations | 3 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 6144 |\n",
"| train/ | |\n",
"| approx_kl | 0.024416003 |\n",
"| clip_fraction | 0.462 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.152 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 85.2 |\n",
"| n_updates | 20 |\n",
"| policy_gradient_loss | -0.0987 |\n",
"| value_loss | 218 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 4.75 |\n",
"| time/ | |\n",
"| fps | 566 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 14 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
"| approx_kl | 0.026305672 |\n",
"| clip_fraction | 0.45 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.161 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 144 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.105 |\n",
"| value_loss | 220 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 1.47 |\n",
"| time/ | |\n",
"| fps | 554 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 18 |\n",
"| total_timesteps | 10240 |\n",
"| train/ | |\n",
"| approx_kl | 0.02928267 |\n",
"| clip_fraction | 0.498 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.167 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 127 |\n",
"| n_updates | 40 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 207 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 1.62 |\n",
"| time/ | |\n",
"| fps | 546 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 22 |\n",
"| total_timesteps | 12288 |\n",
"| train/ | |\n",
"| approx_kl | 0.028425258 |\n",
"| clip_fraction | 0.483 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.143 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 109 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.117 |\n",
"| value_loss | 240 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5.98 |\n",
"| ep_rew_mean | 6.14 |\n",
"| time/ | |\n",
"| fps | 541 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 26 |\n",
"| total_timesteps | 14336 |\n",
"| train/ | |\n",
"| approx_kl | 0.026178032 |\n",
"| clip_fraction | 0.453 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.174 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 141 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 235 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 3.03 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 30 |\n",
"| total_timesteps | 16384 |\n",
"| train/ | |\n",
"| approx_kl | 0.02457074 |\n",
"| clip_fraction | 0.423 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.171 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 111 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 212 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 9.54 |\n",
"| time/ | |\n",
"| fps | 532 |\n",
"| iterations | 9 |\n",
"| time_elapsed | 34 |\n",
"| total_timesteps | 18432 |\n",
"| train/ | |\n",
"| approx_kl | 0.024578478 |\n",
"| clip_fraction | 0.417 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.178 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 121 |\n",
"| n_updates | 80 |\n",
"| policy_gradient_loss | -0.114 |\n",
"| value_loss | 232 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 3.81 |\n",
"| time/ | |\n",
"| fps | 527 |\n",
"| iterations | 10 |\n",
"| time_elapsed | 38 |\n",
"| total_timesteps | 20480 |\n",
"| train/ | |\n",
"| approx_kl | 0.022704324 |\n",
"| clip_fraction | 0.379 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.194 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 108 |\n",
"| n_updates | 90 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 216 |\n",
"-----------------------------------------\n"
] ]
}, },
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -175 |\n",
"| exploration_rate | 0.525 |\n",
"| time/ | |\n",
"| episodes | 10000 |\n",
"| fps | 4606 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 49989 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -208 |\n",
"| exploration_rate | 0.0502 |\n",
"| time/ | |\n",
"| episodes | 20000 |\n",
"| fps | 1118 |\n",
"| time_elapsed | 89 |\n",
"| total_timesteps | 99980 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 24.6 |\n",
"| n_updates | 12494 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -230 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 30000 |\n",
"| fps | 856 |\n",
"| time_elapsed | 175 |\n",
"| total_timesteps | 149974 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.7 |\n",
"| n_updates | 24993 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -242 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 40000 |\n",
"| fps | 766 |\n",
"| time_elapsed | 260 |\n",
"| total_timesteps | 199967 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 24 |\n",
"| n_updates | 37491 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -186 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 50000 |\n",
"| fps | 722 |\n",
"| time_elapsed | 346 |\n",
"| total_timesteps | 249962 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.5 |\n",
"| n_updates | 49990 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -183 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 60000 |\n",
"| fps | 694 |\n",
"| time_elapsed | 431 |\n",
"| total_timesteps | 299957 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.6 |\n",
"| n_updates | 62489 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -181 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 70000 |\n",
"| fps | 675 |\n",
"| time_elapsed | 517 |\n",
"| total_timesteps | 349953 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 26.8 |\n",
"| n_updates | 74988 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -196 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 80000 |\n",
"| fps | 663 |\n",
"| time_elapsed | 603 |\n",
"| total_timesteps | 399936 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.5 |\n",
"| n_updates | 87483 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -174 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 90000 |\n",
"| fps | 653 |\n",
"| time_elapsed | 688 |\n",
"| total_timesteps | 449928 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.1 |\n",
"| n_updates | 99981 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -155 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 100000 |\n",
"| fps | 645 |\n",
"| time_elapsed | 774 |\n",
"| total_timesteps | 499920 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.8 |\n",
"| n_updates | 112479 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -153 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 110000 |\n",
"| fps | 638 |\n",
"| time_elapsed | 860 |\n",
"| total_timesteps | 549916 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 16 |\n",
"| n_updates | 124978 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -164 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 120000 |\n",
"| fps | 633 |\n",
"| time_elapsed | 947 |\n",
"| total_timesteps | 599915 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 137478 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -145 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 130000 |\n",
"| fps | 628 |\n",
"| time_elapsed | 1033 |\n",
"| total_timesteps | 649910 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 149977 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -154 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 140000 |\n",
"| fps | 624 |\n",
"| time_elapsed | 1120 |\n",
"| total_timesteps | 699902 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.9 |\n",
"| n_updates | 162475 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -192 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 150000 |\n",
"| fps | 621 |\n",
"| time_elapsed | 1206 |\n",
"| total_timesteps | 749884 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.3 |\n",
"| n_updates | 174970 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -170 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 160000 |\n",
"| fps | 618 |\n",
"| time_elapsed | 1293 |\n",
"| total_timesteps | 799869 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.7 |\n",
"| n_updates | 187467 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -233 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 170000 |\n",
"| fps | 615 |\n",
"| time_elapsed | 1380 |\n",
"| total_timesteps | 849855 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.6 |\n",
"| n_updates | 199963 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -146 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 180000 |\n",
"| fps | 613 |\n",
"| time_elapsed | 1466 |\n",
"| total_timesteps | 899847 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 19.4 |\n",
"| n_updates | 212461 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -142 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 190000 |\n",
"| fps | 611 |\n",
"| time_elapsed | 1553 |\n",
"| total_timesteps | 949846 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.9 |\n",
"| n_updates | 224961 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -171 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 200000 |\n",
"| fps | 609 |\n",
"| time_elapsed | 1640 |\n",
"| total_timesteps | 999839 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.3 |\n",
"| n_updates | 237459 |\n",
"----------------------------------\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>" "<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
] ]
}, },
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"total_timesteps = 1_000_000\n", "total_timesteps = 20_000\n",
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n", "model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" "model.learn(total_timesteps=total_timesteps)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model.save(\"dqn_new_rewards\")" "model.save(\"dqn_wordle\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n",
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n"
]
}
],
"source": [ "source": [
"# model = DQN.load(\"dqn_wordle\")" "model = PPO.load(\"dqn_wordle\")"
] ]
}, },
{ {
@@ -553,18 +284,38 @@
"execution_count": 7, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[[ 7 18 1 19 16 3 3 3 2 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
"0\n" "0\n"
] ]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
} }
], ],
"source": [ "source": [
"env = gym_wordle.wordle.WordleEnv()\n", "env = gym_wordle.wordle.WordleEnv()\n",
"\n", "\n",
"for i in range(1000):\n", "for i in tqdm(range(1000)):\n",
" \n", " \n",
" state, info = env.reset()\n", " state, info = env.reset()\n",
"\n", "\n",
@@ -581,62 +332,9 @@
" if info[\"correct\"]:\n", " if info[\"correct\"]:\n",
" wins += 1\n", " wins += 1\n",
"\n", "\n",
"print(wins)" "print(state, reward, info)\n",
] "\n",
}, "print(wins)\n"
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
" -130)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state, reward"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"blah = (14, 1, 9, 22, 5)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blah in info['guesses']"
] ]
}, },
{ {
@@ -663,7 +361,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.5" "version": "3.8.10"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -3,11 +3,9 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
from sty import fg, bg, ef, rs from sty import fg, bg, ef, rs
from collections import Counter from collections import Counter, defaultdict
from gym_wordle.utils import to_english, to_array, get_words from gym_wordle.utils import to_english, to_array, get_words
from typing import Optional from typing import Optional
from collections import defaultdict
class WordList(gym.spaces.Discrete): class WordList(gym.spaces.Discrete):
"""Super class for defining a space of valid words according to a specified """Super class for defining a space of valid words according to a specified
@@ -161,14 +159,7 @@ class WordleEnv(gym.Env):
self.n_rounds = 6 self.n_rounds = 6
self.n_letters = 5 self.n_letters = 5
self.info = { self.info = {'correct': False, 'guesses': defaultdict(int)}
'correct': False,
'guesses': set(),
'known_positions': np.full(5, -1), # -1 for unknown, else letter index
'known_letters': set(), # Letters known to be in the word
'not_in_word': set(), # Letters known not to be in the word
'tried_positions': defaultdict(set) # Positions tried for each letter
}
def _highlighter(self, char: str, flag: int) -> str: def _highlighter(self, char: str, flag: int) -> str:
"""Terminal renderer functionality. Properly highlights a character """Terminal renderer functionality. Properly highlights a character
@@ -200,57 +191,13 @@ class WordleEnv(gym.Env):
""" """
self.round = 0 self.round = 0
self.solution = self.solution_space.sample() self.solution = self.solution_space.sample()
self.soln_hash = set(self.solution_space[self.solution])
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
self.info = { self.info = {'correct': False, 'guesses': defaultdict(int)}
'correct': False,
'guesses': set(),
'known_positions': np.full(5, -1),
'known_letters': set(),
'not_in_word': set(),
'tried_positions': defaultdict(set)
}
self.simulate_first_guess()
return self.state, self.info return self.state, self.info
def simulate_first_guess(self):
fixed_first_guess = "rates"
fixed_first_guess_array = to_array(fixed_first_guess)
# Simulate the feedback for each letter in the fixed first guess
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array
for i, letter in enumerate(fixed_first_guess_array):
if letter in self.solution_space[self.solution]:
if letter == self.solution_space[self.solution][i]:
feedback[i] = 1 # Correct position
else:
feedback[i] = 2 # Correct letter, wrong position
else:
feedback[i] = 3 # Letter not in word
# Update the state to reflect the fixed first guess and its feedback
self.state[0, :self.n_letters] = fixed_first_guess_array
self.state[0, self.n_letters:] = feedback
# Update self.info based on the feedback
for i, flag in enumerate(feedback):
if flag == self.right_pos:
# Mark letter as correctly placed
self.info['known_positions'][i] = fixed_first_guess_array[i]
elif flag == self.wrong_pos:
# Note the letter is in the word but in a different position
self.info['known_letters'].add(fixed_first_guess_array[i])
elif flag == self.wrong_char:
# Note the letter is not in the word
self.info['not_in_word'].add(fixed_first_guess_array[i])
# Since we're simulating the first guess, increment the round counter
self.round = 1
def render(self, mode: str = 'human'): def render(self, mode: str = 'human'):
"""Renders the Wordle environment. """Renders the Wordle environment.
@@ -272,69 +219,67 @@ class WordleEnv(gym.Env):
super().render(mode=mode) super().render(mode=mode)
def step(self, action): def step(self, action):
"""Run one step of the Wordle game. Every game must be previously
initialized by a call to the `reset` method.
Args:
action: Word guessed by the agent.
Returns:
state (object): Wordle game state after the guess.
reward (float): Reward associated with the guess.
done (bool): Whether the game has ended.
info (dict): Auxiliary diagnostic information.
"""
assert self.action_space.contains(action), 'Invalid word!' assert self.action_space.contains(action), 'Invalid word!'
guessed_word = self.action_space[action] action = self.action_space[action]
solution_word = self.solution_space[self.solution] solution = self.solution_space[self.solution]
reward = 0 self.state[self.round][:self.n_letters] = action
correct_guess = np.array_equal(guessed_word, solution_word)
# Initialize flags for current guess counter = Counter()
current_flags = np.full(self.n_letters, self.wrong_char) for i, char in enumerate(action):
flag_i = i + self.n_letters
counter[char] += 1
# Track newly discovered information if char == solution[i]:
new_info = False self.state[self.round, flag_i] = self.right_pos
elif counter[char] <= (char == solution).sum():
for i in range(self.n_letters): self.state[self.round, flag_i] = self.wrong_pos
guessed_letter = guessed_word[i]
if guessed_letter in solution_word:
# Penalize for reusing a letter found to not be in the word
if guessed_letter in self.info['not_in_word']:
reward -= 2
# Handle correct letter in the correct position
if guessed_letter == solution_word[i]:
current_flags[i] = self.right_pos
if self.info['known_positions'][i] != guessed_letter:
reward += 10 # Large reward for new correct placement
new_info = True
self.info['known_positions'][i] = guessed_letter
else:
reward += 20 # Large reward for repeating correct placement
else:
current_flags[i] = self.wrong_pos
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
reward += 10 # Reward for guessing a letter in a new position
new_info = True
else:
reward -= 20 # Penalize for not leveraging known information
self.info['known_letters'].add(guessed_letter)
self.info['tried_positions'][guessed_letter].add(i)
else: else:
# New incorrect letter self.state[self.round, flag_i] = self.wrong_char
if guessed_letter not in self.info['not_in_word']:
reward -= 2 # Penalize for guessing a letter not in the word
self.info['not_in_word'].add(guessed_letter)
new_info = True
else:
reward -= 15 # Larger penalty for repeating an incorrect letter
# Update observation state with the current guess and flags
self.state[self.round, :self.n_letters] = guessed_word
self.state[self.round, self.n_letters:] = current_flags
# Check if the game is over
done = self.round == self.n_rounds - 1 or correct_guess
self.info['correct'] = correct_guess
if correct_guess:
reward += 100 # Major reward for winning
elif done:
reward -= 50 # Penalty for losing without using new information effectively
elif not new_info:
reward -= 10 # Penalty if no new information was used in this guess
self.round += 1 self.round += 1
correct = (action == solution).all()
game_over = (self.round == self.n_rounds)
done = correct or game_over
reward = 0
# correct spot
reward += np.sum(self.state[:, 5:] == 1) * 2
# correct letter not correct spot
reward += np.sum(self.state[:, 5:] == 2) * 1
# incorrect letter
reward += np.sum(self.state[:, 5:] == 3) * -1
# guess same word as before
hashable_action = to_english(action)
if hashable_action in self.info['guesses']:
reward += -10 * self.info['guesses'][hashable_action]
else: # guess different word
reward += 10
self.info['guesses'][hashable_action] += 1
# for game ending in win or loss
reward += 10 if correct else -10 if done else 0
self.info['correct'] = correct
# observation, reward, terminated, truncated, info
return self.state, reward, done, False, self.info return self.state, reward, done, False, self.info

View File

@@ -1,189 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def my_func()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(<function __main__.<lambda>()>, {})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t['t']"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'x' in t"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"x = np.array([1, 1, 1])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"x[:] = 0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"'abcde'aaa\n",
" 33221\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}