updated wordle to gymnasium env

This commit is contained in:
Ethan Shapiro 2024-03-15 18:19:58 -07:00
parent 9172326013
commit bbe9a1891c
4 changed files with 293 additions and 1032 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
**/data/* **/data/*
**/*.zip **/*.zip
**/__pycache__ **/__pycache__
/env

File diff suppressed because it is too large Load Diff

View File

@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]:
return np.array([_char_d[c] for c in word]) return np.array([_char_d[c] for c in word])
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
"""Loads a list of words in array form. """Loads a list of words in array form.
If specified, this will recompute the list from the human-readable list of If specified, this will recompute the list from the human-readable list of
@ -53,14 +53,14 @@ def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
five. five.
""" """
assert category in {'guess', 'solution'} assert category in {'guess', 'solution'}
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
if build: if build:
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
with open(list_path, 'r') as f: with open(list_path, 'r') as f:
words = np.array([to_array(line.strip()) for line in f]) words = np.array([to_array(line.strip()) for line in f])
np.save(arr_path, words) np.save(arr_path, words)
return np.load(arr_path) return np.load(arr_path)
@ -69,16 +69,16 @@ def play():
"""Play Wordle yourself!""" """Play Wordle yourself!"""
import gym import gym
import gym_wordle import gym_wordle
env = gym.make('Wordle-v0') # load the environment env = gym.make('Wordle-v0') # load the environment
env.reset() env.reset()
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking! solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
done = False done = False
while not done: while not done:
action = -1 action = -1
# in general, the environment won't be forgiving if you input an # in general, the environment won't be forgiving if you input an
# invalid word, but for this function I want to let you screw up user # invalid word, but for this function I want to let you screw up user
@ -86,8 +86,8 @@ def play():
while not env.action_space.contains(action): while not env.action_space.contains(action):
guess = input('Guess: ') guess = input('Guess: ')
action = env.unwrapped.action_space.index_of(to_array(guess)) action = env.unwrapped.action_space.index_of(to_array(guess))
state, reward, done, info = env.step(action) state, reward, done, info = env.step(action)
env.render() env.render()
print(f"The word was {solution}") print(f"The word was {solution}")

View File

