started new letter guess environment

This commit is contained in:
Ethan Shapiro 2024-03-19 11:52:10 -07:00
parent e799c14ece
commit fc197acb6e
14 changed files with 389053 additions and 700 deletions

87
dqn_letter_gssr.ipynb Normal file
View File

@ -0,0 +1,87 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def load_valid_words(file_path='wordle_words.txt'):\n",
" \"\"\"\n",
" Load valid five-letter words from a specified text file.\n",
"\n",
" Parameters:\n",
" - file_path (str): The path to the text file containing valid words.\n",
"\n",
" Returns:\n",
" - list[str]: A list of valid words loaded from the file.\n",
" \"\"\"\n",
" with open(file_path, 'r') as file:\n",
" valid_words = [line.strip() for line in file if len(line.strip()) == 5]\n",
" return valid_words"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[2], line 5\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstable_baselines3\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menv_checker\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m check_env\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mletter_guess\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m LetterGuessingEnv\n\u001b[1;32m----> 5\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mLetterGuessingEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalid_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mload_valid_words\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Make sure to load your valid words\u001b[39;00m\n\u001b[0;32m 6\u001b[0m check_env(env) \u001b[38;5;66;03m# Optional: Verify the environment is compatible with SB3\u001b[39;00m\n\u001b[0;32m 8\u001b[0m model \u001b[38;5;241m=\u001b[39m PPO(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, env, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
"\u001b[1;31mTypeError\u001b[0m: LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'"
]
}
],
"source": [
"from stable_baselines3 import PPO # Or any other suitable RL algorithm\n",
"from stable_baselines3.common.env_checker import check_env\n",
"from letter_guess import LetterGuessingEnv\n",
"\n",
"env = LetterGuessingEnv(valid_words=load_valid_words()) # Make sure to load your valid words\n",
"check_env(env) # Optional: Verify the environment is compatible with SB3\n",
"\n",
"model = PPO(\"MlpPolicy\", env, verbose=1)\n",
"\n",
"# Train for a certain number of timesteps\n",
"model.learn(total_timesteps=100000)\n",
"\n",
"# Save the model\n",
"model.save(\"wordle_ppo_model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -35,13 +35,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Wrapping the env in a DummyVecEnv.\n"
]
},
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "7c52630b65904d5e8e200be505d2121a", "model_id": "6921a0721569456abf5bceac7e7b6b34",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@ -52,29 +60,20 @@
"metadata": {}, "metadata": {},
"output_type": "display_data" "output_type": "display_data"
}, },
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"----------------------------------\n", "----------------------------------\n",
"| rollout/ | |\n", "| rollout/ | |\n",
"| ep_len_mean | 5 |\n", "| ep_len_mean | 4.97 |\n",
"| ep_rew_mean | -175 |\n", "| ep_rew_mean | -63.8 |\n",
"| exploration_rate | 0.525 |\n", "| exploration_rate | 0.05 |\n",
"| time/ | |\n", "| time/ | |\n",
"| episodes | 10000 |\n", "| episodes | 10000 |\n",
"| fps | 4606 |\n", "| fps | 1628 |\n",
"| time_elapsed | 10 |\n", "| time_elapsed | 30 |\n",
"| total_timesteps | 49989 |\n", "| total_timesteps | 49995 |\n",
"----------------------------------\n" "----------------------------------\n"
] ]
}, },
@ -85,395 +84,17 @@
"----------------------------------\n", "----------------------------------\n",
"| rollout/ | |\n", "| rollout/ | |\n",
"| ep_len_mean | 5 |\n", "| ep_len_mean | 5 |\n",
"| ep_rew_mean | -208 |\n", "| ep_rew_mean | -70.5 |\n",
"| exploration_rate | 0.0502 |\n", "| exploration_rate | 0.05 |\n",
"| time/ | |\n", "| time/ | |\n",
"| episodes | 20000 |\n", "| episodes | 20000 |\n",
"| fps | 1118 |\n", "| fps | 662 |\n",
"| time_elapsed | 89 |\n", "| time_elapsed | 150 |\n",
"| total_timesteps | 99980 |\n", "| total_timesteps | 99992 |\n",
"| train/ | |\n", "| train/ | |\n",
"| learning_rate | 0.0001 |\n", "| learning_rate | 0.0001 |\n",
"| loss | 24.6 |\n", "| loss | 11.7 |\n",
"| n_updates | 12494 |\n", "| n_updates | 12497 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -230 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 30000 |\n",
"| fps | 856 |\n",
"| time_elapsed | 175 |\n",
"| total_timesteps | 149974 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.7 |\n",
"| n_updates | 24993 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -242 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 40000 |\n",
"| fps | 766 |\n",
"| time_elapsed | 260 |\n",
"| total_timesteps | 199967 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 24 |\n",
"| n_updates | 37491 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -186 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 50000 |\n",
"| fps | 722 |\n",
"| time_elapsed | 346 |\n",
"| total_timesteps | 249962 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.5 |\n",
"| n_updates | 49990 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -183 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 60000 |\n",
"| fps | 694 |\n",
"| time_elapsed | 431 |\n",
"| total_timesteps | 299957 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.6 |\n",
"| n_updates | 62489 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -181 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 70000 |\n",
"| fps | 675 |\n",
"| time_elapsed | 517 |\n",
"| total_timesteps | 349953 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 26.8 |\n",
"| n_updates | 74988 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -196 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 80000 |\n",
"| fps | 663 |\n",
"| time_elapsed | 603 |\n",
"| total_timesteps | 399936 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.5 |\n",
"| n_updates | 87483 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -174 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 90000 |\n",
"| fps | 653 |\n",
"| time_elapsed | 688 |\n",
"| total_timesteps | 449928 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.1 |\n",
"| n_updates | 99981 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -155 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 100000 |\n",
"| fps | 645 |\n",
"| time_elapsed | 774 |\n",
"| total_timesteps | 499920 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.8 |\n",
"| n_updates | 112479 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -153 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 110000 |\n",
"| fps | 638 |\n",
"| time_elapsed | 860 |\n",
"| total_timesteps | 549916 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 16 |\n",
"| n_updates | 124978 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -164 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 120000 |\n",
"| fps | 633 |\n",
"| time_elapsed | 947 |\n",
"| total_timesteps | 599915 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 137478 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -145 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 130000 |\n",
"| fps | 628 |\n",
"| time_elapsed | 1033 |\n",
"| total_timesteps | 649910 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.8 |\n",
"| n_updates | 149977 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -154 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 140000 |\n",
"| fps | 624 |\n",
"| time_elapsed | 1120 |\n",
"| total_timesteps | 699902 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.9 |\n",
"| n_updates | 162475 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -192 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 150000 |\n",
"| fps | 621 |\n",
"| time_elapsed | 1206 |\n",
"| total_timesteps | 749884 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 18.3 |\n",
"| n_updates | 174970 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -170 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 160000 |\n",
"| fps | 618 |\n",
"| time_elapsed | 1293 |\n",
"| total_timesteps | 799869 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 17.7 |\n",
"| n_updates | 187467 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -233 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 170000 |\n",
"| fps | 615 |\n",
"| time_elapsed | 1380 |\n",
"| total_timesteps | 849855 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 21.6 |\n",
"| n_updates | 199963 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -146 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 180000 |\n",
"| fps | 613 |\n",
"| time_elapsed | 1466 |\n",
"| total_timesteps | 899847 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 19.4 |\n",
"| n_updates | 212461 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -142 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 190000 |\n",
"| fps | 611 |\n",
"| time_elapsed | 1553 |\n",
"| total_timesteps | 949846 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 22.9 |\n",
"| n_updates | 224961 |\n",
"----------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5 |\n",
"| ep_rew_mean | -171 |\n",
"| exploration_rate | 0.05 |\n",
"| time/ | |\n",
"| episodes | 200000 |\n",
"| fps | 609 |\n",
"| time_elapsed | 1640 |\n",
"| total_timesteps | 999839 |\n",
"| train/ | |\n",
"| learning_rate | 0.0001 |\n",
"| loss | 20.3 |\n",
"| n_updates | 237459 |\n",
"----------------------------------\n" "----------------------------------\n"
] ]
}, },
@ -503,27 +124,27 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>" "<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>"
] ]
}, },
"execution_count": 5, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"total_timesteps = 1_000_000\n", "total_timesteps = 100_000\n",
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n", "model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" "model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model.save(\"dqn_new_rewards\")" "model.save(\"dqn_new_state\")"
] ]
}, },
{ {
@ -557,6 +178,76 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
" 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
" 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n",
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"0\n" "0\n"
] ]
} }
@ -578,6 +269,7 @@
"\n", "\n",
" state, reward, done, truncated, info = env.step(action)\n", " state, reward, done, truncated, info = env.step(action)\n",
"\n", "\n",
" print(state)\n",
" if info[\"correct\"]:\n", " if info[\"correct\"]:\n",
" wins += 1\n", " wins += 1\n",
"\n", "\n",
@ -586,22 +278,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n", "(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", " 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n",
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" -130)" " 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
" -50)"
] ]
}, },
"execution_count": 8, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -610,35 +306,6 @@
"state, reward" "state, reward"
] ]
}, },
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"blah = (14, 1, 9, 22, 5)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blah in info['guesses']"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,

