5 Commits

Author SHA1 Message Date
Arthur Lu
cf977e4797 try penalizing duplicate guesses 2024-03-15 18:48:21 -07:00
Ethan Shapiro
bbe9a1891c updated wordle to gymnasium env 2024-03-15 18:19:58 -07:00
Arthur Lu
9172326013 upload wordle env, fix indexing issue in wordle env, attempt to improve reward (no improvement) 2024-03-14 16:47:11 -07:00
Arthur Lu
4836be8121 remove debug prints 2024-03-14 15:00:19 -07:00
Arthur Lu
5672169073 copy the wordle env locally and fix the obs return 2024-03-14 14:49:17 -07:00
18 changed files with 16041 additions and 15557 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
**/data/*
**/*.zip
**/__pycache__
**/__pycache__
/env

369
dqn_wordle.ipynb Normal file
View File

@@ -0,0 +1,369 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gym_wordle\n",
"from stable_baselines3 import DQN, PPO, common\n",
"import numpy as np\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<Monitor<WordleEnv instance>>\n"
]
}
],
"source": [
"env = gym_wordle.wordle.WordleEnv()\n",
"env = common.monitor.Monitor(env)\n",
"\n",
"print(env)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Wrapping the env in a DummyVecEnv.\n",
"---------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 2.14 |\n",
"| time/ | |\n",
"| fps | 750 |\n",
"| iterations | 1 |\n",
"| time_elapsed | 2 |\n",
"| total_timesteps | 2048 |\n",
"---------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 4.59 |\n",
"| time/ | |\n",
"| fps | 625 |\n",
"| iterations | 2 |\n",
"| time_elapsed | 6 |\n",
"| total_timesteps | 4096 |\n",
"| train/ | |\n",
"| approx_kl | 0.022059526 |\n",
"| clip_fraction | 0.331 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | -0.0118 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 130 |\n",
"| n_updates | 10 |\n",
"| policy_gradient_loss | -0.0851 |\n",
"| value_loss | 253 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 5.86 |\n",
"| time/ | |\n",
"| fps | 585 |\n",
"| iterations | 3 |\n",
"| time_elapsed | 10 |\n",
"| total_timesteps | 6144 |\n",
"| train/ | |\n",
"| approx_kl | 0.024416003 |\n",
"| clip_fraction | 0.462 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.152 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 85.2 |\n",
"| n_updates | 20 |\n",
"| policy_gradient_loss | -0.0987 |\n",
"| value_loss | 218 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 4.75 |\n",
"| time/ | |\n",
"| fps | 566 |\n",
"| iterations | 4 |\n",
"| time_elapsed | 14 |\n",
"| total_timesteps | 8192 |\n",
"| train/ | |\n",
"| approx_kl | 0.026305672 |\n",
"| clip_fraction | 0.45 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.47 |\n",
"| explained_variance | 0.161 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 144 |\n",
"| n_updates | 30 |\n",
"| policy_gradient_loss | -0.105 |\n",
"| value_loss | 220 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 1.47 |\n",
"| time/ | |\n",
"| fps | 554 |\n",
"| iterations | 5 |\n",
"| time_elapsed | 18 |\n",
"| total_timesteps | 10240 |\n",
"| train/ | |\n",
"| approx_kl | 0.02928267 |\n",
"| clip_fraction | 0.498 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.167 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 127 |\n",
"| n_updates | 40 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 207 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 1.62 |\n",
"| time/ | |\n",
"| fps | 546 |\n",
"| iterations | 6 |\n",
"| time_elapsed | 22 |\n",
"| total_timesteps | 12288 |\n",
"| train/ | |\n",
"| approx_kl | 0.028425258 |\n",
"| clip_fraction | 0.483 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.143 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 109 |\n",
"| n_updates | 50 |\n",
"| policy_gradient_loss | -0.117 |\n",
"| value_loss | 240 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 5.98 |\n",
"| ep_rew_mean | 6.14 |\n",
"| time/ | |\n",
"| fps | 541 |\n",
"| iterations | 7 |\n",
"| time_elapsed | 26 |\n",
"| total_timesteps | 14336 |\n",
"| train/ | |\n",
"| approx_kl | 0.026178032 |\n",
"| clip_fraction | 0.453 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.46 |\n",
"| explained_variance | 0.174 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 141 |\n",
"| n_updates | 60 |\n",
"| policy_gradient_loss | -0.116 |\n",
"| value_loss | 235 |\n",
"-----------------------------------------\n",
"----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 3.03 |\n",
"| time/ | |\n",
"| fps | 537 |\n",
"| iterations | 8 |\n",
"| time_elapsed | 30 |\n",
"| total_timesteps | 16384 |\n",
"| train/ | |\n",
"| approx_kl | 0.02457074 |\n",
"| clip_fraction | 0.423 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.171 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 111 |\n",
"| n_updates | 70 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 212 |\n",
"----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 9.54 |\n",
"| time/ | |\n",
"| fps | 532 |\n",
"| iterations | 9 |\n",
"| time_elapsed | 34 |\n",
"| total_timesteps | 18432 |\n",
"| train/ | |\n",
"| approx_kl | 0.024578478 |\n",
"| clip_fraction | 0.417 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.178 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 121 |\n",
"| n_updates | 80 |\n",
"| policy_gradient_loss | -0.114 |\n",
"| value_loss | 232 |\n",
"-----------------------------------------\n",
"-----------------------------------------\n",
"| rollout/ | |\n",
"| ep_len_mean | 6 |\n",
"| ep_rew_mean | 3.81 |\n",
"| time/ | |\n",
"| fps | 527 |\n",
"| iterations | 10 |\n",
"| time_elapsed | 38 |\n",
"| total_timesteps | 20480 |\n",
"| train/ | |\n",
"| approx_kl | 0.022704324 |\n",
"| clip_fraction | 0.379 |\n",
"| clip_range | 0.2 |\n",
"| entropy_loss | -9.45 |\n",
"| explained_variance | 0.194 |\n",
"| learning_rate | 0.0003 |\n",
"| loss | 108 |\n",
"| n_updates | 90 |\n",
"| policy_gradient_loss | -0.112 |\n",
"| value_loss | 216 |\n",
"-----------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_timesteps = 20_000\n",
"model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
"model.learn(total_timesteps=total_timesteps)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model.save(\"dqn_wordle\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = PPO.load(\"dqn_wordle\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 7 18 1 19 16 3 3 3 2 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [16 9 5 14 4 3 3 3 3 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]\n",
" [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n",
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"env = gym_wordle.wordle.WordleEnv()\n",
"\n",
"for i in tqdm(range(1000)):\n",
" \n",
" state, info = env.reset()\n",
"\n",
" done = False\n",
"\n",
" wins = 0\n",
"\n",
" while not done:\n",
"\n",
" action, _states = model.predict(state, deterministic=True)\n",
"\n",
" state, reward, done, truncated, info = env.step(action)\n",
"\n",
" if info[\"correct\"]:\n",
" wins += 1\n",
"\n",
"print(state, reward, info)\n",
"\n",
"print(wins)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,38 +0,0 @@
import gym
import sys
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import wordle_gym
import numpy as np
from tqdm import tqdm
def train (model, env, total_timesteps = 100000):
model.learn(total_timesteps=total_timesteps, progress_bar=True)
model.save("dqn_wordle")
def test(model, env, test_num=1000):
total_correct = 0
for i in tqdm(range(test_num)):
model = DQN.load("dqn_wordle")
env = gym.make("wordle-v0")
obs = env.reset()
done = False
while not done:
action, _states = model.predict(obs)
obs, rewards, done, info = env.step(action)
return total_correct / test_num
if __name__ == "__main__":
env = gym.make("wordle-v0")
model = DQN("MlpPolicy", env, verbose=0)
print(env)
print(model)
train(model, env, total_timesteps=500000)
print(test(model, env))

7
gym_wordle/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
from gym.envs.registration import register
from .wordle import WordleEnv
register(
id='Wordle-v0',
entry_point='gym_wordle.wordle:WordleEnv'
)

Binary file not shown.

Binary file not shown.

93
gym_wordle/utils.py Normal file
View File

@@ -0,0 +1,93 @@
import numpy as np
import numpy.typing as npt
from pathlib import Path
_chars = ' abcdefghijklmnopqrstuvwxyz'
_char_d = {c: i for i, c in enumerate(_chars)}
def to_english(array: npt.NDArray[np.int64]) -> str:
"""Converts a numpy integer array into a corresponding English string.
Args:
array: Word in array (int) form. It is assumed that each integer in the
array is between 0,...,26 (inclusive).
Returns:
A (lowercase) string representation of the word.
"""
return ''.join(_chars[i] for i in array)
def to_array(word: str) -> npt.NDArray[np.int64]:
"""Converts a string of characters into a corresponding numpy array.
Args:
word: Word in string form. It is assumed that each character in the
string is either an empty space ' ' or lowercase alphabetical
character.
Returns:
An array representation of the word.
"""
return np.array([_char_d[c] for c in word])
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
"""Loads a list of words in array form.
If specified, this will recompute the list from the human-readable list of
words, and save the results in array form.
Args:
category: Either 'guess' or 'solution', which corresponds to the list
of acceptable guess words and the list of acceptable solution words.
build: If True, recomputes and saves the array-version of the computed
list for future access.
Returns:
An array representation of the list of words specified by the category.
This array has two dimensions, and the number of columns is fixed at
five.
"""
assert category in {'guess', 'solution'}
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
if build:
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
with open(list_path, 'r') as f:
words = np.array([to_array(line.strip()) for line in f])
np.save(arr_path, words)
return np.load(arr_path)
def play():
"""Play Wordle yourself!"""
import gym
import gym_wordle
env = gym.make('Wordle-v0') # load the environment
env.reset()
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
done = False
while not done:
action = -1
# in general, the environment won't be forgiving if you input an
# invalid word, but for this function I want to let you screw up user
# input without consequence, so just loops until valid input is taken
while not env.action_space.contains(action):
guess = input('Guess: ')
action = env.unwrapped.action_space.index_of(to_array(guess))
state, reward, done, info = env.step(action)
env.render()
print(f"The word was {solution}")

285
gym_wordle/wordle.py Normal file
View File

@@ -0,0 +1,285 @@
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from sty import fg, bg, ef, rs
from collections import Counter, defaultdict
from gym_wordle.utils import to_english, to_array, get_words
from typing import Optional
class WordList(gym.spaces.Discrete):
"""Super class for defining a space of valid words according to a specified
list.
The space is a subclass of gym.spaces.Discrete, where each element
corresponds to an index of a valid word in the word list. The obfuscation
is necessary for more direct implementation of RL algorithms, which expect
spaces of less sophisticated form.
In addition to the default methods of the Discrete space, it implements
a __getitem__ method for easy index lookup, and an index_of method to
convert potential words into their corresponding index (if they exist).
"""
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
"""
Args:
words: Collection of words in array form with shape (_, 5), where
each word is a row of the array. Each array element is an integer
between 0,...,26 (inclusive).
kwargs: See documentation for gym.spaces.MultiDiscrete
"""
super().__init__(words.shape[0], **kwargs)
self.words = words
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
"""Obtains the (int-encoded) word associated with the given index.
Args:
index: Index for the list of words.
Returns:
Associated word at the position specified by index.
"""
return self.words[index]
def index_of(self, word: npt.NDArray[np.int64]) -> int:
"""Given a word, determine its index in the list (if it exists),
otherwise returning -1 if no index exists.
Args:
word: Word to find in the word list.
Returns:
The index of the given word if it exists, otherwise -1.
"""
try:
index, = np.nonzero((word == self.words).all(axis=1))
return index[0]
except:
return -1
class SolutionList(WordList):
"""Space for *solution* words to the Wordle environment.
In the game Wordle, there are two different collections of words:
* "guesses", which the game accepts as valid words to use to guess the
answer.
* "solutions", which the game uses to choose solutions from.
Of course, the set of solutions is a strict subset of the set of guesses.
This class represents the set of solution words.
"""
def __init__(self, **kwargs):
"""
Args:
kwargs: See documentation for gym.spaces.MultiDiscrete
"""
words = get_words('solution')
super().__init__(words, **kwargs)
class WordleObsSpace(gym.spaces.Box):
"""Implementation of the state (observation) space in terms of gym
primitives, in this case, gym.spaces.Box.
The Wordle observation space can be thought of as a 6x5 array with two
channels:
- the character channel, indicating which characters are placed on the
board (unfilled rows are marked with the empty character, 0)
- the flag channel, indicating the in-game information associated with
each character's placement (green highlight, yellow highlight, etc.)
where there are 6 rows, one for each turn in the game, and 5 columns, since
the solution will always be a word of length 5.
For simplicity, and compatibility with stable_baselines algorithms,
this multichannel is modeled as a 6x10 array, where the two channels are
horizontally appended (along columns). Thus each row in the observation
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
c0...c4 and its associated flags are f0...f4.
"""
def __init__(self, **kwargs):
self.n_rows = 6
self.n_cols = 5
self.max_char = 26
self.max_flag = 4
low = np.zeros((self.n_rows, 2*self.n_cols))
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
np.full((self.n_rows, self.n_cols), self.max_flag)]
super().__init__(low, high, dtype=np.int64, **kwargs)
class GuessList(WordList):
"""Space for *guess* words to the Wordle environment.
This class represents the set of guess words.
"""
def __init__(self, **kwargs):
"""
Args:
kwargs: See documentation for gym.spaces.MultiDiscrete
"""
words = get_words('guess')
super().__init__(words, **kwargs)
class WordleEnv(gym.Env):
metadata = {'render.modes': ['human']}
# Character flag codes
no_char = 0
right_pos = 1
wrong_pos = 2
wrong_char = 3
def __init__(self):
super().__init__()
self.action_space = GuessList()
self.solution_space = SolutionList()
self.observation_space = WordleObsSpace()
self._highlights = {
self.right_pos: (bg.green, bg.rs),
self.wrong_pos: (bg.yellow, bg.rs),
self.wrong_char: ('', ''),
self.no_char: ('', ''),
}
self.n_rounds = 6
self.n_letters = 5
self.info = {'correct': False, 'guesses': defaultdict(int)}
def _highlighter(self, char: str, flag: int) -> str:
"""Terminal renderer functionality. Properly highlights a character
based on the flag associated with it.
Args:
char: Character in question.
flag: Associated flag, one of:
- 0: no character (render no background)
- 1: right position (render green background)
- 2: wrong position (render yellow background)
- 3: wrong character (render no background)
Returns:
Correct ASCII sequence producing the desired character in the
correct background.
"""
front, back = self._highlights[flag]
return front + char + back
def reset(self, seed=None, options=None):
"""Reset the environment to an initial state and returns an initial
observation.
Note: The observation space instance should be a Box space.
Returns:
state (object): The initial observation of the space.
"""
self.round = 0
self.solution = self.solution_space.sample()
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
self.info = {'correct': False, 'guesses': defaultdict(int)}
return self.state, self.info
def render(self, mode: str = 'human'):
"""Renders the Wordle environment.
Currently supported render modes:
- human: renders the Wordle game to the terminal.
Args:
mode: the mode to render with.
"""
if mode == 'human':
for row in self.state:
text = ''.join(map(
self._highlighter,
to_english(row[:self.n_letters]).upper(),
row[self.n_letters:]
))
print(text)
else:
super().render(mode=mode)
def step(self, action):
"""Run one step of the Wordle game. Every game must be previously
initialized by a call to the `reset` method.
Args:
action: Word guessed by the agent.
Returns:
state (object): Wordle game state after the guess.
reward (float): Reward associated with the guess.
done (bool): Whether the game has ended.
info (dict): Auxiliary diagnostic information.
"""
assert self.action_space.contains(action), 'Invalid word!'
action = self.action_space[action]
solution = self.solution_space[self.solution]
self.state[self.round][:self.n_letters] = action
counter = Counter()
for i, char in enumerate(action):
flag_i = i + self.n_letters
counter[char] += 1
if char == solution[i]:
self.state[self.round, flag_i] = self.right_pos
elif counter[char] <= (char == solution).sum():
self.state[self.round, flag_i] = self.wrong_pos
else:
self.state[self.round, flag_i] = self.wrong_char
self.round += 1
correct = (action == solution).all()
game_over = (self.round == self.n_rounds)
done = correct or game_over
reward = 0
# correct spot
reward += np.sum(self.state[:, 5:] == 1) * 2
# correct letter not correct spot
reward += np.sum(self.state[:, 5:] == 2) * 1
# incorrect letter
reward += np.sum(self.state[:, 5:] == 3) * -1
# guess same word as before
hashable_action = to_english(action)
if hashable_action in self.info['guesses']:
reward += -10 * self.info['guesses'][hashable_action]
else: # guess different word
reward += 10
self.info['guesses'][hashable_action] += 1
# for game ending in win or loss
reward += 10 if correct else -10 if done else 0
self.info['correct'] = correct
# observation, reward, terminated, truncated, info
return self.state, reward, done, False, self.info

