mirror of https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-22 18:49:21 +00:00

Compare commits
3 Commits (e799c14ece ...)

arthur-tes

Author | SHA1 | Date
---|---|---
 | dd5889da33 |
 | 848ea719b7 |
 | f641d77c47 |
3  .gitignore  (vendored)

@@ -1,4 +1,3 @@
 **/data/*
 **/*.zip
 **/__pycache__
-/env
671  dqn_wordle.ipynb  (deleted)

@@ -1,671 +0,0 @@
Deleted notebook (671 lines of JSON; kernel: Python 3 (ipykernel), Python 3.11.5, nbformat 4). Cells rendered below for readability.

In [1]:
    import gym
    import gym_wordle
    from stable_baselines3 import DQN, PPO, common
    import numpy as np
    import tqdm

In [2]:
    env = gym_wordle.wordle.WordleEnv()
    env = common.monitor.Monitor(env)

    print(env)

  Output:
    <Monitor<WordleEnv instance>>

In [5]:
    total_timesteps = 1_000_000
    model = DQN("MlpPolicy", env, verbose=1, device='cuda')
    model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)

  Output:
    Using cuda device
    Wrapping the env with a `Monitor` wrapper
    Wrapping the env in a DummyVecEnv.

  Training logs, consolidated from the 20 periodic rollout/train dumps (ep_len_mean was 5 and learning_rate 0.0001 in every dump):

    episodes | total_timesteps | ep_rew_mean | exploration_rate | fps | time_elapsed | loss | n_updates
    ---|---|---|---|---|---|---|---
    10000  | 49989  | -175 | 0.525  | 4606 | 10   | -    | -
    20000  | 99980  | -208 | 0.0502 | 1118 | 89   | 24.6 | 12494
    30000  | 149974 | -230 | 0.05   | 856  | 175  | 18.7 | 24993
    40000  | 199967 | -242 | 0.05   | 766  | 260  | 24   | 37491
    50000  | 249962 | -186 | 0.05   | 722  | 346  | 21.5 | 49990
    60000  | 299957 | -183 | 0.05   | 694  | 431  | 17.6 | 62489
    70000  | 349953 | -181 | 0.05   | 675  | 517  | 26.8 | 74988
    80000  | 399936 | -196 | 0.05   | 663  | 603  | 22.5 | 87483
    90000  | 449928 | -174 | 0.05   | 653  | 688  | 21.1 | 99981
    100000 | 499920 | -155 | 0.05   | 645  | 774  | 22.8 | 112479
    110000 | 549916 | -153 | 0.05   | 638  | 860  | 16   | 124978
    120000 | 599915 | -164 | 0.05   | 633  | 947  | 17.8 | 137478
    130000 | 649910 | -145 | 0.05   | 628  | 1033 | 17.8 | 149977
    140000 | 699902 | -154 | 0.05   | 624  | 1120 | 20.9 | 162475
    150000 | 749884 | -192 | 0.05   | 621  | 1206 | 18.3 | 174970
    160000 | 799869 | -170 | 0.05   | 618  | 1293 | 17.7 | 187467
    170000 | 849855 | -233 | 0.05   | 615  | 1380 | 21.6 | 199963
    180000 | 899847 | -146 | 0.05   | 613  | 1466 | 19.4 | 212461
    190000 | 949846 | -142 | 0.05   | 611  | 1553 | 22.9 | 224961
    200000 | 999839 | -171 | 0.05   | 609  | 1640 | 20.3 | 237459

  Result:
    <stable_baselines3.dqn.dqn.DQN at 0x294981ca090>

In [6]:
    model.save("dqn_new_rewards")

In [5]:
    # model = DQN.load("dqn_wordle")

  stderr:
    c:\Repository\cse151b-final-project\env\Lib\site-packages\stable_baselines3\common\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.
    Exception: code() argument 13 must be str, not int
      warnings.warn(
    c:\Repository\cse151b-final-project\env\Lib\site-packages\stable_baselines3\common\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.
    Exception: code() argument 13 must be str, not int
      warnings.warn(

In [7]:
    env = gym_wordle.wordle.WordleEnv()

    for i in range(1000):

        state, info = env.reset()

        done = False

        wins = 0

        while not done:

            action, _states = model.predict(state, deterministic=True)

            state, reward, done, truncated, info = env.step(action)

            if info["correct"]:
                wins += 1

    print(wins)

  Output:
    0

In [8]:
    state, reward

  Result:
    (array([[18,  1, 20,  5, 19,  3,  3,  3,  3,  3],
            [14, 15,  9, 12, 25,  2,  3,  2,  2,  2],
            [25, 21,  3, 11, 15,  2,  3,  3,  3,  3],
            [25, 21,  3, 11, 15,  2,  3,  3,  3,  3],
            [ 1, 20, 13, 15, 19,  3,  3,  3,  3,  3],
            [25, 21,  3, 11, 15,  2,  3,  3,  3,  3]], dtype=int64),
     -130)

In [21]:
    blah = (14, 1, 9, 22, 5)

In [23]:
    blah in info['guesses']

  Result:
    True

(One empty, unexecuted cell at the end of the notebook.)
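Note that the evaluation cell above (In [7]) resets wins inside the episode loop, so the printed count only reflects the final episode. A minimal corrected sketch, assuming the same WordleEnv API shown in the deleted gym_wordle/wordle.py further down and a trained stable-baselines3 model:

# Hedged sketch, not part of the commit: accumulate wins across all episodes.
import gym_wordle

env = gym_wordle.wordle.WordleEnv()
episodes = 1000
wins = 0  # accumulate over all episodes, not per episode

for _ in range(episodes):
    state, info = env.reset()
    done = False
    while not done:
        action, _states = model.predict(state, deterministic=True)
        state, reward, done, truncated, info = env.step(action)
    if info["correct"]:
        wins += 1

print(f"win rate: {wins / episodes:.3f}")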
38  dqn_wordle.py  (new file)

@@ -0,0 +1,38 @@
import gym
import sys
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import wordle_gym
import numpy as np
from tqdm import tqdm


def train(model, env, total_timesteps=100000):
    model.learn(total_timesteps=total_timesteps, progress_bar=True)
    model.save("dqn_wordle")


def test(model, env, test_num=1000):

    total_correct = 0

    for i in tqdm(range(test_num)):

        model = DQN.load("dqn_wordle")

        env = gym.make("wordle-v0")
        obs = env.reset()
        done = False
        while not done:
            action, _states = model.predict(obs)
            obs, rewards, done, info = env.step(action)

    return total_correct / test_num


if __name__ == "__main__":

    env = gym.make("wordle-v0")
    model = DQN("MlpPolicy", env, verbose=0)
    print(env)
    print(model)

    train(model, env, total_timesteps=500000)
    print(test(model, env))
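make_vec_env is imported above but never used. A hedged sketch of how the registered environment could be vectorized for faster rollout collection; this assumes the wordle-v0 registration from wordle_gym/__init__.py below is importable and an SB3 version that supports multi-environment off-policy training:

# Hedged sketch, not part of the commit: train on several parallel copies of
# the registered wordle-v0 environment using SB3's make_vec_env helper.
import wordle_gym  # noqa: F401  (import runs the gym registration side effect)
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env("wordle-v0", n_envs=4)  # 4 parallel environments
model = DQN("MlpPolicy", vec_env, verbose=0)
model.learn(total_timesteps=100_000, progress_bar=True)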
Deleted file (7 lines):

@@ -1,7 +0,0 @@
from gym.envs.registration import register
from .wordle import WordleEnv

register(
    id='Wordle-v0',
    entry_point='gym_wordle.wordle:WordleEnv'
)

Binary file not shown.
Binary file not shown.
Deleted file (93 lines), imported elsewhere as gym_wordle.utils:

@@ -1,93 +0,0 @@
import numpy as np
import numpy.typing as npt

from pathlib import Path


_chars = ' abcdefghijklmnopqrstuvwxyz'
_char_d = {c: i for i, c in enumerate(_chars)}


def to_english(array: npt.NDArray[np.int64]) -> str:
    """Converts a numpy integer array into a corresponding English string.

    Args:
        array: Word in array (int) form. It is assumed that each integer in the
            array is between 0,...,26 (inclusive).

    Returns:
        A (lowercase) string representation of the word.
    """
    return ''.join(_chars[i] for i in array)


def to_array(word: str) -> npt.NDArray[np.int64]:
    """Converts a string of characters into a corresponding numpy array.

    Args:
        word: Word in string form. It is assumed that each character in the
            string is either an empty space ' ' or lowercase alphabetical
            character.

    Returns:
        An array representation of the word.
    """
    return np.array([_char_d[c] for c in word])


def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
    """Loads a list of words in array form.

    If specified, this will recompute the list from the human-readable list of
    words, and save the results in array form.

    Args:
        category: Either 'guess' or 'solution', which corresponds to the list
            of acceptable guess words and the list of acceptable solution words.
        build: If True, recomputes and saves the array-version of the computed
            list for future access.

    Returns:
        An array representation of the list of words specified by the category.
        This array has two dimensions, and the number of columns is fixed at
        five.
    """
    assert category in {'guess', 'solution'}

    arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
    if build:
        list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'

        with open(list_path, 'r') as f:
            words = np.array([to_array(line.strip()) for line in f])
            np.save(arr_path, words)

    return np.load(arr_path)


def play():
    """Play Wordle yourself!"""
    import gym
    import gym_wordle

    env = gym.make('Wordle-v0')  # load the environment

    env.reset()
    solution = to_english(env.unwrapped.solution_space[env.solution]).upper()  # no peeking!

    done = False

    while not done:
        action = -1

        # in general, the environment won't be forgiving if you input an
        # invalid word, but for this function I want to let you screw up user
        # input without consequence, so just loops until valid input is taken
        while not env.action_space.contains(action):
            guess = input('Guess: ')
            action = env.unwrapped.action_space.index_of(to_array(guess))

        state, reward, done, info = env.step(action)
        env.render()

    print(f"The word was {solution}")
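A quick usage example of the two encoders above; the values follow from _chars, where the space character is 0 and 'a'..'z' map to 1..26:

# Example only: round-tripping a word through the integer encoding.
from gym_wordle.utils import to_array, to_english

arr = to_array("hello")
print(arr)              # [ 8  5 12 12 15]
print(to_english(arr))  # hello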
Deleted file (340 lines), used by the notebook as gym_wordle.wordle:

@@ -1,340 +0,0 @@
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from sty import fg, bg, ef, rs

from collections import Counter
from gym_wordle.utils import to_english, to_array, get_words
from typing import Optional
from collections import defaultdict


class WordList(gym.spaces.Discrete):
    """Super class for defining a space of valid words according to a specified
    list.

    The space is a subclass of gym.spaces.Discrete, where each element
    corresponds to an index of a valid word in the word list. The obfuscation
    is necessary for more direct implementation of RL algorithms, which expect
    spaces of less sophisticated form.

    In addition to the default methods of the Discrete space, it implements
    a __getitem__ method for easy index lookup, and an index_of method to
    convert potential words into their corresponding index (if they exist).
    """

    def __init__(self, words: npt.NDArray[np.int64], **kwargs):
        """
        Args:
            words: Collection of words in array form with shape (_, 5), where
                each word is a row of the array. Each array element is an
                integer between 0,...,26 (inclusive).
            kwargs: See documentation for gym.spaces.MultiDiscrete
        """
        super().__init__(words.shape[0], **kwargs)
        self.words = words

    def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
        """Obtains the (int-encoded) word associated with the given index.

        Args:
            index: Index for the list of words.

        Returns:
            Associated word at the position specified by index.
        """
        return self.words[index]

    def index_of(self, word: npt.NDArray[np.int64]) -> int:
        """Given a word, determine its index in the list (if it exists),
        otherwise returning -1 if no index exists.

        Args:
            word: Word to find in the word list.

        Returns:
            The index of the given word if it exists, otherwise -1.
        """
        try:
            index, = np.nonzero((word == self.words).all(axis=1))
            return index[0]
        except:
            return -1


class SolutionList(WordList):
    """Space for *solution* words to the Wordle environment.

    In the game Wordle, there are two different collections of words:

    * "guesses", which the game accepts as valid words to use to guess the
      answer.
    * "solutions", which the game uses to choose solutions from.

    Of course, the set of solutions is a strict subset of the set of guesses.

    This class represents the set of solution words.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: See documentation for gym.spaces.MultiDiscrete
        """
        words = get_words('solution')
        super().__init__(words, **kwargs)


class WordleObsSpace(gym.spaces.Box):
    """Implementation of the state (observation) space in terms of gym
    primitives, in this case, gym.spaces.Box.

    The Wordle observation space can be thought of as a 6x5 array with two
    channels:

    - the character channel, indicating which characters are placed on the
      board (unfilled rows are marked with the empty character, 0)
    - the flag channel, indicating the in-game information associated with
      each character's placement (green highlight, yellow highlight, etc.)

    where there are 6 rows, one for each turn in the game, and 5 columns, since
    the solution will always be a word of length 5.

    For simplicity, and compatibility with stable_baselines algorithms,
    this multichannel is modeled as a 6x10 array, where the two channels are
    horizontally appended (along columns). Thus each row in the observation
    should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
    c0...c4 and its associated flags are f0...f4.
    """

    def __init__(self, **kwargs):
        self.n_rows = 6
        self.n_cols = 5
        self.max_char = 26
        self.max_flag = 4

        low = np.zeros((self.n_rows, 2*self.n_cols))
        high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
                     np.full((self.n_rows, self.n_cols), self.max_flag)]

        super().__init__(low, high, dtype=np.int64, **kwargs)


class GuessList(WordList):
    """Space for *guess* words to the Wordle environment.

    This class represents the set of guess words.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: See documentation for gym.spaces.MultiDiscrete
        """
        words = get_words('guess')
        super().__init__(words, **kwargs)


class WordleEnv(gym.Env):
    metadata = {'render.modes': ['human']}

    # Character flag codes
    no_char = 0
    right_pos = 1
    wrong_pos = 2
    wrong_char = 3

    def __init__(self):
        super().__init__()

        self.action_space = GuessList()
        self.solution_space = SolutionList()

        self.observation_space = WordleObsSpace()

        self._highlights = {
            self.right_pos: (bg.green, bg.rs),
            self.wrong_pos: (bg.yellow, bg.rs),
            self.wrong_char: ('', ''),
            self.no_char: ('', ''),
        }

        self.n_rounds = 6
        self.n_letters = 5
        self.info = {
            'correct': False,
            'guesses': set(),
            'known_positions': np.full(5, -1),  # -1 for unknown, else letter index
            'known_letters': set(),  # Letters known to be in the word
            'not_in_word': set(),  # Letters known not to be in the word
            'tried_positions': defaultdict(set)  # Positions tried for each letter
        }

    def _highlighter(self, char: str, flag: int) -> str:
        """Terminal renderer functionality. Properly highlights a character
        based on the flag associated with it.

        Args:
            char: Character in question.
            flag: Associated flag, one of:
                - 0: no character (render no background)
                - 1: right position (render green background)
                - 2: wrong position (render yellow background)
                - 3: wrong character (render no background)

        Returns:
            Correct ASCII sequence producing the desired character in the
            correct background.
        """
        front, back = self._highlights[flag]
        return front + char + back

    def reset(self, seed=None, options=None):
        """Reset the environment to an initial state and returns an initial
        observation.

        Note: The observation space instance should be a Box space.

        Returns:
            state (object): The initial observation of the space.
        """
        self.round = 0
        self.solution = self.solution_space.sample()
        self.soln_hash = set(self.solution_space[self.solution])

        self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)

        self.info = {
            'correct': False,
            'guesses': set(),
            'known_positions': np.full(5, -1),
            'known_letters': set(),
            'not_in_word': set(),
            'tried_positions': defaultdict(set)
        }

        self.simulate_first_guess()

        return self.state, self.info

    def simulate_first_guess(self):
        fixed_first_guess = "rates"
        fixed_first_guess_array = to_array(fixed_first_guess)

        # Simulate the feedback for each letter in the fixed first guess
        feedback = np.zeros(self.n_letters, dtype=int)  # Initialize feedback array
        for i, letter in enumerate(fixed_first_guess_array):
            if letter in self.solution_space[self.solution]:
                if letter == self.solution_space[self.solution][i]:
                    feedback[i] = 1  # Correct position
                else:
                    feedback[i] = 2  # Correct letter, wrong position
            else:
                feedback[i] = 3  # Letter not in word

        # Update the state to reflect the fixed first guess and its feedback
        self.state[0, :self.n_letters] = fixed_first_guess_array
        self.state[0, self.n_letters:] = feedback

        # Update self.info based on the feedback
        for i, flag in enumerate(feedback):
            if flag == self.right_pos:
                # Mark letter as correctly placed
                self.info['known_positions'][i] = fixed_first_guess_array[i]
            elif flag == self.wrong_pos:
                # Note the letter is in the word but in a different position
                self.info['known_letters'].add(fixed_first_guess_array[i])
            elif flag == self.wrong_char:
                # Note the letter is not in the word
                self.info['not_in_word'].add(fixed_first_guess_array[i])

        # Since we're simulating the first guess, increment the round counter
        self.round = 1

    def render(self, mode: str = 'human'):
        """Renders the Wordle environment.

        Currently supported render modes:

        - human: renders the Wordle game to the terminal.

        Args:
            mode: the mode to render with.
        """
        if mode == 'human':
            for row in self.state:
                text = ''.join(map(
                    self._highlighter,
                    to_english(row[:self.n_letters]).upper(),
                    row[self.n_letters:]
                ))
                print(text)
        else:
            super().render(mode=mode)

    def step(self, action):
        assert self.action_space.contains(action), 'Invalid word!'

        guessed_word = self.action_space[action]
        solution_word = self.solution_space[self.solution]

        reward = 0
        correct_guess = np.array_equal(guessed_word, solution_word)

        # Initialize flags for current guess
        current_flags = np.full(self.n_letters, self.wrong_char)

        # Track newly discovered information
        new_info = False

        for i in range(self.n_letters):
            guessed_letter = guessed_word[i]
            if guessed_letter in solution_word:
                # Penalize for reusing a letter found to not be in the word
                if guessed_letter in self.info['not_in_word']:
                    reward -= 2

                # Handle correct letter in the correct position
                if guessed_letter == solution_word[i]:
                    current_flags[i] = self.right_pos
                    if self.info['known_positions'][i] != guessed_letter:
                        reward += 10  # Large reward for new correct placement
                        new_info = True
                        self.info['known_positions'][i] = guessed_letter
                    else:
                        reward += 20  # Large reward for repeating correct placement
                else:
                    current_flags[i] = self.wrong_pos
                    if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
                        reward += 10  # Reward for guessing a letter in a new position
                        new_info = True
                    else:
                        reward -= 20  # Penalize for not leveraging known information
                    self.info['known_letters'].add(guessed_letter)
                    self.info['tried_positions'][guessed_letter].add(i)
            else:
                # New incorrect letter
                if guessed_letter not in self.info['not_in_word']:
                    reward -= 2  # Penalize for guessing a letter not in the word
                    self.info['not_in_word'].add(guessed_letter)
                    new_info = True
                else:
                    reward -= 15  # Larger penalty for repeating an incorrect letter

        # Update observation state with the current guess and flags
        self.state[self.round, :self.n_letters] = guessed_word
        self.state[self.round, self.n_letters:] = current_flags

        # Check if the game is over
        done = self.round == self.n_rounds - 1 or correct_guess
        self.info['correct'] = correct_guess

        if correct_guess:
            reward += 100  # Major reward for winning
        elif done:
            reward -= 50  # Penalty for losing without using new information effectively
        elif not new_info:
            reward -= 10  # Penalty if no new information was used in this guess

        self.round += 1

        return self.state, reward, done, False, self.info
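For reference, a minimal sketch of one random-policy episode against the WordleEnv above, showing the reset()/step() signatures and the 6x10 state layout (5 letter codes followed by 5 flag codes per row; reset() already fills row 0 with the fixed first guess "rates" and its feedback):

# Hedged sketch, not part of the commit.
import gym_wordle

env = gym_wordle.wordle.WordleEnv()
state, info = env.reset()

done = False
while not done:
    action = env.action_space.sample()  # random valid guess index
    state, reward, done, truncated, info = env.step(action)

env.render()                           # prints the highlighted 6-row board
print(state.shape, info["correct"])    # (6, 10) and whether the word was found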
189  test.ipynb  (deleted)

@@ -1,189 +0,0 @@
Deleted scratch notebook (189 lines of JSON; kernel "env", Python 3.11.5, nbformat 4). Cells rendered below for readability.

In [1]:
    from collections import defaultdict

In [ ]:
    def my_func()

In [2]:
    t = defaultdict(lambda: [0, 1, 2, 3, 4])

In [3]:
    t

  Result:
    defaultdict(<function __main__.<lambda>()>, {})

In [4]:
    t['t']

  Result:
    [0, 1, 2, 3, 4]

In [5]:
    t

  Result:
    defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})

In [6]:
    'x' in t

  Result:
    False

In [7]:
    import numpy as np

In [10]:
    x = np.array([1, 1, 1])

In [11]:
    x[:] = 0

In [12]:
    x

  Result:
    array([0, 0, 0])

In [ ]:
    'abcde'aaa
     33221
9  wordle_gym/__init__.py  (new file)

@@ -0,0 +1,9 @@
from gym.envs.registration import register

register(
    id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv",
)

register(
    id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv",
)
0  wordle_gym/envs/__init__.py  (new, empty file)
15  wordle_gym/envs/strategies/base.py  (new file)

@@ -0,0 +1,15 @@
from enum import Enum

from typing import List

class StrategyType(Enum):
    RANDOM = 1
    ELIMINATION = 2
    PROBABILITY = 3

class Strategy:
    def __init__(self, type: StrategyType):
        self.type = type

    def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
        raise NotImplementedError("Strategy.get_best_word() not implemented")
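For reference, a minimal concrete subclass of the Strategy base class above might look like the following; the class name, fixed opener word, and choice logic are illustrative assumptions, not part of the commit:

# Hedged sketch: a trivial Strategy implementation that ignores feedback and
# always returns a fixed opening word.
from typing import List

from base import Strategy, StrategyType


class FixedOpener(Strategy):
    def __init__(self, word: str = "rates"):
        super().__init__(StrategyType.RANDOM)
        self.word = word

    def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
        # Always suggest the same word, regardless of previous guesses/state.
        return self.word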
2  wordle_gym/envs/strategies/elimination.py  (new file)

@@ -0,0 +1,2 @@
def get_best_word(state):
    pass  # stub; body not present in this commit
20  wordle_gym/envs/strategies/probabilistic.py  (new file)

@@ -0,0 +1,20 @@
from random import sample
from typing import List

from base import Strategy
from base import StrategyType

from utils import freq

class Random(Strategy):
    def __init__(self):
        self.words = freq.get_5_letter_word_freqs()
        super().__init__(StrategyType.RANDOM)

    def get_best_word(self, state: List[List[int]]):
        pass  # stub; body not present in this commit


if __name__ == "__main__":
    r = Random()
    print(r.get_best_word([]))
29  wordle_gym/envs/strategies/rand.py  (new file)

@@ -0,0 +1,29 @@
from random import sample
from typing import List

from base import Strategy
from base import StrategyType

from utils import freq

class Random(Strategy):
    def __init__(self):
        self.words = freq.get_5_letter_word_freqs()
        super().__init__(StrategyType.RANDOM)

    def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
        correct_letters = []
        regex = ""
        for g, s in zip(guesses, state):
            for c, s in zip(g, s):
                if s == 2:
                    correct_letters.append(c)
                    regex += c


if __name__ == "__main__":
    r = Random()
    print(r.get_best_word([]))
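get_best_word above collects correctly placed letters but never returns a word. A hedged sketch of one way it could be finished using only the pieces already imported (sample and the frequency dictionary); the consistency filter is an assumption, not the commit's design:

# Hedged sketch: pick a random candidate that keeps every letter already
# confirmed in its correct position (state code 2).
def get_best_word(self, guesses, state):
    # Map position -> confirmed letter from previous feedback.
    confirmed = {}
    for g, s_row in zip(guesses, state):
        for i, (c, flag) in enumerate(zip(g, s_row)):
            if flag == 2:
                confirmed[i] = c

    # Keep only words consistent with the confirmed positions.
    candidates = [w for w in self.words
                  if all(w[i] == c for i, c in confirmed.items())]

    # Fall back to the full list if filtering removed everything.
    return sample(candidates or list(self.words), 1)[0]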
27  wordle_gym/envs/strategies/utils/freq.py  (new file)

@@ -0,0 +1,27 @@
from os import path

def get_5_letter_word_freqs():
    """
    Returns a list of words with 5 letters.
    """
    FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt")
    lines = read_file(FILEPATH)
    return {k: v for k, v in get_freq(lines).items() if len(k) == 5}


def read_file(filename):
    """
    Reads a file and returns a list of words and frequencies
    """
    with open(filename, 'r') as f:
        return f.readlines()


def get_freq(lines):
    """
    Returns a dictionary of words and their frequencies
    """
    freqs = {}
    for word, freq in map(lambda x: x.split("\t"), lines):
        freqs[word] = int(freq)
    return freqs
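The loader above expects a tab-separated word/count file (the Norvig word-frequency list). A small illustration of the format get_freq assumes and how it is parsed; the sample words and counts are made up for illustration:

# Hedged example: get_freq expects lines of the form "word<TAB>count".
sample_lines = ["about\t12345\n", "their\t9876\n", "cat\t555\n"]

freqs = {}
for word, count in map(lambda x: x.split("\t"), sample_lines):
    freqs[word] = int(count)   # int() tolerates the trailing newline

print(freqs)  # {'about': 12345, 'their': 9876, 'cat': 555}
# get_5_letter_word_freqs() would then keep only the 5-letter keys: about, their.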
131  wordle_gym/envs/wordle_env.py  (new file)

@@ -0,0 +1,131 @@
import os

import gym
from gym import error, spaces, utils
from gym.utils import seeding

from enum import Enum
from collections import Counter
import numpy as np

WORD_LENGTH = 5
TOTAL_GUESSES = 6
SOLUTION_PATH = "../words/solution.csv"
VALID_WORDS_PATH = "../words/guess.csv"

class LetterState(Enum):
    ABSENT = 0
    PRESENT = 1
    CORRECT_POSITION = 2


class WordleEnv(gym.Env):
    metadata = {"render.modes": ["human"]}

    def _current_path(self):
        return os.path.dirname(os.path.abspath(__file__))

    def _read_solutions(self):
        return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()

    def _get_valid_words(self):
        words = []
        for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
            words.append((word, Counter(word)))
        return words

    def get_valid(self):
        return self._valid_words

    def __init__(self):
        self._solutions = self._read_solutions()
        self._valid_words = self._get_valid_words()
        self.action_space = spaces.Discrete(len(self._valid_words))
        self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
        np.random.seed(0)
        self.reset()

    def _check_guess(self, guess, guess_counter):
        c = guess_counter & self.solution_ct
        result = []
        correct = True
        reward = 0
        for i, char in enumerate(guess):
            if c.get(char, 0) > 0:
                if self.solution[i] == char:
                    result.append(2)
                    reward += 2
                else:
                    result.append(1)
                    correct = False
                    reward += 1
                c[char] -= 1
            else:
                result.append(0)
                correct = False
        return result, correct, reward

    def step(self, action):
        """
        action: index of word in valid_words

        returns:
            observation: (TOTAL_GUESSES, WORD_LENGTH)
            reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
            done: True if game over, w/ or w/o correct answer
            additional_info: empty
        """
        guess, guess_counter = self._valid_words[action]
        if guess in self.guesses:
            return self.obs, -1, False, {}
        self.guesses.append(guess)
        result, correct, reward = self._check_guess(guess, guess_counter)
        done = False

        for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
            self.obs[i] = result[i - self.guess_no*WORD_LENGTH]

        self.guess_no += 1
        if correct:
            done = True
            reward = 1200
        if self.guess_no == TOTAL_GUESSES:
            done = True
            if not correct:
                reward = -15
        return self.obs, reward, done, {}

    def reset(self):
        self.solution = self._solutions[np.random.randint(len(self._solutions))]
        self.solution_ct = Counter(self.solution)
        self.guess_no = 0
        self.guesses = []
        self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
        return self.obs

    def render(self, mode="human"):
        m = {
            0: "⬜",
            1: "🟨",
            2: "🟩"
        }
        print("Solution:", self.solution)
        for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
            o_n = "".join(map(lambda x: m[x], o))
            print(g, o_n)

    def close(self):
        pass

if __name__ == "__main__":
    env = WordleEnv()
    print(env.action_space)
    print(env.observation_space)
    print(env.solution)
    print(env.step(0))
    print(env.step(0))
    print(env.step(0))
    print(env.step(0))
    print(env.step(0))
    print(env.step(0))
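_check_guess above uses a Counter intersection so that repeated letters in a guess are only credited as many times as they appear in the solution. A standalone worked example of that scoring logic; the words are chosen for illustration only:

# Hedged, standalone illustration of the Counter-intersection scoring:
# 0 = absent, 1 = present elsewhere, 2 = correct position.
from collections import Counter

solution = "erase"
guess = "spree"
c = Counter(guess) & Counter(solution)  # shared letters, with multiplicity

result = []
for i, char in enumerate(guess):
    if c.get(char, 0) > 0:
        result.append(2 if solution[i] == char else 1)
        c[char] -= 1
    else:
        result.append(0)

print(result)  # [1, 0, 1, 1, 2]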
File diff suppressed because it is too large
File diff suppressed because it is too large