mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-22 18:49:21 +00:00
Compare commits
15 Commits
arthur-tes
...
4f8ca4aa06
Author | SHA1 | Date | |
---|---|---|---|
|
4f8ca4aa06 | ||
|
b46d335044 | ||
|
284a29d7af | ||
|
3747af9d22 | ||
|
4fb81317f0 | ||
|
12601964bd | ||
|
c448e02512 | ||
|
848d385482 | ||
|
f40301cac9 | ||
|
fc197acb6e | ||
|
e799c14ece | ||
|
bbe9a1891c | ||
|
9172326013 | ||
|
4836be8121 | ||
|
5672169073 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,2 +1,6 @@
|
|||||||
**/data/*
|
**/data/*
|
||||||
**/*.zip
|
**/__pycache__
|
||||||
|
/env
|
||||||
|
**/runs/*
|
||||||
|
**/wandb/*
|
||||||
|
**/models/*
|
545
dqn_letter_gssr.ipynb
Normal file
545
dqn_letter_gssr.ipynb
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def load_valid_words(file_path='wordle_words.txt'):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Load valid five-letter words from a specified text file.\n",
|
||||||
|
"\n",
|
||||||
|
" Parameters:\n",
|
||||||
|
" - file_path (str): The path to the text file containing valid words.\n",
|
||||||
|
"\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" - list[str]: A list of valid words loaded from the file.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" with open(file_path, 'r') as file:\n",
|
||||||
|
" valid_words = [line.strip() for line in file if len(line.strip()) == 5]\n",
|
||||||
|
" return valid_words"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from stable_baselines3 import PPO, DQN # Or any other suitable RL algorithm\n",
|
||||||
|
"from stable_baselines3.common.env_checker import check_env\n",
|
||||||
|
"from letter_guess import LetterGuessingEnv\n",
|
||||||
|
"from tqdm import tqdm"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"env = LetterGuessingEnv(valid_words=load_valid_words()) # Make sure to load your valid words\n",
|
||||||
|
"check_env(env) # Optional: Verify the environment is compatible with SB3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"initial_state = env.clone_state()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"obs, _ = env.reset()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_save_path = \"wordle_ppo_model\"\n",
|
||||||
|
"model = PPO.load(model_save_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"action, _ = model.predict(obs)\n",
|
||||||
|
"obs, reward, done, _, info = env.step(action)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"5"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"action % 26"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"5"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ord('f') - ord('a')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'f'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chr(ord('a') + action % 26)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||||
|
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||||
|
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||||
|
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||||
|
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"obs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"env.set_state(initial_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"all(env.get_obs() == obs)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Perform your action to see the outcome\n",
|
||||||
|
"action = # Define your action\n",
|
||||||
|
"observation, reward, done, info = env.step(action)\n",
|
||||||
|
"\n",
|
||||||
|
"# Revert to the initial state\n",
|
||||||
|
"env.env.set_state(initial_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import wandb\n",
|
||||||
|
"from wandb.integration.sb3 import WandbCallback"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
|
||||||
|
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mltcptgeneral\u001b[0m (\u001b[33mfulltime\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"Tracking run with wandb version 0.16.4"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"Run data is saved locally in <code>/home/art/cse151b-final-project/wandb/run-20240319_211220-cyh5nscz</code>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"Syncing run <strong><a href='https://wandb.ai/fulltime/wordle/runs/cyh5nscz' target=\"_blank\">distinctive-flower-20</a></strong> to <a href='https://wandb.ai/fulltime/wordle' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
" View project at <a href='https://wandb.ai/fulltime/wordle' target=\"_blank\">https://wandb.ai/fulltime/wordle</a>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
" View run at <a href='https://wandb.ai/fulltime/wordle/runs/cyh5nscz' target=\"_blank\">https://wandb.ai/fulltime/wordle/runs/cyh5nscz</a>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.HTML object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model_save_path = \"wordle_ppo_model_test\"\n",
|
||||||
|
"config = {\n",
|
||||||
|
" \"policy_type\": \"MlpPolicy\",\n",
|
||||||
|
" \"total_timesteps\": 200_000\n",
|
||||||
|
"}\n",
|
||||||
|
"run = wandb.init(\n",
|
||||||
|
" project=\"wordle\",\n",
|
||||||
|
" config=config,\n",
|
||||||
|
" sync_tensorboard=True\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Using cuda device\n",
|
||||||
|
"Wrapping the env with a `Monitor` wrapper\n",
|
||||||
|
"Wrapping the env in a DummyVecEnv.\n",
|
||||||
|
"Logging to runs/cyh5nscz/PPO_1\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "ca60c274a90b4dddaf275fe164012f16",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"---------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 2.54 |\n",
|
||||||
|
"| ep_rew_mean | -3.66 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| fps | 721 |\n",
|
||||||
|
"| iterations | 1 |\n",
|
||||||
|
"| time_elapsed | 2 |\n",
|
||||||
|
"| total_timesteps | 2048 |\n",
|
||||||
|
"---------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"-----------------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 2.53 |\n",
|
||||||
|
"| ep_rew_mean | -3.61 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| fps | 718 |\n",
|
||||||
|
"| iterations | 2 |\n",
|
||||||
|
"| time_elapsed | 5 |\n",
|
||||||
|
"| total_timesteps | 4096 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| approx_kl | 0.011673957 |\n",
|
||||||
|
"| clip_fraction | 0.0292 |\n",
|
||||||
|
"| clip_range | 0.2 |\n",
|
||||||
|
"| entropy_loss | -3.25 |\n",
|
||||||
|
"| explained_variance | -0.126 |\n",
|
||||||
|
"| learning_rate | 0.0003 |\n",
|
||||||
|
"| loss | 0.576 |\n",
|
||||||
|
"| n_updates | 10 |\n",
|
||||||
|
"| policy_gradient_loss | -0.0197 |\n",
|
||||||
|
"| value_loss | 3.58 |\n",
|
||||||
|
"-----------------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"-----------------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 2.7 |\n",
|
||||||
|
"| ep_rew_mean | -3.56 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| fps | 698 |\n",
|
||||||
|
"| iterations | 3 |\n",
|
||||||
|
"| time_elapsed | 8 |\n",
|
||||||
|
"| total_timesteps | 6144 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| approx_kl | 0.019258872 |\n",
|
||||||
|
"| clip_fraction | 0.198 |\n",
|
||||||
|
"| clip_range | 0.2 |\n",
|
||||||
|
"| entropy_loss | -3.22 |\n",
|
||||||
|
"| explained_variance | -0.211 |\n",
|
||||||
|
"| learning_rate | 0.0003 |\n",
|
||||||
|
"| loss | 0.187 |\n",
|
||||||
|
"| n_updates | 20 |\n",
|
||||||
|
"| policy_gradient_loss | -0.0215 |\n",
|
||||||
|
"| value_loss | 0.637 |\n",
|
||||||
|
"-----------------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"-----------------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 2.73 |\n",
|
||||||
|
"| ep_rew_mean | -3.43 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| fps | 681 |\n",
|
||||||
|
"| iterations | 4 |\n",
|
||||||
|
"| time_elapsed | 12 |\n",
|
||||||
|
"| total_timesteps | 8192 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| approx_kl | 0.021500897 |\n",
|
||||||
|
"| clip_fraction | 0.171 |\n",
|
||||||
|
"| clip_range | 0.2 |\n",
|
||||||
|
"| entropy_loss | -3.17 |\n",
|
||||||
|
"| explained_variance | 0.378 |\n",
|
||||||
|
"| learning_rate | 0.0003 |\n",
|
||||||
|
"| loss | 0.185 |\n",
|
||||||
|
"| n_updates | 30 |\n",
|
||||||
|
"| policy_gradient_loss | -0.0214 |\n",
|
||||||
|
"| value_loss | 0.479 |\n",
|
||||||
|
"-----------------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"-----------------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 2.92 |\n",
|
||||||
|
"| ep_rew_mean | -3.36 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| fps | 682 |\n",
|
||||||
|
"| iterations | 5 |\n",
|
||||||
|
"| time_elapsed | 14 |\n",
|
||||||
|
"| total_timesteps | 10240 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| approx_kl | 0.018113121 |\n",
|
||||||
|
"| clip_fraction | 0.101 |\n",
|
||||||
|
"| clip_range | 0.2 |\n",
|
||||||
|
"| entropy_loss | -3.11 |\n",
|
||||||
|
"| explained_variance | 0.448 |\n",
|
||||||
|
"| learning_rate | 0.0003 |\n",
|
||||||
|
"| loss | 0.203 |\n",
|
||||||
|
"| n_updates | 40 |\n",
|
||||||
|
"| policy_gradient_loss | -0.0183 |\n",
|
||||||
|
"| value_loss | 0.455 |\n",
|
||||||
|
"-----------------------------------------\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model = PPO(config[\"policy_type\"], env=env, verbose=2, tensorboard_log=f\"runs/{run.id}\", batch_size=64)\n",
|
||||||
|
"\n",
|
||||||
|
"# Train for a certain number of timesteps\n",
|
||||||
|
"model.learn(\n",
|
||||||
|
" total_timesteps=config[\"total_timesteps\"],\n",
|
||||||
|
" callback=WandbCallback(\n",
|
||||||
|
" model_save_path=f\"models/{run.id}\",\n",
|
||||||
|
" verbose=2,\n",
|
||||||
|
" ),\n",
|
||||||
|
"\tprogress_bar=True\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"run.finish()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model.save(model_save_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = PPO.load(model_save_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"rewards = 0\n",
|
||||||
|
"for i in tqdm(range(1000)):\n",
|
||||||
|
" obs, _ = env.reset()\n",
|
||||||
|
" done = False\n",
|
||||||
|
" while not done:\n",
|
||||||
|
" action, _ = model.predict(obs)\n",
|
||||||
|
" obs, reward, done, _, info = env.step(action)\n",
|
||||||
|
" rewards += reward\n",
|
||||||
|
"print(rewards / 1000)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
114
dqn_wordle.ipynb
114
dqn_wordle.ipynb
@@ -1,114 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import gym\n",
|
|
||||||
"import gym_wordle\n",
|
|
||||||
"from stable_baselines3 import DQN\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import tqdm"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"env = gym.make(\"Wordle-v0\")\n",
|
|
||||||
"\n",
|
|
||||||
"print(env)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 35,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"total_timesteps = 100000\n",
|
|
||||||
"model = DQN(\"MlpPolicy\", env, verbose=0)\n",
|
|
||||||
"model.learn(total_timesteps=total_timesteps, progress_bar=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def test(model):\n",
|
|
||||||
"\n",
|
|
||||||
" end_rewards = []\n",
|
|
||||||
"\n",
|
|
||||||
" for i in range(1000):\n",
|
|
||||||
" \n",
|
|
||||||
" state = env.reset()\n",
|
|
||||||
"\n",
|
|
||||||
" done = False\n",
|
|
||||||
"\n",
|
|
||||||
" while not done:\n",
|
|
||||||
"\n",
|
|
||||||
" action, _states = model.predict(state, deterministic=True)\n",
|
|
||||||
"\n",
|
|
||||||
" state, reward, done, info = env.step(action)\n",
|
|
||||||
" \n",
|
|
||||||
" end_rewards.append(reward == 0)\n",
|
|
||||||
" \n",
|
|
||||||
" return np.sum(end_rewards) / len(end_rewards)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model.save(\"dqn_wordle\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = DQN.load(\"dqn_wordle\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(test(model))"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.10"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
129
eric_wordle/.gitignore
vendored
Normal file
129
eric_wordle/.gitignore
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
11
eric_wordle/README.md
Normal file
11
eric_wordle/README.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# N-dle Solver
|
||||||
|
|
||||||
|
A solver designed to beat New York Time's Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, can extend to solve the more general N-dle problem (for quordle, octordle, etc.)
|
||||||
|
|
||||||
|
I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6).
|
||||||
|
|
||||||
|
## Usage:
|
||||||
|
1. Run `python main.py --n 1`
|
||||||
|
2. Follow the prompts
|
||||||
|
|
||||||
|
Currently only supports solving for 1 word at a time (i.e. wordle).
|
206
eric_wordle/ai.py
Normal file
206
eric_wordle/ai.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from stable_baselines3 import PPO, DQN
|
||||||
|
from letter_guess import LetterGuessingEnv
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def load_valid_words(file_path='wordle_words.txt'):
|
||||||
|
"""
|
||||||
|
Load valid five-letter words from a specified text file.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- file_path (str): The path to the text file containing valid words.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- list[str]: A list of valid words loaded from the file.
|
||||||
|
"""
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
||||||
|
return valid_words
|
||||||
|
|
||||||
|
|
||||||
|
class AI:
|
||||||
|
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"):
|
||||||
|
self.device = device
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
self.num_letters = num_letters
|
||||||
|
self.num_guesses = 6
|
||||||
|
|
||||||
|
self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file)
|
||||||
|
self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1]
|
||||||
|
|
||||||
|
self.domains = None
|
||||||
|
self.possible_letters = None
|
||||||
|
|
||||||
|
self.use_q_model = use_q_model
|
||||||
|
if use_q_model:
|
||||||
|
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
|
||||||
|
self.q_env = LetterGuessingEnv(load_valid_words(vocab_file))
|
||||||
|
self.q_env_state, _ = self.q_env.reset()
|
||||||
|
|
||||||
|
# load model
|
||||||
|
self.q_model = PPO.load(model_file, device=self.device)
|
||||||
|
|
||||||
|
self.reset("")
|
||||||
|
|
||||||
|
def solve_eval(self, results_callback):
|
||||||
|
num_guesses = 0
|
||||||
|
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||||
|
num_guesses += 1
|
||||||
|
if self.use_q_model:
|
||||||
|
self.freeze_state = self.q_env.clone_state()
|
||||||
|
|
||||||
|
# sample a word, this would use the q_env_state if the q_model is used
|
||||||
|
word = self.sample(num_guesses)
|
||||||
|
|
||||||
|
# get emulated results
|
||||||
|
results = results_callback(word)
|
||||||
|
if self.use_q_model:
|
||||||
|
self.q_env.set_state(self.freeze_state)
|
||||||
|
# step the q_env to match the guess we just made
|
||||||
|
for i in range(len(word)):
|
||||||
|
char = word[i]
|
||||||
|
action = ord(char) - ord('a')
|
||||||
|
self.q_env_state, _, _, _, _ = self.q_env.step(action)
|
||||||
|
|
||||||
|
self.arc_consistency(word, results)
|
||||||
|
return num_guesses, word
|
||||||
|
|
||||||
|
def solve(self):
|
||||||
|
num_guesses = 0
|
||||||
|
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||||
|
num_guesses += 1
|
||||||
|
if self.use_q_model:
|
||||||
|
self.freeze_state = self.q_env.clone_state()
|
||||||
|
|
||||||
|
# sample a word, this would use the q_env_state if the q_model is used
|
||||||
|
word = self.sample(num_guesses)
|
||||||
|
|
||||||
|
print('-----------------------------------------------')
|
||||||
|
print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
|
||||||
|
print('-----------------------------------------------')
|
||||||
|
|
||||||
|
print(f'Performing arc consistency check on {word}...')
|
||||||
|
print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Collect results
|
||||||
|
for l in word:
|
||||||
|
while True:
|
||||||
|
result = input(f'{l}: ')
|
||||||
|
if result not in ['0', '1', '2']:
|
||||||
|
print('Incorrect option. Try again.')
|
||||||
|
continue
|
||||||
|
results.append(result)
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.use_q_model:
|
||||||
|
self.q_env.set_state(self.freeze_state)
|
||||||
|
# step the q_env to match the guess we just made
|
||||||
|
for i in range(len(word)):
|
||||||
|
char = word[i]
|
||||||
|
action = ord(char) - ord('a')
|
||||||
|
self.q_env_state, _, _, _, _ = self.q_env.step(action)
|
||||||
|
|
||||||
|
self.arc_consistency(word, results)
|
||||||
|
return num_guesses, word
|
||||||
|
|
||||||
|
def arc_consistency(self, word, results):
|
||||||
|
self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
|
||||||
|
|
||||||
|
for i in range(len(word)):
|
||||||
|
if results[i] == '0':
|
||||||
|
if word[i] in self.possible_letters:
|
||||||
|
if word[i] in self.domains[i]:
|
||||||
|
self.domains[i].remove(word[i])
|
||||||
|
else:
|
||||||
|
for j in range(len(self.domains)):
|
||||||
|
if word[i] in self.domains[j] and len(self.domains[j]) > 1:
|
||||||
|
self.domains[j].remove(word[i])
|
||||||
|
if results[i] == '1':
|
||||||
|
if word[i] in self.domains[i]:
|
||||||
|
self.domains[i].remove(word[i])
|
||||||
|
if results[i] == '2':
|
||||||
|
self.domains[i] = [word[i]]
|
||||||
|
|
||||||
|
def reset(self, target_word):
|
||||||
|
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
|
||||||
|
self.possible_letters = []
|
||||||
|
|
||||||
|
if self.use_q_model:
|
||||||
|
self.q_env_state, _ = self.q_env.reset()
|
||||||
|
self.q_env.target_word = target_word
|
||||||
|
|
||||||
|
def sample(self, num_guesses):
|
||||||
|
"""
|
||||||
|
Samples a best word given the current domains
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# Compile a regex of possible words with the current domain
|
||||||
|
regex_string = ''
|
||||||
|
for domain in self.domains:
|
||||||
|
regex_string += ''.join(['[', ''.join(domain), ']', '{1}'])
|
||||||
|
pattern = re.compile(regex_string)
|
||||||
|
|
||||||
|
# From the words with the highest scores, only return the best word that match the regex pattern
|
||||||
|
max_qval = float('-inf')
|
||||||
|
best_word = None
|
||||||
|
for word, _ in self.best_words:
|
||||||
|
# reset the state back to before we guessed a word
|
||||||
|
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||||
|
if self.use_q_model and num_guesses == 3:
|
||||||
|
self.q_env.set_state(self.freeze_state)
|
||||||
|
# Use policy to grade word
|
||||||
|
# get the state and action pairs
|
||||||
|
curr_qval = 0
|
||||||
|
|
||||||
|
for l in word:
|
||||||
|
action = ord(l) - ord('a')
|
||||||
|
q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device))
|
||||||
|
_, _, _, _, _ = self.q_env.step(action)
|
||||||
|
curr_qval += q_val
|
||||||
|
|
||||||
|
if curr_qval > max_qval:
|
||||||
|
max_qval = curr_qval
|
||||||
|
best_word = word
|
||||||
|
else:
|
||||||
|
# otherwise return the word from eric heuristic
|
||||||
|
return word
|
||||||
|
return best_word
|
||||||
|
|
||||||
|
def get_vocab(self, vocab_file):
|
||||||
|
vocab = []
|
||||||
|
with open(vocab_file, 'r') as f:
|
||||||
|
for l in f:
|
||||||
|
vocab.append(l.strip())
|
||||||
|
|
||||||
|
# Count letter frequencies at each index
|
||||||
|
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||||
|
for word in vocab:
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
letter_freqs[i][l] += 1
|
||||||
|
|
||||||
|
# Assign a score to each letter at each index by the probability of it appearing
|
||||||
|
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||||
|
for i in range(len(letter_scores)):
|
||||||
|
max_freq = np.max(list(letter_freqs[i].values()))
|
||||||
|
for l in letter_scores[i].keys():
|
||||||
|
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||||
|
|
||||||
|
# Find a sorted list of words ranked by sum of letter scores
|
||||||
|
vocab_scores = {} # (score, word)
|
||||||
|
for word in vocab:
|
||||||
|
score = 0
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
score += letter_scores[i][l]
|
||||||
|
|
||||||
|
# # Optimization: If repeating letters, deduct a couple points
|
||||||
|
# if len(set(word)) < len(word):
|
||||||
|
# score -= 0.25 * (len(word) - len(set(word)))
|
||||||
|
|
||||||
|
vocab_scores[word] = score
|
||||||
|
|
||||||
|
return vocab, vocab_scores, letter_scores
|
37
eric_wordle/dist.py
Normal file
37
eric_wordle/dist.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
words = []
|
||||||
|
with open('words.txt', 'r') as f:
|
||||||
|
for l in f:
|
||||||
|
words.append(l.strip())
|
||||||
|
|
||||||
|
# Count letter frequencies at each index
|
||||||
|
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||||
|
for word in words:
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
letter_freqs[i][l] += 1
|
||||||
|
|
||||||
|
# Assign a score to each letter at each index by the probability of it appearing
|
||||||
|
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||||
|
for i in range(len(letter_scores)):
|
||||||
|
max_freq = np.max(list(letter_freqs[i].values()))
|
||||||
|
for l in letter_scores[i].keys():
|
||||||
|
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||||
|
|
||||||
|
# Find a sorted list of words ranked by sum of letter scores
|
||||||
|
word_scores = [] # (score, word)
|
||||||
|
for word in words:
|
||||||
|
score = 0
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
score += letter_scores[i][l]
|
||||||
|
word_scores.append((score, word))
|
||||||
|
|
||||||
|
sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1]
|
||||||
|
print(sorted_by_second[:10])
|
||||||
|
|
||||||
|
for i, (score, word) in enumerate(sorted_by_second):
|
||||||
|
if word == 'soare':
|
||||||
|
print(f'{word} with a score of {score} is found at index {i}')
|
||||||
|
|
63
eric_wordle/eval.py
Normal file
63
eric_wordle/eval.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import argparse
|
||||||
|
from ai import AI
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
global solution
|
||||||
|
|
||||||
|
def result_callback(word):
|
||||||
|
|
||||||
|
global solution
|
||||||
|
|
||||||
|
result = ['0', '0', '0', '0', '0']
|
||||||
|
|
||||||
|
for i, letter in enumerate(word):
|
||||||
|
|
||||||
|
if solution[i] == word[i]:
|
||||||
|
result[i] = '2'
|
||||||
|
elif letter in solution:
|
||||||
|
result[i] = '1'
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
global solution
|
||||||
|
|
||||||
|
if args.n is None:
|
||||||
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
|
|
||||||
|
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
|
||||||
|
|
||||||
|
total_guesses = 0
|
||||||
|
wins = 0
|
||||||
|
num_eval = args.num_eval
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
|
||||||
|
for i in tqdm(range(num_eval)):
|
||||||
|
idx = np.random.choice(range(len(ai.vocab)))
|
||||||
|
solution = ai.vocab[idx]
|
||||||
|
|
||||||
|
ai.reset(solution)
|
||||||
|
|
||||||
|
guesses, word = ai.solve_eval(results_callback=result_callback)
|
||||||
|
if word != solution:
|
||||||
|
total_guesses += 5
|
||||||
|
else:
|
||||||
|
total_guesses += guesses
|
||||||
|
wins += 1
|
||||||
|
|
||||||
|
print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--n', dest='n', type=int, default=None)
|
||||||
|
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||||
|
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
|
||||||
|
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
||||||
|
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
||||||
|
parser.add_argument('--device', dest="device", type=str, default="cuda")
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
1
eric_wordle/letter_guess.py
Symbolic link
1
eric_wordle/letter_guess.py
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../letter_guess.py
|
22
eric_wordle/main.py
Normal file
22
eric_wordle/main.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import argparse
|
||||||
|
from ai import AI
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.n is None:
|
||||||
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
|
print(f"using q model? {args.q_model}")
|
||||||
|
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
|
||||||
|
ai.reset("lingo")
|
||||||
|
ai.solve()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--n', dest='n', type=int, default=None)
|
||||||
|
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||||
|
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
||||||
|
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
||||||
|
parser.add_argument('--device', dest="device", type=str, default="cuda")
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
15
eric_wordle/process.py
Normal file
15
eric_wordle/process.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import pandas
|
||||||
|
|
||||||
|
print('Loading in words dictionary; this may take a while...')
|
||||||
|
df = pandas.read_json('words_dictionary.json')
|
||||||
|
print('Done loading words dictionary.')
|
||||||
|
words = []
|
||||||
|
for word in df.axes[0].tolist():
|
||||||
|
if len(word) != 5:
|
||||||
|
continue
|
||||||
|
words.append(word)
|
||||||
|
words.sort()
|
||||||
|
|
||||||
|
with open('words.txt', 'w') as f:
|
||||||
|
for word in words:
|
||||||
|
f.write(word + '\n')
|
15919
eric_wordle/words.txt
Normal file
15919
eric_wordle/words.txt
Normal file
File diff suppressed because it is too large
Load Diff
370104
eric_wordle/words_dictionary.json
Normal file
370104
eric_wordle/words_dictionary.json
Normal file
File diff suppressed because it is too large
Load Diff
2
eval.sh
Executable file
2
eval.sh
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000
|
||||||
|
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 --q_model True --model_file wordle_ppo_model
|
1
inference.sh
Executable file
1
inference.sh
Executable file
@@ -0,0 +1 @@
|
|||||||
|
python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt --q_model True --model_file wordle_ppo_model --device cpu
|
130
letter_guess.py
Normal file
130
letter_guess.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class LetterGuessingEnv(gym.Env):
|
||||||
|
"""
|
||||||
|
Custom Gymnasium environment for a letter guessing game with a focus on forming
|
||||||
|
valid prefixes and words from a list of valid Wordle words. The environment tracks
|
||||||
|
the current guess prefix and validates it against known valid words, ending the game
|
||||||
|
early with a negative reward for invalid prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = {'render_modes': ['human']}
|
||||||
|
|
||||||
|
def __init__(self, valid_words, seed=None):
|
||||||
|
self.action_space = spaces.Discrete(26)
|
||||||
|
self.observation_space = spaces.Box(low=0, high=1, shape=(26*2 + 26*4,), dtype=np.int32)
|
||||||
|
|
||||||
|
self.valid_words = valid_words # List of valid Wordle words
|
||||||
|
self.target_word = '' # Target word for the current episode
|
||||||
|
self.valid_words_str = ' '.join(self.valid_words) + ' '
|
||||||
|
self.letter_flags = None
|
||||||
|
self.letter_positions = None
|
||||||
|
self.guessed_letters = set()
|
||||||
|
self.guess_prefix = "" # Tracks the current guess prefix
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def clone_state(self):
|
||||||
|
# Clone the current state
|
||||||
|
return {
|
||||||
|
'target_word': self.target_word,
|
||||||
|
'letter_flags': copy.deepcopy(self.letter_flags),
|
||||||
|
'letter_positions': copy.deepcopy(self.letter_positions),
|
||||||
|
'guessed_letters': copy.deepcopy(self.guessed_letters),
|
||||||
|
'guess_prefix': self.guess_prefix,
|
||||||
|
'round': self.round
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
# Restore the state
|
||||||
|
self.target_word = state['target_word']
|
||||||
|
self.letter_flags = copy.deepcopy(state['letter_flags'])
|
||||||
|
self.letter_positions = copy.deepcopy(state['letter_positions'])
|
||||||
|
self.guessed_letters = copy.deepcopy(state['guessed_letters'])
|
||||||
|
self.guess_prefix = state['guess_prefix']
|
||||||
|
self.round = state['round']
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
letter_index = action # Assuming action is the letter index directly
|
||||||
|
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
||||||
|
letter = chr(ord('a') + letter_index)
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
done = False
|
||||||
|
|
||||||
|
# Check if the letter has already been used in the guess prefix
|
||||||
|
if letter in self.guessed_letters:
|
||||||
|
reward = -1 # Penalize for repeating letters in the prefix
|
||||||
|
else:
|
||||||
|
# Add the new letter to the prefix and update guessed letters set
|
||||||
|
self.guess_prefix += letter
|
||||||
|
self.guessed_letters.add(letter)
|
||||||
|
|
||||||
|
# Update letter flags based on whether the letter is in the target word
|
||||||
|
if self.target_word[position] == letter:
|
||||||
|
self.letter_flags[letter_index, :] = [1, 0] # Update flag for correct guess
|
||||||
|
elif letter in self.target_word:
|
||||||
|
self.letter_flags[letter_index, :] = [0, 1] # Update flag for correct guess wrong position
|
||||||
|
else:
|
||||||
|
self.letter_flags[letter_index, :] = [0, 0] # Update flag for incorrect guess
|
||||||
|
|
||||||
|
reward = 1 # Reward for adding new information by trying a new letter
|
||||||
|
|
||||||
|
# Update the letter_positions matrix to reflect the new guess
|
||||||
|
if position == 4:
|
||||||
|
self.letter_positions[:, :] = 1
|
||||||
|
else:
|
||||||
|
self.letter_positions[:, position] = 0
|
||||||
|
self.letter_positions[letter_index, position] = 1
|
||||||
|
|
||||||
|
# Use regex to check if the current prefix can lead to a valid word
|
||||||
|
if not re.search(r'\b' + self.guess_prefix, self.valid_words_str):
|
||||||
|
reward = -5 # Penalize for forming an invalid prefix
|
||||||
|
done = True # End the episode if the prefix is invalid
|
||||||
|
|
||||||
|
# guessed a full word so we reset our guess prefix to guess next round
|
||||||
|
if len(self.guess_prefix) == len(self.target_word):
|
||||||
|
self.guess_prefix = ''
|
||||||
|
self.round += 1
|
||||||
|
|
||||||
|
# end after 3 rounds of total guesses
|
||||||
|
if self.round == 3:
|
||||||
|
# reward = 5
|
||||||
|
done = True
|
||||||
|
|
||||||
|
obs = self.get_obs()
|
||||||
|
|
||||||
|
if reward < -5:
|
||||||
|
print(obs, reward, done)
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
return obs, reward, done, False, {}
|
||||||
|
|
||||||
|
def reset(self, seed=None):
|
||||||
|
self.target_word = random.choice(self.valid_words)
|
||||||
|
# self.target_word_encoded = self.encode_word(self.target_word)
|
||||||
|
self.letter_flags = np.ones((26, 2), dtype=np.int32)
|
||||||
|
self.letter_positions = np.ones((26, 4), dtype=np.int32)
|
||||||
|
self.guessed_letters = set()
|
||||||
|
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||||
|
self.round = 0
|
||||||
|
return self.get_obs(), {}
|
||||||
|
|
||||||
|
def encode_word(self, word):
|
||||||
|
encoded = np.zeros((26,))
|
||||||
|
for char in word:
|
||||||
|
index = ord(char) - ord('a')
|
||||||
|
encoded[index] = 1
|
||||||
|
return encoded
|
||||||
|
|
||||||
|
def get_obs(self):
|
||||||
|
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
||||||
|
|
||||||
|
def render(self, mode='human'):
|
||||||
|
pass # Optional: Implement rendering logic if needed
|
BIN
wordle_ppo_model.zip
Normal file
BIN
wordle_ppo_model.zip
Normal file
Binary file not shown.
2317
wordle_words.txt
Normal file
2317
wordle_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user