mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-25 17:49:10 +00:00
started new letter guess environment
This commit is contained in:
parent
e799c14ece
commit
fc197acb6e
87
dqn_letter_gssr.ipynb
Normal file
87
dqn_letter_gssr.ipynb
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def load_valid_words(file_path='wordle_words.txt'):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Load valid five-letter words from a specified text file.\n",
|
||||||
|
"\n",
|
||||||
|
" Parameters:\n",
|
||||||
|
" - file_path (str): The path to the text file containing valid words.\n",
|
||||||
|
"\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" - list[str]: A list of valid words loaded from the file.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" with open(file_path, 'r') as file:\n",
|
||||||
|
" valid_words = [line.strip() for line in file if len(line.strip()) == 5]\n",
|
||||||
|
" return valid_words"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "TypeError",
|
||||||
|
"evalue": "LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[1;32mIn[2], line 5\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstable_baselines3\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01menv_checker\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m check_env\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mletter_guess\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m LetterGuessingEnv\n\u001b[1;32m----> 5\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mLetterGuessingEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalid_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mload_valid_words\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Make sure to load your valid words\u001b[39;00m\n\u001b[0;32m 6\u001b[0m check_env(env) \u001b[38;5;66;03m# Optional: Verify the environment is compatible with SB3\u001b[39;00m\n\u001b[0;32m 8\u001b[0m model \u001b[38;5;241m=\u001b[39m PPO(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, env, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
||||||
|
"\u001b[1;31mTypeError\u001b[0m: LetterGuessingEnv.__init__() missing 1 required positional argument: 'seed'"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from stable_baselines3 import PPO # Or any other suitable RL algorithm\n",
|
||||||
|
"from stable_baselines3.common.env_checker import check_env\n",
|
||||||
|
"from letter_guess import LetterGuessingEnv\n",
|
||||||
|
"\n",
|
||||||
|
"env = LetterGuessingEnv(valid_words=load_valid_words()) # Make sure to load your valid words\n",
|
||||||
|
"check_env(env) # Optional: Verify the environment is compatible with SB3\n",
|
||||||
|
"\n",
|
||||||
|
"model = PPO(\"MlpPolicy\", env, verbose=1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Train for a certain number of timesteps\n",
|
||||||
|
"model.learn(total_timesteps=100000)\n",
|
||||||
|
"\n",
|
||||||
|
"# Save the model\n",
|
||||||
|
"model.save(\"wordle_ppo_model\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
557
dqn_wordle.ipynb
557
dqn_wordle.ipynb
@ -35,13 +35,21 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Using cuda device\n",
|
||||||
|
"Wrapping the env in a DummyVecEnv.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "7c52630b65904d5e8e200be505d2121a",
|
"model_id": "6921a0721569456abf5bceac7e7b6b34",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@ -52,29 +60,20 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "display_data"
|
"output_type": "display_data"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Using cuda device\n",
|
|
||||||
"Wrapping the env with a `Monitor` wrapper\n",
|
|
||||||
"Wrapping the env in a DummyVecEnv.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"----------------------------------\n",
|
"----------------------------------\n",
|
||||||
"| rollout/ | |\n",
|
"| rollout/ | |\n",
|
||||||
"| ep_len_mean | 5 |\n",
|
"| ep_len_mean | 4.97 |\n",
|
||||||
"| ep_rew_mean | -175 |\n",
|
"| ep_rew_mean | -63.8 |\n",
|
||||||
"| exploration_rate | 0.525 |\n",
|
"| exploration_rate | 0.05 |\n",
|
||||||
"| time/ | |\n",
|
"| time/ | |\n",
|
||||||
"| episodes | 10000 |\n",
|
"| episodes | 10000 |\n",
|
||||||
"| fps | 4606 |\n",
|
"| fps | 1628 |\n",
|
||||||
"| time_elapsed | 10 |\n",
|
"| time_elapsed | 30 |\n",
|
||||||
"| total_timesteps | 49989 |\n",
|
"| total_timesteps | 49995 |\n",
|
||||||
"----------------------------------\n"
|
"----------------------------------\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -85,395 +84,17 @@
|
|||||||
"----------------------------------\n",
|
"----------------------------------\n",
|
||||||
"| rollout/ | |\n",
|
"| rollout/ | |\n",
|
||||||
"| ep_len_mean | 5 |\n",
|
"| ep_len_mean | 5 |\n",
|
||||||
"| ep_rew_mean | -208 |\n",
|
"| ep_rew_mean | -70.5 |\n",
|
||||||
"| exploration_rate | 0.0502 |\n",
|
"| exploration_rate | 0.05 |\n",
|
||||||
"| time/ | |\n",
|
"| time/ | |\n",
|
||||||
"| episodes | 20000 |\n",
|
"| episodes | 20000 |\n",
|
||||||
"| fps | 1118 |\n",
|
"| fps | 662 |\n",
|
||||||
"| time_elapsed | 89 |\n",
|
"| time_elapsed | 150 |\n",
|
||||||
"| total_timesteps | 99980 |\n",
|
"| total_timesteps | 99992 |\n",
|
||||||
"| train/ | |\n",
|
"| train/ | |\n",
|
||||||
"| learning_rate | 0.0001 |\n",
|
"| learning_rate | 0.0001 |\n",
|
||||||
"| loss | 24.6 |\n",
|
"| loss | 11.7 |\n",
|
||||||
"| n_updates | 12494 |\n",
|
"| n_updates | 12497 |\n",
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -230 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 30000 |\n",
|
|
||||||
"| fps | 856 |\n",
|
|
||||||
"| time_elapsed | 175 |\n",
|
|
||||||
"| total_timesteps | 149974 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 18.7 |\n",
|
|
||||||
"| n_updates | 24993 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -242 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 40000 |\n",
|
|
||||||
"| fps | 766 |\n",
|
|
||||||
"| time_elapsed | 260 |\n",
|
|
||||||
"| total_timesteps | 199967 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 24 |\n",
|
|
||||||
"| n_updates | 37491 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -186 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 50000 |\n",
|
|
||||||
"| fps | 722 |\n",
|
|
||||||
"| time_elapsed | 346 |\n",
|
|
||||||
"| total_timesteps | 249962 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.5 |\n",
|
|
||||||
"| n_updates | 49990 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -183 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 60000 |\n",
|
|
||||||
"| fps | 694 |\n",
|
|
||||||
"| time_elapsed | 431 |\n",
|
|
||||||
"| total_timesteps | 299957 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.6 |\n",
|
|
||||||
"| n_updates | 62489 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -181 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 70000 |\n",
|
|
||||||
"| fps | 675 |\n",
|
|
||||||
"| time_elapsed | 517 |\n",
|
|
||||||
"| total_timesteps | 349953 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 26.8 |\n",
|
|
||||||
"| n_updates | 74988 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -196 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 80000 |\n",
|
|
||||||
"| fps | 663 |\n",
|
|
||||||
"| time_elapsed | 603 |\n",
|
|
||||||
"| total_timesteps | 399936 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.5 |\n",
|
|
||||||
"| n_updates | 87483 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -174 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 90000 |\n",
|
|
||||||
"| fps | 653 |\n",
|
|
||||||
"| time_elapsed | 688 |\n",
|
|
||||||
"| total_timesteps | 449928 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.1 |\n",
|
|
||||||
"| n_updates | 99981 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -155 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 100000 |\n",
|
|
||||||
"| fps | 645 |\n",
|
|
||||||
"| time_elapsed | 774 |\n",
|
|
||||||
"| total_timesteps | 499920 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.8 |\n",
|
|
||||||
"| n_updates | 112479 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -153 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 110000 |\n",
|
|
||||||
"| fps | 638 |\n",
|
|
||||||
"| time_elapsed | 860 |\n",
|
|
||||||
"| total_timesteps | 549916 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 16 |\n",
|
|
||||||
"| n_updates | 124978 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -164 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 120000 |\n",
|
|
||||||
"| fps | 633 |\n",
|
|
||||||
"| time_elapsed | 947 |\n",
|
|
||||||
"| total_timesteps | 599915 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.8 |\n",
|
|
||||||
"| n_updates | 137478 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -145 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 130000 |\n",
|
|
||||||
"| fps | 628 |\n",
|
|
||||||
"| time_elapsed | 1033 |\n",
|
|
||||||
"| total_timesteps | 649910 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.8 |\n",
|
|
||||||
"| n_updates | 149977 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -154 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 140000 |\n",
|
|
||||||
"| fps | 624 |\n",
|
|
||||||
"| time_elapsed | 1120 |\n",
|
|
||||||
"| total_timesteps | 699902 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 20.9 |\n",
|
|
||||||
"| n_updates | 162475 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -192 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 150000 |\n",
|
|
||||||
"| fps | 621 |\n",
|
|
||||||
"| time_elapsed | 1206 |\n",
|
|
||||||
"| total_timesteps | 749884 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 18.3 |\n",
|
|
||||||
"| n_updates | 174970 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -170 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 160000 |\n",
|
|
||||||
"| fps | 618 |\n",
|
|
||||||
"| time_elapsed | 1293 |\n",
|
|
||||||
"| total_timesteps | 799869 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.7 |\n",
|
|
||||||
"| n_updates | 187467 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -233 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 170000 |\n",
|
|
||||||
"| fps | 615 |\n",
|
|
||||||
"| time_elapsed | 1380 |\n",
|
|
||||||
"| total_timesteps | 849855 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.6 |\n",
|
|
||||||
"| n_updates | 199963 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -146 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 180000 |\n",
|
|
||||||
"| fps | 613 |\n",
|
|
||||||
"| time_elapsed | 1466 |\n",
|
|
||||||
"| total_timesteps | 899847 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 19.4 |\n",
|
|
||||||
"| n_updates | 212461 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -142 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 190000 |\n",
|
|
||||||
"| fps | 611 |\n",
|
|
||||||
"| time_elapsed | 1553 |\n",
|
|
||||||
"| total_timesteps | 949846 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.9 |\n",
|
|
||||||
"| n_updates | 224961 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -171 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 200000 |\n",
|
|
||||||
"| fps | 609 |\n",
|
|
||||||
"| time_elapsed | 1640 |\n",
|
|
||||||
"| total_timesteps | 999839 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 20.3 |\n",
|
|
||||||
"| n_updates | 237459 |\n",
|
|
||||||
"----------------------------------\n"
|
"----------------------------------\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -503,27 +124,27 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>"
|
"<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"total_timesteps = 1_000_000\n",
|
"total_timesteps = 100_000\n",
|
||||||
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||||
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model.save(\"dqn_new_rewards\")"
|
"model.save(\"dqn_new_state\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -557,6 +178,76 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
"0\n"
|
"0\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -578,6 +269,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" state, reward, done, truncated, info = env.step(action)\n",
|
" state, reward, done, truncated, info = env.step(action)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" print(state)\n",
|
||||||
" if info[\"correct\"]:\n",
|
" if info[\"correct\"]:\n",
|
||||||
" wins += 1\n",
|
" wins += 1\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -586,22 +278,26 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
|
"(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
" 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" -130)"
|
" 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
|
||||||
|
" -50)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -610,35 +306,6 @@
|
|||||||
"state, reward"
|
"state, reward"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 21,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"blah = (14, 1, 9, 22, 5)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"True"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"blah in info['guesses']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
129
eric_wordle/.gitignore
vendored
Normal file
129
eric_wordle/.gitignore
vendored
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
11
eric_wordle/README.md
Normal file
11
eric_wordle/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# N-dle Solver
|
||||||
|
|
||||||
|
A solver designed to beat the New York Times' Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, it can be extended to solve the more general N-dle problem (for quordle, octordle, etc.)
|
||||||
|
|
||||||
|
I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6).
|
||||||
|
|
||||||
|
## Usage:
|
||||||
|
1. Run `python main.py --n 1`
|
||||||
|
2. Follow the prompts
|
||||||
|
|
||||||
|
Currently only supports solving for 1 word at a time (i.e. wordle).
|
126
eric_wordle/ai.py
Normal file
126
eric_wordle/ai.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class AI:
    """Constraint-propagation Wordle solver.

    Keeps a per-position domain of candidate letters and narrows the domains
    with user-supplied feedback after each guess (0 = letter not in word,
    1 = letter in word but wrong position, 2 = letter in correct position).
    """

    def __init__(self, vocab_file, num_letters=5, num_guesses=6):
        """
        Parameters:
        - vocab_file (str): path to a newline-separated word list.
        - num_letters (int): word length (5 for standard Wordle).
        - num_guesses (int): maximum number of guesses allowed.
        """
        self.vocab_file = vocab_file
        self.num_letters = num_letters
        # Bug fix: this was hard-coded to 6, silently ignoring the
        # `num_guesses` parameter.
        self.num_guesses = num_guesses

        self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file)
        # Words ranked best-first by summed positional letter scores.
        self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1]

        self.domains = None           # list (one per position) of candidate letters
        self.possible_letters = None  # letters known to be in the word somewhere

        self.reset()

    def solve(self):
        """Interactively solve one puzzle: propose a guess, collect feedback via
        arc consistency, and repeat until every position's domain is a singleton."""
        num_guesses = 0
        while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
            num_guesses += 1
            word = self.sample()

            print('-----------------------------------------------')
            print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
            print('-----------------------------------------------')
            self.arc_consistency(word)

        print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')

    def arc_consistency(self, word):
        """Prompt the user for per-letter feedback on `word` and prune the
        position domains accordingly.

        Feedback codes: '0' letter absent, '1' wrong position, '2' correct position.
        """
        print(f'Performing arc consistency check on {word}...')
        print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
        results = []

        # Collect one feedback digit per letter, re-prompting on bad input.
        for l in word:
            while True:
                result = input(f'{l}: ')
                if result not in ['0', '1', '2']:
                    print('Incorrect option. Try again.')
                    continue
                results.append(result)
                break

        self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']

        for i in range(len(word)):
            if results[i] == '0':
                if word[i] in self.possible_letters:
                    # Letter exists elsewhere (duplicate letter case): only
                    # rule it out at this index.
                    if word[i] in self.domains[i]:
                        self.domains[i].remove(word[i])
                else:
                    # Letter is nowhere in the word: remove it from every
                    # still-open domain.
                    for j in range(len(self.domains)):
                        if word[i] in self.domains[j] and len(self.domains[j]) > 1:
                            self.domains[j].remove(word[i])
            if results[i] == '1':
                # In the word, but not at this index.
                if word[i] in self.domains[i]:
                    self.domains[i].remove(word[i])
            if results[i] == '2':
                # Nailed it: collapse this position's domain.
                self.domains[i] = [word[i]]

    def reset(self):
        """Restore full a-z domains for every position and clear known letters."""
        self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
        self.possible_letters = []

    def sample(self):
        """
        Samples a best word given the current domains.

        Returns the highest-scoring vocabulary word that matches the current
        per-position domains and contains every known-present letter.
        NOTE(review): returns None if no vocabulary word satisfies the
        constraints — callers should be prepared for that.
        """
        # Compile a regex of possible words with the current domain.
        regex_string = ''
        for domain in self.domains:
            regex_string += ''.join(['[', ''.join(domain), ']', '{1}'])
        pattern = re.compile(regex_string)

        # From the words with the highest scores, only return the best word
        # that matches the regex pattern.
        for word, _ in self.best_words:
            if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
                return word

    def get_vocab(self, vocab_file):
        """Load the vocabulary and score it.

        Returns:
        - vocab (list[str]): words from `vocab_file`, one per line.
        - vocab_scores (dict[str, float]): word -> summed positional letter score.
        - letter_scores (list[dict[str, float]]): per-position letter scores,
          each normalised by the most frequent letter at that position.
        """
        vocab = []
        with open(vocab_file, 'r') as f:
            for l in f:
                vocab.append(l.strip())

        # Count letter frequencies at each index.
        letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
        for word in vocab:
            for i, l in enumerate(word):
                letter_freqs[i][l] += 1

        # Assign a score to each letter at each index by the probability of it appearing.
        letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
        for i in range(len(letter_scores)):
            max_freq = np.max(list(letter_freqs[i].values()))
            for l in letter_scores[i].keys():
                letter_scores[i][l] = letter_freqs[i][l] / max_freq

        # Find a sorted list of words ranked by sum of letter scores.
        vocab_scores = {}  # word -> score
        for word in vocab:
            score = 0
            for i, l in enumerate(word):
                score += letter_scores[i][l]
            vocab_scores[word] = score

        return vocab, vocab_scores, letter_scores