129
eric_wordle/.gitignore vendored Normal file
View File

@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

11
eric_wordle/README.md Normal file
View File

@ -0,0 +1,11 @@
# N-dle Solver
A solver designed to beat New York Time's Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, can extend to solve the more general N-dle problem (for quordle, octordle, etc.)
I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6).
## Usage:
1. Run `python main.py --n 1`
2. Follow the prompts
Currently only supports solving for 1 word at a time (i.e. wordle).

126
eric_wordle/ai.py Normal file
View File

@ -0,0 +1,126 @@
import re
import string
import numpy as np
class AI:
def __init__(self, vocab_file, num_letters=5, num_guesses=6):
self.vocab_file = vocab_file
self.num_letters = num_letters
self.num_guesses = 6
self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file)
self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1]
self.domains = None
self.possible_letters = None
self.reset()
def solve(self):
num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1
word = self.sample()
# # Always start with these two words
# if num_guesses == 1:
# word = 'soare'
# elif num_guesses == 2:
# word = 'culti'
print('-----------------------------------------------')
print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
print('-----------------------------------------------')
self.arc_consistency(word)
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
def arc_consistency(self, word):
print(f'Performing arc consistency check on {word}...')
print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
results = []
# Collect results
for l in word:
while True:
result = input(f'{l}: ')
if result not in ['0', '1', '2']:
print('Incorrect option. Try again.')
continue
results.append(result)
break
self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
for i in range(len(word)):
if results[i] == '0':
if word[i] in self.possible_letters:
if word[i] in self.domains[i]:
self.domains[i].remove(word[i])
else:
for j in range(len(self.domains)):
if word[i] in self.domains[j] and len(self.domains[j]) > 1:
self.domains[j].remove(word[i])
if results[i] == '1':
if word[i] in self.domains[i]:
self.domains[i].remove(word[i])
if results[i] == '2':
self.domains[i] = [word[i]]
def reset(self):
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
self.possible_letters = []
def sample(self):
"""
Samples a best word given the current domains
:return:
"""
# Compile a regex of possible words with the current domain
regex_string = ''
for domain in self.domains:
regex_string += ''.join(['[', ''.join(domain), ']', '{1}'])
pattern = re.compile(regex_string)
# From the words with the highest scores, only return the best word that match the regex pattern
for word, _ in self.best_words:
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
return word
def get_vocab(self, vocab_file):
vocab = []
with open(vocab_file, 'r') as f:
for l in f:
vocab.append(l.strip())
# Count letter frequencies at each index
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
for word in vocab:
for i, l in enumerate(word):
letter_freqs[i][l] += 1
# Assign a score to each letter at each index by the probability of it appearing
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
for i in range(len(letter_scores)):
max_freq = np.max(list(letter_freqs[i].values()))
for l in letter_scores[i].keys():
letter_scores[i][l] = letter_freqs[i][l] / max_freq
# Find a sorted list of words ranked by sum of letter scores
vocab_scores = {} # (score, word)
for word in vocab:
score = 0
for i, l in enumerate(word):
score += letter_scores[i][l]
# # Optimization: If repeating letters, deduct a couple points
# if len(set(word)) < len(word):
# score -= 0.25 * (len(word) - len(set(word)))
vocab_scores[word] = score
return vocab, vocab_scores, letter_scores