@ -1,17 +1,17 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from sty import fg, bg, ef, rs from sty import fg, bg, ef, rs
from collections import Counter from collections import Counter
from gym_wordle.utils import to_english, to_array, get_words from gym_wordle.utils import to_english, to_array, get_words
from typing import Optional from typing import Optional
class WordList(gym.spaces.Discrete): class WordList(gym.spaces.Discrete):
"""Super class for defining a space of valid words according to a specified """Super class for defining a space of valid words according to a specified
list. list.
TODO: Fix these paragraphs
The space is a subclass of gym.spaces.Discrete, where each element The space is a subclass of gym.spaces.Discrete, where each element
corresponds to an index of a valid word in the word list. The obfuscation corresponds to an index of a valid word in the word list. The obfuscation
is necessary for more direct implementation of RL algorithms, which expect is necessary for more direct implementation of RL algorithms, which expect
@ -66,16 +66,15 @@ class SolutionList(WordList):
In the game Wordle, there are two different collections of words: In the game Wordle, there are two different collections of words:
* "guesses", which the game accepts as valid words to use to guess the * "guesses", which the game accepts as valid words to use to guess the
answer. answer.
* "solutions", which the game uses to choose solutions from. * "solutions", which the game uses to choose solutions from.
Of course, the set of solutions is a strict subset of the set of guesses. Of course, the set of solutions is a strict subset of the set of guesses.
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
This class represents the set of solution words. This class represents the set of solution words.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Args: Args:
@ -87,7 +86,7 @@ class SolutionList(WordList):
class WordleObsSpace(gym.spaces.Box): class WordleObsSpace(gym.spaces.Box):
"""Implementation of the state (observation) space in terms of gym """Implementation of the state (observation) space in terms of gym
primatives, in this case, gym.spaces.Box. primitives, in this case, gym.spaces.Box.
The Wordle observation space can be thought of as a 6x5 array with two The Wordle observation space can be thought of as a 6x5 array with two
channels: channels:
@ -100,20 +99,11 @@ class WordleObsSpace(gym.spaces.Box):
where there are 6 rows, one for each turn in the game, and 5 columns, since where there are 6 rows, one for each turn in the game, and 5 columns, since
the solution will always be a word of length 5. the solution will always be a word of length 5.
For simplicity, and compatibility with the stable_baselines algorithms, For simplicity, and compatibility with stable_baselines algorithms,
this multichannel is modeled as a 6x10 array, where the two channels are this multichannel is modeled as a 6x10 array, where the two channels are
horizontally appended (along columns). Thus each row in the observation horizontally appended (along columns). Thus each row in the observation
should be interpreted as should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
c0...c4 and its associated flags are f0...f4.
c0 c1 c2 c3 c4 f0 f1 f2 f3 f4
when the word is c0...c4 and its associated flags are f0...f4.
While the superclass method `sample` is available to the WordleObsSpace, it
should be emphasized that the output of `sample` will (almost surely) not
correspond to a real game configuration, because the sampling is not out of
possible game configurations. Instead, the Box superclass just samples the
integer array space uniformly.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -130,20 +120,11 @@ class WordleObsSpace(gym.spaces.Box):
class GuessList(WordList): class GuessList(WordList):
"""Space for *solution* words to the Wordle environment. """Space for *guess* words to the Wordle environment.
In the game Wordle, there are two different collections of words:
* "guesses", which the game accepts as valid words to use to guess the
answer.
* "solutions", which the game uses to choose solutions from.
Of course, the set of solutions is a strict subset of the set of guesses.
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
This class represents the set of guess words. This class represents the set of guess words.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Args: Args:
@ -154,10 +135,9 @@ class GuessList(WordList):
class WordleEnv(gym.Env): class WordleEnv(gym.Env):
metadata = {'render.modes': ['human']} metadata = {'render.modes': ['human']}
# character flag codes # Character flag codes
no_char = 0 no_char = 0
right_pos = 1 right_pos = 1
wrong_pos = 2 wrong_pos = 2
@ -166,7 +146,6 @@ class WordleEnv(gym.Env):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.seed()
self.action_space = GuessList() self.action_space = GuessList()
self.solution_space = SolutionList() self.solution_space = SolutionList()
@ -181,6 +160,7 @@ class WordleEnv(gym.Env):
self.n_rounds = 6 self.n_rounds = 6
self.n_letters = 5 self.n_letters = 5
self.info = {'correct': False, 'guesses': set()}
def _highlighter(self, char: str, flag: int) -> str: def _highlighter(self, char: str, flag: int) -> str:
"""Terminal renderer functionality. Properly highlights a character """Terminal renderer functionality. Properly highlights a character
@ -201,73 +181,74 @@ class WordleEnv(gym.Env):
front, back = self._highlights[flag] front, back = self._highlights[flag]
return front + char + back return front + char + back
def reset(self): def reset(self, seed=None, options=None):
"""Reset the environment to an initial state and returns an initial
observation.
Note: The observation space instance should be a Box space.
Returns:
state (object): The initial observation of the space.
"""
self.round = 0 self.round = 0
self.solution = self.solution_space.sample() self.solution = self.solution_space.sample()
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
return self.state self.info = {'correct': False, 'guesses': set()}
def render(self, mode: str ='human'): return self.state, self.info
def render(self, mode: str = 'human'):
"""Renders the Wordle environment. """Renders the Wordle environment.
Currently supported render modes: Currently supported render modes:
- human: renders the Wordle game to the terminal. - human: renders the Wordle game to the terminal.
Args: Args:
mode: the mode to render with mode: the mode to render with.
""" """
if mode == 'human': if mode == 'human':
for row in self.states: for row in self.state:
text = ''.join(map( text = ''.join(map(
self._highlighter, self._highlighter,
to_english(row[:self.n_letters]).upper(), to_english(row[:self.n_letters]).upper(),
row[self.n_letters:] row[self.n_letters:]
)) ))
print(text) print(text)
else: else:
super(WordleEnv, self).render(mode=mode) super().render(mode=mode)
def step(self, action): def step(self, action):
"""Run one step of the Wordle game. Every game must be previously """Run one step of the Wordle game. Every game must be previously
initialized by a call to the `reset` method. initialized by a call to the `reset` method.
Args: Args:
action: Word guessed by the agent. action: Word guessed by the agent.
Returns: Returns:
state (object): Wordle game state after the guess. state (object): Wordle game state after the guess.
reward (float): Reward associated with the guess (-1 for incorrect, reward (float): Reward associated with the guess.
0 for correct) done (bool): Whether the game has ended.
done (bool): Whether the game has ended (by a correct guess or info (dict): Auxiliary diagnostic information.
after six guesses).
info (dict): Auxiliary diagnostic information (empty).
""" """
assert self.action_space.contains(action), 'Invalid word!' assert self.action_space.contains(action), 'Invalid word!'
# transform the action, solution indices to their words action = self.action_space[action]
action = self.action_space[action]
solution = self.solution_space[self.solution] solution = self.solution_space[self.solution]
# populate the word chars into the row (character channel)
self.state[self.round][:self.n_letters] = action self.state[self.round][:self.n_letters] = action
# populate the flag characters into the row (flag channel)
counter = Counter() counter = Counter()
for i, char in enumerate(action): for i, char in enumerate(action):
flag_i = i + self.n_letters # starts at 5 flag_i = i + self.n_letters
counter[char] += 1 counter[char] += 1
if char == solution[i]: # character is in correct position if char == solution[i]:
self.state[self.round, flag_i] = self.right_pos self.state[self.round, flag_i] = self.right_pos
elif counter[char] <= (char == solution).sum(): elif counter[char] <= (char == solution).sum():
# current character has been seen within correct number of
# occurrences
self.state[self.round, flag_i] = self.wrong_pos self.state[self.round, flag_i] = self.wrong_pos
else: else:
# wrong character, or "correct" character too many times
self.state[self.round, flag_i] = self.wrong_char self.state[self.round, flag_i] = self.wrong_char
self.round += 1 self.round += 1
@ -277,20 +258,29 @@ class WordleEnv(gym.Env):
done = correct or game_over done = correct or game_over
# Total reward equals -(number of incorrect guesses)
# reward = 0. if correct else -1.
# correct +10
# guesses new letter +1
# guesses correct letter +1
# spent another guess -1
reward = 0 reward = 0
reward += np.sum(self.state[:, 5:] == 1) * 1 # correct spot
reward += np.sum(self.state[:, 5:] == 2) * 0.5 reward += np.sum(self.state[:, 5:] == 1) * 2
# correct letter not correct spot
reward += np.sum(self.state[:, 5:] == 2) * 1
# incorrect letter
reward += np.sum(self.state[:, 5:] == 3) * -1 reward += np.sum(self.state[:, 5:] == 3) * -1
# guess same word as before
hashable_action = tuple(action)
if hashable_action in self.info['guesses']:
reward += -10
else: # guess different word
reward += 10
self.info['guesses'].add(hashable_action)
# for game ending in win or loss
reward += 10 if correct else -10 if done else 0 reward += 10 if correct else -10 if done else 0
info = {'correct': correct} self.info['correct'] = correct
return self.state, reward, done, info # observation, reward, terminated, truncated, info
return self.state, reward, done, False, self.info