|
37
eric_wordle/dist.py
Normal file
37
eric_wordle/dist.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
"""Rank the Wordle vocabulary by positional letter-frequency scores.

Reads `words.txt`, scores each letter by how often it appears at each of the
five positions (normalised by the most frequent letter at that position),
sums those scores per word, prints the ten best words, and reports where the
classic opener "soare" lands in the ranking.
"""
import string

import numpy as np

# Load the vocabulary, one word per line.
with open('words.txt', 'r') as f:
    words = [line.strip() for line in f]

# Tally how often each letter occurs at each of the five positions.
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
for word in words:
    for i, l in enumerate(word):
        letter_freqs[i][l] += 1

# Normalise each position's counts by that position's most frequent letter.
letter_scores = []
for freqs in letter_freqs:
    max_freq = np.max(list(freqs.values()))
    letter_scores.append({letter: freqs[letter] / max_freq for letter in string.ascii_lowercase})

# Score every word as the sum of its positional letter scores.
word_scores = [(sum(letter_scores[i][l] for i, l in enumerate(word)), word) for word in words]

# Rank best-first and show the top ten.
sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1]
print(sorted_by_second[:10])

for i, (score, word) in enumerate(sorted_by_second):
    if word == 'soare':
        print(f'{word} with a score of {score} is found at index {i}')
