mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-22 18:49:21 +00:00
Compare commits
5 Commits
arthur-tes
...
gymnasium-
Author | SHA1 | Date | |
---|---|---|---|
|
cf977e4797 | ||
|
bbe9a1891c | ||
|
9172326013 | ||
|
4836be8121 | ||
|
5672169073 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,4 @@
|
||||
**/data/*
|
||||
**/*.zip
|
||||
**/__pycache__
|
||||
/env
|
345
dqn_wordle.ipynb
345
dqn_wordle.ipynb
@@ -2,69 +2,268 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import gym_wordle\n",
|
||||
"from stable_baselines3 import DQN\n",
|
||||
"from stable_baselines3 import DQN, PPO, common\n",
|
||||
"import numpy as np\n",
|
||||
"import tqdm"
|
||||
"from tqdm import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<Monitor<WordleEnv instance>>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym.make(\"Wordle-v0\")\n",
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"env = common.monitor.Monitor(env)\n",
|
||||
"\n",
|
||||
"print(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using cuda device\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n",
|
||||
"---------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 2.14 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 750 |\n",
|
||||
"| iterations | 1 |\n",
|
||||
"| time_elapsed | 2 |\n",
|
||||
"| total_timesteps | 2048 |\n",
|
||||
"---------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 4.59 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 625 |\n",
|
||||
"| iterations | 2 |\n",
|
||||
"| time_elapsed | 6 |\n",
|
||||
"| total_timesteps | 4096 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022059526 |\n",
|
||||
"| clip_fraction | 0.331 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | -0.0118 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 130 |\n",
|
||||
"| n_updates | 10 |\n",
|
||||
"| policy_gradient_loss | -0.0851 |\n",
|
||||
"| value_loss | 253 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 5.86 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 585 |\n",
|
||||
"| iterations | 3 |\n",
|
||||
"| time_elapsed | 10 |\n",
|
||||
"| total_timesteps | 6144 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024416003 |\n",
|
||||
"| clip_fraction | 0.462 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.152 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 85.2 |\n",
|
||||
"| n_updates | 20 |\n",
|
||||
"| policy_gradient_loss | -0.0987 |\n",
|
||||
"| value_loss | 218 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 4.75 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 566 |\n",
|
||||
"| iterations | 4 |\n",
|
||||
"| time_elapsed | 14 |\n",
|
||||
"| total_timesteps | 8192 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.026305672 |\n",
|
||||
"| clip_fraction | 0.45 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.47 |\n",
|
||||
"| explained_variance | 0.161 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 144 |\n",
|
||||
"| n_updates | 30 |\n",
|
||||
"| policy_gradient_loss | -0.105 |\n",
|
||||
"| value_loss | 220 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.47 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 554 |\n",
|
||||
"| iterations | 5 |\n",
|
||||
"| time_elapsed | 18 |\n",
|
||||
"| total_timesteps | 10240 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02928267 |\n",
|
||||
"| clip_fraction | 0.498 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.167 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 127 |\n",
|
||||
"| n_updates | 40 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 207 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 1.62 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 546 |\n",
|
||||
"| iterations | 6 |\n",
|
||||
"| time_elapsed | 22 |\n",
|
||||
"| total_timesteps | 12288 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.028425258 |\n",
|
||||
"| clip_fraction | 0.483 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.143 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 109 |\n",
|
||||
"| n_updates | 50 |\n",
|
||||
"| policy_gradient_loss | -0.117 |\n",
|
||||
"| value_loss | 240 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5.98 |\n",
|
||||
"| ep_rew_mean | 6.14 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 541 |\n",
|
||||
"| iterations | 7 |\n",
|
||||
"| time_elapsed | 26 |\n",
|
||||
"| total_timesteps | 14336 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.026178032 |\n",
|
||||
"| clip_fraction | 0.453 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.46 |\n",
|
||||
"| explained_variance | 0.174 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 141 |\n",
|
||||
"| n_updates | 60 |\n",
|
||||
"| policy_gradient_loss | -0.116 |\n",
|
||||
"| value_loss | 235 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 3.03 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 537 |\n",
|
||||
"| iterations | 8 |\n",
|
||||
"| time_elapsed | 30 |\n",
|
||||
"| total_timesteps | 16384 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.02457074 |\n",
|
||||
"| clip_fraction | 0.423 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.171 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 111 |\n",
|
||||
"| n_updates | 70 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 212 |\n",
|
||||
"----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 9.54 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 532 |\n",
|
||||
"| iterations | 9 |\n",
|
||||
"| time_elapsed | 34 |\n",
|
||||
"| total_timesteps | 18432 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.024578478 |\n",
|
||||
"| clip_fraction | 0.417 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.178 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 121 |\n",
|
||||
"| n_updates | 80 |\n",
|
||||
"| policy_gradient_loss | -0.114 |\n",
|
||||
"| value_loss | 232 |\n",
|
||||
"-----------------------------------------\n",
|
||||
"-----------------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 6 |\n",
|
||||
"| ep_rew_mean | 3.81 |\n",
|
||||
"| time/ | |\n",
|
||||
"| fps | 527 |\n",
|
||||
"| iterations | 10 |\n",
|
||||
"| time_elapsed | 38 |\n",
|
||||
"| total_timesteps | 20480 |\n",
|
||||
"| train/ | |\n",
|
||||
"| approx_kl | 0.022704324 |\n",
|
||||
"| clip_fraction | 0.379 |\n",
|
||||
"| clip_range | 0.2 |\n",
|
||||
"| entropy_loss | -9.45 |\n",
|
||||
"| explained_variance | 0.194 |\n",
|
||||
"| learning_rate | 0.0003 |\n",
|
||||
"| loss | 108 |\n",
|
||||
"| n_updates | 90 |\n",
|
||||
"| policy_gradient_loss | -0.112 |\n",
|
||||
"| value_loss | 216 |\n",
|
||||
"-----------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"total_timesteps = 100000\n",
|
||||
"model = DQN(\"MlpPolicy\", env, verbose=0)\n",
|
||||
"model.learn(total_timesteps=total_timesteps, progress_bar=True)"
|
||||
"total_timesteps = 20_000\n",
|
||||
"model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||
"model.learn(total_timesteps=total_timesteps)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test(model):\n",
|
||||
"\n",
|
||||
" end_rewards = []\n",
|
||||
"\n",
|
||||
" for i in range(1000):\n",
|
||||
" \n",
|
||||
" state = env.reset()\n",
|
||||
"\n",
|
||||
" done = False\n",
|
||||
"\n",
|
||||
" while not done:\n",
|
||||
"\n",
|
||||
" action, _states = model.predict(state, deterministic=True)\n",
|
||||
"\n",
|
||||
" state, reward, done, info = env.step(action)\n",
|
||||
" \n",
|
||||
" end_rewards.append(reward == 0)\n",
|
||||
" \n",
|
||||
" return np.sum(end_rewards) / len(end_rewards)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -73,11 +272,69 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = DQN.load(\"dqn_wordle\")"
|
||||
"model = PPO.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[ 7 18 1 19 16 3 3 3 2 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [16 9 5 14 4 3 3 3 3 3]\n",
|
||||
" [ 7 18 1 19 16 3 3 3 2 3]\n",
|
||||
" [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
|
||||
"0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"\n",
|
||||
"for i in tqdm(range(1000)):\n",
|
||||
" \n",
|
||||
" state, info = env.reset()\n",
|
||||
"\n",
|
||||
" done = False\n",
|
||||
"\n",
|
||||
" wins = 0\n",
|
||||
"\n",
|
||||
" while not done:\n",
|
||||
"\n",
|
||||
" action, _states = model.predict(state, deterministic=True)\n",
|
||||
"\n",
|
||||
" state, reward, done, truncated, info = env.step(action)\n",
|
||||
"\n",
|
||||
" if info[\"correct\"]:\n",
|
||||
" wins += 1\n",
|
||||
"\n",
|
||||
"print(state, reward, info)\n",
|
||||
"\n",
|
||||
"print(wins)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -85,9 +342,7 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(test(model))"
|
||||
]
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
7
gym_wordle/__init__.py
Normal file
7
gym_wordle/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from gym.envs.registration import register
|
||||
from .wordle import WordleEnv
|
||||
|
||||
register(
|
||||
id='Wordle-v0',
|
||||
entry_point='gym_wordle.wordle:WordleEnv'
|
||||
)
|
12972
gym_wordle/dictionary/guess_list.csv
Normal file
12972
gym_wordle/dictionary/guess_list.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
gym_wordle/dictionary/guess_list.npy
Normal file
BIN
gym_wordle/dictionary/guess_list.npy
Normal file
Binary file not shown.
2315
gym_wordle/dictionary/solution_list.csv
Normal file
2315
gym_wordle/dictionary/solution_list.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
gym_wordle/dictionary/solution_list.npy
Normal file
BIN
gym_wordle/dictionary/solution_list.npy
Normal file
Binary file not shown.
93
gym_wordle/utils.py
Normal file
93
gym_wordle/utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_chars = ' abcdefghijklmnopqrstuvwxyz'
|
||||
_char_d = {c: i for i, c in enumerate(_chars)}
|
||||
|
||||
|
||||
def to_english(array: npt.NDArray[np.int64]) -> str:
|
||||
"""Converts a numpy integer array into a corresponding English string.
|
||||
|
||||
Args:
|
||||
array: Word in array (int) form. It is assumed that each integer in the
|
||||
array is between 0,...,26 (inclusive).
|
||||
|
||||
Returns:
|
||||
A (lowercase) string representation of the word.
|
||||
"""
|
||||
return ''.join(_chars[i] for i in array)
|
||||
|
||||
|
||||
def to_array(word: str) -> npt.NDArray[np.int64]:
|
||||
"""Converts a string of characters into a corresponding numpy array.
|
||||
|
||||
Args:
|
||||
word: Word in string form. It is assumed that each character in the
|
||||
string is either an empty space ' ' or lowercase alphabetical
|
||||
character.
|
||||
|
||||
Returns:
|
||||
An array representation of the word.
|
||||
"""
|
||||
return np.array([_char_d[c] for c in word])
|
||||
|
||||
|
||||
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
|
||||
"""Loads a list of words in array form.
|
||||
|
||||
If specified, this will recompute the list from the human-readable list of
|
||||
words, and save the results in array form.
|
||||
|
||||
Args:
|
||||
category: Either 'guess' or 'solution', which corresponds to the list
|
||||
of acceptable guess words and the list of acceptable solution words.
|
||||
build: If True, recomputes and saves the array-version of the computed
|
||||
list for future access.
|
||||
|
||||
Returns:
|
||||
An array representation of the list of words specified by the category.
|
||||
This array has two dimensions, and the number of columns is fixed at
|
||||
five.
|
||||
"""
|
||||
assert category in {'guess', 'solution'}
|
||||
|
||||
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
|
||||
if build:
|
||||
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
|
||||
|
||||
with open(list_path, 'r') as f:
|
||||
words = np.array([to_array(line.strip()) for line in f])
|
||||
np.save(arr_path, words)
|
||||
|
||||
return np.load(arr_path)
|
||||
|
||||
|
||||
def play():
|
||||
"""Play Wordle yourself!"""
|
||||
import gym
|
||||
import gym_wordle
|
||||
|
||||
env = gym.make('Wordle-v0') # load the environment
|
||||
|
||||
env.reset()
|
||||
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
|
||||
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = -1
|
||||
|
||||
# in general, the environment won't be forgiving if you input an
|
||||
# invalid word, but for this function I want to let you screw up user
|
||||
# input without consequence, so just loops until valid input is taken
|
||||
while not env.action_space.contains(action):
|
||||
guess = input('Guess: ')
|
||||
action = env.unwrapped.action_space.index_of(to_array(guess))
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
print(f"The word was {solution}")
|
285
gym_wordle/wordle.py
Normal file
285
gym_wordle/wordle.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
"""Super class for defining a space of valid words according to a specified
|
||||
list.
|
||||
|
||||
The space is a subclass of gym.spaces.Discrete, where each element
|
||||
corresponds to an index of a valid word in the word list. The obfuscation
|
||||
is necessary for more direct implementation of RL algorithms, which expect
|
||||
spaces of less sophisticated form.
|
||||
|
||||
In addition to the default methods of the Discrete space, it implements
|
||||
a __getitem__ method for easy index lookup, and an index_of method to
|
||||
convert potential words into their corresponding index (if they exist).
|
||||
"""
|
||||
|
||||
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
|
||||
"""
|
||||
Args:
|
||||
words: Collection of words in array form with shape (_, 5), where
|
||||
each word is a row of the array. Each array element is an integer
|
||||
between 0,...,26 (inclusive).
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
super().__init__(words.shape[0], **kwargs)
|
||||
self.words = words
|
||||
|
||||
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
|
||||
"""Obtains the (int-encoded) word associated with the given index.
|
||||
|
||||
Args:
|
||||
index: Index for the list of words.
|
||||
|
||||
Returns:
|
||||
Associated word at the position specified by index.
|
||||
"""
|
||||
return self.words[index]
|
||||
|
||||
def index_of(self, word: npt.NDArray[np.int64]) -> int:
|
||||
"""Given a word, determine its index in the list (if it exists),
|
||||
otherwise returning -1 if no index exists.
|
||||
|
||||
Args:
|
||||
word: Word to find in the word list.
|
||||
|
||||
Returns:
|
||||
The index of the given word if it exists, otherwise -1.
|
||||
"""
|
||||
try:
|
||||
index, = np.nonzero((word == self.words).all(axis=1))
|
||||
return index[0]
|
||||
except:
|
||||
return -1
|
||||
|
||||
|
||||
class SolutionList(WordList):
|
||||
"""Space for *solution* words to the Wordle environment.
|
||||
|
||||
In the game Wordle, there are two different collections of words:
|
||||
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
|
||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||
|
||||
This class represents the set of solution words.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
words = get_words('solution')
|
||||
super().__init__(words, **kwargs)
|
||||
|
||||
|
||||
class WordleObsSpace(gym.spaces.Box):
|
||||
"""Implementation of the state (observation) space in terms of gym
|
||||
primitives, in this case, gym.spaces.Box.
|
||||
|
||||
The Wordle observation space can be thought of as a 6x5 array with two
|
||||
channels:
|
||||
|
||||
- the character channel, indicating which characters are placed on the
|
||||
board (unfilled rows are marked with the empty character, 0)
|
||||
- the flag channel, indicating the in-game information associated with
|
||||
each character's placement (green highlight, yellow highlight, etc.)
|
||||
|
||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
||||
the solution will always be a word of length 5.
|
||||
|
||||
For simplicity, and compatibility with stable_baselines algorithms,
|
||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
||||
horizontally appended (along columns). Thus each row in the observation
|
||||
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
|
||||
c0...c4 and its associated flags are f0...f4.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.n_rows = 6
|
||||
self.n_cols = 5
|
||||
self.max_char = 26
|
||||
self.max_flag = 4
|
||||
|
||||
low = np.zeros((self.n_rows, 2*self.n_cols))
|
||||
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
|
||||
np.full((self.n_rows, self.n_cols), self.max_flag)]
|
||||
|
||||
super().__init__(low, high, dtype=np.int64, **kwargs)
|
||||
|
||||
|
||||
class GuessList(WordList):
|
||||
"""Space for *guess* words to the Wordle environment.
|
||||
|
||||
This class represents the set of guess words.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
words = get_words('guess')
|
||||
super().__init__(words, **kwargs)
|
||||
|
||||
|
||||
class WordleEnv(gym.Env):
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
# Character flag codes
|
||||
no_char = 0
|
||||
right_pos = 1
|
||||
wrong_pos = 2
|
||||
wrong_char = 3
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.action_space = GuessList()
|
||||
self.solution_space = SolutionList()
|
||||
|
||||
self.observation_space = WordleObsSpace()
|
||||
|
||||
self._highlights = {
|
||||
self.right_pos: (bg.green, bg.rs),
|
||||
self.wrong_pos: (bg.yellow, bg.rs),
|
||||
self.wrong_char: ('', ''),
|
||||
self.no_char: ('', ''),
|
||||
}
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
based on the flag associated with it.
|
||||
|
||||
Args:
|
||||
char: Character in question.
|
||||
flag: Associated flag, one of:
|
||||
- 0: no character (render no background)
|
||||
- 1: right position (render green background)
|
||||
- 2: wrong position (render yellow background)
|
||||
- 3: wrong character (render no background)
|
||||
|
||||
Returns:
|
||||
Correct ASCII sequence producing the desired character in the
|
||||
correct background.
|
||||
"""
|
||||
front, back = self._highlights[flag]
|
||||
return front + char + back
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""Reset the environment to an initial state and returns an initial
|
||||
observation.
|
||||
|
||||
Note: The observation space instance should be a Box space.
|
||||
|
||||
Returns:
|
||||
state (object): The initial observation of the space.
|
||||
"""
|
||||
self.round = 0
|
||||
self.solution = self.solution_space.sample()
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
return self.state, self.info
|
||||
|
||||
def render(self, mode: str = 'human'):
|
||||
"""Renders the Wordle environment.
|
||||
|
||||
Currently supported render modes:
|
||||
- human: renders the Wordle game to the terminal.
|
||||
|
||||
Args:
|
||||
mode: the mode to render with.
|
||||
"""
|
||||
if mode == 'human':
|
||||
for row in self.state:
|
||||
text = ''.join(map(
|
||||
self._highlighter,
|
||||
to_english(row[:self.n_letters]).upper(),
|
||||
row[self.n_letters:]
|
||||
))
|
||||
print(text)
|
||||
else:
|
||||
super().render(mode=mode)
|
||||
|
||||
def step(self, action):
|
||||
"""Run one step of the Wordle game. Every game must be previously
|
||||
initialized by a call to the `reset` method.
|
||||
|
||||
Args:
|
||||
action: Word guessed by the agent.
|
||||
|
||||
Returns:
|
||||
state (object): Wordle game state after the guess.
|
||||
reward (float): Reward associated with the guess.
|
||||
done (bool): Whether the game has ended.
|
||||
info (dict): Auxiliary diagnostic information.
|
||||
"""
|
||||
assert self.action_space.contains(action), 'Invalid word!'
|
||||
|
||||
action = self.action_space[action]
|
||||
solution = self.solution_space[self.solution]
|
||||
|
||||
self.state[self.round][:self.n_letters] = action
|
||||
|
||||
counter = Counter()
|
||||
for i, char in enumerate(action):
|
||||
flag_i = i + self.n_letters
|
||||
counter[char] += 1
|
||||
|
||||
if char == solution[i]:
|
||||
self.state[self.round, flag_i] = self.right_pos
|
||||
elif counter[char] <= (char == solution).sum():
|
||||
self.state[self.round, flag_i] = self.wrong_pos
|
||||
else:
|
||||
self.state[self.round, flag_i] = self.wrong_char
|
||||
|
||||
self.round += 1
|
||||
|
||||
correct = (action == solution).all()
|
||||
game_over = (self.round == self.n_rounds)
|
||||
|
||||
done = correct or game_over
|
||||
|
||||
reward = 0
|
||||
# correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 1) * 2
|
||||
|
||||
# correct letter not correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 2) * 1
|
||||
|
||||
# incorrect letter
|
||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||
|
||||
# guess same word as before
|
||||
hashable_action = to_english(action)
|
||||
if hashable_action in self.info['guesses']:
|
||||
reward += -10 * self.info['guesses'][hashable_action]
|
||||
else: # guess different word
|
||||
reward += 10
|
||||
|
||||
self.info['guesses'][hashable_action] += 1
|
||||
|
||||
# for game ending in win or loss
|
||||
reward += 10 if correct else -10 if done else 0
|
||||
|
||||
self.info['correct'] = correct
|
||||
|
||||
# observation, reward, terminated, truncated, info
|
||||
return self.state, reward, done, False, self.info
|
Reference in New Issue
Block a user