2024-03-16 01:19:58 +00:00
|
|
|
import gymnasium as gym
|
2024-03-14 23:47:11 +00:00
|
|
|
import numpy as np
|
|
|
|
import numpy.typing as npt
|
|
|
|
from sty import fg, bg, ef, rs
|
|
|
|
|
2024-03-16 01:48:21 +00:00
|
|
|
from collections import Counter, defaultdict
|
2024-03-14 23:47:11 +00:00
|
|
|
from gym_wordle.utils import to_english, to_array, get_words
|
2024-03-16 01:19:58 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2024-03-14 23:47:11 +00:00
|
|
|
class WordList(gym.spaces.Discrete):
    """Super class for defining a space of valid words according to a specified
    list.

    The space is a subclass of gym.spaces.Discrete, where each element
    corresponds to an index of a valid word in the word list. The obfuscation
    is necessary for more direct implementation of RL algorithms, which expect
    spaces of less sophisticated form.

    In addition to the default methods of the Discrete space, it implements
    a __getitem__ method for easy index lookup, and an index_of method to
    convert potential words into their corresponding index (if they exist).
    """

    def __init__(self, words: npt.NDArray[np.int64], **kwargs):
        """
        Args:
            words: Collection of words in array form with shape (_, 5), where
                each word is a row of the array. Each array element is an
                integer between 0,...,26 (inclusive).
            kwargs: Forwarded to gym.spaces.Discrete.
        """
        # One discrete element per row (word) of the list.
        super().__init__(words.shape[0], **kwargs)
        self.words = words

    def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
        """Obtains the (int-encoded) word associated with the given index.

        Args:
            index: Index for the list of words.

        Returns:
            Associated word at the position specified by index.
        """
        return self.words[index]

    def index_of(self, word: npt.NDArray[np.int64]) -> int:
        """Given a word, determine its index in the list (if it exists),
        otherwise returning -1 if no index exists.

        Args:
            word: Word to find in the word list. Any array-like with as many
                elements as one row of ``self.words`` is accepted (for example
                shape (5,) or (1, 5)).

        Returns:
            The index of the first matching row as a Python int, or -1 if the
            word is not in the list or has the wrong number of characters.
        """
        word = np.asarray(word)
        # Reject inputs that cannot be a row of the list up front, instead of
        # relying on a broadcasting failure caught by a bare ``except``.
        if self.words.ndim != 2 or word.size != self.words.shape[1]:
            return -1
        matches = np.flatnonzero((self.words == word.reshape(-1)).all(axis=1))
        return int(matches[0]) if matches.size else -1
