try penalizing duplicate guesses

Arthur Lu 2024-03-15 18:48:21 -07:00
parent bbe9a1891c
commit cf977e4797
2 changed files with 188 additions and 250 deletions

Changed file 1 of 2: the training notebook

@@ -6,11 +6,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gym\n",
     "import gym_wordle\n",
     "from stable_baselines3 import DQN, PPO, common\n",
     "import numpy as np\n",
-    "import tqdm"
+    "from tqdm import tqdm"
    ]
   },
   {
@@ -42,213 +41,213 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Using cpu device\n",
+     "Using cuda device\n",
      "Wrapping the env in a DummyVecEnv.\n",
-     [10 PPO progress tables (iterations 1-10, 2048 steps each, 20480 timesteps total) from the CPU run: fps decays from 544 to about 151, time_elapsed reaches 133 s, and ep_rew_mean wanders between 1.19 and 7.8, ending at 6.14]
+     [10 PPO progress tables (iterations 1-10, 2048 steps each, 20480 timesteps total) from the CUDA run: fps decays from 750 to 527, time_elapsed reaches 38 s, and ep_rew_mean wanders between 1.47 and 9.54, ending at 3.81]
     ]
    },
    {
     "data": {
      "text/plain": [
-      "<stable_baselines3.ppo.ppo.PPO at 0x2200a962b50>"
+      "<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
      ]
     },
     "execution_count": 3,
@@ -275,101 +274,48 @@
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n",
-      "Exception: code() argument 13 must be str, not int\n",
-      " warnings.warn(\n",
-      "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
-      "Exception: code() argument 13 must be str, not int\n",
-      " warnings.warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "model = PPO.load(\"dqn_wordle\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
-     ]
-    },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      [10 6x10 guess/flag state matrices, one per episode of the old 10-episode loop; the policy repeats the same guess (letters 16 1 9 19 5, i.e. "paise") nearly every round]
+      "[[ 7 18  1 19 16  3  3  3  2  3]\n",
+      " [16  9  5 14  4  3  3  3  3  3]\n",
+      " [16  9  5 14  4  3  3  3  3  3]\n",
+      " [16  9  5 14  4  3  3  3  3  3]\n",
+      " [ 7 18  1 19 16  3  3  3  2  3]\n",
+      " [ 7 18  1 19 16  3  3  3  2  3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
       "0\n"
      ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\n"
-     ]
     }
    ],
    "source": [
     "env = gym_wordle.wordle.WordleEnv()\n",
     "\n",
-    "for i in range(10):\n",
+    "for i in tqdm(range(1000)):\n",
     "    \n",
     "    state, info = env.reset()\n",
     "\n",
@@ -386,18 +332,11 @@
     "    if info[\"correct\"]:\n",
     "        wins += 1\n",
     "\n",
-    "print(state, reward, info)\n",
-    "\n",
     "print(wins)\n"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "print()"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -422,7 +361,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.8.10"
   }
  },
  "nbformat": 4,

Changed file 2 of 2: the gym_wordle environment module (WordleEnv)

@@ -3,11 +3,10 @@ import numpy as np
 import numpy.typing as npt
 from sty import fg, bg, ef, rs
-from collections import Counter
+from collections import Counter, defaultdict
 from gym_wordle.utils import to_english, to_array, get_words
 from typing import Optional
-
 
 
 class WordList(gym.spaces.Discrete):
     """Super class for defining a space of valid words according to a specified
     list.
@@ -160,7 +159,7 @@ class WordleEnv(gym.Env):
         self.n_rounds = 6
         self.n_letters = 5
 
-        self.info = {'correct': False, 'guesses': set()}
+        self.info = {'correct': False, 'guesses': defaultdict(int)}
 
     def _highlighter(self, char: str, flag: int) -> str:
         """Terminal renderer functionality. Properly highlights a character
@@ -195,7 +194,7 @@
         self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
 
-        self.info = {'correct': False, 'guesses': set()}
+        self.info = {'correct': False, 'guesses': defaultdict(int)}
 
         return self.state, self.info
@@ -269,13 +268,13 @@
         reward += np.sum(self.state[:, 5:] == 3) * -1
 
         # guess same word as before
-        hashable_action = tuple(action)
+        hashable_action = to_english(action)
         if hashable_action in self.info['guesses']:
-            reward += -10
+            reward += -10 * self.info['guesses'][hashable_action]
         else:  # guess different word
             reward += 10
 
-        self.info['guesses'].add(hashable_action)
+        self.info['guesses'][hashable_action] += 1
 
         # for game ending in win or loss
         reward += 10 if correct else -10 if done else 0
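
The substance of the commit is the last hunk: guesses are now tracked per English word in a defaultdict counter, and repeating a word costs -10 for each prior use instead of a flat -10, so the penalty escalates with every repetition. A minimal standalone sketch of that shaping term (the counter mirrors info['guesses']; to_english is the existing gym_wordle helper, everything else here is illustrative):

from collections import defaultdict

guesses = defaultdict(int)  # mirrors self.info['guesses'] in WordleEnv

def repeat_guess_reward(word: str) -> int:
    # Membership testing a defaultdict does not insert a key, so the
    # first use of a word still falls through to the +10 branch.
    if word in guesses:
        r = -10 * guesses[word]  # 2nd use: -10, 3rd use: -20, ...
    else:
        r = 10                   # reward trying a new word
    guesses[word] += 1
    return r

assert repeat_guess_reward("grasp") == 10
assert repeat_guess_reward("grasp") == -10
assert repeat_guess_reward("grasp") == -20

This escalation shows up in the new evaluation output above, where repeating 'grasp' and 'piend' three times each drives the episode reward down to -54.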