cse151b-final-project/dqn_wordle.ipynb

431 lines
16 KiB
Plaintext
Raw Normal View History

2024-03-13 18:04:30 +00:00
{
"cells": [
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 1,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n",
2024-03-13 20:57:23 +00:00
"import numpy as np\n",
"import tqdm"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 2,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Monitor<WordleEnv instance>>\n"
]
}
],
2024-03-13 18:04:30 +00:00
"source": [
"# Build the Wordle environment, then wrap it in stable_baselines3's Monitor wrapper.\n",
"# NOTE(review): presumably the wrapper is what feeds the episode reward/length\n",
"# statistics shown in the training log - confirm.\n",
"env = gym_wordle.wordle.WordleEnv()\n",
"env = common.monitor.Monitor(env)\n",
2024-03-13 18:04:30 +00:00
"\n",
"# Print the wrapped env so the wrapper chain is visible in the cell output.\n",
"print(env)"
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 3,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-03-16 01:19:58 +00:00
"Using cpu device\n",
2024-03-14 22:00:19 +00:00
"Wrapping the env in a DummyVecEnv.\n",
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 5.59 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 544 |\n",
2024-03-14 22:00:19 +00:00
"| iterations | 1 |\n",
2024-03-16 01:19:58 +00:00
"| time_elapsed | 3 |\n",
2024-03-14 22:00:19 +00:00
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 1.77 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 245 |\n",
"| iterations | 2 |\n",
2024-03-16 01:19:58 +00:00
"| time_elapsed | 16 |\n",
"| total_timesteps | 4096 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.021515464 |\n",
"| clip_fraction | 0.335 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
2024-03-16 01:19:58 +00:00
"| explained_variance | 0.00118 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 89.5 |\n",
"| n_updates | 10 |\n",
2024-03-16 01:19:58 +00:00
"| policy_gradient_loss | -0.0854 |\n",
"| value_loss | 262 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 1.31 |\n",
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 211 |\n",
"| iterations | 3 |\n",
2024-03-16 01:19:58 +00:00
"| time_elapsed | 29 |\n",
"| total_timesteps | 6144 |\n",
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.02457875 |\n",
"| clip_fraction | 0.465 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
2024-03-16 01:19:58 +00:00
"| explained_variance | 0.161 |\n",
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 118 |\n",
"| n_updates | 20 |\n",
2024-03-16 01:19:58 +00:00
"| policy_gradient_loss | -0.0987 |\n",
"| value_loss | 217 |\n",
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
2024-03-16 01:19:58 +00:00
"| ep_len_mean | 5.96 |\n",
"| ep_rew_mean | 5.79 |\n",
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 196 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 41 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.02515613 |\n",
"| clip_fraction | 0.447 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
2024-03-16 01:19:58 +00:00
"| explained_variance | 0.151 |\n",
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 138 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.103 |\n",
"| value_loss | 242 |\n",
"----------------------------------------\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 4.9 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 177 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 57 |\n",
"| total_timesteps | 10240 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.026685718 |\n",
"| clip_fraction | 0.444 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
2024-03-16 01:19:58 +00:00
"| explained_variance | 0.176 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 96.7 |\n",
"| n_updates | 40 |\n",
"| policy_gradient_loss | -0.111 |\n",
"| value_loss | 211 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 1.19 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 164 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 74 |\n",
"| total_timesteps | 12288 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.02762504 |\n",
"| clip_fraction | 0.463 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
2024-03-16 01:19:58 +00:00
"| explained_variance | 0.186 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 103 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.115 |\n",
"| value_loss | 200 |\n",
2024-03-14 22:00:19 +00:00
"----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 5.5 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 155 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 92 |\n",
"| total_timesteps | 14336 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.02694263 |\n",
"| clip_fraction | 0.458 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.15 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 84.1 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 225 |\n",
2024-03-14 22:00:19 +00:00
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 7.27 |\n",
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 154 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 106 |\n",
2024-03-16 01:19:58 +00:00
"| total_timesteps | 16384 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.024316464 |\n",
"| clip_fraction | 0.412 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.173 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 126 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 227 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 7.8 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 151 |\n",
"| iterations | 9 |\n",
"| time_elapsed | 121 |\n",
"| total_timesteps | 18432 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.022988513 |\n",
"| clip_fraction | 0.391 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.206 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 139 |\n",
"| n_updates | 80 |\n",
"| policy_gradient_loss | -0.111 |\n",
"| value_loss | 228 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:19:58 +00:00
"| ep_rew_mean | 6.14 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:19:58 +00:00
"| fps | 153 |\n",
"| iterations | 10 |\n",
"| time_elapsed | 133 |\n",
"| total_timesteps | 20480 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:19:58 +00:00
"| approx_kl | 0.022813996 |\n",
"| clip_fraction | 0.372 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.199 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:19:58 +00:00
"| loss | 117 |\n",
"| n_updates | 90 |\n",
"| policy_gradient_loss | -0.108 |\n",
"| value_loss | 212 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n"
]
2024-03-14 22:00:19 +00:00
},
{
"data": {
"text/plain": [
2024-03-16 01:19:58 +00:00
"<stable_baselines3.ppo.ppo.PPO at 0x2200a962b50>"
2024-03-14 22:00:19 +00:00
]
},
2024-03-16 01:19:58 +00:00
"execution_count": 3,
2024-03-14 22:00:19 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2024-03-13 18:04:30 +00:00
"source": [
2024-03-16 01:19:58 +00:00
"total_timesteps = 20_000\n",
"model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
"model.learn(total_timesteps=total_timesteps)"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 4,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"model.save(\"dqn_wordle\")"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 5,
2024-03-13 18:04:30 +00:00
"metadata": {},
2024-03-16 01:19:58 +00:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n",
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n"
]
}
],
2024-03-13 18:04:30 +00:00
"source": [
"model = PPO.load(\"dqn_wordle\")"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 6,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-03-16 01:19:58 +00:00
"[[16 1 9 19 5 3 2 3 3 1]\n",
" [18 8 5 13 5 2 3 2 3 1]\n",
" [16 1 9 19 5 3 2 3 3 1]\n",
" [16 1 9 19 5 3 2 3 3 1]\n",
" [16 1 9 19 5 3 2 3 3 1]\n",
" [16 1 9 19 5 3 2 3 3 1]]\n",
"[[16 1 9 19 5 3 3 3 3 3]\n",
" [18 8 5 13 5 3 2 3 3 3]\n",
" [16 1 9 19 5 3 3 3 3 3]\n",
" [16 1 9 19 5 3 3 3 3 3]\n",
" [16 1 9 19 5 3 3 3 3 3]\n",
" [16 1 9 19 5 3 3 3 3 3]]\n",
"[[16 1 9 19 5 3 3 1 3 3]\n",
" [18 8 5 13 5 3 3 3 3 3]\n",
" [16 1 9 19 5 3 3 1 3 3]\n",
" [16 1 9 19 5 3 3 1 3 3]\n",
" [16 1 9 19 5 3 3 1 3 3]\n",
" [16 1 9 19 5 3 3 1 3 3]]\n",
"[[16 1 9 19 5 1 2 2 3 3]\n",
" [18 8 5 13 5 3 3 3 3 3]\n",
" [16 1 9 19 5 1 2 2 3 3]\n",
" [16 1 9 19 5 1 2 2 3 3]\n",
" [16 1 9 19 5 1 2 2 3 3]\n",
" [16 1 9 19 5 1 2 2 3 3]]\n",
"[[16 1 9 19 5 1 1 3 1 1]\n",
" [18 1 11 5 4 2 1 3 2 3]\n",
" [16 1 9 19 5 1 1 3 1 1]\n",
" [16 1 9 19 5 1 1 3 1 1]\n",
" [16 1 9 19 5 1 1 3 1 1]\n",
" [16 1 9 19 5 1 1 3 1 1]]\n",
"[[16 1 9 19 5 3 3 1 1 3]\n",
" [18 1 11 5 4 3 3 3 3 3]\n",
" [16 1 9 19 5 3 3 1 1 3]\n",
" [16 1 9 19 5 3 3 1 1 3]\n",
" [16 1 9 19 5 3 3 1 1 3]\n",
" [16 1 9 19 5 3 3 1 1 3]]\n",
"[[16 1 9 19 5 3 3 3 2 1]\n",
" [18 1 11 5 4 3 3 3 2 3]\n",
" [16 1 9 19 5 3 3 3 2 1]\n",
" [16 1 9 19 5 3 3 3 2 1]\n",
" [16 1 9 19 5 3 3 3 2 1]\n",
" [16 1 9 19 5 3 3 3 2 1]]\n",
"[[16 1 9 19 5 3 2 3 3 3]\n",
" [18 8 5 13 5 1 3 3 2 3]\n",
" [16 1 9 19 5 3 2 3 3 3]\n",
" [16 1 9 19 5 3 2 3 3 3]\n",
" [16 1 9 19 5 3 2 3 3 3]\n",
" [16 1 9 19 5 3 2 3 3 3]]\n",
"[[16 1 9 19 5 3 1 3 3 3]\n",
" [18 8 5 13 5 3 3 3 3 3]\n",
" [16 1 9 19 5 3 1 3 3 3]\n",
" [16 1 9 19 5 3 1 3 3 3]\n",
" [16 1 9 19 5 3 1 3 3 3]\n",
" [16 1 9 19 5 3 1 3 3 3]]\n",
"[[16 1 9 19 5 3 3 3 3 2]\n",
" [18 8 5 13 5 3 3 1 1 2]\n",
" [16 1 9 19 5 3 3 3 3 2]\n",
" [16 1 9 19 5 3 3 3 3 2]\n",
" [16 1 9 19 5 3 3 3 3 2]\n",
" [16 1 9 19 5 3 3 3 3 2]]\n",
2024-03-14 22:00:19 +00:00
"0\n"
]
}
],
2024-03-13 20:57:23 +00:00
"source": [
"# Evaluate the loaded policy by playing several games on a fresh, unwrapped env\n",
"# and counting how many of them end with the correct word.\n",
"N_EVAL_EPISODES = 10\n",
"\n",
"env = gym_wordle.wordle.WordleEnv()\n",
"\n",
"# The win counter must live outside the episode loop; resetting it per episode\n",
"# would make the final print report only the last game.\n",
"wins = 0\n",
"\n",
"for _ in range(N_EVAL_EPISODES):\n",
"    state, info = env.reset()\n",
"    done = False\n",
"    truncated = False\n",
"\n",
"    # Stop on termination OR truncation so an episode can never loop forever.\n",
"    while not (done or truncated):\n",
"        action, _states = model.predict(state, deterministic=True)\n",
"        state, reward, done, truncated, info = env.step(action)\n",
"\n",
"    # `info` now belongs to the last step of the episode: count one win per game.\n",
"    if info[\"correct\"]:\n",
"        wins += 1\n",
"\n",
"# Total number of games won out of N_EVAL_EPISODES.\n",
"print(wins)\n"
2024-03-13 20:57:23 +00:00
]
2024-03-16 01:19:58 +00:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2024-03-13 18:04:30 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-03-16 01:19:58 +00:00
"version": "3.11.5"
2024-03-13 18:04:30 +00:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}