View File

@@ -1,9 +0,0 @@
from gym.envs.registration import register
register(
id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv",
)
register(
id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv",
)

View File

@@ -1,15 +0,0 @@
from enum import Enum
from typing import List
class StrategyType(Enum):
RANDOM = 1
ELIMINATION = 2
PROBABILITY = 3
class Strategy:
def __init__(self, type: StrategyType):
self.type = type
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
raise NotImplementedError("Strategy.get_best_word() not implemented")

View File

@@ -1,2 +0,0 @@
def get_best_word(state):

View File

@@ -1,20 +0,0 @@
from random import sample
from typing import List
from base import Strategy
from base import StrategyType
from utils import freq
class Random(Strategy):
def __init__(self):
self.words = freq.get_5_letter_word_freqs()
super().__init__(StrategyType.RANDOM)
def get_best_word(self, state: List[List[int]]):
if __name__ == "__main__":
r = Random()
print(r.get_best_word([]))

View File

@@ -1,29 +0,0 @@
from random import sample
from typing import List
from base import Strategy
from base import StrategyType
from utils import freq
class Random(Strategy):
def __init__(self):
self.words = freq.get_5_letter_word_freqs()
super().__init__(StrategyType.RANDOM)
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
correct_letters = []
regex = ""
for g, s in zip(guesses, state):
for c, s in zip(g, s):
if s == 2:
correct_letters.append(c)
regex += c
if __name__ == "__main__":
r = Random()
print(r.get_best_word([]))

