Mirror of https://github.com/ltcptgeneral/cse151b-final-project.git (synced 2024-12-28 10:59:08 +00:00)

Compare commits
6 commits: f40301cac9 ... 284a29d7af
Author | SHA1 | Date
---|---|---
 | 284a29d7af |
 | 3747af9d22 |
 | 4fb81317f0 |
 | 12601964bd |
 | c448e02512 |
 | 848d385482 |
dqn_wordle.ipynb (338 lines removed)
@@ -1,338 +0,0 @@
Contents of the removed notebook (cells labeled by execution count):

Cell [1]:
    import gym
    import gym_wordle
    from stable_baselines3 import DQN, PPO, common
    import numpy as np
    import tqdm

Cell [2]:
    env = gym_wordle.wordle.WordleEnv()
    env = common.monitor.Monitor(env)

    print(env)

  Output:
    <Monitor<WordleEnv instance>>

Cell [3]:
    total_timesteps = 100_000
    model = DQN("MlpPolicy", env, verbose=1, device='cuda')
    model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)

  Output:
    Using cuda device
    Wrapping the env in a DummyVecEnv.
    ----------------------------------
    | rollout/            |          |
    |    ep_len_mean      | 4.97     |
    |    ep_rew_mean      | -63.8    |
    |    exploration_rate | 0.05     |
    | time/               |          |
    |    episodes         | 10000    |
    |    fps              | 1628     |
    |    time_elapsed     | 30       |
    |    total_timesteps  | 49995    |
    ----------------------------------
    ----------------------------------
    | rollout/            |          |
    |    ep_len_mean      | 5        |
    |    ep_rew_mean      | -70.5    |
    |    exploration_rate | 0.05     |
    | time/               |          |
    |    episodes         | 20000    |
    |    fps              | 662      |
    |    time_elapsed     | 150      |
    |    total_timesteps  | 99992    |
    | train/              |          |
    |    learning_rate    | 0.0001   |
    |    loss             | 11.7     |
    |    n_updates        | 12497    |
    ----------------------------------
    <stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>

Cell [4]:
    model.save("dqn_new_state")

Cell [5]:
    # model = DQN.load("dqn_wordle")

  Output (stderr):
    c:\Repository\cse151b-final-project\env\Lib\site-packages\stable_baselines3\common\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.
    Exception: code() argument 13 must be str, not int
    c:\Repository\cse151b-final-project\env\Lib\site-packages\stable_baselines3\common\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.
    Exception: code() argument 13 must be str, not int

Cell [7]:
    env = gym_wordle.wordle.WordleEnv()

    for i in range(1000):

        state, info = env.reset()

        done = False

        wins = 0

        while not done:

            action, _states = model.predict(state, deterministic=True)

            state, reward, done, truncated, info = env.step(action)

            print(state)
            if info["correct"]:
                wins += 1

    print(wins)

  Output:
    [ten 0/1 observation vectors, one printed per env.step — not reproduced here]
    0

Cell [6]:
    state, reward

  Output:
    (array([... 0/1 observation vector ...]), -50)

A final empty code cell follows. Notebook metadata: kernel "Python 3 (ipykernel)", Python 3.11.5, nbformat 4.2.
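A note on the evaluation cell above (execution count 7): `wins = 0` sits inside the episode loop and `print(wins)` runs only after all episodes finish, so the printed count reflects just the final episode. A minimal sketch with the counter hoisted out, assuming the intent was to count wins across all 1000 evaluation episodes (it reuses the notebook's `env`, `model`, and `info["correct"]` flag unchanged):

    wins = 0  # count wins across every evaluation episode, not per episode
    for i in range(1000):
        state, info = env.reset()
        done = False
        while not done:
            action, _states = model.predict(state, deterministic=True)
            state, reward, done, truncated, info = env.step(action)
            if info["correct"]:
                wins += 1
    print(wins)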
@@ -3,9 +3,27 @@ import string
 import numpy as np
+from stable_baselines3 import PPO, DQN
+from letter_guess import LetterGuessingEnv
+
+
+def load_valid_words(file_path='wordle_words.txt'):
+    """
+    Load valid five-letter words from a specified text file.
+
+    Parameters:
+    - file_path (str): The path to the text file containing valid words.
+
+    Returns:
+    - list[str]: A list of valid words loaded from the file.
+    """
+    with open(file_path, 'r') as file:
+        valid_words = [line.strip() for line in file if len(line.strip()) == 5]
+    return valid_words
+
 
 class AI:
-    def __init__(self, vocab_file, num_letters=5, num_guesses=6):
+    def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
         self.vocab_file = vocab_file
         self.num_letters = num_letters
         self.num_guesses = 6
@@ -16,8 +34,38 @@ class AI:
         self.domains = None
         self.possible_letters = None
 
+        self.use_q_model = use_q_model
+        if use_q_model:
+            # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
+            self.q_env = LetterGuessingEnv(vocab_file)
+            self.q_env_state, _ = self.q_env.reset()
+
+            # load model
+            self.q_model = PPO.load(model_file)
+
         self.reset()
 
+    def solve_eval(self, results_callback):
+        num_guesses = 0
+        while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
+            num_guesses += 1
+
+            # sample a word, this would use the q_env_state if the q_model is used
+            word = self.sample()
+
+            # get emulated results
+            results = results_callback(word)
+            if self.use_q_model:
+                self.q_env.set_state(self.q_env_state)
+                # step the q_env to match the guess we just made
+                for i in range(len(word)):
+                    char = word[i]
+                    action = ord(char) - ord('a')
+                    self.q_env_state, _, _, _, _ = self.q_env.step(action)
+
+            self.arc_consistency(word, results)
+        return num_guesses, word
+
     def solve(self):
         num_guesses = 0
         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
@@ -33,12 +81,7 @@ class AI:
             print('-----------------------------------------------')
             print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
             print('-----------------------------------------------')
-            self.arc_consistency(word)
-
-        print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
-
-    def arc_consistency(self, word):
             print(f'Performing arc consistency check on {word}...')
             print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
             results = []
@@ -53,6 +96,12 @@ class AI:
                 results.append(result)
                 break
 
+            self.arc_consistency(word, results)
+
+        print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
+        return num_guesses
+
+    def arc_consistency(self, word, results):
         self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
 
         for i in range(len(word)):
@@ -70,11 +119,13 @@ class AI:
             if results[i] == '2':
                 self.domains[i] = [word[i]]
 
     def reset(self):
         self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
         self.possible_letters = []
 
+        if self.use_q_model:
+            self.q_env_state, _ = self.q_env.reset()
+
     def sample(self):
         """
         Samples a best word given the current domains
@@ -87,9 +138,30 @@ class AI:
         pattern = re.compile(regex_string)
 
         # From the words with the highest scores, only return the best word that match the regex pattern
+        max_qval = float('-inf')
+        best_word = None
         for word, _ in self.best_words:
+            # reset the state back to before we guessed a word
             if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
+                if self.use_q_model:
+                    self.q_env.set_state(self.q_env_state)
+                    # Use policy to grade word
+                    # get the state and action pairs
+                    curr_qval = 0
+
+                    for l in word:
+                        action = ord(l) - ord('a')
+                        q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
+                        _, _, _, _, _ = self.q_env.step(action)
+                        curr_qval += q_val
+
+                    if curr_qval > max_qval:
+                        max_qval = curr_qval
+                        best_word = word
+                else:
+                    # otherwise return the word from eric heuristic
                     return word
+        return best_word
 
     def get_vocab(self, vocab_file):
         vocab = []
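One note on the policy-scoring loop added to sample(): in stable-baselines3, evaluate_actions on a PPO policy expects batched torch tensors and returns a (values, log_prob, entropy) tuple rather than a single number. The following is only a rough sketch of how that call could be adapted, assuming a recent stable-baselines3 (obs_to_tensor and evaluate_actions are existing SB3 APIs); score_word is a hypothetical helper mirroring the loop above, and the choice of log_prob versus values as the score is an assumption:

    import torch as th

    def score_word(q_model, q_env, word):
        # Hypothetical helper: accumulate a per-letter policy score for `word`,
        # converting the observation and action to tensors before evaluate_actions.
        total = 0.0
        for letter in word:
            action = ord(letter) - ord('a')
            obs_tensor, _ = q_model.policy.obs_to_tensor(q_env.get_obs())
            actions = th.tensor([action], device=obs_tensor.device)
            values, log_prob, entropy = q_model.policy.evaluate_actions(obs_tensor, actions)
            total += log_prob.item()  # or values.item(), depending on the intended score
            q_env.step(action)
        return total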
eric_wordle/eval.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import argparse
+from ai import AI
+import numpy as np
+from tqdm import tqdm
+
+global solution
+
+
+def result_callback(word):
+    global solution
+
+    result = ['0', '0', '0', '0', '0']
+
+    for i, letter in enumerate(word):
+        if solution[i] == word[i]:
+            result[i] = '2'
+        elif letter in solution:
+            result[i] = '1'
+        else:
+            pass
+
+    return result
+
+
+def main(args):
+    global solution
+
+    if args.n is None:
+        raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
+
+    ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
+
+    total_guesses = 0
+    wins = 0
+    num_eval = args.num_eval
+
+    for i in tqdm(range(num_eval)):
+        idx = np.random.choice(range(len(ai.vocab)))
+        solution = ai.vocab[idx]
+        guesses, word = ai.solve_eval(results_callback=result_callback)
+        if word != solution:
+            total_guesses += 5
+        else:
+            total_guesses += guesses
+            wins += 1
+        ai.reset()
+
+    print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--n', dest='n', type=int, default=None)
+    parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
+    parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
+    parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
+    parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
+    args = parser.parse_args()
+    main(args)
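One caveat in the new eval.py flag handling: argparse's type=bool converts any non-empty string to True, so --q_model False would still enable the Q-model path (the eval.sh added later in this diff only ever passes --q_model True, so its two runs behave as intended). A possible alternative, sketched below with illustrative help text, is a presence-based flag:

    import argparse

    parser = argparse.ArgumentParser()
    # store_true makes the flag presence-based: omit it for False, pass --q_model for True
    parser.add_argument('--q_model', dest='q_model', action='store_true',
                        help='score candidate words with the trained PPO policy')
    args = parser.parse_args(['--q_model'])
    print(args.q_model)  # True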
eric_wordle/letter_guess.py (new symbolic link, 1 line)
@@ -0,0 +1 @@
+../letter_guess.py
eval.sh (new executable file, 2 lines)
@@ -0,0 +1,2 @@
+python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000
+python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 --q_model True --model_file wordle_ppo_model
@@ -3,6 +3,7 @@ from gymnasium import spaces
 import numpy as np
 import random
 import re
+import copy
 
 
 class LetterGuessingEnv(gym.Env):
@@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env):
         self.reset()
 
+    def clone_state(self):
+        # Clone the current state
+        return {
+            'target_word': self.target_word,
+            'letter_flags': copy.deepcopy(self.letter_flags),
+            'letter_positions': copy.deepcopy(self.letter_positions),
+            'guessed_letters': copy.deepcopy(self.guessed_letters),
+            'guess_prefix': self.guess_prefix,
+            'round': self.round
+        }
+
+    def set_state(self, state):
+        # Restore the state
+        self.target_word = state['target_word']
+        self.letter_flags = copy.deepcopy(state['letter_flags'])
+        self.letter_positions = copy.deepcopy(state['letter_positions'])
+        self.guessed_letters = copy.deepcopy(state['guessed_letters'])
+        self.guess_prefix = state['guess_prefix']
+        self.round = state['round']
+
     def step(self, action):
-        letter_index = action % 26  # Assuming action is the letter index directly
+        letter_index = action  # Assuming action is the letter index directly
         position = len(self.guess_prefix)  # The next position in the prefix is determined by its current length
         letter = chr(ord('a') + letter_index)
 
@@ -72,15 +93,16 @@ class LetterGuessingEnv(gym.Env):
             self.guess_prefix = ''
             self.round += 1
 
-        # end after 5 rounds of total guesses
-        if self.round == 2:
+        # end after 3 rounds of total guesses
+        if self.round == 3:
             # reward = 5
             done = True
 
-        obs = self._get_obs()
+        obs = self.get_obs()
 
-        if reward < -50:
+        if reward < -5:
             print(obs, reward, done)
+            exit(0)
 
         return obs, reward, done, False, {}
 
@@ -91,8 +113,8 @@ class LetterGuessingEnv(gym.Env):
         self.letter_positions = np.ones((26, 4), dtype=np.int32)
         self.guessed_letters = set()
         self.guess_prefix = ""  # Reset the guess prefix for the new episode
-        self.round = 1
-        return self._get_obs(), {}
+        self.round = 0
+        return self.get_obs(), {}
 
     def encode_word(self, word):
         encoded = np.zeros((26,))
@@ -101,7 +123,7 @@ class LetterGuessingEnv(gym.Env):
             encoded[index] = 1
         return encoded
 
-    def _get_obs(self):
+    def get_obs(self):
         return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
 
     def render(self, mode='human'):