37
eric_wordle/dist.py Normal file
View File

@ -0,0 +1,37 @@
import string
import numpy as np
words = []
with open('words.txt', 'r') as f:
for l in f:
words.append(l.strip())
# Count letter frequencies at each index
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
for word in words:
for i, l in enumerate(word):
letter_freqs[i][l] += 1
# Assign a score to each letter at each index by the probability of it appearing
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
for i in range(len(letter_scores)):
max_freq = np.max(list(letter_freqs[i].values()))
for l in letter_scores[i].keys():
letter_scores[i][l] = letter_freqs[i][l] / max_freq
# Find a sorted list of words ranked by sum of letter scores
word_scores = [] # (score, word)
for word in words:
score = 0
for i, l in enumerate(word):
score += letter_scores[i][l]
word_scores.append((score, word))
sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1]
print(sorted_by_second[:10])
for i, (score, word) in enumerate(sorted_by_second):
if word == 'soare':
print(f'{word} with a score of {score} is found at index {i}')

18
eric_wordle/main.py Normal file
View File

@ -0,0 +1,18 @@
import argparse
from ai import AI
def main(args):
if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
ai = AI(args.vocab_file)
ai.solve()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--n', dest='n', type=int, default=None)
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
args = parser.parse_args()
main(args)

15
eric_wordle/process.py Normal file
View File