View File

@@ -1,27 +0,0 @@
from os import path
def get_5_letter_word_freqs():
"""
Returns a list of words with 5 letters.
"""
FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt")
lines = read_file(FILEPATH)
return {k:v for k, v in get_freq(lines).items() if len(k) == 5}
def read_file(filename):
"""
Reads a file and returns a list of words and frequencies
"""
with open(filename, 'r') as f:
return f.readlines()
def get_freq(lines):
"""
Returns a dictionary of words and their frequencies
"""
freqs = {}
for word, freq in map(lambda x: x.split("\t"), lines):
freqs[word] = int(freq)
return freqs

View File

@@ -1,131 +0,0 @@
import os
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum
from collections import Counter
import numpy as np
WORD_LENGTH = 5
TOTAL_GUESSES = 6
SOLUTION_PATH = "../words/solution.csv"
VALID_WORDS_PATH = "../words/guess.csv"
class LetterState(Enum):
ABSENT = 0
PRESENT = 1
CORRECT_POSITION = 2
class WordleEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def _current_path(self):
return os.path.dirname(os.path.abspath(__file__))
def _read_solutions(self):
return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()
def _get_valid_words(self):
words = []
for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
words.append((word, Counter(word)))
return words
def get_valid(self):
return self._valid_words
def __init__(self):
self._solutions = self._read_solutions()
self._valid_words = self._get_valid_words()
self.action_space = spaces.Discrete(len(self._valid_words))
self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
np.random.seed(0)
self.reset()
def _check_guess(self, guess, guess_counter):
c = guess_counter & self.solution_ct
result = []
correct = True
reward = 0
for i, char in enumerate(guess):
if c.get(char, 0) > 0:
if self.solution[i] == char:
result.append(2)
reward += 2
else:
result.append(1)
correct = False
reward += 1
c[char] -= 1
else:
result.append(0)
correct = False
return result, correct, reward
def step(self, action):
"""
action: index of word in valid_words
returns:
observation: (TOTAL_GUESSES, WORD_LENGTH)
reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
done: True if game over, w/ or w/o correct answer
additional_info: empty
"""
guess, guess_counter = self._valid_words[action]
if guess in self.guesses:
return self.obs, -1, False, {}
self.guesses.append(guess)
result, correct, reward = self._check_guess(guess, guess_counter)
done = False
for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
self.obs[i] = result[i - self.guess_no*WORD_LENGTH]
self.guess_no += 1
if correct:
done = True
reward = 1200
if self.guess_no == TOTAL_GUESSES:
done = True
if not correct:
reward = -15
return self.obs, reward, done, {}
def reset(self):
self.solution = self._solutions[np.random.randint(len(self._solutions))]
self.solution_ct = Counter(self.solution)
self.guess_no = 0
self.guesses = []
self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
return self.obs
def render(self, mode="human"):
m = {
0: "",
1: "🟨",
2: "🟩"
}
print("Solution:", self.solution)
for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
o_n = "".join(map(lambda x: m[x], o))
print(g, o_n)
def close(self):
pass
if __name__ == "__main__":
env = WordleEnv()
print(env.action_space)
print(env.observation_space)
print(env.solution)
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))