mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-26 01:59:10 +00:00
updated wordle to gymnasium env
This commit is contained in:
parent
9172326013
commit
bbe9a1891c
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
**/data/*
|
||||
**/*.zip
|
||||
**/__pycache__
|
||||
**/__pycache__
|
||||
/env
|
1158
dqn_wordle.ipynb
1158
dqn_wordle.ipynb
File diff suppressed because it is too large
Load Diff
@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]:
|
||||
return np.array([_char_d[c] for c in word])
|
||||
|
||||
|
||||
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
|
||||
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
|
||||
"""Loads a list of words in array form.
|
||||
|
||||
If specified, this will recompute the list from the human-readable list of
|
||||
@ -53,14 +53,14 @@ def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
|
||||
five.
|
||||
"""
|
||||
assert category in {'guess', 'solution'}
|
||||
|
||||
|
||||
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
|
||||
if build:
|
||||
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
|
||||
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
|
||||
|
||||
with open(list_path, 'r') as f:
|
||||
words = np.array([to_array(line.strip()) for line in f])
|
||||
np.save(arr_path, words)
|
||||
with open(list_path, 'r') as f:
|
||||
words = np.array([to_array(line.strip()) for line in f])
|
||||
np.save(arr_path, words)
|
||||
|
||||
return np.load(arr_path)
|
||||
|
||||
@ -69,16 +69,16 @@ def play():
|
||||
"""Play Wordle yourself!"""
|
||||
import gym
|
||||
import gym_wordle
|
||||
|
||||
|
||||
env = gym.make('Wordle-v0') # load the environment
|
||||
|
||||
|
||||
env.reset()
|
||||
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
|
||||
|
||||
done = False
|
||||
|
||||
|
||||
while not done:
|
||||
action = -1
|
||||
action = -1
|
||||
|
||||
# in general, the environment won't be forgiving if you input an
|
||||
# invalid word, but for this function I want to let you screw up user
|
||||
@ -86,8 +86,8 @@ def play():
|
||||
while not env.action_space.contains(action):
|
||||
guess = input('Guess: ')
|
||||
action = env.unwrapped.action_space.index_of(to_array(guess))
|
||||
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
print(f"The word was {solution}")
|
||||
|
||||
print(f"The word was {solution}")
|
||||
|
@ -1,17 +1,17 @@
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
from collections import Counter
|
||||
from collections import Counter
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
"""Super class for defining a space of valid words according to a specified
|
||||
list.
|
||||
|
||||
TODO: Fix these paragraphs
|
||||
The space is a subclass of gym.spaces.Discrete, where each element
|
||||
corresponds to an index of a valid word in the word list. The obfuscation
|
||||
is necessary for more direct implementation of RL algorithms, which expect
|
||||
@ -66,16 +66,15 @@ class SolutionList(WordList):
|
||||
|
||||
In the game Wordle, there are two different collections of words:
|
||||
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
|
||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||
|
||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
||||
|
||||
This class represents the set of solution words.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
@ -87,7 +86,7 @@ class SolutionList(WordList):
|
||||
|
||||
class WordleObsSpace(gym.spaces.Box):
|
||||
"""Implementation of the state (observation) space in terms of gym
|
||||
primatives, in this case, gym.spaces.Box.
|
||||
primitives, in this case, gym.spaces.Box.
|
||||
|
||||
The Wordle observation space can be thought of as a 6x5 array with two
|
||||
channels:
|
||||
@ -100,20 +99,11 @@ class WordleObsSpace(gym.spaces.Box):
|
||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
||||
the solution will always be a word of length 5.
|
||||
|
||||
For simplicity, and compatibility with the stable_baselines algorithms,
|
||||
For simplicity, and compatibility with stable_baselines algorithms,
|
||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
||||
horizontally appended (along columns). Thus each row in the observation
|
||||
should be interpreted as
|
||||
|
||||
c0 c1 c2 c3 c4 f0 f1 f2 f3 f4
|
||||
|
||||
when the word is c0...c4 and its associated flags are f0...f4.
|
||||
|
||||
While the superclass method `sample` is available to the WordleObsSpace, it
|
||||
should be emphasized that the output of `sample` will (almost surely) not
|
||||
correspond to a real game configuration, because the sampling is not out of
|
||||
possible game configurations. Instead, the Box superclass just samples the
|
||||
integer array space uniformly.
|
||||
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
|
||||
c0...c4 and its associated flags are f0...f4.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -130,20 +120,11 @@ class WordleObsSpace(gym.spaces.Box):
|
||||
|
||||
|
||||
class GuessList(WordList):
|
||||
"""Space for *solution* words to the Wordle environment.
|
||||
|
||||
In the game Wordle, there are two different collections of words:
|
||||
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
|
||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||
|
||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
||||
"""Space for *guess* words to the Wordle environment.
|
||||
|
||||
This class represents the set of guess words.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
@ -154,10 +135,9 @@ class GuessList(WordList):
|
||||
|
||||
|
||||
class WordleEnv(gym.Env):
|
||||
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
# character flag codes
|
||||
# Character flag codes
|
||||
no_char = 0
|
||||
right_pos = 1
|
||||
wrong_pos = 2
|
||||
@ -166,7 +146,6 @@ class WordleEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.seed()
|
||||
self.action_space = GuessList()
|
||||
self.solution_space = SolutionList()
|
||||
|
||||
@ -181,6 +160,7 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
@ -201,73 +181,74 @@ class WordleEnv(gym.Env):
|
||||
front, back = self._highlights[flag]
|
||||
return front + char + back
|
||||
|
||||
def reset(self):
|
||||
def reset(self, seed=None, options=None):
|
||||
"""Reset the environment to an initial state and returns an initial
|
||||
observation.
|
||||
|
||||
Note: The observation space instance should be a Box space.
|
||||
|
||||
Returns:
|
||||
state (object): The initial observation of the space.
|
||||
"""
|
||||
self.round = 0
|
||||
self.solution = self.solution_space.sample()
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||
|
||||
return self.state
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
|
||||
def render(self, mode: str ='human'):
|
||||
return self.state, self.info
|
||||
|
||||
def render(self, mode: str = 'human'):
|
||||
"""Renders the Wordle environment.
|
||||
|
||||
Currently supported render modes:
|
||||
|
||||
- human: renders the Wordle game to the terminal.
|
||||
|
||||
Args:
|
||||
mode: the mode to render with
|
||||
mode: the mode to render with.
|
||||
"""
|
||||
if mode == 'human':
|
||||
for row in self.states:
|
||||
for row in self.state:
|
||||
text = ''.join(map(
|
||||
self._highlighter,
|
||||
to_english(row[:self.n_letters]).upper(),
|
||||
self._highlighter,
|
||||
to_english(row[:self.n_letters]).upper(),
|
||||
row[self.n_letters:]
|
||||
))
|
||||
|
||||
print(text)
|
||||
else:
|
||||
super(WordleEnv, self).render(mode=mode)
|
||||
|
||||
super().render(mode=mode)
|
||||
|
||||
def step(self, action):
|
||||
"""Run one step of the Wordle game. Every game must be previously
|
||||
initialized by a call to the `reset` method.
|
||||
|
||||
Args:
|
||||
action: Word guessed by the agent.
|
||||
|
||||
Returns:
|
||||
state (object): Wordle game state after the guess.
|
||||
reward (float): Reward associated with the guess (-1 for incorrect,
|
||||
0 for correct)
|
||||
done (bool): Whether the game has ended (by a correct guess or
|
||||
after six guesses).
|
||||
info (dict): Auxiliary diagnostic information (empty).
|
||||
reward (float): Reward associated with the guess.
|
||||
done (bool): Whether the game has ended.
|
||||
info (dict): Auxiliary diagnostic information.
|
||||
"""
|
||||
assert self.action_space.contains(action), 'Invalid word!'
|
||||
|
||||
# transform the action, solution indices to their words
|
||||
action = self.action_space[action]
|
||||
action = self.action_space[action]
|
||||
solution = self.solution_space[self.solution]
|
||||
|
||||
# populate the word chars into the row (character channel)
|
||||
self.state[self.round][:self.n_letters] = action
|
||||
|
||||
# populate the flag characters into the row (flag channel)
|
||||
counter = Counter()
|
||||
for i, char in enumerate(action):
|
||||
flag_i = i + self.n_letters # starts at 5
|
||||
flag_i = i + self.n_letters
|
||||
counter[char] += 1
|
||||
|
||||
if char == solution[i]: # character is in correct position
|
||||
if char == solution[i]:
|
||||
self.state[self.round, flag_i] = self.right_pos
|
||||
elif counter[char] <= (char == solution).sum():
|
||||
# current character has been seen within correct number of
|
||||
# occurrences
|
||||
elif counter[char] <= (char == solution).sum():
|
||||
self.state[self.round, flag_i] = self.wrong_pos
|
||||
else:
|
||||
# wrong character, or "correct" character too many times
|
||||
self.state[self.round, flag_i] = self.wrong_char
|
||||
|
||||
self.round += 1
|
||||
@ -277,20 +258,29 @@ class WordleEnv(gym.Env):
|
||||
|
||||
done = correct or game_over
|
||||
|
||||
# Total reward equals -(number of incorrect guesses)
|
||||
# reward = 0. if correct else -1.
|
||||
|
||||
# correct +10
|
||||
# guesses new letter +1
|
||||
# guesses correct letter +1
|
||||
# spent another guess -1
|
||||
|
||||
reward = 0
|
||||
reward += np.sum(self.state[:, 5:] == 1) * 1
|
||||
reward += np.sum(self.state[:, 5:] == 2) * 0.5
|
||||
# correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 1) * 2
|
||||
|
||||
# correct letter not correct spot
|
||||
reward += np.sum(self.state[:, 5:] == 2) * 1
|
||||
|
||||
# incorrect letter
|
||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||
|
||||
# guess same word as before
|
||||
hashable_action = tuple(action)
|
||||
if hashable_action in self.info['guesses']:
|
||||
reward += -10
|
||||
else: # guess different word
|
||||
reward += 10
|
||||
|
||||
self.info['guesses'].add(hashable_action)
|
||||
|
||||
# for game ending in win or loss
|
||||
reward += 10 if correct else -10 if done else 0
|
||||
|
||||
info = {'correct': correct}
|
||||
self.info['correct'] = correct
|
||||
|
||||
return self.state, reward, done, info
|
||||
# observation, reward, terminated, truncated, info
|
||||
return self.state, reward, done, False, self.info
|
||||
|
Loading…
Reference in New Issue
Block a user