@ -0,0 +1,15 @@
import pandas
print('Loading in words dictionary; this may take a while...')
df = pandas.read_json('words_dictionary.json')
print('Done loading words dictionary.')
words = []
for word in df.axes[0].tolist():
if len(word) != 5:
continue
words.append(word)
words.sort()
with open('words.txt', 'w') as f:
for word in words:
f.write(word + '\n')

15919
eric_wordle/words.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -150,7 +150,14 @@ class WordleEnv(gym.Env):
self.action_space = GuessList() self.action_space = GuessList()
self.solution_space = SolutionList() self.solution_space = SolutionList()
self.observation_space = WordleObsSpace() # Example setup based on the flattened state size you're now using
num_position_availability = 26 * 5 # 26 letters for each of the 5 positions
num_global_availability = 26 # Global letter availability
num_letter_found_flags = 5 # One flag for each position
total_size = num_position_availability + num_global_availability + num_letter_found_flags
# Define the observation space to match the flattened state format
self.observation_space = gym.spaces.Box(low=0, high=2, shape=(total_size,), dtype=np.float32)
self._highlights = { self._highlights = {
self.right_pos: (bg.green, bg.rs), self.right_pos: (bg.green, bg.rs),
@ -169,6 +176,7 @@ class WordleEnv(gym.Env):
'not_in_word': set(), # Letters known not to be in the word 'not_in_word': set(), # Letters known not to be in the word
'tried_positions': defaultdict(set) # Positions tried for each letter 'tried_positions': defaultdict(set) # Positions tried for each letter
} }
self.reset()
def _highlighter(self, char: str, flag: int) -> str: def _highlighter(self, char: str, flag: int) -> str:
"""Terminal renderer functionality. Properly highlights a character """Terminal renderer functionality. Properly highlights a character
@ -202,7 +210,11 @@ class WordleEnv(gym.Env):
self.solution = self.solution_space.sample() self.solution = self.solution_space.sample()
self.soln_hash = set(self.solution_space[self.solution]) self.soln_hash = set(self.solution_space[self.solution])
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) self.state = {
'position_availability': [np.ones(26) for _ in range(5)], # Each position can initially have any letter
'global_availability': np.ones(26), # Initially, all letters are available
'letter_found': np.zeros(5) # Initially, no correct letters are found
}
self.info = { self.info = {
'correct': False, 'correct': False,
@ -215,41 +227,32 @@ class WordleEnv(gym.Env):
self.simulate_first_guess() self.simulate_first_guess()
return self.state, self.info return self.get_observation(), self.info
def simulate_first_guess(self): def simulate_first_guess(self):
fixed_first_guess = "rates" fixed_first_guess = "rates" # Example: Using 'rates' as the fixed first guess
fixed_first_guess_array = to_array(fixed_first_guess) # Convert the fixed guess into the appropriate format (e.g., indices of letters)
fixed_guess_indices = to_array(fixed_first_guess)
solution_indices = self.solution_space[self.solution]
# Simulate the feedback for each letter in the fixed first guess for pos in range(5): # Iterate over each position in the word
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array letter_idx = fixed_guess_indices[pos]
for i, letter in enumerate(fixed_first_guess_array): if letter_idx == solution_indices[pos]: # Correct letter in the correct position
if letter in self.solution_space[self.solution]: self.state['position_availability'][pos] = np.zeros(26)
if letter == self.solution_space[self.solution][i]: self.state['position_availability'][pos][letter_idx] = 1
feedback[i] = 1 # Correct position self.state['letter_found'][pos] = 1
else: elif letter_idx in solution_indices: # Correct letter in the wrong position
feedback[i] = 2 # Correct letter, wrong position self.state['position_availability'][pos][letter_idx] = 0
else: # Mark this letter as still available in other positions
feedback[i] = 3 # Letter not in word for other_pos in range(5):
if self.state['letter_found'][other_pos] == 0: # If not already found
# Update the state to reflect the fixed first guess and its feedback self.state['position_availability'][other_pos][letter_idx] = 1
self.state[0, :self.n_letters] = fixed_first_guess_array else: # Letter not in the word
self.state[0, self.n_letters:] = feedback self.state['global_availability'][letter_idx] = 0
# Update all positions to reflect this letter is not in the word
# Update self.info based on the feedback for other_pos in range(5):
for i, flag in enumerate(feedback): self.state['position_availability'][other_pos][letter_idx] = 0
if flag == self.right_pos: self.round = 1 # Increment round to reflect that first guess has been simulated
# Mark letter as correctly placed
self.info['known_positions'][i] = fixed_first_guess_array[i]
elif flag == self.wrong_pos:
# Note the letter is in the word but in a different position
self.info['known_letters'].add(fixed_first_guess_array[i])
elif flag == self.wrong_char:
# Note the letter is not in the word
self.info['not_in_word'].add(fixed_first_guess_array[i])
# Since we're simulating the first guess, increment the round counter
self.round = 1
def render(self, mode: str = 'human'): def render(self, mode: str = 'human'):
"""Renders the Wordle environment. """Renders the Wordle environment.
@ -280,49 +283,46 @@ class WordleEnv(gym.Env):
reward = 0 reward = 0
correct_guess = np.array_equal(guessed_word, solution_word) correct_guess = np.array_equal(guessed_word, solution_word)
# Initialize flags for current guess # Initialize flags for current guess based on the new state structure
current_flags = np.full(self.n_letters, self.wrong_char) current_flags = np.zeros((self.n_letters, 26)) # Replaced with a more detailed flag system
# Track newly discovered information # Track newly discovered information
new_info = False new_info = False
for i in range(self.n_letters): for i in range(self.n_letters):
guessed_letter = guessed_word[i] guessed_letter = guessed_word[i] - 1
if guessed_letter in solution_word: if guessed_letter in solution_word:
# Penalize for reusing a letter found to not be in the word
if guessed_letter in self.info['not_in_word']: if guessed_letter in self.info['not_in_word']:
reward -= 2 reward -= 2 # Penalize for reusing a letter found to not be in the word
# Handle correct letter in the correct position
if guessed_letter == solution_word[i]: if guessed_letter == solution_word[i]:
current_flags[i] = self.right_pos # Handle correct letter in the correct position
if self.info['known_positions'][i] != guessed_letter: current_flags[i, :] = 0 # Set all other letters to not possible
reward += 10 # Large reward for new correct placement current_flags[i, guessed_letter] = 2 # Mark this letter as correct
new_info = True self.info['known_positions'][i] = 1 # Update known_positions
self.info['known_positions'][i] = guessed_letter reward += 10 # Reward for correct placement
else:
reward += 20 # Large reward for repeating correct placement
else:
current_flags[i] = self.wrong_pos
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
reward += 10 # Reward for guessing a letter in a new position
new_info = True new_info = True
else: else:
reward -= 20 # Penalize for not leveraging known information # Correct letter, wrong position
self.info['known_letters'].add(guessed_letter) if self.info['known_positions'][i] == 0:
self.info['tried_positions'][guessed_letter].add(i) # Only update if we haven't already found the correct letter for this position
current_flags[:, guessed_letter] = 2 # Mark this letter as found in another position
reward += 5
new_info = True
else: else:
# New incorrect letter # Letter not in word
if guessed_letter not in self.info['not_in_word']: if guessed_letter not in self.info['not_in_word']:
reward -= 2 # Penalize for guessing a letter not in the word
self.info['not_in_word'].add(guessed_letter) self.info['not_in_word'].add(guessed_letter)
reward -= 2 # Penalize for guessing a letter not in the word
new_info = True new_info = True
else: for pos in range(self.n_letters):
reward -= 15 # Larger penalty for repeating an incorrect letter # Update all positions to reflect this letter is not correct
current_flags[pos, guessed_letter] = 0
# Update observation state with the current guess and flags # Update global letter availability based on the guess
self.state[self.round, :self.n_letters] = guessed_word for letter in range(26):
self.state[self.round, self.n_letters:] = current_flags if letter not in guessed_word or letter in self.info['not_in_word']:
self.state['global_availability'][letter] = 0
# Check if the game is over # Check if the game is over
done = self.round == self.n_rounds - 1 or correct_guess done = self.round == self.n_rounds - 1 or correct_guess
@ -337,4 +337,17 @@ class WordleEnv(gym.Env):
self.round += 1 self.round += 1
return self.state, reward, done, False, self.info return self.get_observation(), reward, done, False, self.info
def get_observation(self):
# Flatten the position-specific letter availability
position_availability_flat = np.concatenate(self.state['position_availability'])
# Global availability is already a 1D array, but ensure consistency in data handling
global_availability_flat = self.state['global_availability'].flatten()
# Concatenate all parts of the state into a single flat array for the DQN input
full_state_flat = np.concatenate(
[position_availability_flat, global_availability_flat, self.state['letter_found']])
return full_state_flat

99
letter_guess.py Normal file
View File

@ -0,0 +1,99 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import re
class LetterGuessingEnv(gym.Env):
"""
Custom Gymnasium environment for a letter guessing game with a focus on forming
valid prefixes and words from a list of valid Wordle words. The environment tracks
the current guess prefix and validates it against known valid words, ending the game
early with a negative reward for invalid prefixes.
"""
metadata = {'render_modes': ['human']}
def __init__(self, valid_words, seed):
self.action_space = spaces.Discrete(26)
self.observation_space = spaces.Box(low=0, high=2, shape=(26*2 + 26*4,), dtype=np.int32)
self.valid_words = valid_words # List of valid Wordle words
self.target_word = '' # Target word for the current episode
self.valid_words_str = '_'.join(self.valid_words) + '_'
self.letter_flags = None
self.letter_positions = None
self.guessed_letters = set()
self.guess_prefix = "" # Tracks the current guess prefix
self.round = 1
self.reset()
def step(self, action):
letter_index = action % 26 # Assuming action is the letter index directly
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
letter = chr(ord('a') + letter_index)
reward = 0
done = False
# Check if the letter has already been used in the guess prefix
if letter in self.guessed_letters:
reward = -1 # Penalize for repeating letters in the prefix
else:
# Add the new letter to the prefix and update guessed letters set
self.guess_prefix += letter
self.guessed_letters.add(letter)
# Update letter flags based on whether the letter is in the target word
if self.target_word_encoded[letter_index] == 1:
self.letter_flags[letter_index, :] = [1, 0] # Update flag for correct guess
else:
self.letter_flags[letter_index, :] = [0, 0] # Update flag for incorrect guess
reward = 1 # Reward for adding new information by trying a new letter
# Update the letter_positions matrix to reflect the new guess
self.letter_positions[:, position] = 0
self.letter_positions[letter_index, position] = 1
# Use regex to check if the current prefix can lead to a valid word
if not re.search(r'\b' + self.guess_prefix, self.valid_words_str):
reward = -5 # Penalize for forming an invalid prefix
done = True # End the episode if the prefix is invalid
# guessed a full word so we reset our guess prefix to guess next round
if len(self.guess_prefix) == len(self.target_word):
self.guess_prefix == ''
self.round += 1
# end after 3 rounds of total guesses
if self.round == 3:
done = True
obs = self._get_obs()
return obs, reward, done, False, {}
def reset(self, seed):
self.target_word = random.choice(self.valid_words)
self.target_word_encoded = self.encode_word(self.target_word)
self.letter_flags = np.ones((26, 2)) * 2
self.letter_positions = np.ones((26, 4))
self.guessed_letters = set()
self.guess_prefix = "" # Reset the guess prefix for the new episode
return self._get_obs()
def encode_word(self, word):
encoded = np.zeros((26,))
for char in word:
index = ord(char) - ord('a')
encoded[index] = 1
return encoded
def _get_obs(self):
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
def render(self, mode='human'):
pass # Optional: Implement rendering logic if needed

View File

@ -1,189 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def my_func()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(<function __main__.<lambda>()>, {})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t['t']"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'x' in t"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"x = np.array([1, 1, 1])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"x[:] = 0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"'abcde'aaa\n",
" 33221\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

2317
wordle_words.txt Normal file

File diff suppressed because it is too large Load Diff