mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-10 07:04:45 +00:00
still doesnt train
This commit is contained in:
parent
f641d77c47
commit
848ea719b7
@ -1,47 +1,40 @@
|
|||||||
# %%
|
|
||||||
from stable_baselines3 import DQN
|
|
||||||
import numpy as np
|
|
||||||
import wordle.state
|
|
||||||
import gym
|
import gym
|
||||||
|
import sys
|
||||||
|
from stable_baselines3 import DQN
|
||||||
|
from stable_baselines3.common.env_util import make_vec_env
|
||||||
|
import wordle_gym
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
# %%
|
def train (model, env, total_timesteps = 100000):
|
||||||
env = gym.make("WordleEnvFull-v0")
|
|
||||||
|
|
||||||
print(env)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
total_timesteps = 100000
|
|
||||||
model = DQN("MlpPolicy", env, verbose=0)
|
|
||||||
model.learn(total_timesteps=total_timesteps, progress_bar=True)
|
model.learn(total_timesteps=total_timesteps, progress_bar=True)
|
||||||
|
|
||||||
# %%
|
|
||||||
def test(model):
|
|
||||||
|
|
||||||
end_rewards = []
|
|
||||||
|
|
||||||
for i in range(1000):
|
|
||||||
|
|
||||||
state = env.reset()
|
|
||||||
|
|
||||||
done = False
|
|
||||||
|
|
||||||
while not done:
|
|
||||||
|
|
||||||
action, _states = model.predict(state, deterministic=True)
|
|
||||||
|
|
||||||
state, reward, done, info = env.step(action)
|
|
||||||
|
|
||||||
end_rewards.append(reward == 0)
|
|
||||||
|
|
||||||
return np.sum(end_rewards) / len(end_rewards)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
model.save("dqn_wordle")
|
model.save("dqn_wordle")
|
||||||
|
|
||||||
# %%
|
def test(model, env, test_num=1000):
|
||||||
|
|
||||||
|
total_correct = 0
|
||||||
|
|
||||||
|
for i in tqdm(range(test_num)):
|
||||||
|
|
||||||
model = DQN.load("dqn_wordle")
|
model = DQN.load("dqn_wordle")
|
||||||
|
|
||||||
# %%
|
env = gym.make("wordle-v0")
|
||||||
print(test(model))
|
obs = env.reset()
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action, _states = model.predict(obs)
|
||||||
|
obs, rewards, done, info = env.step(action)
|
||||||
|
|
||||||
|
print(action, obs, rewards)
|
||||||
|
|
||||||
|
return total_correct / test_num
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
env = gym.make("wordle-v0")
|
||||||
|
model = DQN("MlpPolicy", env, verbose=0)
|
||||||
|
print(env)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
train(model, env, total_timesteps=10000)
|
||||||
|
print(test(model, env, test_num=1))
|
@ -1,83 +0,0 @@
|
|||||||
from gym.envs.registration import (
|
|
||||||
registry,
|
|
||||||
register,
|
|
||||||
make,
|
|
||||||
spec,
|
|
||||||
load_env_plugins as _load_env_plugins,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Classic
|
|
||||||
# ----------------------------------------
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv10-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv10",
|
|
||||||
max_episode_steps=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv100-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv100",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv100OneAction-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv100OneAction",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv100TwoAction-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv100TwoAction",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv100FullAction-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv100FullAction",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv100WithMask-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv100WithMask",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv1000-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv1000",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv1000WithMask-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv1000WithMask",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnv1000FullAction-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnv1000FullAction",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnvFull-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnvFull",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnvReal-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnvReal",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id="WordleEnvRealWithMask-v0",
|
|
||||||
entry_point="wordle.wordle:WordleEnvRealWithMask",
|
|
||||||
max_episode_steps=500,
|
|
||||||
)
|
|
@ -1,3 +0,0 @@
|
|||||||
WORDLE_CHARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
|
||||||
WORDLE_N = 5
|
|
||||||
REWARD = 10
|
|
162
wordle/state.py
162
wordle/state.py
@ -1,162 +0,0 @@
|
|||||||
"""
|
|
||||||
Keep the state in a 1D int array
|
|
||||||
|
|
||||||
index[0] = remaining steps
|
|
||||||
Rest of data is laid out as binary array
|
|
||||||
|
|
||||||
[1..27] = whether char has been guessed or not
|
|
||||||
|
|
||||||
[[status, status, status, status, status]
|
|
||||||
for _ in "ABCD..."]
|
|
||||||
where status has codes
|
|
||||||
[1, 0, 0] - char is definitely not in this spot
|
|
||||||
[0, 1, 0] - char is maybe in this spot
|
|
||||||
[0, 0, 1] - char is definitely in this spot
|
|
||||||
"""
|
|
||||||
import collections
|
|
||||||
from typing import List
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from wordle.const import WORDLE_CHARS, WORDLE_N
|
|
||||||
|
|
||||||
|
|
||||||
WordleState = np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
def get_nvec(max_turns: int):
|
|
||||||
return [max_turns] + [2] * len(WORDLE_CHARS) + [2] * 3 * WORDLE_N * len(WORDLE_CHARS)
|
|
||||||
|
|
||||||
|
|
||||||
def new(max_turns: int) -> WordleState:
|
|
||||||
return np.array(
|
|
||||||
[max_turns] + [0] * len(WORDLE_CHARS) + [0, 1, 0] * WORDLE_N * len(WORDLE_CHARS),
|
|
||||||
dtype=np.int32)
|
|
||||||
|
|
||||||
|
|
||||||
def remaining_steps(state: WordleState) -> int:
|
|
||||||
return state[0]
|
|
||||||
|
|
||||||
|
|
||||||
NO = 0
|
|
||||||
SOMEWHERE = 1
|
|
||||||
YES = 2
|
|
||||||
|
|
||||||
|
|
||||||
def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
|
|
||||||
"""
|
|
||||||
return a copy of state that has been updated to new state
|
|
||||||
|
|
||||||
From a mask we need slighty different logic since we don't know the
|
|
||||||
goal word.
|
|
||||||
|
|
||||||
:param state:
|
|
||||||
:param word:
|
|
||||||
:param goal_word:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
state = state.copy()
|
|
||||||
|
|
||||||
prior_yes = []
|
|
||||||
prior_maybe = []
|
|
||||||
# We need two passes because first pass sets definitely yesses
|
|
||||||
# second pass sets the no's for those who aren't already yes
|
|
||||||
state[0] -= 1
|
|
||||||
for i, c in enumerate(word):
|
|
||||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
|
||||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
|
||||||
state[1 + cint] = 1
|
|
||||||
if mask[i] == YES:
|
|
||||||
prior_yes.append(c)
|
|
||||||
# char at position i = yes, all other chars at position i == no
|
|
||||||
state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
|
|
||||||
for ocint in range(len(WORDLE_CHARS)):
|
|
||||||
if ocint != cint:
|
|
||||||
oc_offset = 1 + len(WORDLE_CHARS) + ocint * WORDLE_N * 3
|
|
||||||
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
|
||||||
|
|
||||||
for i, c in enumerate(word):
|
|
||||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
|
||||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
|
||||||
if mask[i] == SOMEWHERE:
|
|
||||||
prior_maybe.append(c)
|
|
||||||
# Char at position i = no, other chars stay as they are
|
|
||||||
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
|
||||||
elif mask[i] == NO:
|
|
||||||
# Need to check this first in case there's prior maybe + yes
|
|
||||||
if c in prior_maybe:
|
|
||||||
# Then the maybe could be anywhere except here
|
|
||||||
state[offset+3*i:offset+3*i+3] = [1, 0, 0]
|
|
||||||
elif c in prior_yes:
|
|
||||||
# No maybe, definitely a yes, so it's zero everywhere except the yesses
|
|
||||||
for j in range(WORDLE_N):
|
|
||||||
# Only flip no if previously was maybe
|
|
||||||
if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
|
|
||||||
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
|
||||||
else:
|
|
||||||
# Just straight up no
|
|
||||||
state[offset:offset+3*WORDLE_N] = [1, 0, 0]*WORDLE_N
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def get_mask(word: str, goal_word: str) -> List[int]:
|
|
||||||
# Definite yesses first
|
|
||||||
mask = [0, 0, 0, 0, 0]
|
|
||||||
counts = collections.Counter(goal_word)
|
|
||||||
for i, c in enumerate(word):
|
|
||||||
if goal_word[i] == c:
|
|
||||||
mask[i] = 2
|
|
||||||
counts[c] -= 1
|
|
||||||
|
|
||||||
for i, c in enumerate(word):
|
|
||||||
if mask[i] == 2:
|
|
||||||
continue
|
|
||||||
elif c in counts:
|
|
||||||
if counts[c] > 0:
|
|
||||||
mask[i] = 1
|
|
||||||
counts[c] -= 1
|
|
||||||
else:
|
|
||||||
for j in range(i+1, len(mask)):
|
|
||||||
if mask[j] == 2:
|
|
||||||
continue
|
|
||||||
mask[j] = 0
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
|
|
||||||
"""
|
|
||||||
return a copy of state that has been updated to new state
|
|
||||||
|
|
||||||
:param state:
|
|
||||||
:param word:
|
|
||||||
:param goal_word:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
mask = get_mask(word, goal_word)
|
|
||||||
return update_from_mask(state, word, mask)
|
|
||||||
|
|
||||||
|
|
||||||
def update(state: WordleState, word: str, goal_word: str) -> WordleState:
|
|
||||||
state = state.copy()
|
|
||||||
|
|
||||||
state[0] -= 1
|
|
||||||
for i, c in enumerate(word):
|
|
||||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
|
||||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
|
||||||
state[1 + cint] = 1
|
|
||||||
if goal_word[i] == c:
|
|
||||||
# char at position i = yes, all other chars at position i == no
|
|
||||||
state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
|
|
||||||
for ocint in range(len(WORDLE_CHARS)):
|
|
||||||
if ocint != cint:
|
|
||||||
oc_offset = 1 + len(WORDLE_CHARS) + ocint * WORDLE_N * 3
|
|
||||||
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
|
||||||
elif c in goal_word:
|
|
||||||
# Char at position i = no, other chars stay as they are
|
|
||||||
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
|
||||||
else:
|
|
||||||
# Char at all positions = no
|
|
||||||
state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
173
wordle/wordle.py
173
wordle/wordle.py
@ -1,173 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import gym
|
|
||||||
from gym import spaces
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import wordle.state
|
|
||||||
from wordle.const import WORDLE_N, REWARD
|
|
||||||
|
|
||||||
CUR_PATH = os.environ.get('PYTHONPATH', '.')
|
|
||||||
import os
|
|
||||||
dirname = os.path.dirname(__file__)
|
|
||||||
VALID_WORDS_PATH = f'{dirname}/wordle_words.txt'
|
|
||||||
|
|
||||||
|
|
||||||
def _load_words(limit: Optional[int]=None) -> List[str]:
|
|
||||||
with open(VALID_WORDS_PATH, 'r') as f:
|
|
||||||
lines = [x.strip().upper() for x in f.readlines()]
|
|
||||||
if not limit:
|
|
||||||
return lines
|
|
||||||
else:
|
|
||||||
return lines[:limit]
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnvBase(gym.Env):
|
|
||||||
"""
|
|
||||||
Actions:
|
|
||||||
Can play any 5 letter word in vocabulary
|
|
||||||
* 13k for full vocab
|
|
||||||
State space is defined as:
|
|
||||||
* 6 possibilities for turns (WORDLE_TURNS)
|
|
||||||
* Each VALID_CHAR has a state of 0/1 for whether it's been guessed before
|
|
||||||
* For each in VALID_CHARS [A-Z] can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
|
|
||||||
for full game, this is (3^5)^26
|
|
||||||
Each state has 1 + 5*26 possibilities
|
|
||||||
Reward:
|
|
||||||
Reward is 10 for guessing the right word, -10 for not guessing the right word after 6 guesses.
|
|
||||||
Starting State:
|
|
||||||
Random goal word
|
|
||||||
Initial state with turn 0, all chars Unvisited + Maybe
|
|
||||||
"""
|
|
||||||
def __init__(self, words: List[str],
|
|
||||||
max_turns: int,
|
|
||||||
allowable_words: Optional[int] = None,
|
|
||||||
frequencies: Optional[List[float]]=None,
|
|
||||||
mask_based_state_updates: bool=False):
|
|
||||||
assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
|
|
||||||
self.words = words
|
|
||||||
self.max_turns = max_turns
|
|
||||||
self.allowable_words = allowable_words
|
|
||||||
self.mask_based_state_updates = mask_based_state_updates
|
|
||||||
if not self.allowable_words:
|
|
||||||
self.allowable_words = len(self.words)
|
|
||||||
|
|
||||||
self.frequencies = None
|
|
||||||
if frequencies:
|
|
||||||
assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
|
|
||||||
self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
|
|
||||||
|
|
||||||
self.action_space = spaces.Discrete(len(self.words))
|
|
||||||
self.observation_space = spaces.MultiDiscrete(wordle.state.get_nvec(self.max_turns))
|
|
||||||
|
|
||||||
self.done = True
|
|
||||||
self.goal_word: int = -1
|
|
||||||
|
|
||||||
self.state: wordle.state.WordleState = None
|
|
||||||
self.state_updater = wordle.state.update
|
|
||||||
if self.mask_based_state_updates:
|
|
||||||
self.state_updater = wordle.state.update_mask
|
|
||||||
|
|
||||||
def step(self, action: int):
|
|
||||||
if self.done:
|
|
||||||
raise ValueError(
|
|
||||||
"You are calling 'step()' even though this "
|
|
||||||
"environment has already returned done = True. You "
|
|
||||||
"should always call 'reset()' once you receive 'done = "
|
|
||||||
"True' -- any further steps are undefined behavior."
|
|
||||||
)
|
|
||||||
self.state = self.state_updater(state=self.state,
|
|
||||||
word=self.words[action],
|
|
||||||
goal_word=self.words[self.goal_word])
|
|
||||||
|
|
||||||
reward = 0
|
|
||||||
if action == self.goal_word:
|
|
||||||
self.done = True
|
|
||||||
#reward = REWARD
|
|
||||||
if wordle.state.remaining_steps(self.state) == self.max_turns-1:
|
|
||||||
reward = 0#-10*REWARD # No reward for guessing off the bat
|
|
||||||
else:
|
|
||||||
#reward = REWARD*(self.state.remaining_steps() + 1) / self.max_turns
|
|
||||||
reward = REWARD
|
|
||||||
elif wordle.state.remaining_steps(self.state) == 0:
|
|
||||||
self.done = True
|
|
||||||
reward = -REWARD
|
|
||||||
|
|
||||||
return self.state.copy(), reward, self.done, False, {"goal_id": self.goal_word}
|
|
||||||
|
|
||||||
def reset(self, options = None, seed: Optional[int] = None):
|
|
||||||
self.state = wordle.state.new(self.max_turns)
|
|
||||||
self.done = False
|
|
||||||
self.goal_word = int(np.random.random()*self.allowable_words)
|
|
||||||
|
|
||||||
return self.state.copy(), {"goal_id": self.goal_word}
|
|
||||||
|
|
||||||
def set_goal_word(self, goal_word: str):
|
|
||||||
self.goal_word = self.words.index(goal_word)
|
|
||||||
|
|
||||||
def set_goal_id(self, goal_id: int):
|
|
||||||
self.goal_word = goal_id
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv10(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(10), max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv100(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(100), max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv100OneAction(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(100), allowable_words=1, max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv100WithMask(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(100), max_turns=6,
|
|
||||||
mask_based_state_updates=True)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv100TwoAction(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(100), allowable_words=2, max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv100FullAction(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(), allowable_words=100, max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv1000(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(1000), max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv1000WithMask(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(1000), max_turns=6,
|
|
||||||
mask_based_state_updates=True)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv1000FullAction(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(), allowable_words=1000, max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnvFull(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(), max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnvReal(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(), allowable_words=2315, max_turns=6)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnvRealWithMask(WordleEnvBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(words=_load_words(), allowable_words=2315, max_turns=6,
|
|
||||||
mask_based_state_updates=True)
|
|
9
wordle_gym/__init__.py
Normal file
9
wordle_gym/__init__.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
register(
|
||||||
|
id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv",
|
||||||
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv",
|
||||||
|
)
|
0
wordle_gym/envs/__init__.py
Normal file
0
wordle_gym/envs/__init__.py
Normal file
15
wordle_gym/envs/strategies/base.py
Normal file
15
wordle_gym/envs/strategies/base.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class StrategyType(Enum):
|
||||||
|
RANDOM = 1
|
||||||
|
ELIMINATION = 2
|
||||||
|
PROBABILITY = 3
|
||||||
|
|
||||||
|
class Strategy:
|
||||||
|
def __init__(self, type: StrategyType):
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
|
||||||
|
raise NotImplementedError("Strategy.get_best_word() not implemented")
|
2
wordle_gym/envs/strategies/elimination.py
Normal file
2
wordle_gym/envs/strategies/elimination.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
def get_best_word(state):
|
||||||
|
|
20
wordle_gym/envs/strategies/probabilistic.py
Normal file
20
wordle_gym/envs/strategies/probabilistic.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from random import sample
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from base import Strategy
|
||||||
|
from base import StrategyType
|
||||||
|
|
||||||
|
from utils import freq
|
||||||
|
|
||||||
|
class Random(Strategy):
|
||||||
|
def __init__(self):
|
||||||
|
self.words = freq.get_5_letter_word_freqs()
|
||||||
|
super().__init__(StrategyType.RANDOM)
|
||||||
|
|
||||||
|
def get_best_word(self, state: List[List[int]]):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
r = Random()
|
||||||
|
print(r.get_best_word([]))
|
29
wordle_gym/envs/strategies/rand.py
Normal file
29
wordle_gym/envs/strategies/rand.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from random import sample
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from base import Strategy
|
||||||
|
from base import StrategyType
|
||||||
|
|
||||||
|
from utils import freq
|
||||||
|
|
||||||
|
class Random(Strategy):
|
||||||
|
def __init__(self):
|
||||||
|
self.words = freq.get_5_letter_word_freqs()
|
||||||
|
super().__init__(StrategyType.RANDOM)
|
||||||
|
|
||||||
|
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
|
||||||
|
correct_letters = []
|
||||||
|
regex = ""
|
||||||
|
for g, s in zip(guesses, state):
|
||||||
|
for c, s in zip(g, s):
|
||||||
|
if s == 2:
|
||||||
|
correct_letters.append(c)
|
||||||
|
regex += c
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
r = Random()
|
||||||
|
print(r.get_best_word([]))
|
27
wordle_gym/envs/strategies/utils/freq.py
Normal file
27
wordle_gym/envs/strategies/utils/freq.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from os import path
|
||||||
|
|
||||||
|
def get_5_letter_word_freqs():
|
||||||
|
"""
|
||||||
|
Returns a list of words with 5 letters.
|
||||||
|
"""
|
||||||
|
FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt")
|
||||||
|
lines = read_file(FILEPATH)
|
||||||
|
return {k:v for k, v in get_freq(lines).items() if len(k) == 5}
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(filename):
|
||||||
|
"""
|
||||||
|
Reads a file and returns a list of words and frequencies
|
||||||
|
"""
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
return f.readlines()
|
||||||
|
|
||||||
|
|
||||||
|
def get_freq(lines):
|
||||||
|
"""
|
||||||
|
Returns a dictionary of words and their frequencies
|
||||||
|
"""
|
||||||
|
freqs = {}
|
||||||
|
for word, freq in map(lambda x: x.split("\t"), lines):
|
||||||
|
freqs[word] = int(freq)
|
||||||
|
return freqs
|
131
wordle_gym/envs/wordle_env.py
Normal file
131
wordle_gym/envs/wordle_env.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
import gym
|
||||||
|
from gym import error, spaces, utils
|
||||||
|
from gym.utils import seeding
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from collections import Counter
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
WORD_LENGTH = 5
|
||||||
|
TOTAL_GUESSES = 6
|
||||||
|
SOLUTION_PATH = "../words/solution.csv"
|
||||||
|
VALID_WORDS_PATH = "../words/guess.csv"
|
||||||
|
|
||||||
|
class LetterState(Enum):
|
||||||
|
ABSENT = 0
|
||||||
|
PRESENT = 1
|
||||||
|
CORRECT_POSITION = 2
|
||||||
|
|
||||||
|
|
||||||
|
class WordleEnv(gym.Env):
|
||||||
|
metadata = {"render.modes": ["human"]}
|
||||||
|
|
||||||
|
def _current_path(self):
|
||||||
|
return os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
def _read_solutions(self):
|
||||||
|
return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()
|
||||||
|
|
||||||
|
def _get_valid_words(self):
|
||||||
|
words = []
|
||||||
|
for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
|
||||||
|
words.append((word, Counter(word)))
|
||||||
|
return words
|
||||||
|
|
||||||
|
def get_valid(self):
|
||||||
|
return self._valid_words
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._solutions = self._read_solutions()
|
||||||
|
self._valid_words = self._get_valid_words()
|
||||||
|
self.action_space = spaces.Discrete(len(self._valid_words))
|
||||||
|
self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
|
||||||
|
np.random.seed(0)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def _check_guess(self, guess, guess_counter):
|
||||||
|
c = guess_counter & self.solution_ct
|
||||||
|
result = []
|
||||||
|
correct = True
|
||||||
|
reward = 0
|
||||||
|
for i, char in enumerate(guess):
|
||||||
|
if c.get(char, 0) > 0:
|
||||||
|
if self.solution[i] == char:
|
||||||
|
result.append(2)
|
||||||
|
reward += 2
|
||||||
|
else:
|
||||||
|
result.append(1)
|
||||||
|
correct = False
|
||||||
|
reward += 1
|
||||||
|
c[char] -= 1
|
||||||
|
else:
|
||||||
|
result.append(0)
|
||||||
|
correct = False
|
||||||
|
return result, correct, reward
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
"""
|
||||||
|
action: index of word in valid_words
|
||||||
|
|
||||||
|
returns:
|
||||||
|
observation: (TOTAL_GUESSES, WORD_LENGTH)
|
||||||
|
reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
|
||||||
|
done: True if game over, w/ or w/o correct answer
|
||||||
|
additional_info: empty
|
||||||
|
"""
|
||||||
|
guess, guess_counter = self._valid_words[action]
|
||||||
|
if guess in self.guesses:
|
||||||
|
return self.obs, -1, False, {}
|
||||||
|
self.guesses.append(guess)
|
||||||
|
result, correct, reward = self._check_guess(guess, guess_counter)
|
||||||
|
done = False
|
||||||
|
|
||||||
|
for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
|
||||||
|
self.obs[i] = result[i - self.guess_no*WORD_LENGTH]
|
||||||
|
|
||||||
|
self.guess_no += 1
|
||||||
|
if correct:
|
||||||
|
done = True
|
||||||
|
reward = 1200
|
||||||
|
if self.guess_no == TOTAL_GUESSES:
|
||||||
|
done = True
|
||||||
|
if not correct:
|
||||||
|
reward = -15
|
||||||
|
return self.obs, reward, done, {}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.solution = self._solutions[np.random.randint(len(self._solutions))]
|
||||||
|
self.solution_ct = Counter(self.solution)
|
||||||
|
self.guess_no = 0
|
||||||
|
self.guesses = []
|
||||||
|
self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
|
||||||
|
return self.obs
|
||||||
|
|
||||||
|
def render(self, mode="human"):
|
||||||
|
m = {
|
||||||
|
0: "⬜",
|
||||||
|
1: "🟨",
|
||||||
|
2: "🟩"
|
||||||
|
}
|
||||||
|
print("Solution:", self.solution)
|
||||||
|
for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
|
||||||
|
o_n = "".join(map(lambda x: m[x], o))
|
||||||
|
print(g, o_n)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
env = WordleEnv()
|
||||||
|
print(env.action_space)
|
||||||
|
print(env.observation_space)
|
||||||
|
print(env.solution)
|
||||||
|
print(env.step(0))
|
||||||
|
print(env.step(0))
|
||||||
|
print(env.step(0))
|
||||||
|
print(env.step(0))
|
||||||
|
print(env.step(0))
|
||||||
|
print(env.step(0))
|
File diff suppressed because it is too large
Load Diff
2315
wordle_gym/words/solution.csv
Normal file
2315
wordle_gym/words/solution.csv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user