mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-25 17:49:10 +00:00
try penalizing duplicate guesses
This commit is contained in:
parent
bbe9a1891c
commit
cf977e4797
425
dqn_wordle.ipynb
425
dqn_wordle.ipynb
@ -6,11 +6,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import gym_wordle\n",
|
||||
"from stable_baselines3 import DQN, PPO, common\n",
|
||||
"import numpy as np\n",
|
||||
"import tqdm"
|
||||
"from tqdm import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -42,213 +41,213 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using cpu device\n",
|
||||
"Using cuda device\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n",
|
||||
"---------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.59 |\n",
|
||||
"| ep_rew_mean | 2.14 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 544 |\n",
|
||||
"| fps | 750 |\n",
|
||||
"| iterations | 1 |\n",
|
||||
"| time_elapsed | 3 |\n",
|
||||
"| time_elapsed | 2 |\n",
|
||||
"| total_timesteps | 2048 |\n",
|
||||
"---------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.77 |\n",
|
||||
"| ep_rew_mean | 4.59 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 245 |\n",
|
||||
"| fps | 625 |\n",
|
||||
"| iterations | 2 |\n",
|
||||
"| time_elapsed | 16 |\n",
|
||||
"| time_elapsed | 6 |\n",
|
||||
"| total_timesteps | 4096 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.021515464 |\n",
|
||||
"| clip_fraction | 0.335 |\n",
|
||||
"| approx_kl | 0.022059526 |\n",
|
||||
"| clip_fraction | 0.331 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.00118 |\n",
|
||||
"| explained_variance | -0.0118 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 89.5 |\n",
|
||||
"| loss | 130 |\n",
|
||||
"| n_updates | 10 |\n",
|
||||
"| policy_gradient_loss | -0.0854 |\n",
|
||||
"| value_loss | 262 |\n",
|
||||
"| policy_gradient_loss | -0.0851 |\n",
|
||||
"| value_loss | 253 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.86 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 585 |\n",
|
||||
"| iterations | 3 |\n",
|
||||
"| time_elapsed | 10 |\n",
|
||||
"| total_timesteps | 6144 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024416003 |\n",
|
||||
"| clip_fraction | 0.462 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.152 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 85.2 |\n",
|
||||
"| n_updates | 20 |\n",
|
||||
"| policy_gradient_loss | -0.0987 |\n",
|
||||
"| value_loss | 218 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 4.75 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 566 |\n",
|
||||
"| iterations | 4 |\n",
|
||||
"| time_elapsed | 14 |\n",
|
||||
"| total_timesteps | 8192 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.026305672 |\n",
|
||||
"| clip_fraction | 0.45 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.161 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 144 |\n",
|
||||
"| n_updates | 30 |\n",
|
||||
"| policy_gradient_loss | -0.105 |\n",
|
||||
"| value_loss | 220 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.31 |\n",
|
||||
"| ep_rew_mean | 1.47 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 211 |\n",
|
||||
"| iterations | 3 |\n",
|
||||
"| time_elapsed | 29 |\n",
|
||||
"| total_timesteps | 6144 |\n",
|
||||
"| fps | 554 |\n",
|
||||
"| iterations | 5 |\n",
|
||||
"| time_elapsed | 18 |\n",
|
||||
"| total_timesteps | 10240 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02457875 |\n",
|
||||
"| clip_fraction | 0.465 |\n",
|
||||
"| approx_kl | 0.02928267 |\n",
|
||||
"| clip_fraction | 0.498 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.161 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.167 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 118 |\n",
|
||||
"| n_updates | 20 |\n",
|
||||
"| policy_gradient_loss | -0.0987 |\n",
|
||||
"| value_loss | 217 |\n",
|
||||
"----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5.96 |\n",
|
||||
"| ep_rew_mean | 5.79 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 196 |\n",
|
||||
"| iterations | 4 |\n",
|
||||
"| time_elapsed | 41 |\n",
|
||||
"| total_timesteps | 8192 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02515613 |\n",
|
||||
"| clip_fraction | 0.447 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.151 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 138 |\n",
|
||||
"| n_updates | 30 |\n",
|
||||
"| policy_gradient_loss | -0.103 |\n",
|
||||
"| value_loss | 242 |\n",
|
||||
"| loss | 127 |\n",
|
||||
"| n_updates | 40 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 207 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 4.9 |\n",
|
||||
"| ep_rew_mean | 1.62 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 177 |\n",
|
||||
"| iterations | 5 |\n",
|
||||
"| time_elapsed | 57 |\n",
|
||||
"| total_timesteps | 10240 |\n",
|
||||
"| fps | 546 |\n",
|
||||
"| iterations | 6 |\n",
|
||||
"| time_elapsed | 22 |\n",
|
||||
"| total_timesteps | 12288 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.026685718 |\n",
|
||||
"| clip_fraction | 0.444 |\n",
|
||||
"| approx_kl | 0.028425258 |\n",
|
||||
"| clip_fraction | 0.483 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.176 |\n",
|
||||
"| explained_variance | 0.143 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 96.7 |\n",
|
||||
"| n_updates | 40 |\n",
|
||||
"| policy_gradient_loss | -0.111 |\n",
|
||||
"| value_loss | 211 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.19 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 164 |\n",
|
||||
"| iterations | 6 |\n",
|
||||
"| time_elapsed | 74 |\n",
|
||||
"| total_timesteps | 12288 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02762504 |\n",
|
||||
"| clip_fraction | 0.463 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.186 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 103 |\n",
|
||||
"| n_updates | 50 |\n",
|
||||
"| policy_gradient_loss | -0.115 |\n",
|
||||
"| value_loss | 200 |\n",
|
||||
"----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.5 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 155 |\n",
|
||||
"| iterations | 7 |\n",
|
||||
"| time_elapsed | 92 |\n",
|
||||
"| total_timesteps | 14336 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02694263 |\n",
|
||||
"| clip_fraction | 0.458 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.15 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 84.1 |\n",
|
||||
"| n_updates | 60 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 225 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 7.27 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 154 |\n",
|
||||
"| iterations | 8 |\n",
|
||||
"| time_elapsed | 106 |\n",
|
||||
"| total_timesteps | 16384 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024316464 |\n",
|
||||
"| clip_fraction | 0.412 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.173 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 126 |\n",
|
||||
"| n_updates | 70 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 227 |\n",
|
||||
"| loss | 109 |\n",
|
||||
"| n_updates | 50 |\n",
|
||||
"| policy_gradient_loss | -0.117 |\n",
|
||||
"| value_loss | 240 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 7.8 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 151 |\n",
|
||||
"| iterations | 9 |\n",
|
||||
"| time_elapsed | 121 |\n",
|
||||
"| total_timesteps | 18432 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022988513 |\n",
|
||||
"| clip_fraction | 0.391 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.206 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 139 |\n",
|
||||
"| n_updates | 80 |\n",
|
||||
"| policy_gradient_loss | -0.111 |\n",
|
||||
"| value_loss | 228 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_len_mean | 5.98 |\n",
|
||||
"| ep_rew_mean | 6.14 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 153 |\n",
|
||||
"| iterations | 10 |\n",
|
||||
"| time_elapsed | 133 |\n",
|
||||
"| total_timesteps | 20480 |\n",
|
||||
"| fps | 541 |\n",
|
||||
"| iterations | 7 |\n",
|
||||
"| time_elapsed | 26 |\n",
|
||||
"| total_timesteps | 14336 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022813996 |\n",
|
||||
"| clip_fraction | 0.372 |\n",
|
||||
"| approx_kl | 0.026178032 |\n",
|
||||
"| clip_fraction | 0.453 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.174 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 141 |\n",
|
||||
"| n_updates | 60 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 235 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 3.03 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 537 |\n",
|
||||
"| iterations | 8 |\n",
|
||||
"| time_elapsed | 30 |\n",
|
||||
"| total_timesteps | 16384 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02457074 |\n",
|
||||
"| clip_fraction | 0.423 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.171 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 111 |\n",
|
||||
"| n_updates | 70 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 212 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 9.54 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 532 |\n",
|
||||
"| iterations | 9 |\n",
|
||||
"| time_elapsed | 34 |\n",
|
||||
"| total_timesteps | 18432 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024578478 |\n",
|
||||
"| clip_fraction | 0.417 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.199 |\n",
|
||||
"| explained_variance | 0.178 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 117 |\n",
|
||||
"| loss | 121 |\n",
|
||||
"| n_updates | 80 |\n",
|
||||
"| policy_gradient_loss | -0.114 |\n",
|
||||
"| value_loss | 232 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 3.81 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 527 |\n",
|
||||
"| iterations | 10 |\n",
|
||||
"| time_elapsed | 38 |\n",
|
||||
"| total_timesteps | 20480 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022704324 |\n",
|
||||
"| clip_fraction | 0.379 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.194 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 108 |\n",
|
||||
"| n_updates | 90 |\n",
|
||||
"| policy_gradient_loss | -0.108 |\n",
|
||||
"| value_loss | 212 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 216 |\n",
|
||||
"-----------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x2200a962b50>"
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
@ -275,101 +274,48 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n",
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [18 8 5 13 5 2 3 2 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 1]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [18 8 5 13 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 3 3]]\n",
|
||||
"[[16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]\n",
|
||||
" [16 1 9 19 5 1 2 2 3 3]]\n",
|
||||
"[[16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [18 1 11 5 4 2 1 3 2 3]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]\n",
|
||||
" [16 1 9 19 5 1 1 3 1 1]]\n",
|
||||
"[[16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [18 1 11 5 4 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]\n",
|
||||
" [16 1 9 19 5 3 3 1 1 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [18 1 11 5 4 3 3 3 2 3]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]\n",
|
||||
" [16 1 9 19 5 3 3 3 2 1]]\n",
|
||||
"[[16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [18 8 5 13 5 1 3 3 2 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 2 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [18 8 5 13 5 3 3 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]\n",
|
||||
" [16 1 9 19 5 3 1 3 3 3]]\n",
|
||||
"[[16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [18 8 5 13 5 3 3 1 1 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]\n",
|
||||
" [16 1 9 19 5 3 3 3 3 2]]\n",
|
||||
"[[ 7 18 1 19 16 3 3 3 2 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [ 7 18 1 19 16 3 3 3 2 3]\n",
|
||||
" [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
|
||||
"0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"\n",
|
||||
"for i in range(10):\n",
|
||||
"for i in tqdm(range(1000)):\n",
|
||||
" \n",
|
||||
" state, info = env.reset()\n",
|
||||
"\n",
|
||||
@ -386,18 +332,11 @@
|
||||
" if info[\"correct\"]:\n",
|
||||
" wins += 1\n",
|
||||
"\n",
|
||||
"print(state, reward, info)\n",
|
||||
"\n",
|
||||
"print(wins)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -422,7 +361,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -3,11 +3,10 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
from collections import Counter
|
||||
from collections import Counter, defaultdict
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
"""Super class for defining a space of valid words according to a specified
|
||||
list.
|
||||
@ -160,7 +159,7 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
@ -195,7 +194,7 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
return self.state, self.info
|
||||
|
||||
@ -269,13 +268,13 @@ class WordleEnv(gym.Env):
|
||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||
|
||||
# guess same word as before
|
||||
hashable_action = tuple(action)
|
||||
hashable_action = to_english(action)
|
||||
if hashable_action in self.info['guesses']:
|
||||
reward += -10
|
||||
reward += -10 * self.info['guesses'][hashable_action]
|
||||
else: # guess different word
|
||||
reward += 10
|
||||
|
||||
self.info['guesses'].add(hashable_action)
|
||||
self.info['guesses'][hashable_action] += 1
|
||||
|
||||
# for game ending in win or loss
|
||||
reward += 10 if correct else -10 if done else 0
|
||||
|
Loading…
Reference in New Issue
Block a user