|
|
|
|
|
|
|
|
|
|
|
|
class SolutionList(WordList):
    """Space of the words Wordle may pick as the hidden *solution*.

    Wordle distinguishes two collections of words:

    * "guesses": words the game accepts as a valid attempt at the answer.
    * "solutions": words the game draws the answer from.

    Of course, the set of solutions is a strict subset of the set of guesses.
    This class represents the solution collection, loaded with
    ``get_words('solution')``.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Passed unchanged to WordList and, from there, to
                gym.spaces.Discrete.
        """
        super().__init__(get_words('solution'), **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class WordleObsSpace(gym.spaces.Box):
    """Observation (state) space of the Wordle environment, expressed with the
    gym primitive gym.spaces.Box.

    Conceptually the observation is a 6x5 board carrying two channels:

        - a character channel holding the letters placed on the board
          (rows not yet played hold the empty character, 0);
        - a flag channel holding the in-game feedback attached to each placed
          character (green highlight, yellow highlight, etc.).

    There is one row per turn of the game (6) and one column per letter (5),
    since the solution is always a 5-letter word.

    For simplicity and compatibility with stable_baselines algorithms, the two
    channels are joined side by side (along the columns) into a single 6x10
    array. Each row therefore reads c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 for the word
    c0...c4 with associated flags f0...f4.
    """

    def __init__(self, **kwargs):
        self.n_rows = 6
        self.n_cols = 5
        self.max_char = 26
        # NOTE(review): WordleEnv's flag codes only go up to 3 (wrong_char);
        # confirm whether this upper bound of 4 is intentional headroom.
        self.max_flag = 4

        board_shape = (self.n_rows, self.n_cols)
        char_upper = np.full(board_shape, self.max_char)
        flag_upper = np.full(board_shape, self.max_flag)
        # Character block on the left, flag block on the right.
        high = np.concatenate([char_upper, flag_upper], axis=1)
        low = np.zeros((self.n_rows, 2 * self.n_cols))

        super().__init__(low, high, dtype=np.int64, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class GuessList(WordList):
    """Space of every word the Wordle environment accepts as a *guess*.

    Words are loaded with ``get_words('guess')``; each element of the space is
    the index of one such word.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Passed unchanged to WordList and, from there, to
                gym.spaces.Discrete.
        """
        super().__init__(get_words('guess'), **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class WordleEnv(gym.Env):
    """Gymnasium environment that plays one game of Wordle per episode.

    Actions are indices into a GuessList, the hidden solution is an index into
    a SolutionList drawn on every reset, and observations are the 6x10 int64
    board described by WordleObsSpace (five character columns followed by
    five flag columns per row).
    """

    # Gymnasium reads 'render_modes'; the legacy gym key 'render.modes' is kept
    # so anything still looking it up keeps working.
    metadata = {'render.modes': ['human'], 'render_modes': ['human']}

    # Character flag codes
    no_char = 0
    right_pos = 1
    wrong_pos = 2
    wrong_char = 3

    def __init__(self):
        super().__init__()

        self.action_space = GuessList()
        self.solution_space = SolutionList()

        self.observation_space = WordleObsSpace()

        # Flag code -> (prefix, suffix) wrapped around a rendered character.
        self._highlights = {
            self.right_pos: (bg.green, bg.rs),
            self.wrong_pos: (bg.yellow, bg.rs),
            self.wrong_char: ('', ''),
            self.no_char: ('', ''),
        }

        self.n_rounds = 6
        self.n_letters = 5
        # 'guesses' maps the English form of each guessed word to the number of
        # times it has been guessed this episode (drives the repeat penalty).
        self.info = {'correct': False, 'guesses': defaultdict(int)}

    def _highlighter(self, char: str, flag: int) -> str:
        """Terminal renderer functionality. Properly highlights a character
        based on the flag associated with it.

        Args:
            char: Character in question.
            flag: Associated flag, one of:
                - 0: no character (render no background)
                - 1: right position (render green background)
                - 2: wrong position (render yellow background)
                - 3: wrong character (render no background)

        Returns:
            Correct ASCII sequence producing the desired character in the
            correct background.
        """
        front, back = self._highlights[flag]
        return front + char + back

    def reset(self, seed=None, options=None):
        """Reset the environment to an initial state and return an initial
        observation.

        Args:
            seed: Optional seed for the environment RNG (``self.np_random``),
                from which the hidden solution is drawn; the same seed yields
                the same solution index.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            A ``(state, info)`` tuple: the all-zero board and a fresh info dict
            with ``correct=False`` and no recorded guesses.
        """
        # Gymnasium contract: re-seeds self.np_random when a seed is given.
        super().reset(seed=seed)

        self.round = 0
        # Draw from the env RNG (not the space's private RNG) so that
        # reset(seed=...) actually controls which solution is chosen.
        self.solution = self.np_random.integers(self.solution_space.n)

        self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
        self.info = {'correct': False, 'guesses': defaultdict(int)}

        return self.state, self.info

    def render(self, mode: str = 'human'):
        """Renders the Wordle environment.

        Currently supported render modes:
            - human: prints the board to the terminal, one row per round, with
              each letter highlighted according to its flag.

        Args:
            mode: The mode to render with.

        Raises:
            NotImplementedError: If ``mode`` is not a supported render mode.
        """
        if mode != 'human':
            # gymnasium.Env.render() takes no ``mode`` argument, so delegating
            # to super() would fail with a TypeError; report the real problem.
            raise NotImplementedError(f'Unsupported render mode: {mode!r}')

        for row in self.state:
            letters = to_english(row[:self.n_letters]).upper()
            flags = row[self.n_letters:]
            print(''.join(map(self._highlighter, letters, flags)))

    def _score_guess(self, guess, solution):
        """Compute the per-letter flag codes for ``guess`` against ``solution``.

        Follows Wordle's duplicate-letter rule. Exact matches are flagged
        ``right_pos`` first. Each remaining guess letter (left to right) is
        flagged ``wrong_pos`` only while unmatched copies of it remain in the
        solution, and ``wrong_char`` otherwise. For example, guessing EERIE
        against THOSE flags only the last E (``right_pos``). The first two E's
        are ``wrong_char``, because the solution's single E is already matched.

        Args:
            guess: Int-encoded guess word.
            solution: Int-encoded solution word of the same length.

        Returns:
            int64 array with one flag code per letter.
        """
        exact = guess == solution
        flags = np.where(exact, self.right_pos, self.wrong_char).astype(np.int64)
        # Solution letters not consumed by an exact match; each one can justify
        # at most one yellow highlight.
        unmatched = Counter(solution[~exact].tolist())
        for i in np.flatnonzero(~exact):
            char = int(guess[i])
            if unmatched[char] > 0:
                flags[i] = self.wrong_pos
                unmatched[char] -= 1
        return flags

    def step(self, action):
        """Run one step of the Wordle game. Every game must be previously
        initialized by a call to the `reset` method.

        Args:
            action: Index of the guessed word in ``self.action_space``.

        Returns:
            A 5-tuple ``(state, reward, terminated, truncated, info)``:
                state: Wordle game state after the guess.
                reward: Reward associated with the guess.
                terminated: Whether the game has ended (solved, or all rounds
                    used).
                truncated: Always False.
                info: Auxiliary diagnostic information ('correct' and the
                    per-word 'guesses' counts).

        Raises:
            AssertionError: If ``action`` is not in the action space (only
                while assertions are enabled).
        """
        assert self.action_space.contains(action), 'Invalid word!'

        guess = self.action_space[action]
        solution = self.solution_space[self.solution]

        row = self.round
        self.state[row, :self.n_letters] = guess
        self.state[row, self.n_letters:] = self._score_guess(guess, solution)
        self.round += 1

        correct = bool((guess == solution).all())
        game_over = self.round == self.n_rounds
        done = correct or game_over

        reward = 0

        flags = self.state[:, self.n_letters:]
        # NOTE(review): the three terms below count flags over *every* row
        # played so far, not only the current guess, so earlier guesses are
        # credited again on each step. Behavior kept; confirm it is intended.
        # correct spot
        reward += np.sum(flags == self.right_pos) * 2
        # correct letter not correct spot
        reward += np.sum(flags == self.wrong_pos) * 1
        # incorrect letter
        reward += np.sum(flags == self.wrong_char) * -1

        # Repeated guesses are penalized in proportion to how often the word
        # was already guessed; a new word earns a bonus.
        guessed_word = to_english(guess)
        guesses = self.info['guesses']
        if guessed_word in guesses:
            reward += -10 * guesses[guessed_word]
        else:
            reward += 10
        guesses[guessed_word] += 1

        # for game ending in win or loss
        reward += 10 if correct else -10 if done else 0

        self.info['correct'] = correct

        # observation, reward, terminated, truncated, info
        return self.state, reward, done, False, self.info
|