cse151b-final-project/dqn_wordle.ipynb

370 lines
14 KiB
Plaintext
Raw Normal View History

2024-03-13 18:04:30 +00:00
{
"cells": [
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 1,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n",
2024-03-13 20:57:23 +00:00
"import numpy as np\n",
2024-03-16 01:48:21 +00:00
"from tqdm import tqdm"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 2,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Monitor<WordleEnv instance>>\n"
]
}
],
2024-03-13 18:04:30 +00:00
"source": [
"# Create the Wordle env and wrap it in SB3's Monitor so training can report episode reward/length (rollout/ep_* in the log)\n",
"env = common.monitor.Monitor(gym_wordle.wordle.WordleEnv())\n",
"\n",
"print(env)"
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 3,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-03-16 01:48:21 +00:00
"Using cuda device\n",
2024-03-14 22:00:19 +00:00
"Wrapping the env in a DummyVecEnv.\n",
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 2.14 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 750 |\n",
2024-03-14 22:00:19 +00:00
"| iterations | 1 |\n",
2024-03-16 01:48:21 +00:00
"| time_elapsed | 2 |\n",
2024-03-14 22:00:19 +00:00
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 4.59 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 625 |\n",
"| iterations | 2 |\n",
2024-03-16 01:48:21 +00:00
"| time_elapsed | 6 |\n",
"| total_timesteps | 4096 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.022059526 |\n",
"| clip_fraction | 0.331 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
2024-03-16 01:48:21 +00:00
"| explained_variance | -0.0118 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 130 |\n",
"| n_updates | 10 |\n",
2024-03-16 01:48:21 +00:00
"| policy_gradient_loss | -0.0851 |\n",
"| value_loss | 253 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 5.86 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 585 |\n",
"| iterations | 3 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 6144 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.024416003 |\n",
"| clip_fraction | 0.462 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:48:21 +00:00
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.152 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 85.2 |\n",
"| n_updates | 20 |\n",
"| policy_gradient_loss | -0.0987 |\n",
"| value_loss | 218 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 4.75 |\n",
"| time/ | |\n",
"| fps | 566 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 14 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
"| approx_kl | 0.026305672 |\n",
"| clip_fraction | 0.45 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.161 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 144 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.105 |\n",
"| value_loss | 220 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 1.47 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 554 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 18 |\n",
"| total_timesteps | 10240 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.02928267 |\n",
"| clip_fraction | 0.498 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.46 |\n",
2024-03-16 01:48:21 +00:00
"| explained_variance | 0.167 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 127 |\n",
"| n_updates | 40 |\n",
2024-03-16 01:19:58 +00:00
"| policy_gradient_loss | -0.116 |\n",
2024-03-16 01:48:21 +00:00
"| value_loss | 207 |\n",
2024-03-14 22:00:19 +00:00
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 1.62 |\n",
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 546 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 22 |\n",
"| total_timesteps | 12288 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.028425258 |\n",
"| clip_fraction | 0.483 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:48:21 +00:00
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.143 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 109 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.117 |\n",
"| value_loss | 240 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
2024-03-16 01:48:21 +00:00
"| ep_len_mean | 5.98 |\n",
"| ep_rew_mean | 6.14 |\n",
"| time/ | |\n",
"| fps | 541 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 26 |\n",
"| total_timesteps | 14336 |\n",
"| train/ | |\n",
"| approx_kl | 0.026178032 |\n",
"| clip_fraction | 0.453 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.174 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 141 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 235 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 3.03 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 30 |\n",
"| total_timesteps | 16384 |\n",
"| train/ | |\n",
"| approx_kl | 0.02457074 |\n",
"| clip_fraction | 0.423 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.171 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 111 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 212 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
2024-03-14 22:00:19 +00:00
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 9.54 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 532 |\n",
2024-03-16 01:19:58 +00:00
"| iterations | 9 |\n",
2024-03-16 01:48:21 +00:00
"| time_elapsed | 34 |\n",
2024-03-16 01:19:58 +00:00
"| total_timesteps | 18432 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.024578478 |\n",
"| clip_fraction | 0.417 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.45 |\n",
2024-03-16 01:48:21 +00:00
"| explained_variance | 0.178 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 121 |\n",
2024-03-16 01:19:58 +00:00
"| n_updates | 80 |\n",
2024-03-16 01:48:21 +00:00
"| policy_gradient_loss | -0.114 |\n",
"| value_loss | 232 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
2024-03-16 01:48:21 +00:00
"| ep_rew_mean | 3.81 |\n",
2024-03-14 22:00:19 +00:00
"| time/ | |\n",
2024-03-16 01:48:21 +00:00
"| fps | 527 |\n",
2024-03-16 01:19:58 +00:00
"| iterations | 10 |\n",
2024-03-16 01:48:21 +00:00
"| time_elapsed | 38 |\n",
2024-03-16 01:19:58 +00:00
"| total_timesteps | 20480 |\n",
2024-03-14 22:00:19 +00:00
"| train/ | |\n",
2024-03-16 01:48:21 +00:00
"| approx_kl | 0.022704324 |\n",
"| clip_fraction | 0.379 |\n",
2024-03-14 22:00:19 +00:00
"| clip_range | 0.2 |\n",
2024-03-16 01:19:58 +00:00
"| entropy_loss | -9.45 |\n",
2024-03-16 01:48:21 +00:00
"| explained_variance | 0.194 |\n",
2024-03-14 22:00:19 +00:00
"| learning_rate | 0.0003 |\n",
2024-03-16 01:48:21 +00:00
"| loss | 108 |\n",
2024-03-16 01:19:58 +00:00
"| n_updates | 90 |\n",
2024-03-16 01:48:21 +00:00
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 216 |\n",
2024-03-14 22:00:19 +00:00
"-----------------------------------------\n"
]
2024-03-14 22:00:19 +00:00
},
{
"data": {
"text/plain": [
2024-03-16 01:48:21 +00:00
"<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
2024-03-14 22:00:19 +00:00
]
},
2024-03-16 01:19:58 +00:00
"execution_count": 3,
2024-03-14 22:00:19 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
2024-03-13 18:04:30 +00:00
"source": [
2024-03-16 01:19:58 +00:00
"total_timesteps = 20_000\n",
"seed = 0  # fixed seed so a fresh Restart & Run All reproduces the training run\n",
"\n",
"# device=\"auto\" still uses CUDA when it is available, but falls back to CPU instead of crashing\n",
"model = PPO(\"MlpPolicy\", env, verbose=1, device=\"auto\", seed=seed)\n",
"model.learn(total_timesteps=total_timesteps)"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 4,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [],
"source": [
"# Persist the trained policy. NOTE: the file name says DQN, but the model trained above is PPO.\n",
"model.save(path=\"dqn_wordle\")"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:19:58 +00:00
"execution_count": 5,
2024-03-13 18:04:30 +00:00
"metadata": {},
2024-03-16 01:48:21 +00:00
"outputs": [],
2024-03-13 18:04:30 +00:00
"source": [
"# Reload the saved policy; no env is attached here, and the evaluation cell only calls model.predict on it\n",
"model = PPO.load(path=\"dqn_wordle\")"
2024-03-13 18:04:30 +00:00
]
},
{
"cell_type": "code",
2024-03-16 01:48:21 +00:00
"execution_count": 7,
2024-03-13 18:04:30 +00:00
"metadata": {},
"outputs": [
2024-03-16 01:48:21 +00:00
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-03-16 01:48:21 +00:00
"[[ 7 18 1 19 16 3 3 3 2 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
2024-03-14 22:00:19 +00:00
"0\n"
]
2024-03-16 01:48:21 +00:00
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
2024-03-13 20:57:23 +00:00
"source": [
"env = gym_wordle.wordle.WordleEnv()\n",
"n_eval_episodes = 1000\n",
"\n",
"# Win counter for the whole evaluation run: it must be initialised once,\n",
"# outside the episode loop, otherwise it is reset at the start of every episode.\n",
"wins = 0\n",
"\n",
"for _ in tqdm(range(n_eval_episodes)):\n",
"    state, info = env.reset()\n",
"    done = truncated = False\n",
"\n",
"    # env.step returns (obs, reward, terminated, truncated, info); stop on either end condition\n",
"    while not (done or truncated):\n",
"        action, _states = model.predict(state, deterministic=True)\n",
"        state, reward, done, truncated, info = env.step(action)\n",
"\n",
"        if info[\"correct\"]:\n",
"            wins += 1\n",
"\n",
"# Final board, reward and info of the last evaluated episode only\n",
"print(state, reward, info)\n",
"\n",
"print(f\"wins: {wins}/{n_eval_episodes} ({wins / n_eval_episodes:.1%})\")\n"
2024-03-13 20:57:23 +00:00
]
2024-03-16 01:19:58 +00:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2024-03-13 18:04:30 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2024-03-16 01:48:21 +00:00
"version": "3.8.10"
2024-03-13 18:04:30 +00:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}