diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb index 1e31b94..2f3f31b 100644 --- a/dqn_wordle.ipynb +++ b/dqn_wordle.ipynb @@ -35,240 +35,495 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7c52630b65904d5e8e200be505d2121a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "Using cpu device\n", - "Wrapping the env in a DummyVecEnv.\n", - "---------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.59 |\n", - "| time/ | |\n", - "| fps | 544 |\n", - "| iterations | 1 |\n", - "| time_elapsed | 3 |\n", - "| total_timesteps | 2048 |\n", - "---------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.77 |\n", - "| time/ | |\n", - "| fps | 245 |\n", - "| iterations | 2 |\n", - "| time_elapsed | 16 |\n", - "| total_timesteps | 4096 |\n", - "| train/ | |\n", - "| approx_kl | 0.021515464 |\n", - "| clip_fraction | 0.335 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.00118 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 89.5 |\n", - "| n_updates | 10 |\n", - "| policy_gradient_loss | -0.0854 |\n", - "| value_loss | 262 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.31 |\n", - "| time/ | |\n", - "| fps | 211 |\n", - "| iterations | 3 |\n", - "| time_elapsed | 29 |\n", - "| total_timesteps | 6144 |\n", - "| train/ | |\n", - "| approx_kl | 0.02457875 |\n", - "| clip_fraction | 0.465 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.161 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 118 |\n", - "| n_updates | 20 |\n", - "| policy_gradient_loss | -0.0987 |\n", - "| value_loss | 217 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 5.96 |\n", - "| ep_rew_mean | 5.79 |\n", - "| time/ | |\n", - "| fps | 196 |\n", - "| iterations | 4 |\n", - "| time_elapsed | 41 |\n", - "| total_timesteps | 8192 |\n", - "| train/ | |\n", - "| approx_kl | 0.02515613 |\n", - "| clip_fraction | 0.447 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.151 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 138 |\n", - "| n_updates | 30 |\n", - "| policy_gradient_loss | -0.103 |\n", - "| value_loss | 242 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 4.9 |\n", - "| time/ | |\n", - "| fps | 177 |\n", - "| iterations | 5 |\n", - "| time_elapsed | 57 |\n", - "| total_timesteps | 10240 |\n", - "| train/ | |\n", - "| approx_kl | 0.026685718 |\n", - "| clip_fraction | 0.444 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.176 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 96.7 |\n", - "| n_updates | 40 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 211 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.19 |\n", - "| time/ | |\n", - "| fps | 164 |\n", - "| iterations | 6 |\n", - "| time_elapsed | 74 |\n", - "| total_timesteps | 12288 |\n", - "| train/ | |\n", - "| approx_kl | 0.02762504 |\n", - "| clip_fraction | 0.463 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.186 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 103 |\n", - "| n_updates | 50 |\n", - "| policy_gradient_loss | -0.115 |\n", - "| value_loss | 200 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.5 |\n", - "| time/ | |\n", - "| fps | 155 |\n", - "| iterations | 7 |\n", - "| time_elapsed | 92 |\n", - "| total_timesteps | 14336 |\n", - "| train/ | |\n", - "| approx_kl | 0.02694263 |\n", - "| clip_fraction | 0.458 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.15 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 84.1 |\n", - "| n_updates | 60 |\n", - "| policy_gradient_loss | -0.116 |\n", - "| value_loss | 225 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.27 |\n", - "| time/ | |\n", - "| fps | 154 |\n", - "| iterations | 8 |\n", - "| time_elapsed | 106 |\n", - "| total_timesteps | 16384 |\n", - "| train/ | |\n", - "| approx_kl | 0.024316464 |\n", - "| clip_fraction | 0.412 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.173 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 126 |\n", - "| n_updates | 70 |\n", - "| policy_gradient_loss | -0.112 |\n", - "| value_loss | 227 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.8 |\n", - "| time/ | |\n", - "| fps | 151 |\n", - "| iterations | 9 |\n", - "| time_elapsed | 121 |\n", - "| total_timesteps | 18432 |\n", - "| train/ | |\n", - "| approx_kl | 0.022988513 |\n", - "| clip_fraction | 0.391 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.206 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 139 |\n", - "| n_updates | 80 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 228 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 6.14 |\n", - "| time/ | |\n", - "| fps | 153 |\n", - "| iterations | 10 |\n", - "| time_elapsed | 133 |\n", - "| total_timesteps | 20480 |\n", - "| train/ | |\n", - "| approx_kl | 0.022813996 |\n", - "| clip_fraction | 0.372 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.199 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 117 |\n", - "| n_updates | 90 |\n", - "| policy_gradient_loss | -0.108 |\n", - "| value_loss | 212 |\n", - "-----------------------------------------\n" + "Using cuda device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -175 |\n", + "| exploration_rate | 0.525 |\n", + "| time/ | |\n", + "| episodes | 10000 |\n", + "| fps | 4606 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 49989 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -208 |\n", + "| exploration_rate | 0.0502 |\n", + "| time/ | |\n", + "| episodes | 20000 |\n", + "| fps | 1118 |\n", + "| time_elapsed | 89 |\n", + "| total_timesteps | 99980 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 24.6 |\n", + "| n_updates | 12494 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -230 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 30000 |\n", + "| fps | 856 |\n", + "| time_elapsed | 175 |\n", + "| total_timesteps | 149974 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 18.7 |\n", + "| n_updates | 24993 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -242 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 40000 |\n", + "| fps | 766 |\n", + "| time_elapsed | 260 |\n", + "| total_timesteps | 199967 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 24 |\n", + "| n_updates | 37491 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -186 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 50000 |\n", + "| fps | 722 |\n", + "| time_elapsed | 346 |\n", + "| total_timesteps | 249962 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.5 |\n", + "| n_updates | 49990 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -183 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 60000 |\n", + "| fps | 694 |\n", + "| time_elapsed | 431 |\n", + "| total_timesteps | 299957 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.6 |\n", + "| n_updates | 62489 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -181 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 70000 |\n", + "| fps | 675 |\n", + "| time_elapsed | 517 |\n", + "| total_timesteps | 349953 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 26.8 |\n", + "| n_updates | 74988 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -196 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 80000 |\n", + "| fps | 663 |\n", + "| time_elapsed | 603 |\n", + "| total_timesteps | 399936 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.5 |\n", + "| n_updates | 87483 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -174 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 90000 |\n", + "| fps | 653 |\n", + "| time_elapsed | 688 |\n", + "| total_timesteps | 449928 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.1 |\n", + "| n_updates | 99981 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -155 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 100000 |\n", + "| fps | 645 |\n", + "| time_elapsed | 774 |\n", + "| total_timesteps | 499920 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.8 |\n", + "| n_updates | 112479 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -153 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 110000 |\n", + "| fps | 638 |\n", + "| time_elapsed | 860 |\n", + "| total_timesteps | 549916 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 16 |\n", + "| n_updates | 124978 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -164 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 120000 |\n", + "| fps | 633 |\n", + "| time_elapsed | 947 |\n", + "| total_timesteps | 599915 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.8 |\n", + "| n_updates | 137478 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -145 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 130000 |\n", + "| fps | 628 |\n", + "| time_elapsed | 1033 |\n", + "| total_timesteps | 649910 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.8 |\n", + "| n_updates | 149977 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -154 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 140000 |\n", + "| fps | 624 |\n", + "| time_elapsed | 1120 |\n", + "| total_timesteps | 699902 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 20.9 |\n", + "| n_updates | 162475 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -192 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 150000 |\n", + "| fps | 621 |\n", + "| time_elapsed | 1206 |\n", + "| total_timesteps | 749884 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 18.3 |\n", + "| n_updates | 174970 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -170 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 160000 |\n", + "| fps | 618 |\n", + "| time_elapsed | 1293 |\n", + "| total_timesteps | 799869 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 17.7 |\n", + "| n_updates | 187467 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -233 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 170000 |\n", + "| fps | 615 |\n", + "| time_elapsed | 1380 |\n", + "| total_timesteps | 849855 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 21.6 |\n", + "| n_updates | 199963 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -146 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 180000 |\n", + "| fps | 613 |\n", + "| time_elapsed | 1466 |\n", + "| total_timesteps | 899847 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 19.4 |\n", + "| n_updates | 212461 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -142 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 190000 |\n", + "| fps | 611 |\n", + "| time_elapsed | 1553 |\n", + "| total_timesteps | 949846 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 22.9 |\n", + "| n_updates | 224961 |\n", + "----------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 5 |\n", + "| ep_rew_mean | -171 |\n", + "| exploration_rate | 0.05 |\n", + "| time/ | |\n", + "| episodes | 200000 |\n", + "| fps | 609 |\n", + "| time_elapsed | 1640 |\n", + "| total_timesteps | 999839 |\n", + "| train/ | |\n", + "| learning_rate | 0.0001 |\n", + "| loss | 20.3 |\n", + "| n_updates | 237459 |\n", + "----------------------------------\n" ] }, { "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], "text/plain": [ - "" + "\n" ] }, - "execution_count": 3, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "total_timesteps = 20_000\n", - "model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n", - "model.learn(total_timesteps=total_timesteps)" + "total_timesteps = 1_000_000\n", + "model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n", + "model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "model.save(\"dqn_wordle\")" + "model.save(\"dqn_new_rewards\")" ] }, { @@ -280,88 +535,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n", + "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", "Exception: code() argument 13 must be str, not int\n", " warnings.warn(\n", - "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", + "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n", "Exception: code() argument 13 must be str, not int\n", " warnings.warn(\n" ] } ], "source": [ - "model = PPO.load(\"dqn_wordle\")" + "# model = DQN.load(\"dqn_wordle\")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[[16 1 9 19 5 3 2 3 3 1]\n", - " [18 8 5 13 5 2 3 2 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]]\n", - "[[16 1 9 19 5 3 3 3 3 3]\n", - " [18 8 5 13 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]]\n", - "[[16 1 9 19 5 3 3 1 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]]\n", - "[[16 1 9 19 5 1 2 2 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]]\n", - "[[16 1 9 19 5 1 1 3 1 1]\n", - " [18 1 11 5 4 2 1 3 2 3]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]]\n", - "[[16 1 9 19 5 3 3 1 1 3]\n", - " [18 1 11 5 4 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]]\n", - "[[16 1 9 19 5 3 3 3 2 1]\n", - " [18 1 11 5 4 3 3 3 2 3]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]]\n", - "[[16 1 9 19 5 3 2 3 3 3]\n", - " [18 8 5 13 5 1 3 3 2 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]]\n", - "[[16 1 9 19 5 3 1 3 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]]\n", - "[[16 1 9 19 5 3 3 3 3 2]\n", - " [18 8 5 13 5 3 3 1 1 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]]\n", "0\n" ] } @@ -369,7 +564,7 @@ "source": [ "env = gym_wordle.wordle.WordleEnv()\n", "\n", - "for i in range(10):\n", + "for i in range(1000):\n", " \n", " state, info = env.reset()\n", "\n", @@ -386,16 +581,62 @@ " if info[\"correct\"]:\n", " wins += 1\n", "\n", - "print(wins)\n" + "print(wins)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n", + " [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n", + " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", + " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", + " [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n", + " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n", + " -130)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state, reward" + ] + }, + { + "cell_type": "code", + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "print()" + "blah = (14, 1, 9, 22, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "blah in info['guesses']" ] }, { diff --git a/gym_wordle/wordle.py b/gym_wordle/wordle.py index ff6586a..23c28b9 100644 --- a/gym_wordle/wordle.py +++ b/gym_wordle/wordle.py @@ -6,6 +6,7 @@ from sty import fg, bg, ef, rs from collections import Counter from gym_wordle.utils import to_english, to_array, get_words from typing import Optional +from collections import defaultdict class WordList(gym.spaces.Discrete): @@ -160,7 +161,14 @@ class WordleEnv(gym.Env): self.n_rounds = 6 self.n_letters = 5 - self.info = {'correct': False, 'guesses': set()} + self.info = { + 'correct': False, + 'guesses': set(), + 'known_positions': np.full(5, -1), # -1 for unknown, else letter index + 'known_letters': set(), # Letters known to be in the word + 'not_in_word': set(), # Letters known not to be in the word + 'tried_positions': defaultdict(set) # Positions tried for each letter + } def _highlighter(self, char: str, flag: int) -> str: """Terminal renderer functionality. Properly highlights a character @@ -192,13 +200,57 @@ class WordleEnv(gym.Env): """ self.round = 0 self.solution = self.solution_space.sample() + self.soln_hash = set(self.solution_space[self.solution]) self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) - self.info = {'correct': False, 'guesses': set()} + self.info = { + 'correct': False, + 'guesses': set(), + 'known_positions': np.full(5, -1), + 'known_letters': set(), + 'not_in_word': set(), + 'tried_positions': defaultdict(set) + } + + self.simulate_first_guess() return self.state, self.info + def simulate_first_guess(self): + fixed_first_guess = "rates" + fixed_first_guess_array = to_array(fixed_first_guess) + + # Simulate the feedback for each letter in the fixed first guess + feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array + for i, letter in enumerate(fixed_first_guess_array): + if letter in self.solution_space[self.solution]: + if letter == self.solution_space[self.solution][i]: + feedback[i] = 1 # Correct position + else: + feedback[i] = 2 # Correct letter, wrong position + else: + feedback[i] = 3 # Letter not in word + + # Update the state to reflect the fixed first guess and its feedback + self.state[0, :self.n_letters] = fixed_first_guess_array + self.state[0, self.n_letters:] = feedback + + # Update self.info based on the feedback + for i, flag in enumerate(feedback): + if flag == self.right_pos: + # Mark letter as correctly placed + self.info['known_positions'][i] = fixed_first_guess_array[i] + elif flag == self.wrong_pos: + # Note the letter is in the word but in a different position + self.info['known_letters'].add(fixed_first_guess_array[i]) + elif flag == self.wrong_char: + # Note the letter is not in the word + self.info['not_in_word'].add(fixed_first_guess_array[i]) + + # Since we're simulating the first guess, increment the round counter + self.round = 1 + def render(self, mode: str = 'human'): """Renders the Wordle environment. @@ -220,67 +272,69 @@ class WordleEnv(gym.Env): super().render(mode=mode) def step(self, action): - """Run one step of the Wordle game. Every game must be previously - initialized by a call to the `reset` method. - - Args: - action: Word guessed by the agent. - - Returns: - state (object): Wordle game state after the guess. - reward (float): Reward associated with the guess. - done (bool): Whether the game has ended. - info (dict): Auxiliary diagnostic information. - """ assert self.action_space.contains(action), 'Invalid word!' - action = self.action_space[action] - solution = self.solution_space[self.solution] + guessed_word = self.action_space[action] + solution_word = self.solution_space[self.solution] - self.state[self.round][:self.n_letters] = action + reward = 0 + correct_guess = np.array_equal(guessed_word, solution_word) - counter = Counter() - for i, char in enumerate(action): - flag_i = i + self.n_letters - counter[char] += 1 + # Initialize flags for current guess + current_flags = np.full(self.n_letters, self.wrong_char) - if char == solution[i]: - self.state[self.round, flag_i] = self.right_pos - elif counter[char] <= (char == solution).sum(): - self.state[self.round, flag_i] = self.wrong_pos + # Track newly discovered information + new_info = False + + for i in range(self.n_letters): + guessed_letter = guessed_word[i] + if guessed_letter in solution_word: + # Penalize for reusing a letter found to not be in the word + if guessed_letter in self.info['not_in_word']: + reward -= 2 + + # Handle correct letter in the correct position + if guessed_letter == solution_word[i]: + current_flags[i] = self.right_pos + if self.info['known_positions'][i] != guessed_letter: + reward += 10 # Large reward for new correct placement + new_info = True + self.info['known_positions'][i] = guessed_letter + else: + reward += 20 # Large reward for repeating correct placement + else: + current_flags[i] = self.wrong_pos + if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]: + reward += 10 # Reward for guessing a letter in a new position + new_info = True + else: + reward -= 20 # Penalize for not leveraging known information + self.info['known_letters'].add(guessed_letter) + self.info['tried_positions'][guessed_letter].add(i) else: - self.state[self.round, flag_i] = self.wrong_char + # New incorrect letter + if guessed_letter not in self.info['not_in_word']: + reward -= 2 # Penalize for guessing a letter not in the word + self.info['not_in_word'].add(guessed_letter) + new_info = True + else: + reward -= 15 # Larger penalty for repeating an incorrect letter + + # Update observation state with the current guess and flags + self.state[self.round, :self.n_letters] = guessed_word + self.state[self.round, self.n_letters:] = current_flags + + # Check if the game is over + done = self.round == self.n_rounds - 1 or correct_guess + self.info['correct'] = correct_guess + + if correct_guess: + reward += 100 # Major reward for winning + elif done: + reward -= 50 # Penalty for losing without using new information effectively + elif not new_info: + reward -= 10 # Penalty if no new information was used in this guess self.round += 1 - correct = (action == solution).all() - game_over = (self.round == self.n_rounds) - - done = correct or game_over - - reward = 0 - # correct spot - reward += np.sum(self.state[:, 5:] == 1) * 2 - - # correct letter not correct spot - reward += np.sum(self.state[:, 5:] == 2) * 1 - - # incorrect letter - reward += np.sum(self.state[:, 5:] == 3) * -1 - - # guess same word as before - hashable_action = tuple(action) - if hashable_action in self.info['guesses']: - reward += -10 - else: # guess different word - reward += 10 - - self.info['guesses'].add(hashable_action) - - # for game ending in win or loss - reward += 10 if correct else -10 if done else 0 - - self.info['correct'] = correct - - # observation, reward, terminated, truncated, info return self.state, reward, done, False, self.info diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..9c4c96e --- /dev/null +++ b/test.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def my_func()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "t = defaultdict(lambda: [0, 1, 2, 3, 4])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "defaultdict(()>, {})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 1, 2, 3, 4]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t['t']" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "defaultdict(()>, {'t': [0, 1, 2, 3, 4]})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'x' in t" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "x = np.array([1, 1, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "x[:] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "'abcde'aaa\n", + " 33221\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}