mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-01-15 18:35:56 +00:00
Compare commits
2 Commits
e799c14ece
...
f40301cac9
Author | SHA1 | Date | |
---|---|---|---|
|
f40301cac9 | ||
|
fc197acb6e |
2
.gitignore
vendored
2
.gitignore
vendored
@ -2,3 +2,5 @@
|
|||||||
**/*.zip
|
**/*.zip
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
/env
|
/env
|
||||||
|
**/runs/*
|
||||||
|
**/wandb/*
|
2996
dqn_letter_gssr.ipynb
Normal file
2996
dqn_letter_gssr.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
557
dqn_wordle.ipynb
557
dqn_wordle.ipynb
@ -35,13 +35,21 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Using cuda device\n",
|
||||||
|
"Wrapping the env in a DummyVecEnv.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "7c52630b65904d5e8e200be505d2121a",
|
"model_id": "6921a0721569456abf5bceac7e7b6b34",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@ -52,29 +60,20 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "display_data"
|
"output_type": "display_data"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Using cuda device\n",
|
|
||||||
"Wrapping the env with a `Monitor` wrapper\n",
|
|
||||||
"Wrapping the env in a DummyVecEnv.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"----------------------------------\n",
|
"----------------------------------\n",
|
||||||
"| rollout/ | |\n",
|
"| rollout/ | |\n",
|
||||||
"| ep_len_mean | 5 |\n",
|
"| ep_len_mean | 4.97 |\n",
|
||||||
"| ep_rew_mean | -175 |\n",
|
"| ep_rew_mean | -63.8 |\n",
|
||||||
"| exploration_rate | 0.525 |\n",
|
"| exploration_rate | 0.05 |\n",
|
||||||
"| time/ | |\n",
|
"| time/ | |\n",
|
||||||
"| episodes | 10000 |\n",
|
"| episodes | 10000 |\n",
|
||||||
"| fps | 4606 |\n",
|
"| fps | 1628 |\n",
|
||||||
"| time_elapsed | 10 |\n",
|
"| time_elapsed | 30 |\n",
|
||||||
"| total_timesteps | 49989 |\n",
|
"| total_timesteps | 49995 |\n",
|
||||||
"----------------------------------\n"
|
"----------------------------------\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -85,395 +84,17 @@
|
|||||||
"----------------------------------\n",
|
"----------------------------------\n",
|
||||||
"| rollout/ | |\n",
|
"| rollout/ | |\n",
|
||||||
"| ep_len_mean | 5 |\n",
|
"| ep_len_mean | 5 |\n",
|
||||||
"| ep_rew_mean | -208 |\n",
|
"| ep_rew_mean | -70.5 |\n",
|
||||||
"| exploration_rate | 0.0502 |\n",
|
"| exploration_rate | 0.05 |\n",
|
||||||
"| time/ | |\n",
|
"| time/ | |\n",
|
||||||
"| episodes | 20000 |\n",
|
"| episodes | 20000 |\n",
|
||||||
"| fps | 1118 |\n",
|
"| fps | 662 |\n",
|
||||||
"| time_elapsed | 89 |\n",
|
"| time_elapsed | 150 |\n",
|
||||||
"| total_timesteps | 99980 |\n",
|
"| total_timesteps | 99992 |\n",
|
||||||
"| train/ | |\n",
|
"| train/ | |\n",
|
||||||
"| learning_rate | 0.0001 |\n",
|
"| learning_rate | 0.0001 |\n",
|
||||||
"| loss | 24.6 |\n",
|
"| loss | 11.7 |\n",
|
||||||
"| n_updates | 12494 |\n",
|
"| n_updates | 12497 |\n",
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -230 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 30000 |\n",
|
|
||||||
"| fps | 856 |\n",
|
|
||||||
"| time_elapsed | 175 |\n",
|
|
||||||
"| total_timesteps | 149974 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 18.7 |\n",
|
|
||||||
"| n_updates | 24993 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -242 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 40000 |\n",
|
|
||||||
"| fps | 766 |\n",
|
|
||||||
"| time_elapsed | 260 |\n",
|
|
||||||
"| total_timesteps | 199967 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 24 |\n",
|
|
||||||
"| n_updates | 37491 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -186 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 50000 |\n",
|
|
||||||
"| fps | 722 |\n",
|
|
||||||
"| time_elapsed | 346 |\n",
|
|
||||||
"| total_timesteps | 249962 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.5 |\n",
|
|
||||||
"| n_updates | 49990 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -183 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 60000 |\n",
|
|
||||||
"| fps | 694 |\n",
|
|
||||||
"| time_elapsed | 431 |\n",
|
|
||||||
"| total_timesteps | 299957 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.6 |\n",
|
|
||||||
"| n_updates | 62489 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -181 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 70000 |\n",
|
|
||||||
"| fps | 675 |\n",
|
|
||||||
"| time_elapsed | 517 |\n",
|
|
||||||
"| total_timesteps | 349953 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 26.8 |\n",
|
|
||||||
"| n_updates | 74988 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -196 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 80000 |\n",
|
|
||||||
"| fps | 663 |\n",
|
|
||||||
"| time_elapsed | 603 |\n",
|
|
||||||
"| total_timesteps | 399936 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.5 |\n",
|
|
||||||
"| n_updates | 87483 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -174 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 90000 |\n",
|
|
||||||
"| fps | 653 |\n",
|
|
||||||
"| time_elapsed | 688 |\n",
|
|
||||||
"| total_timesteps | 449928 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.1 |\n",
|
|
||||||
"| n_updates | 99981 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -155 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 100000 |\n",
|
|
||||||
"| fps | 645 |\n",
|
|
||||||
"| time_elapsed | 774 |\n",
|
|
||||||
"| total_timesteps | 499920 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.8 |\n",
|
|
||||||
"| n_updates | 112479 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -153 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 110000 |\n",
|
|
||||||
"| fps | 638 |\n",
|
|
||||||
"| time_elapsed | 860 |\n",
|
|
||||||
"| total_timesteps | 549916 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 16 |\n",
|
|
||||||
"| n_updates | 124978 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -164 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 120000 |\n",
|
|
||||||
"| fps | 633 |\n",
|
|
||||||
"| time_elapsed | 947 |\n",
|
|
||||||
"| total_timesteps | 599915 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.8 |\n",
|
|
||||||
"| n_updates | 137478 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -145 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 130000 |\n",
|
|
||||||
"| fps | 628 |\n",
|
|
||||||
"| time_elapsed | 1033 |\n",
|
|
||||||
"| total_timesteps | 649910 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.8 |\n",
|
|
||||||
"| n_updates | 149977 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -154 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 140000 |\n",
|
|
||||||
"| fps | 624 |\n",
|
|
||||||
"| time_elapsed | 1120 |\n",
|
|
||||||
"| total_timesteps | 699902 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 20.9 |\n",
|
|
||||||
"| n_updates | 162475 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -192 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 150000 |\n",
|
|
||||||
"| fps | 621 |\n",
|
|
||||||
"| time_elapsed | 1206 |\n",
|
|
||||||
"| total_timesteps | 749884 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 18.3 |\n",
|
|
||||||
"| n_updates | 174970 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -170 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 160000 |\n",
|
|
||||||
"| fps | 618 |\n",
|
|
||||||
"| time_elapsed | 1293 |\n",
|
|
||||||
"| total_timesteps | 799869 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 17.7 |\n",
|
|
||||||
"| n_updates | 187467 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -233 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 170000 |\n",
|
|
||||||
"| fps | 615 |\n",
|
|
||||||
"| time_elapsed | 1380 |\n",
|
|
||||||
"| total_timesteps | 849855 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 21.6 |\n",
|
|
||||||
"| n_updates | 199963 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -146 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 180000 |\n",
|
|
||||||
"| fps | 613 |\n",
|
|
||||||
"| time_elapsed | 1466 |\n",
|
|
||||||
"| total_timesteps | 899847 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 19.4 |\n",
|
|
||||||
"| n_updates | 212461 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -142 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 190000 |\n",
|
|
||||||
"| fps | 611 |\n",
|
|
||||||
"| time_elapsed | 1553 |\n",
|
|
||||||
"| total_timesteps | 949846 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 22.9 |\n",
|
|
||||||
"| n_updates | 224961 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -171 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 200000 |\n",
|
|
||||||
"| fps | 609 |\n",
|
|
||||||
"| time_elapsed | 1640 |\n",
|
|
||||||
"| total_timesteps | 999839 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 20.3 |\n",
|
|
||||||
"| n_updates | 237459 |\n",
|
|
||||||
"----------------------------------\n"
|
"----------------------------------\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -503,27 +124,27 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>"
|
"<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"total_timesteps = 1_000_000\n",
|
"total_timesteps = 100_000\n",
|
||||||
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||||
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model.save(\"dqn_new_rewards\")"
|
"model.save(\"dqn_new_state\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -557,6 +178,76 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||||
|
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||||
|
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||||
|
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
"0\n"
|
"0\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -578,6 +269,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" state, reward, done, truncated, info = env.step(action)\n",
|
" state, reward, done, truncated, info = env.step(action)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" print(state)\n",
|
||||||
" if info[\"correct\"]:\n",
|
" if info[\"correct\"]:\n",
|
||||||
" wins += 1\n",
|
" wins += 1\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -586,22 +278,26 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
|
"(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
" 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n",
|
||||||
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
|
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||||
" -130)"
|
" 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||||
|
" 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
|
||||||
|
" -50)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -610,35 +306,6 @@
|
|||||||
"state, reward"
|
"state, reward"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 21,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"blah = (14, 1, 9, 22, 5)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"True"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 23,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"blah in info['guesses']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
129
eric_wordle/.gitignore
vendored
Normal file
129
eric_wordle/.gitignore
vendored
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
11
eric_wordle/README.md
Normal file
11
eric_wordle/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# N-dle Solver
|
||||||
|
|
||||||
|
A solver designed to beat New York Time's Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, can extend to solve the more general N-dle problem (for quordle, octordle, etc.)
|
||||||
|
|
||||||
|
I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6).
|
||||||
|
|
||||||
|
## Usage:
|
||||||
|
1. Run `python main.py --n 1`
|
||||||
|
2. Follow the prompts
|
||||||
|
|
||||||
|
Currently only supports solving for 1 word at a time (i.e. wordle).
|
126
eric_wordle/ai.py
Normal file
126
eric_wordle/ai.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import re
|
||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class AI:
|
||||||
|
def __init__(self, vocab_file, num_letters=5, num_guesses=6):
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
self.num_letters = num_letters
|
||||||
|
self.num_guesses = 6
|
||||||
|
|
||||||
|
self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file)
|
||||||
|
self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1]
|
||||||
|
|
||||||
|
self.domains = None
|
||||||
|
self.possible_letters = None
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def solve(self):
|
||||||
|
num_guesses = 0
|
||||||
|
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||||
|
num_guesses += 1
|
||||||
|
word = self.sample()
|
||||||
|
|
||||||
|
# # Always start with these two words
|
||||||
|
# if num_guesses == 1:
|
||||||
|
# word = 'soare'
|
||||||
|
# elif num_guesses == 2:
|
||||||
|
# word = 'culti'
|
||||||
|
|
||||||
|
print('-----------------------------------------------')
|
||||||
|
print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
|
||||||
|
print('-----------------------------------------------')
|
||||||
|
self.arc_consistency(word)
|
||||||
|
|
||||||
|
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
|
||||||
|
|
||||||
|
|
||||||
|
def arc_consistency(self, word):
|
||||||
|
print(f'Performing arc consistency check on {word}...')
|
||||||
|
print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Collect results
|
||||||
|
for l in word:
|
||||||
|
while True:
|
||||||
|
result = input(f'{l}: ')
|
||||||
|
if result not in ['0', '1', '2']:
|
||||||
|
print('Incorrect option. Try again.')
|
||||||
|
continue
|
||||||
|
results.append(result)
|
||||||
|
break
|
||||||
|
|
||||||
|
self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
|
||||||
|
|
||||||
|
for i in range(len(word)):
|
||||||
|
if results[i] == '0':
|
||||||
|
if word[i] in self.possible_letters:
|
||||||
|
if word[i] in self.domains[i]:
|
||||||
|
self.domains[i].remove(word[i])
|
||||||
|
else:
|
||||||
|
for j in range(len(self.domains)):
|
||||||
|
if word[i] in self.domains[j] and len(self.domains[j]) > 1:
|
||||||
|
self.domains[j].remove(word[i])
|
||||||
|
if results[i] == '1':
|
||||||
|
if word[i] in self.domains[i]:
|
||||||
|
self.domains[i].remove(word[i])
|
||||||
|
if results[i] == '2':
|
||||||
|
self.domains[i] = [word[i]]
|
||||||
|
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
|
||||||
|
self.possible_letters = []
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""
|
||||||
|
Samples a best word given the current domains
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# Compile a regex of possible words with the current domain
|
||||||
|
regex_string = ''
|
||||||
|
for domain in self.domains:
|
||||||
|
regex_string += ''.join(['[', ''.join(domain), ']', '{1}'])
|
||||||
|
pattern = re.compile(regex_string)
|
||||||
|
|
||||||
|
# From the words with the highest scores, only return the best word that match the regex pattern
|
||||||
|
for word, _ in self.best_words:
|
||||||
|
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||||
|
return word
|
||||||
|
|
||||||
|
def get_vocab(self, vocab_file):
|
||||||
|
vocab = []
|
||||||
|
with open(vocab_file, 'r') as f:
|
||||||
|
for l in f:
|
||||||
|
vocab.append(l.strip())
|
||||||
|
|
||||||
|
# Count letter frequencies at each index
|
||||||
|
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||||
|
for word in vocab:
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
letter_freqs[i][l] += 1
|
||||||
|
|
||||||
|
# Assign a score to each letter at each index by the probability of it appearing
|
||||||
|
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||||
|
for i in range(len(letter_scores)):
|
||||||
|
max_freq = np.max(list(letter_freqs[i].values()))
|
||||||
|
for l in letter_scores[i].keys():
|
||||||
|
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||||
|
|
||||||
|
# Find a sorted list of words ranked by sum of letter scores
|
||||||
|
vocab_scores = {} # (score, word)
|
||||||
|
for word in vocab:
|
||||||
|
score = 0
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
score += letter_scores[i][l]
|
||||||
|
|
||||||
|
# # Optimization: If repeating letters, deduct a couple points
|
||||||
|
# if len(set(word)) < len(word):
|
||||||
|
# score -= 0.25 * (len(word) - len(set(word)))
|
||||||
|
|
||||||
|
vocab_scores[word] = score
|
||||||
|
|
||||||
|
return vocab, vocab_scores, letter_scores
|
37
eric_wordle/dist.py
Normal file
37
eric_wordle/dist.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
words = []
|
||||||
|
with open('words.txt', 'r') as f:
|
||||||
|
for l in f:
|
||||||
|
words.append(l.strip())
|
||||||
|
|
||||||
|
# Count letter frequencies at each index
|
||||||
|
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||||
|
for word in words:
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
letter_freqs[i][l] += 1
|
||||||
|
|
||||||
|
# Assign a score to each letter at each index by the probability of it appearing
|
||||||
|
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||||
|
for i in range(len(letter_scores)):
|
||||||
|
max_freq = np.max(list(letter_freqs[i].values()))
|
||||||
|
for l in letter_scores[i].keys():
|
||||||
|
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||||
|
|
||||||
|
# Find a sorted list of words ranked by sum of letter scores
|
||||||
|
word_scores = [] # (score, word)
|
||||||
|
for word in words:
|
||||||
|
score = 0
|
||||||
|
for i, l in enumerate(word):
|
||||||
|
score += letter_scores[i][l]
|
||||||
|
word_scores.append((score, word))
|
||||||
|
|
||||||
|
sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1]
|
||||||
|
print(sorted_by_second[:10])
|
||||||
|
|
||||||
|
for i, (score, word) in enumerate(sorted_by_second):
|
||||||
|
if word == 'soare':
|
||||||
|
print(f'{word} with a score of {score} is found at index {i}')
|
||||||
|
|
18
eric_wordle/main.py
Normal file
18
eric_wordle/main.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import argparse
|
||||||
|
from ai import AI
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
if args.n is None:
|
||||||
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
|
|
||||||
|
ai = AI(args.vocab_file)
|
||||||
|
ai.solve()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--n', dest='n', type=int, default=None)
|
||||||
|
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
15
eric_wordle/process.py
Normal file
15
eric_wordle/process.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pandas
|
||||||
|
|
||||||
|
print('Loading in words dictionary; this may take a while...')
|
||||||
|
df = pandas.read_json('words_dictionary.json')
|
||||||
|
print('Done loading words dictionary.')
|
||||||
|
words = []
|
||||||
|
for word in df.axes[0].tolist():
|
||||||
|
if len(word) != 5:
|
||||||
|
continue
|
||||||
|
words.append(word)
|
||||||
|
words.sort()
|
||||||
|
|
||||||
|
with open('words.txt', 'w') as f:
|
||||||
|
for word in words:
|
||||||
|
f.write(word + '\n')
|
File diff suppressed because it is too large
Load Diff
370104
eric_wordle/words_dictionary.json
Normal file
370104
eric_wordle/words_dictionary.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,7 +0,0 @@
|
|||||||
from gym.envs.registration import register
|
|
||||||
from .wordle import WordleEnv
|
|
||||||
|
|
||||||
register(
|
|
||||||
id='Wordle-v0',
|
|
||||||
entry_point='gym_wordle.wordle:WordleEnv'
|
|
||||||
)
|
|
Binary file not shown.
Binary file not shown.
@ -1,93 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import numpy.typing as npt
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
_chars = ' abcdefghijklmnopqrstuvwxyz'
|
|
||||||
_char_d = {c: i for i, c in enumerate(_chars)}
|
|
||||||
|
|
||||||
|
|
||||||
def to_english(array: npt.NDArray[np.int64]) -> str:
|
|
||||||
"""Converts a numpy integer array into a corresponding English string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
array: Word in array (int) form. It is assumed that each integer in the
|
|
||||||
array is between 0,...,26 (inclusive).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A (lowercase) string representation of the word.
|
|
||||||
"""
|
|
||||||
return ''.join(_chars[i] for i in array)
|
|
||||||
|
|
||||||
|
|
||||||
def to_array(word: str) -> npt.NDArray[np.int64]:
|
|
||||||
"""Converts a string of characters into a corresponding numpy array.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
word: Word in string form. It is assumed that each character in the
|
|
||||||
string is either an empty space ' ' or lowercase alphabetical
|
|
||||||
character.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An array representation of the word.
|
|
||||||
"""
|
|
||||||
return np.array([_char_d[c] for c in word])
|
|
||||||
|
|
||||||
|
|
||||||
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
|
|
||||||
"""Loads a list of words in array form.
|
|
||||||
|
|
||||||
If specified, this will recompute the list from the human-readable list of
|
|
||||||
words, and save the results in array form.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
category: Either 'guess' or 'solution', which corresponds to the list
|
|
||||||
of acceptable guess words and the list of acceptable solution words.
|
|
||||||
build: If True, recomputes and saves the array-version of the computed
|
|
||||||
list for future access.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An array representation of the list of words specified by the category.
|
|
||||||
This array has two dimensions, and the number of columns is fixed at
|
|
||||||
five.
|
|
||||||
"""
|
|
||||||
assert category in {'guess', 'solution'}
|
|
||||||
|
|
||||||
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
|
|
||||||
if build:
|
|
||||||
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
|
|
||||||
|
|
||||||
with open(list_path, 'r') as f:
|
|
||||||
words = np.array([to_array(line.strip()) for line in f])
|
|
||||||
np.save(arr_path, words)
|
|
||||||
|
|
||||||
return np.load(arr_path)
|
|
||||||
|
|
||||||
|
|
||||||
def play():
|
|
||||||
"""Play Wordle yourself!"""
|
|
||||||
import gym
|
|
||||||
import gym_wordle
|
|
||||||
|
|
||||||
env = gym.make('Wordle-v0') # load the environment
|
|
||||||
|
|
||||||
env.reset()
|
|
||||||
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
|
|
||||||
|
|
||||||
done = False
|
|
||||||
|
|
||||||
while not done:
|
|
||||||
action = -1
|
|
||||||
|
|
||||||
# in general, the environment won't be forgiving if you input an
|
|
||||||
# invalid word, but for this function I want to let you screw up user
|
|
||||||
# input without consequence, so just loops until valid input is taken
|
|
||||||
while not env.action_space.contains(action):
|
|
||||||
guess = input('Guess: ')
|
|
||||||
action = env.unwrapped.action_space.index_of(to_array(guess))
|
|
||||||
|
|
||||||
state, reward, done, info = env.step(action)
|
|
||||||
env.render()
|
|
||||||
|
|
||||||
print(f"The word was {solution}")
|
|
@ -1,340 +0,0 @@
|
|||||||
import gymnasium as gym
|
|
||||||
import numpy as np
|
|
||||||
import numpy.typing as npt
|
|
||||||
from sty import fg, bg, ef, rs
|
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
from gym_wordle.utils import to_english, to_array, get_words
|
|
||||||
from typing import Optional
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
|
|
||||||
class WordList(gym.spaces.Discrete):
|
|
||||||
"""Super class for defining a space of valid words according to a specified
|
|
||||||
list.
|
|
||||||
|
|
||||||
The space is a subclass of gym.spaces.Discrete, where each element
|
|
||||||
corresponds to an index of a valid word in the word list. The obfuscation
|
|
||||||
is necessary for more direct implementation of RL algorithms, which expect
|
|
||||||
spaces of less sophisticated form.
|
|
||||||
|
|
||||||
In addition to the default methods of the Discrete space, it implements
|
|
||||||
a __getitem__ method for easy index lookup, and an index_of method to
|
|
||||||
convert potential words into their corresponding index (if they exist).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
words: Collection of words in array form with shape (_, 5), where
|
|
||||||
each word is a row of the array. Each array element is an integer
|
|
||||||
between 0,...,26 (inclusive).
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
super().__init__(words.shape[0], **kwargs)
|
|
||||||
self.words = words
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
|
|
||||||
"""Obtains the (int-encoded) word associated with the given index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index: Index for the list of words.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Associated word at the position specified by index.
|
|
||||||
"""
|
|
||||||
return self.words[index]
|
|
||||||
|
|
||||||
def index_of(self, word: npt.NDArray[np.int64]) -> int:
|
|
||||||
"""Given a word, determine its index in the list (if it exists),
|
|
||||||
otherwise returning -1 if no index exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
word: Word to find in the word list.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The index of the given word if it exists, otherwise -1.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
index, = np.nonzero((word == self.words).all(axis=1))
|
|
||||||
return index[0]
|
|
||||||
except:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
class SolutionList(WordList):
|
|
||||||
"""Space for *solution* words to the Wordle environment.
|
|
||||||
|
|
||||||
In the game Wordle, there are two different collections of words:
|
|
||||||
|
|
||||||
* "guesses", which the game accepts as valid words to use to guess the
|
|
||||||
answer.
|
|
||||||
* "solutions", which the game uses to choose solutions from.
|
|
||||||
|
|
||||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
|
||||||
|
|
||||||
This class represents the set of solution words.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
words = get_words('solution')
|
|
||||||
super().__init__(words, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleObsSpace(gym.spaces.Box):
|
|
||||||
"""Implementation of the state (observation) space in terms of gym
|
|
||||||
primitives, in this case, gym.spaces.Box.
|
|
||||||
|
|
||||||
The Wordle observation space can be thought of as a 6x5 array with two
|
|
||||||
channels:
|
|
||||||
|
|
||||||
- the character channel, indicating which characters are placed on the
|
|
||||||
board (unfilled rows are marked with the empty character, 0)
|
|
||||||
- the flag channel, indicating the in-game information associated with
|
|
||||||
each character's placement (green highlight, yellow highlight, etc.)
|
|
||||||
|
|
||||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
|
||||||
the solution will always be a word of length 5.
|
|
||||||
|
|
||||||
For simplicity, and compatibility with stable_baselines algorithms,
|
|
||||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
|
||||||
horizontally appended (along columns). Thus each row in the observation
|
|
||||||
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
|
|
||||||
c0...c4 and its associated flags are f0...f4.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.n_rows = 6
|
|
||||||
self.n_cols = 5
|
|
||||||
self.max_char = 26
|
|
||||||
self.max_flag = 4
|
|
||||||
|
|
||||||
low = np.zeros((self.n_rows, 2*self.n_cols))
|
|
||||||
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
|
|
||||||
np.full((self.n_rows, self.n_cols), self.max_flag)]
|
|
||||||
|
|
||||||
super().__init__(low, high, dtype=np.int64, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class GuessList(WordList):
|
|
||||||
"""Space for *guess* words to the Wordle environment.
|
|
||||||
|
|
||||||
This class represents the set of guess words.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
words = get_words('guess')
|
|
||||||
super().__init__(words, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv(gym.Env):
|
|
||||||
metadata = {'render.modes': ['human']}
|
|
||||||
|
|
||||||
# Character flag codes
|
|
||||||
no_char = 0
|
|
||||||
right_pos = 1
|
|
||||||
wrong_pos = 2
|
|
||||||
wrong_char = 3
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.action_space = GuessList()
|
|
||||||
self.solution_space = SolutionList()
|
|
||||||
|
|
||||||
self.observation_space = WordleObsSpace()
|
|
||||||
|
|
||||||
self._highlights = {
|
|
||||||
self.right_pos: (bg.green, bg.rs),
|
|
||||||
self.wrong_pos: (bg.yellow, bg.rs),
|
|
||||||
self.wrong_char: ('', ''),
|
|
||||||
self.no_char: ('', ''),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.n_rounds = 6
|
|
||||||
self.n_letters = 5
|
|
||||||
self.info = {
|
|
||||||
'correct': False,
|
|
||||||
'guesses': set(),
|
|
||||||
'known_positions': np.full(5, -1), # -1 for unknown, else letter index
|
|
||||||
'known_letters': set(), # Letters known to be in the word
|
|
||||||
'not_in_word': set(), # Letters known not to be in the word
|
|
||||||
'tried_positions': defaultdict(set) # Positions tried for each letter
|
|
||||||
}
|
|
||||||
|
|
||||||
def _highlighter(self, char: str, flag: int) -> str:
|
|
||||||
"""Terminal renderer functionality. Properly highlights a character
|
|
||||||
based on the flag associated with it.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
char: Character in question.
|
|
||||||
flag: Associated flag, one of:
|
|
||||||
- 0: no character (render no background)
|
|
||||||
- 1: right position (render green background)
|
|
||||||
- 2: wrong position (render yellow background)
|
|
||||||
- 3: wrong character (render no background)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Correct ASCII sequence producing the desired character in the
|
|
||||||
correct background.
|
|
||||||
"""
|
|
||||||
front, back = self._highlights[flag]
|
|
||||||
return front + char + back
|
|
||||||
|
|
||||||
def reset(self, seed=None, options=None):
|
|
||||||
"""Reset the environment to an initial state and returns an initial
|
|
||||||
observation.
|
|
||||||
|
|
||||||
Note: The observation space instance should be a Box space.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (object): The initial observation of the space.
|
|
||||||
"""
|
|
||||||
self.round = 0
|
|
||||||
self.solution = self.solution_space.sample()
|
|
||||||
self.soln_hash = set(self.solution_space[self.solution])
|
|
||||||
|
|
||||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
|
||||||
|
|
||||||
self.info = {
|
|
||||||
'correct': False,
|
|
||||||
'guesses': set(),
|
|
||||||
'known_positions': np.full(5, -1),
|
|
||||||
'known_letters': set(),
|
|
||||||
'not_in_word': set(),
|
|
||||||
'tried_positions': defaultdict(set)
|
|
||||||
}
|
|
||||||
|
|
||||||
self.simulate_first_guess()
|
|
||||||
|
|
||||||
return self.state, self.info
|
|
||||||
|
|
||||||
def simulate_first_guess(self):
|
|
||||||
fixed_first_guess = "rates"
|
|
||||||
fixed_first_guess_array = to_array(fixed_first_guess)
|
|
||||||
|
|
||||||
# Simulate the feedback for each letter in the fixed first guess
|
|
||||||
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array
|
|
||||||
for i, letter in enumerate(fixed_first_guess_array):
|
|
||||||
if letter in self.solution_space[self.solution]:
|
|
||||||
if letter == self.solution_space[self.solution][i]:
|
|
||||||
feedback[i] = 1 # Correct position
|
|
||||||
else:
|
|
||||||
feedback[i] = 2 # Correct letter, wrong position
|
|
||||||
else:
|
|
||||||
feedback[i] = 3 # Letter not in word
|
|
||||||
|
|
||||||
# Update the state to reflect the fixed first guess and its feedback
|
|
||||||
self.state[0, :self.n_letters] = fixed_first_guess_array
|
|
||||||
self.state[0, self.n_letters:] = feedback
|
|
||||||
|
|
||||||
# Update self.info based on the feedback
|
|
||||||
for i, flag in enumerate(feedback):
|
|
||||||
if flag == self.right_pos:
|
|
||||||
# Mark letter as correctly placed
|
|
||||||
self.info['known_positions'][i] = fixed_first_guess_array[i]
|
|
||||||
elif flag == self.wrong_pos:
|
|
||||||
# Note the letter is in the word but in a different position
|
|
||||||
self.info['known_letters'].add(fixed_first_guess_array[i])
|
|
||||||
elif flag == self.wrong_char:
|
|
||||||
# Note the letter is not in the word
|
|
||||||
self.info['not_in_word'].add(fixed_first_guess_array[i])
|
|
||||||
|
|
||||||
# Since we're simulating the first guess, increment the round counter
|
|
||||||
self.round = 1
|
|
||||||
|
|
||||||
def render(self, mode: str = 'human'):
|
|
||||||
"""Renders the Wordle environment.
|
|
||||||
|
|
||||||
Currently supported render modes:
|
|
||||||
- human: renders the Wordle game to the terminal.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mode: the mode to render with.
|
|
||||||
"""
|
|
||||||
if mode == 'human':
|
|
||||||
for row in self.state:
|
|
||||||
text = ''.join(map(
|
|
||||||
self._highlighter,
|
|
||||||
to_english(row[:self.n_letters]).upper(),
|
|
||||||
row[self.n_letters:]
|
|
||||||
))
|
|
||||||
print(text)
|
|
||||||
else:
|
|
||||||
super().render(mode=mode)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
assert self.action_space.contains(action), 'Invalid word!'
|
|
||||||
|
|
||||||
guessed_word = self.action_space[action]
|
|
||||||
solution_word = self.solution_space[self.solution]
|
|
||||||
|
|
||||||
reward = 0
|
|
||||||
correct_guess = np.array_equal(guessed_word, solution_word)
|
|
||||||
|
|
||||||
# Initialize flags for current guess
|
|
||||||
current_flags = np.full(self.n_letters, self.wrong_char)
|
|
||||||
|
|
||||||
# Track newly discovered information
|
|
||||||
new_info = False
|
|
||||||
|
|
||||||
for i in range(self.n_letters):
|
|
||||||
guessed_letter = guessed_word[i]
|
|
||||||
if guessed_letter in solution_word:
|
|
||||||
# Penalize for reusing a letter found to not be in the word
|
|
||||||
if guessed_letter in self.info['not_in_word']:
|
|
||||||
reward -= 2
|
|
||||||
|
|
||||||
# Handle correct letter in the correct position
|
|
||||||
if guessed_letter == solution_word[i]:
|
|
||||||
current_flags[i] = self.right_pos
|
|
||||||
if self.info['known_positions'][i] != guessed_letter:
|
|
||||||
reward += 10 # Large reward for new correct placement
|
|
||||||
new_info = True
|
|
||||||
self.info['known_positions'][i] = guessed_letter
|
|
||||||
else:
|
|
||||||
reward += 20 # Large reward for repeating correct placement
|
|
||||||
else:
|
|
||||||
current_flags[i] = self.wrong_pos
|
|
||||||
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
|
|
||||||
reward += 10 # Reward for guessing a letter in a new position
|
|
||||||
new_info = True
|
|
||||||
else:
|
|
||||||
reward -= 20 # Penalize for not leveraging known information
|
|
||||||
self.info['known_letters'].add(guessed_letter)
|
|
||||||
self.info['tried_positions'][guessed_letter].add(i)
|
|
||||||
else:
|
|
||||||
# New incorrect letter
|
|
||||||
if guessed_letter not in self.info['not_in_word']:
|
|
||||||
reward -= 2 # Penalize for guessing a letter not in the word
|
|
||||||
self.info['not_in_word'].add(guessed_letter)
|
|
||||||
new_info = True
|
|
||||||
else:
|
|
||||||
reward -= 15 # Larger penalty for repeating an incorrect letter
|
|
||||||
|
|
||||||
# Update observation state with the current guess and flags
|
|
||||||
self.state[self.round, :self.n_letters] = guessed_word
|
|
||||||
self.state[self.round, self.n_letters:] = current_flags
|
|
||||||
|
|
||||||
# Check if the game is over
|
|
||||||
done = self.round == self.n_rounds - 1 or correct_guess
|
|
||||||
self.info['correct'] = correct_guess
|
|
||||||
|
|
||||||
if correct_guess:
|
|
||||||
reward += 100 # Major reward for winning
|
|
||||||
elif done:
|
|
||||||
reward -= 50 # Penalty for losing without using new information effectively
|
|
||||||
elif not new_info:
|
|
||||||
reward -= 10 # Penalty if no new information was used in this guess
|
|
||||||
|
|
||||||
self.round += 1
|
|
||||||
|
|
||||||
return self.state, reward, done, False, self.info
|
|
108
letter_guess.py
Normal file
108
letter_guess.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class LetterGuessingEnv(gym.Env):
|
||||||
|
"""
|
||||||
|
Custom Gymnasium environment for a letter guessing game with a focus on forming
|
||||||
|
valid prefixes and words from a list of valid Wordle words. The environment tracks
|
||||||
|
the current guess prefix and validates it against known valid words, ending the game
|
||||||
|
early with a negative reward for invalid prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = {'render_modes': ['human']}
|
||||||
|
|
||||||
|
def __init__(self, valid_words, seed=None):
|
||||||
|
self.action_space = spaces.Discrete(26)
|
||||||
|
self.observation_space = spaces.Box(low=0, high=1, shape=(26*2 + 26*4,), dtype=np.int32)
|
||||||
|
|
||||||
|
self.valid_words = valid_words # List of valid Wordle words
|
||||||
|
self.target_word = '' # Target word for the current episode
|
||||||
|
self.valid_words_str = ' '.join(self.valid_words) + ' '
|
||||||
|
self.letter_flags = None
|
||||||
|
self.letter_positions = None
|
||||||
|
self.guessed_letters = set()
|
||||||
|
self.guess_prefix = "" # Tracks the current guess prefix
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
letter_index = action % 26 # Assuming action is the letter index directly
|
||||||
|
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
||||||
|
letter = chr(ord('a') + letter_index)
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
done = False
|
||||||
|
|
||||||
|
# Check if the letter has already been used in the guess prefix
|
||||||
|
if letter in self.guessed_letters:
|
||||||
|
reward = -1 # Penalize for repeating letters in the prefix
|
||||||
|
else:
|
||||||
|
# Add the new letter to the prefix and update guessed letters set
|
||||||
|
self.guess_prefix += letter
|
||||||
|
self.guessed_letters.add(letter)
|
||||||
|
|
||||||
|
# Update letter flags based on whether the letter is in the target word
|
||||||
|
if self.target_word[position] == letter:
|
||||||
|
self.letter_flags[letter_index, :] = [1, 0] # Update flag for correct guess
|
||||||
|
elif letter in self.target_word:
|
||||||
|
self.letter_flags[letter_index, :] = [0, 1] # Update flag for correct guess wrong position
|
||||||
|
else:
|
||||||
|
self.letter_flags[letter_index, :] = [0, 0] # Update flag for incorrect guess
|
||||||
|
|
||||||
|
reward = 1 # Reward for adding new information by trying a new letter
|
||||||
|
|
||||||
|
# Update the letter_positions matrix to reflect the new guess
|
||||||
|
if position == 4:
|
||||||
|
self.letter_positions[:,:] = 1
|
||||||
|
else:
|
||||||
|
self.letter_positions[:, position] = 0
|
||||||
|
self.letter_positions[letter_index, position] = 1
|
||||||
|
|
||||||
|
# Use regex to check if the current prefix can lead to a valid word
|
||||||
|
if not re.search(r'\b' + self.guess_prefix, self.valid_words_str):
|
||||||
|
reward = -5 # Penalize for forming an invalid prefix
|
||||||
|
done = True # End the episode if the prefix is invalid
|
||||||
|
|
||||||
|
# guessed a full word so we reset our guess prefix to guess next round
|
||||||
|
if len(self.guess_prefix) == len(self.target_word):
|
||||||
|
self.guess_prefix = ''
|
||||||
|
self.round += 1
|
||||||
|
|
||||||
|
# end after 5 rounds of total guesses
|
||||||
|
if self.round == 2:
|
||||||
|
# reward = 5
|
||||||
|
done = True
|
||||||
|
|
||||||
|
obs = self._get_obs()
|
||||||
|
|
||||||
|
if reward < -50:
|
||||||
|
print(obs, reward, done)
|
||||||
|
|
||||||
|
return obs, reward, done, False, {}
|
||||||
|
|
||||||
|
def reset(self, seed=None):
|
||||||
|
self.target_word = random.choice(self.valid_words)
|
||||||
|
# self.target_word_encoded = self.encode_word(self.target_word)
|
||||||
|
self.letter_flags = np.ones((26, 2), dtype=np.int32)
|
||||||
|
self.letter_positions = np.ones((26, 4), dtype=np.int32)
|
||||||
|
self.guessed_letters = set()
|
||||||
|
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||||
|
self.round = 1
|
||||||
|
return self._get_obs(), {}
|
||||||
|
|
||||||
|
def encode_word(self, word):
|
||||||
|
encoded = np.zeros((26,))
|
||||||
|
for char in word:
|
||||||
|
index = ord(char) - ord('a')
|
||||||
|
encoded[index] = 1
|
||||||
|
return encoded
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
||||||
|
|
||||||
|
def render(self, mode='human'):
|
||||||
|
pass # Optional: Implement rendering logic if needed
|
189
test.ipynb
189
test.ipynb
@ -1,189 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from collections import defaultdict"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def my_func()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"defaultdict(<function __main__.<lambda>()>, {})"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"[0, 1, 2, 3, 4]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t['t']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"False"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"'x' in t"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"x = np.array([1, 1, 1])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"x[:] = 0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"array([0, 0, 0])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"x"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"'abcde'aaa\n",
|
|
||||||
" 33221\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "env",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user