mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-10 07:04:45 +00:00
updated wordle to gymnasium env
This commit is contained in:
parent
9172326013
commit
bbe9a1891c
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
**/data/*
|
**/data/*
|
||||||
**/*.zip
|
**/*.zip
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
|
/env
|
1278
dqn_wordle.ipynb
1278
dqn_wordle.ipynb
File diff suppressed because it is too large
Load Diff
@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]:
|
|||||||
return np.array([_char_d[c] for c in word])
|
return np.array([_char_d[c] for c in word])
|
||||||
|
|
||||||
|
|
||||||
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
|
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
|
||||||
"""Loads a list of words in array form.
|
"""Loads a list of words in array form.
|
||||||
|
|
||||||
If specified, this will recompute the list from the human-readable list of
|
If specified, this will recompute the list from the human-readable list of
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from sty import fg, bg, ef, rs
|
from sty import fg, bg, ef, rs
|
||||||
@ -7,11 +7,11 @@ from collections import Counter
|
|||||||
from gym_wordle.utils import to_english, to_array, get_words
|
from gym_wordle.utils import to_english, to_array, get_words
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class WordList(gym.spaces.Discrete):
|
class WordList(gym.spaces.Discrete):
|
||||||
"""Super class for defining a space of valid words according to a specified
|
"""Super class for defining a space of valid words according to a specified
|
||||||
list.
|
list.
|
||||||
|
|
||||||
TODO: Fix these paragraphs
|
|
||||||
The space is a subclass of gym.spaces.Discrete, where each element
|
The space is a subclass of gym.spaces.Discrete, where each element
|
||||||
corresponds to an index of a valid word in the word list. The obfuscation
|
corresponds to an index of a valid word in the word list. The obfuscation
|
||||||
is necessary for more direct implementation of RL algorithms, which expect
|
is necessary for more direct implementation of RL algorithms, which expect
|
||||||
@ -72,10 +72,9 @@ class SolutionList(WordList):
|
|||||||
|
|
||||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||||
|
|
||||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
|
||||||
|
|
||||||
This class represents the set of solution words.
|
This class represents the set of solution words.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -87,7 +86,7 @@ class SolutionList(WordList):
|
|||||||
|
|
||||||
class WordleObsSpace(gym.spaces.Box):
|
class WordleObsSpace(gym.spaces.Box):
|
||||||
"""Implementation of the state (observation) space in terms of gym
|
"""Implementation of the state (observation) space in terms of gym
|
||||||
primatives, in this case, gym.spaces.Box.
|
primitives, in this case, gym.spaces.Box.
|
||||||
|
|
||||||
The Wordle observation space can be thought of as a 6x5 array with two
|
The Wordle observation space can be thought of as a 6x5 array with two
|
||||||
channels:
|
channels:
|
||||||
@ -100,20 +99,11 @@ class WordleObsSpace(gym.spaces.Box):
|
|||||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
||||||
the solution will always be a word of length 5.
|
the solution will always be a word of length 5.
|
||||||
|
|
||||||
For simplicity, and compatibility with the stable_baselines algorithms,
|
For simplicity, and compatibility with stable_baselines algorithms,
|
||||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
this multichannel is modeled as a 6x10 array, where the two channels are
|
||||||
horizontally appended (along columns). Thus each row in the observation
|
horizontally appended (along columns). Thus each row in the observation
|
||||||
should be interpreted as
|
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
|
||||||
|
c0...c4 and its associated flags are f0...f4.
|
||||||
c0 c1 c2 c3 c4 f0 f1 f2 f3 f4
|
|
||||||
|
|
||||||
when the word is c0...c4 and its associated flags are f0...f4.
|
|
||||||
|
|
||||||
While the superclass method `sample` is available to the WordleObsSpace, it
|
|
||||||
should be emphasized that the output of `sample` will (almost surely) not
|
|
||||||
correspond to a real game configuration, because the sampling is not out of
|
|
||||||
possible game configurations. Instead, the Box superclass just samples the
|
|
||||||
integer array space uniformly.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -130,20 +120,11 @@ class WordleObsSpace(gym.spaces.Box):
|
|||||||
|
|
||||||
|
|
||||||
class GuessList(WordList):
|
class GuessList(WordList):
|
||||||
"""Space for *solution* words to the Wordle environment.
|
"""Space for *guess* words to the Wordle environment.
|
||||||
|
|
||||||
In the game Wordle, there are two different collections of words:
|
|
||||||
|
|
||||||
* "guesses", which the game accepts as valid words to use to guess the
|
|
||||||
answer.
|
|
||||||
* "solutions", which the game uses to choose solutions from.
|
|
||||||
|
|
||||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
|
||||||
|
|
||||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
|
||||||
|
|
||||||
This class represents the set of guess words.
|
This class represents the set of guess words.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -154,10 +135,9 @@ class GuessList(WordList):
|
|||||||
|
|
||||||
|
|
||||||
class WordleEnv(gym.Env):
|
class WordleEnv(gym.Env):
|
||||||
|
|
||||||
metadata = {'render.modes': ['human']}
|
metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
# character flag codes
|
# Character flag codes
|
||||||
no_char = 0
|
no_char = 0
|
||||||
right_pos = 1
|
right_pos = 1
|
||||||
wrong_pos = 2
|
wrong_pos = 2
|
||||||
@ -166,7 +146,6 @@ class WordleEnv(gym.Env):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.seed()
|
|
||||||
self.action_space = GuessList()
|
self.action_space = GuessList()
|
||||||
self.solution_space = SolutionList()
|
self.solution_space = SolutionList()
|
||||||
|
|
||||||
@ -181,6 +160,7 @@ class WordleEnv(gym.Env):
|
|||||||
|
|
||||||
self.n_rounds = 6
|
self.n_rounds = 6
|
||||||
self.n_letters = 5
|
self.n_letters = 5
|
||||||
|
self.info = {'correct': False, 'guesses': set()}
|
||||||
|
|
||||||
def _highlighter(self, char: str, flag: int) -> str:
|
def _highlighter(self, char: str, flag: int) -> str:
|
||||||
"""Terminal renderer functionality. Properly highlights a character
|
"""Terminal renderer functionality. Properly highlights a character
|
||||||
@ -201,35 +181,43 @@ class WordleEnv(gym.Env):
|
|||||||
front, back = self._highlights[flag]
|
front, back = self._highlights[flag]
|
||||||
return front + char + back
|
return front + char + back
|
||||||
|
|
||||||
def reset(self):
|
def reset(self, seed=None, options=None):
|
||||||
|
"""Reset the environment to an initial state and returns an initial
|
||||||
|
observation.
|
||||||
|
|
||||||
|
Note: The observation space instance should be a Box space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state (object): The initial observation of the space.
|
||||||
|
"""
|
||||||
self.round = 0
|
self.round = 0
|
||||||
self.solution = self.solution_space.sample()
|
self.solution = self.solution_space.sample()
|
||||||
|
|
||||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||||
|
|
||||||
return self.state
|
self.info = {'correct': False, 'guesses': set()}
|
||||||
|
|
||||||
def render(self, mode: str ='human'):
|
return self.state, self.info
|
||||||
|
|
||||||
|
def render(self, mode: str = 'human'):
|
||||||
"""Renders the Wordle environment.
|
"""Renders the Wordle environment.
|
||||||
|
|
||||||
Currently supported render modes:
|
Currently supported render modes:
|
||||||
|
|
||||||
- human: renders the Wordle game to the terminal.
|
- human: renders the Wordle game to the terminal.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode: the mode to render with
|
mode: the mode to render with.
|
||||||
"""
|
"""
|
||||||
if mode == 'human':
|
if mode == 'human':
|
||||||
for row in self.states:
|
for row in self.state:
|
||||||
text = ''.join(map(
|
text = ''.join(map(
|
||||||
self._highlighter,
|
self._highlighter,
|
||||||
to_english(row[:self.n_letters]).upper(),
|
to_english(row[:self.n_letters]).upper(),
|
||||||
row[self.n_letters:]
|
row[self.n_letters:]
|
||||||
))
|
))
|
||||||
|
|
||||||
print(text)
|
print(text)
|
||||||
else:
|
else:
|
||||||
super(WordleEnv, self).render(mode=mode)
|
super().render(mode=mode)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Run one step of the Wordle game. Every game must be previously
|
"""Run one step of the Wordle game. Every game must be previously
|
||||||
@ -237,37 +225,30 @@ class WordleEnv(gym.Env):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: Word guessed by the agent.
|
action: Word guessed by the agent.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
state (object): Wordle game state after the guess.
|
state (object): Wordle game state after the guess.
|
||||||
reward (float): Reward associated with the guess (-1 for incorrect,
|
reward (float): Reward associated with the guess.
|
||||||
0 for correct)
|
done (bool): Whether the game has ended.
|
||||||
done (bool): Whether the game has ended (by a correct guess or
|
info (dict): Auxiliary diagnostic information.
|
||||||
after six guesses).
|
|
||||||
info (dict): Auxiliary diagnostic information (empty).
|
|
||||||
"""
|
"""
|
||||||
assert self.action_space.contains(action), 'Invalid word!'
|
assert self.action_space.contains(action), 'Invalid word!'
|
||||||
|
|
||||||
# transform the action, solution indices to their words
|
|
||||||
action = self.action_space[action]
|
action = self.action_space[action]
|
||||||
solution = self.solution_space[self.solution]
|
solution = self.solution_space[self.solution]
|
||||||
|
|
||||||
# populate the word chars into the row (character channel)
|
|
||||||
self.state[self.round][:self.n_letters] = action
|
self.state[self.round][:self.n_letters] = action
|
||||||
|
|
||||||
# populate the flag characters into the row (flag channel)
|
|
||||||
counter = Counter()
|
counter = Counter()
|
||||||
for i, char in enumerate(action):
|
for i, char in enumerate(action):
|
||||||
flag_i = i + self.n_letters # starts at 5
|
flag_i = i + self.n_letters
|
||||||
counter[char] += 1
|
counter[char] += 1
|
||||||
|
|
||||||
if char == solution[i]: # character is in correct position
|
if char == solution[i]:
|
||||||
self.state[self.round, flag_i] = self.right_pos
|
self.state[self.round, flag_i] = self.right_pos
|
||||||
elif counter[char] <= (char == solution).sum():
|
elif counter[char] <= (char == solution).sum():
|
||||||
# current character has been seen within correct number of
|
|
||||||
# occurrences
|
|
||||||
self.state[self.round, flag_i] = self.wrong_pos
|
self.state[self.round, flag_i] = self.wrong_pos
|
||||||
else:
|
else:
|
||||||
# wrong character, or "correct" character too many times
|
|
||||||
self.state[self.round, flag_i] = self.wrong_char
|
self.state[self.round, flag_i] = self.wrong_char
|
||||||
|
|
||||||
self.round += 1
|
self.round += 1
|
||||||
@ -277,20 +258,29 @@ class WordleEnv(gym.Env):
|
|||||||
|
|
||||||
done = correct or game_over
|
done = correct or game_over
|
||||||
|
|
||||||
# Total reward equals -(number of incorrect guesses)
|
|
||||||
# reward = 0. if correct else -1.
|
|
||||||
|
|
||||||
# correct +10
|
|
||||||
# guesses new letter +1
|
|
||||||
# guesses correct letter +1
|
|
||||||
# spent another guess -1
|
|
||||||
|
|
||||||
reward = 0
|
reward = 0
|
||||||
reward += np.sum(self.state[:, 5:] == 1) * 1
|
# correct spot
|
||||||
reward += np.sum(self.state[:, 5:] == 2) * 0.5
|
reward += np.sum(self.state[:, 5:] == 1) * 2
|
||||||
|
|
||||||
|
# correct letter not correct spot
|
||||||
|
reward += np.sum(self.state[:, 5:] == 2) * 1
|
||||||
|
|
||||||
|
# incorrect letter
|
||||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||||
|
|
||||||
|
# guess same word as before
|
||||||
|
hashable_action = tuple(action)
|
||||||
|
if hashable_action in self.info['guesses']:
|
||||||
|
reward += -10
|
||||||
|
else: # guess different word
|
||||||
|
reward += 10
|
||||||
|
|
||||||
|
self.info['guesses'].add(hashable_action)
|
||||||
|
|
||||||
|
# for game ending in win or loss
|
||||||
reward += 10 if correct else -10 if done else 0
|
reward += 10 if correct else -10 if done else 0
|
||||||
|
|
||||||
info = {'correct': correct}
|
self.info['correct'] = correct
|
||||||
|
|
||||||
return self.state, reward, done, info
|
# observation, reward, terminated, truncated, info
|
||||||
|
return self.state, reward, done, False, self.info
|
||||||
|
Loading…
Reference in New Issue
Block a user