cse151b-final-project/dqn_wordle.ipynb
2024-03-18 11:25:14 -07:00

672 lines
21 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n",
"import numpy as np\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Monitor<WordleEnv instance>>\n"
]
}
],
"source": [
"# Build the Wordle environment and wrap it in stable-baselines3's Monitor in one step.\n",
"env = common.monitor.Monitor(gym_wordle.wordle.WordleEnv())\n",
"\n",
"print(env)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c52630b65904d5e8e200be505d2121a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -175 |\n",
"| exploration_rate | 0.525 |\n",
"| time/ | |\n",
"| episodes | 10000 |\n",
"| fps | 4606 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 49989 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -208 |\n",
"| exploration_rate | 0.0502 |\n",
"| time/ | |\n",
"| episodes | 20000 |\n",
"| fps | 1118 |\n",
"| time_elapsed | 89 |\n",
"| total_timesteps | 99980 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 24.6 |\n",
"| n_updates | 12494 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -230 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 30000 |\n",
"| fps | 856 |\n",
"| time_elapsed | 175 |\n",
"| total_timesteps | 149974 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.7 |\n",
"| n_updates | 24993 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -242 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 40000 |\n",
"| fps | 766 |\n",
"| time_elapsed | 260 |\n",
"| total_timesteps | 199967 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 24 |\n",
"| n_updates | 37491 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -186 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 50000 |\n",
"| fps | 722 |\n",
"| time_elapsed | 346 |\n",
"| total_timesteps | 249962 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.5 |\n",
"| n_updates | 49990 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -183 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 60000 |\n",
"| fps | 694 |\n",
"| time_elapsed | 431 |\n",
"| total_timesteps | 299957 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.6 |\n",
"| n_updates | 62489 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -181 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 70000 |\n",
"| fps | 675 |\n",
"| time_elapsed | 517 |\n",
"| total_timesteps | 349953 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 26.8 |\n",
"| n_updates | 74988 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -196 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 80000 |\n",
"| fps | 663 |\n",
"| time_elapsed | 603 |\n",
"| total_timesteps | 399936 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.5 |\n",
"| n_updates | 87483 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -174 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 90000 |\n",
"| fps | 653 |\n",
"| time_elapsed | 688 |\n",
"| total_timesteps | 449928 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.1 |\n",
"| n_updates | 99981 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -155 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 100000 |\n",
"| fps | 645 |\n",
"| time_elapsed | 774 |\n",
"| total_timesteps | 499920 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.8 |\n",
"| n_updates | 112479 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -153 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 110000 |\n",
"| fps | 638 |\n",
"| time_elapsed | 860 |\n",
"| total_timesteps | 549916 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 16 |\n",
"| n_updates | 124978 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -164 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 120000 |\n",
"| fps | 633 |\n",
"| time_elapsed | 947 |\n",
"| total_timesteps | 599915 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 137478 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -145 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 130000 |\n",
"| fps | 628 |\n",
"| time_elapsed | 1033 |\n",
"| total_timesteps | 649910 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 149977 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -154 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 140000 |\n",
"| fps | 624 |\n",
"| time_elapsed | 1120 |\n",
"| total_timesteps | 699902 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.9 |\n",
"| n_updates | 162475 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -192 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 150000 |\n",
"| fps | 621 |\n",
"| time_elapsed | 1206 |\n",
"| total_timesteps | 749884 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.3 |\n",
"| n_updates | 174970 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -170 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 160000 |\n",
"| fps | 618 |\n",
"| time_elapsed | 1293 |\n",
"| total_timesteps | 799869 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.7 |\n",
"| n_updates | 187467 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -233 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 170000 |\n",
"| fps | 615 |\n",
"| time_elapsed | 1380 |\n",
"| total_timesteps | 849855 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.6 |\n",
"| n_updates | 199963 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -146 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 180000 |\n",
"| fps | 613 |\n",
"| time_elapsed | 1466 |\n",
"| total_timesteps | 899847 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 19.4 |\n",
"| n_updates | 212461 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -142 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 190000 |\n",
"| fps | 611 |\n",
"| time_elapsed | 1553 |\n",
"| total_timesteps | 949846 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.9 |\n",
"| n_updates | 224961 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -171 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 200000 |\n",
"| fps | 609 |\n",
"| time_elapsed | 1640 |\n",
"| total_timesteps | 999839 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.3 |\n",
"| n_updates | 237459 |\n",
"----------------------------------\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train a DQN agent with an MLP policy on `env`.\n",
"total_timesteps = 1_000_000\n",
"\n",
"# device=\"auto\" selects the GPU when one is available and falls back to the CPU,\n",
"# so this cell no longer fails on machines without CUDA.\n",
"model = DQN(\"MlpPolicy\", env, verbose=1, device=\"auto\")\n",
"\n",
"# Training statistics are logged every 10,000 episodes.\n",
"# The trailing semicolon keeps the returned model's repr out of the cell output.\n",
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True);"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"dqn_new_rewards\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n",
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
"Exception: code() argument 13 must be str, not int\n",
" warnings.warn(\n"
]
}
],
"source": [
"# model = DQN.load(\"dqn_wordle\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
}
],
"source": [
"# Evaluate the trained agent with greedy (deterministic) actions on a fresh environment.\n",
"env = gym_wordle.wordle.WordleEnv()\n",
"\n",
"n_episodes = 1000\n",
"\n",
"# Initialise the counter once, before the episode loop, so wins accumulate across\n",
"# all episodes instead of being reset to zero at the start of each one.\n",
"wins = 0\n",
"\n",
"for episode in range(n_episodes):\n",
"    state, info = env.reset()\n",
"    done = False\n",
"    truncated = False\n",
"\n",
"    # An episode is over when the env reports either termination or truncation.\n",
"    while not (done or truncated):\n",
"        action, _states = model.predict(state, deterministic=True)\n",
"        state, reward, done, truncated, info = env.step(action)\n",
"\n",
"    # Count the episode once, using the info dict returned by its final step.\n",
"    if info[\"correct\"]:\n",
"        wins += 1\n",
"\n",
"print(wins)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
" -130)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state, reward"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"blah = (14, 1, 9, 22, 5)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blah in info['guesses']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}