|
||||||
|
|
18
eric_wordle/main.py
Normal file
18
eric_wordle/main.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import argparse
from ai import AI


def main(args):
    """Entry point: validate CLI arguments and run the Wordle solver.

    Raises:
    - Exception: if --n was not supplied on the command line.
    """
    # n is required even though the solver currently only handles n == 1
    # (see README); it is reserved for the general N-dle case.
    if args.n is None:
        raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')

    ai = AI(args.vocab_file)
    ai.solve()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', dest='n', type=int, default=None)
    parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
    args = parser.parse_args()
    main(args)
|
15
eric_wordle/process.py
Normal file
15
eric_wordle/process.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""Extract every five-letter word from words_dictionary.json into words.txt,
sorted alphabetically, one word per line."""
import pandas

print('Loading in words dictionary; this may take a while...')
df = pandas.read_json('words_dictionary.json')
print('Done loading words dictionary.')

# The dictionary's words live on the frame's row index; keep only the
# five-letter entries and sort them.
words = sorted(word for word in df.axes[0].tolist() if len(word) == 5)

with open('words.txt', 'w') as f:
    for word in words:
        f.write(word + '\n')
|
15919
eric_wordle/words.txt
Normal file
15919
eric_wordle/words.txt
Normal file
File diff suppressed because it is too large
Load Diff
370104
eric_wordle/words_dictionary.json
Normal file
370104
eric_wordle/words_dictionary.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -150,7 +150,14 @@ class WordleEnv(gym.Env):
|
|||||||
self.action_space = GuessList()
|
self.action_space = GuessList()
|
||||||
self.solution_space = SolutionList()
|
self.solution_space = SolutionList()
|
||||||
|
|
||||||
self.observation_space = WordleObsSpace()
|
# Example setup based on the flattened state size you're now using
|
||||||
|
num_position_availability = 26 * 5 # 26 letters for each of the 5 positions
|
||||||
|
num_global_availability = 26 # Global letter availability
|
||||||
|
num_letter_found_flags = 5 # One flag for each position
|
||||||
|
total_size = num_position_availability + num_global_availability + num_letter_found_flags
|
||||||
|
|
||||||
|
# Define the observation space to match the flattened state format
|
||||||
|
self.observation_space = gym.spaces.Box(low=0, high=2, shape=(total_size,), dtype=np.float32)
|
||||||
|
|
||||||
self._highlights = {
|
self._highlights = {
|
||||||
self.right_pos: (bg.green, bg.rs),
|
self.right_pos: (bg.green, bg.rs),
|
||||||
@ -169,6 +176,7 @@ class WordleEnv(gym.Env):
|
|||||||
'not_in_word': set(), # Letters known not to be in the word
|
'not_in_word': set(), # Letters known not to be in the word
|
||||||
'tried_positions': defaultdict(set) # Positions tried for each letter
|
'tried_positions': defaultdict(set) # Positions tried for each letter
|
||||||
}
|
}
|
||||||
|
self.reset()
|
||||||
|
|
||||||
def _highlighter(self, char: str, flag: int) -> str:
|
def _highlighter(self, char: str, flag: int) -> str:
|
||||||
"""Terminal renderer functionality. Properly highlights a character
|
"""Terminal renderer functionality. Properly highlights a character
|
||||||
@ -202,7 +210,11 @@ class WordleEnv(gym.Env):
|
|||||||
self.solution = self.solution_space.sample()
|
self.solution = self.solution_space.sample()
|
||||||
self.soln_hash = set(self.solution_space[self.solution])
|
self.soln_hash = set(self.solution_space[self.solution])
|
||||||
|
|
||||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
self.state = {
|
||||||
|
'position_availability': [np.ones(26) for _ in range(5)], # Each position can initially have any letter
|
||||||
|
'global_availability': np.ones(26), # Initially, all letters are available
|
||||||
|
'letter_found': np.zeros(5) # Initially, no correct letters are found
|
||||||
|
}
|
||||||
|
|
||||||
self.info = {
|
self.info = {
|
||||||
'correct': False,
|
'correct': False,
|
||||||
@ -215,41 +227,32 @@ class WordleEnv(gym.Env):
|
|||||||
|
|
||||||
self.simulate_first_guess()
|
self.simulate_first_guess()
|
||||||
|
|
||||||
return self.state, self.info
|
return self.get_observation(), self.info
|
||||||
|
|
||||||
def simulate_first_guess(self):
|
def simulate_first_guess(self):
|
||||||
fixed_first_guess = "rates"
|
fixed_first_guess = "rates" # Example: Using 'rates' as the fixed first guess
|
||||||
fixed_first_guess_array = to_array(fixed_first_guess)
|
# Convert the fixed guess into the appropriate format (e.g., indices of letters)
|
||||||
|
fixed_guess_indices = to_array(fixed_first_guess)
|
||||||
|
solution_indices = self.solution_space[self.solution]
|
||||||
|
|
||||||
# Simulate the feedback for each letter in the fixed first guess
|
for pos in range(5): # Iterate over each position in the word
|
||||||
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array
|
letter_idx = fixed_guess_indices[pos]
|
||||||
for i, letter in enumerate(fixed_first_guess_array):
|
if letter_idx == solution_indices[pos]: # Correct letter in the correct position
|
||||||
if letter in self.solution_space[self.solution]:
|
self.state['position_availability'][pos] = np.zeros(26)
|
||||||
if letter == self.solution_space[self.solution][i]:
|
self.state['position_availability'][pos][letter_idx] = 1
|
||||||
feedback[i] = 1 # Correct position
|
self.state['letter_found'][pos] = 1
|
||||||
else:
|
elif letter_idx in solution_indices: # Correct letter in the wrong position
|
||||||
feedback[i] = 2 # Correct letter, wrong position
|
self.state['position_availability'][pos][letter_idx] = 0
|
||||||
else:
|
# Mark this letter as still available in other positions
|
||||||
feedback[i] = 3 # Letter not in word
|
for other_pos in range(5):
|
||||||
|
if self.state['letter_found'][other_pos] == 0: # If not already found
|
||||||
# Update the state to reflect the fixed first guess and its feedback
|
self.state['position_availability'][other_pos][letter_idx] = 1
|
||||||
self.state[0, :self.n_letters] = fixed_first_guess_array
|
else: # Letter not in the word
|
||||||
self.state[0, self.n_letters:] = feedback
|
self.state['global_availability'][letter_idx] = 0
|
||||||
|
# Update all positions to reflect this letter is not in the word
|
||||||
# Update self.info based on the feedback
|
for other_pos in range(5):
|
||||||
for i, flag in enumerate(feedback):
|
self.state['position_availability'][other_pos][letter_idx] = 0
|
||||||
if flag == self.right_pos:
|
self.round = 1 # Increment round to reflect that first guess has been simulated
|
||||||
# Mark letter as correctly placed
|
|
||||||
self.info['known_positions'][i] = fixed_first_guess_array[i]
|
|
||||||
elif flag == self.wrong_pos:
|
|
||||||
# Note the letter is in the word but in a different position
|
|
||||||
self.info['known_letters'].add(fixed_first_guess_array[i])
|
|
||||||
elif flag == self.wrong_char:
|
|
||||||
# Note the letter is not in the word
|
|
||||||
self.info['not_in_word'].add(fixed_first_guess_array[i])
|
|
||||||
|
|
||||||
# Since we're simulating the first guess, increment the round counter
|
|
||||||
self.round = 1
|
|
||||||
|
|
||||||
def render(self, mode: str = 'human'):
|
def render(self, mode: str = 'human'):
|
||||||
"""Renders the Wordle environment.
|
"""Renders the Wordle environment.
|
||||||
@ -280,49 +283,46 @@ class WordleEnv(gym.Env):
|
|||||||
reward = 0
|
reward = 0
|
||||||
correct_guess = np.array_equal(guessed_word, solution_word)
|
correct_guess = np.array_equal(guessed_word, solution_word)
|
||||||
|
|
||||||
# Initialize flags for current guess
|
# Initialize flags for current guess based on the new state structure
|
||||||
current_flags = np.full(self.n_letters, self.wrong_char)
|
current_flags = np.zeros((self.n_letters, 26)) # Replaced with a more detailed flag system
|
||||||
|
|
||||||
# Track newly discovered information
|
# Track newly discovered information
|
||||||
new_info = False
|
new_info = False
|
||||||
|
|
||||||
for i in range(self.n_letters):
|
for i in range(self.n_letters):
|
||||||
guessed_letter = guessed_word[i]
|
guessed_letter = guessed_word[i] - 1
|
||||||
if guessed_letter in solution_word:
|
if guessed_letter in solution_word:
|
||||||
# Penalize for reusing a letter found to not be in the word
|
|
||||||
if guessed_letter in self.info['not_in_word']:
|
if guessed_letter in self.info['not_in_word']:
|
||||||
reward -= 2
|
reward -= 2 # Penalize for reusing a letter found to not be in the word
|
||||||
|
|
||||||
# Handle correct letter in the correct position
|
|
||||||
if guessed_letter == solution_word[i]:
|
if guessed_letter == solution_word[i]:
|
||||||
current_flags[i] = self.right_pos
|
# Handle correct letter in the correct position
|
||||||
if self.info['known_positions'][i] != guessed_letter:
|
current_flags[i, :] = 0 # Set all other letters to not possible
|
||||||
reward += 10 # Large reward for new correct placement
|
current_flags[i, guessed_letter] = 2 # Mark this letter as correct
|
||||||
new_info = True
|
self.info['known_positions'][i] = 1 # Update known_positions
|
||||||
self.info['known_positions'][i] = guessed_letter
|
reward += 10 # Reward for correct placement
|
||||||
else:
|
|
||||||
reward += 20 # Large reward for repeating correct placement
|
|
||||||
else:
|
|
||||||
current_flags[i] = self.wrong_pos
|
|
||||||
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
|
|
||||||
reward += 10 # Reward for guessing a letter in a new position
|
|
||||||
new_info = True
|
|
||||||
else:
|
|
||||||
reward -= 20 # Penalize for not leveraging known information
|
|
||||||
self.info['known_letters'].add(guessed_letter)
|
|
||||||
self.info['tried_positions'][guessed_letter].add(i)
|
|
||||||
else:
|
|
||||||
# New incorrect letter
|
|
||||||
if guessed_letter not in self.info['not_in_word']:
|
|
||||||
reward -= 2 # Penalize for guessing a letter not in the word
|
|
||||||
self.info['not_in_word'].add(guessed_letter)
|
|
||||||
new_info = True
|
new_info = True
|
||||||
else:
|
else:
|
||||||
reward -= 15 # Larger penalty for repeating an incorrect letter
|
# Correct letter, wrong position
|
||||||
|
if self.info['known_positions'][i] == 0:
|
||||||
|
# Only update if we haven't already found the correct letter for this position
|
||||||
|
current_flags[:, guessed_letter] = 2 # Mark this letter as found in another position
|
||||||
|
reward += 5
|
||||||
|
new_info = True
|
||||||
|
else:
|
||||||
|
# Letter not in word
|
||||||
|
if guessed_letter not in self.info['not_in_word']:
|
||||||
|
self.info['not_in_word'].add(guessed_letter)
|
||||||
|
reward -= 2 # Penalize for guessing a letter not in the word
|
||||||
|
new_info = True
|
||||||
|
for pos in range(self.n_letters):
|
||||||
|
# Update all positions to reflect this letter is not correct
|
||||||
|
current_flags[pos, guessed_letter] = 0
|
||||||
|
|
||||||
# Update observation state with the current guess and flags
|
# Update global letter availability based on the guess
|
||||||
self.state[self.round, :self.n_letters] = guessed_word
|
for letter in range(26):
|
||||||
self.state[self.round, self.n_letters:] = current_flags
|
if letter not in guessed_word or letter in self.info['not_in_word']:
|
||||||
|
self.state['global_availability'][letter] = 0
|
||||||
|
|
||||||
# Check if the game is over
|
# Check if the game is over
|
||||||
done = self.round == self.n_rounds - 1 or correct_guess
|
done = self.round == self.n_rounds - 1 or correct_guess
|
||||||
@ -337,4 +337,17 @@ class WordleEnv(gym.Env):
|
|||||||
|
|
||||||
self.round += 1
|
self.round += 1
|
||||||
|
|
||||||
return self.state, reward, done, False, self.info
|
return self.get_observation(), reward, done, False, self.info
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
# Flatten the position-specific letter availability
|
||||||
|
position_availability_flat = np.concatenate(self.state['position_availability'])
|
||||||
|
|
||||||
|
# Global availability is already a 1D array, but ensure consistency in data handling
|
||||||
|
global_availability_flat = self.state['global_availability'].flatten()
|
||||||
|
|
||||||
|
# Concatenate all parts of the state into a single flat array for the DQN input
|
||||||
|
full_state_flat = np.concatenate(
|
||||||
|
[position_availability_flat, global_availability_flat, self.state['letter_found']])
|
||||||
|
|
||||||
|
return full_state_flat
|
||||||
|
99
letter_guess.py
Normal file
99
letter_guess.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import re


class LetterGuessingEnv(gym.Env):
    """
    Custom Gymnasium environment for a letter guessing game with a focus on forming
    valid prefixes and words from a list of valid Wordle words. The environment tracks
    the current guess prefix and validates it against known valid words, ending the game
    early with a negative reward for invalid prefixes.

    Observation: concatenation of a 26x2 per-letter flag matrix (2 = unknown,
    [1,0] = letter in target, [0,0] = letter not in target) and a 26x5
    per-position availability matrix, flattened to a single int vector.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, valid_words, seed=None):
        """
        Parameters:
        - valid_words (list[str]): the five-letter vocabulary.
        - seed (int | None): optional RNG seed. Bug fix: `seed` was a required
          positional argument, so `LetterGuessingEnv(valid_words)` raised
          TypeError (as did the internal `self.reset()` call below).
        """
        self.action_space = spaces.Discrete(26)
        # Bug fix: the position matrix was 26x4 while words have 5 letters,
        # so guessing the fifth letter crashed with an IndexError. It is now
        # 26x5 and the observation size matches (26*2 + 26*5 = 182).
        self.observation_space = spaces.Box(low=0, high=2, shape=(26 * 2 + 26 * 5,), dtype=np.int32)

        self.valid_words = valid_words  # List of valid Wordle words
        self.target_word = ''  # Target word for the current episode
        # '_'-delimited blob of the vocabulary used for fast prefix checks.
        self.valid_words_str = '_'.join(self.valid_words) + '_'
        self.letter_flags = None
        self.letter_positions = None
        self.guessed_letters = set()
        self.guess_prefix = ""  # Tracks the current guess prefix
        self.round = 1

        self.reset(seed=seed)

    def step(self, action):
        """Guess one letter for the next position of the current word.

        Returns (observation, reward, terminated, truncated, info).
        """
        letter_index = action % 26  # action is the letter index directly
        position = len(self.guess_prefix)  # next position = current prefix length
        letter = chr(ord('a') + letter_index)

        reward = 0
        done = False

        # Check if the letter has already been used in the current word guess.
        if letter in self.guessed_letters:
            reward = -1  # Penalize for repeating letters in the prefix
        else:
            # Add the new letter to the prefix and update guessed letters set.
            self.guess_prefix += letter
            self.guessed_letters.add(letter)

            # Update letter flags based on whether the letter is in the target word.
            if self.target_word_encoded[letter_index] == 1:
                self.letter_flags[letter_index, :] = [1, 0]  # letter is in the target
            else:
                self.letter_flags[letter_index, :] = [0, 0]  # letter is not in the target

            reward = 1  # Reward for adding new information by trying a new letter

            # Update the letter_positions matrix to reflect the new guess.
            self.letter_positions[:, position] = 0
            self.letter_positions[letter_index, position] = 1

        # Check whether the prefix can still start a valid word. A word starts
        # at the beginning of the blob or after a '_' separator. Bug fix: the
        # previous r'\b' anchor never matched after '_' (underscore is a word
        # character in regex), so only the very first vocabulary word was
        # ever checked.
        if not re.search('(?:^|_)' + self.guess_prefix, self.valid_words_str):
            reward = -5  # Penalize for forming an invalid prefix
            done = True  # End the episode if the prefix is invalid

        # Guessed a full word, so start a fresh guess for the next round.
        if len(self.guess_prefix) == len(self.target_word):
            # Bug fix: this line was `self.guess_prefix == ''`, a no-op
            # comparison, so the prefix was never cleared between rounds.
            self.guess_prefix = ''
            # Clear per-word guess state so the next round can reuse letters;
            # knowledge flags (letter_flags) persist across rounds.
            self.guessed_letters = set()
            self.letter_positions = np.ones((26, 5))
            self.round += 1

        # End after 3 rounds of total guesses.
        if self.round == 3:
            done = True

        obs = self._get_obs()

        return obs, reward, done, False, {}

    def reset(self, seed=None, options=None):
        """Start a new episode with a freshly sampled target word.

        Returns (observation, info) per the Gymnasium API. Bug fixes: `seed`
        is now optional, the round counter is reset (it previously leaked
        across episodes), and the Gymnasium-mandated info dict is returned.
        """
        if seed is not None:
            random.seed(seed)
        self.target_word = random.choice(self.valid_words)
        self.target_word_encoded = self.encode_word(self.target_word)
        self.letter_flags = np.ones((26, 2)) * 2  # 2 == "unknown"
        self.letter_positions = np.ones((26, 5))  # every letter possible everywhere
        self.guessed_letters = set()
        self.guess_prefix = ""  # Reset the guess prefix for the new episode
        self.round = 1
        return self._get_obs(), {}

    def encode_word(self, word):
        """Return a 26-vector with 1 at the index of each letter present in `word`."""
        encoded = np.zeros((26,))
        for char in word:
            index = ord(char) - ord('a')
            encoded[index] = 1
        return encoded

    def _get_obs(self):
        """Flatten letter flags (26x2) and position availability (26x5) into one vector."""
        return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])

    def render(self, mode='human'):
        pass  # Optional: Implement rendering logic if needed
|
189
test.ipynb
189
test.ipynb
@ -1,189 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from collections import defaultdict"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def my_func()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"defaultdict(<function __main__.<lambda>()>, {})"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"[0, 1, 2, 3, 4]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t['t']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"False"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"'x' in t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"x = np.array([1, 1, 1])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"x[:] = 0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"array([0, 0, 0])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"x"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"'abcde'aaa\n",
|
|
||||||
" 33221\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "env",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
2317
wordle_words.txt
Normal file
2317
wordle_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user