mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-10 07:04:45 +00:00
attempt to use the other wordle gym, causing cuda errors
This commit is contained in:
parent
5ec123e0f1
commit
f641d77c47
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
**/data/*
|
||||
**/*.zip
|
||||
**/__pycache__
|
114
dqn_wordle.ipynb
114
dqn_wordle.ipynb
@ -1,114 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import gym_wordle\n",
|
||||
"from stable_baselines3 import DQN\n",
|
||||
"import numpy as np\n",
|
||||
"import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make(\"Wordle-v0\")\n",
|
||||
"\n",
|
||||
"print(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"total_timesteps = 100000\n",
|
||||
"model = DQN(\"MlpPolicy\", env, verbose=0)\n",
|
||||
"model.learn(total_timesteps=total_timesteps, progress_bar=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test(model):\n",
|
||||
"\n",
|
||||
" end_rewards = []\n",
|
||||
"\n",
|
||||
" for i in range(1000):\n",
|
||||
" \n",
|
||||
" state = env.reset()\n",
|
||||
"\n",
|
||||
" done = False\n",
|
||||
"\n",
|
||||
" while not done:\n",
|
||||
"\n",
|
||||
" action, _states = model.predict(state, deterministic=True)\n",
|
||||
"\n",
|
||||
" state, reward, done, info = env.step(action)\n",
|
||||
" \n",
|
||||
" end_rewards.append(reward == 0)\n",
|
||||
" \n",
|
||||
" return np.sum(end_rewards) / len(end_rewards)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = DQN.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(test(model))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
47
dqn_wordle.py
Normal file
47
dqn_wordle.py
Normal file
@ -0,0 +1,47 @@
|
||||
# %%
|
||||
from stable_baselines3 import DQN
|
||||
import numpy as np
|
||||
import wordle.state
|
||||
import gym
|
||||
|
||||
# %%
|
||||
env = gym.make("WordleEnvFull-v0")
|
||||
|
||||
print(env)
|
||||
|
||||
# %%
|
||||
total_timesteps = 100000
|
||||
model = DQN("MlpPolicy", env, verbose=0)
|
||||
model.learn(total_timesteps=total_timesteps, progress_bar=True)
|
||||
|
||||
# %%
|
||||
def test(model):
|
||||
|
||||
end_rewards = []
|
||||
|
||||
for i in range(1000):
|
||||
|
||||
state = env.reset()
|
||||
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
|
||||
action, _states = model.predict(state, deterministic=True)
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
|
||||
end_rewards.append(reward == 0)
|
||||
|
||||
return np.sum(end_rewards) / len(end_rewards)
|
||||
|
||||
# %%
|
||||
model.save("dqn_wordle")
|
||||
|
||||
# %%
|
||||
model = DQN.load("dqn_wordle")
|
||||
|
||||
# %%
|
||||
print(test(model))
|
||||
|
||||
|
83
wordle/__init__.py
Normal file
83
wordle/__init__.py
Normal file
@ -0,0 +1,83 @@
|
||||
from gym.envs.registration import (
|
||||
registry,
|
||||
register,
|
||||
make,
|
||||
spec,
|
||||
load_env_plugins as _load_env_plugins,
|
||||
)
|
||||
|
||||
|
||||
# Classic
|
||||
# ----------------------------------------
|
||||
|
||||
register(
|
||||
id="WordleEnv10-v0",
|
||||
entry_point="wordle.wordle:WordleEnv10",
|
||||
max_episode_steps=200,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv100-v0",
|
||||
entry_point="wordle.wordle:WordleEnv100",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv100OneAction-v0",
|
||||
entry_point="wordle.wordle:WordleEnv100OneAction",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv100TwoAction-v0",
|
||||
entry_point="wordle.wordle:WordleEnv100TwoAction",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv100FullAction-v0",
|
||||
entry_point="wordle.wordle:WordleEnv100FullAction",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv100WithMask-v0",
|
||||
entry_point="wordle.wordle:WordleEnv100WithMask",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv1000-v0",
|
||||
entry_point="wordle.wordle:WordleEnv1000",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv1000WithMask-v0",
|
||||
entry_point="wordle.wordle:WordleEnv1000WithMask",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnv1000FullAction-v0",
|
||||
entry_point="wordle.wordle:WordleEnv1000FullAction",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnvFull-v0",
|
||||
entry_point="wordle.wordle:WordleEnvFull",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnvReal-v0",
|
||||
entry_point="wordle.wordle:WordleEnvReal",
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
register(
|
||||
id="WordleEnvRealWithMask-v0",
|
||||
entry_point="wordle.wordle:WordleEnvRealWithMask",
|
||||
max_episode_steps=500,
|
||||
)
|
3
wordle/const.py
Normal file
3
wordle/const.py
Normal file
@ -0,0 +1,3 @@
|
||||
WORDLE_CHARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
WORDLE_N = 5
|
||||
REWARD = 10
|
162
wordle/state.py
Normal file
162
wordle/state.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""
|
||||
Keep the state in a 1D int array
|
||||
|
||||
index[0] = remaining steps
|
||||
Rest of data is laid out as binary array
|
||||
|
||||
[1..27] = whether char has been guessed or not
|
||||
|
||||
[[status, status, status, status, status]
|
||||
for _ in "ABCD..."]
|
||||
where status has codes
|
||||
[1, 0, 0] - char is definitely not in this spot
|
||||
[0, 1, 0] - char is maybe in this spot
|
||||
[0, 0, 1] - char is definitely in this spot
|
||||
"""
|
||||
import collections
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
from wordle.const import WORDLE_CHARS, WORDLE_N
|
||||
|
||||
|
||||
WordleState = np.ndarray
|
||||
|
||||
|
||||
def get_nvec(max_turns: int):
|
||||
return [max_turns] + [2] * len(WORDLE_CHARS) + [2] * 3 * WORDLE_N * len(WORDLE_CHARS)
|
||||
|
||||
|
||||
def new(max_turns: int) -> WordleState:
|
||||
return np.array(
|
||||
[max_turns] + [0] * len(WORDLE_CHARS) + [0, 1, 0] * WORDLE_N * len(WORDLE_CHARS),
|
||||
dtype=np.int32)
|
||||
|
||||
|
||||
def remaining_steps(state: WordleState) -> int:
|
||||
return state[0]
|
||||
|
||||
|
||||
NO = 0
|
||||
SOMEWHERE = 1
|
||||
YES = 2
|
||||
|
||||
|
||||
def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
|
||||
"""
|
||||
return a copy of state that has been updated to new state
|
||||
|
||||
From a mask we need slighty different logic since we don't know the
|
||||
goal word.
|
||||
|
||||
:param state:
|
||||
:param word:
|
||||
:param goal_word:
|
||||
:return:
|
||||
"""
|
||||
state = state.copy()
|
||||
|
||||
prior_yes = []
|
||||
prior_maybe = []
|
||||
# We need two passes because first pass sets definitely yesses
|
||||
# second pass sets the no's for those who aren't already yes
|
||||
state[0] -= 1
|
||||
for i, c in enumerate(word):
|
||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
||||
state[1 + cint] = 1
|
||||
if mask[i] == YES:
|
||||
prior_yes.append(c)
|
||||
# char at position i = yes, all other chars at position i == no
|
||||
state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
|
||||
for ocint in range(len(WORDLE_CHARS)):
|
||||
if ocint != cint:
|
||||
oc_offset = 1 + len(WORDLE_CHARS) + ocint * WORDLE_N * 3
|
||||
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
||||
|
||||
for i, c in enumerate(word):
|
||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
||||
if mask[i] == SOMEWHERE:
|
||||
prior_maybe.append(c)
|
||||
# Char at position i = no, other chars stay as they are
|
||||
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
||||
elif mask[i] == NO:
|
||||
# Need to check this first in case there's prior maybe + yes
|
||||
if c in prior_maybe:
|
||||
# Then the maybe could be anywhere except here
|
||||
state[offset+3*i:offset+3*i+3] = [1, 0, 0]
|
||||
elif c in prior_yes:
|
||||
# No maybe, definitely a yes, so it's zero everywhere except the yesses
|
||||
for j in range(WORDLE_N):
|
||||
# Only flip no if previously was maybe
|
||||
if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
|
||||
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
||||
else:
|
||||
# Just straight up no
|
||||
state[offset:offset+3*WORDLE_N] = [1, 0, 0]*WORDLE_N
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def get_mask(word: str, goal_word: str) -> List[int]:
|
||||
# Definite yesses first
|
||||
mask = [0, 0, 0, 0, 0]
|
||||
counts = collections.Counter(goal_word)
|
||||
for i, c in enumerate(word):
|
||||
if goal_word[i] == c:
|
||||
mask[i] = 2
|
||||
counts[c] -= 1
|
||||
|
||||
for i, c in enumerate(word):
|
||||
if mask[i] == 2:
|
||||
continue
|
||||
elif c in counts:
|
||||
if counts[c] > 0:
|
||||
mask[i] = 1
|
||||
counts[c] -= 1
|
||||
else:
|
||||
for j in range(i+1, len(mask)):
|
||||
if mask[j] == 2:
|
||||
continue
|
||||
mask[j] = 0
|
||||
|
||||
return mask
|
||||
|
||||
def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
|
||||
"""
|
||||
return a copy of state that has been updated to new state
|
||||
|
||||
:param state:
|
||||
:param word:
|
||||
:param goal_word:
|
||||
:return:
|
||||
"""
|
||||
mask = get_mask(word, goal_word)
|
||||
return update_from_mask(state, word, mask)
|
||||
|
||||
|
||||
def update(state: WordleState, word: str, goal_word: str) -> WordleState:
|
||||
state = state.copy()
|
||||
|
||||
state[0] -= 1
|
||||
for i, c in enumerate(word):
|
||||
cint = ord(c) - ord(WORDLE_CHARS[0])
|
||||
offset = 1 + len(WORDLE_CHARS) + cint * WORDLE_N * 3
|
||||
state[1 + cint] = 1
|
||||
if goal_word[i] == c:
|
||||
# char at position i = yes, all other chars at position i == no
|
||||
state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
|
||||
for ocint in range(len(WORDLE_CHARS)):
|
||||
if ocint != cint:
|
||||
oc_offset = 1 + len(WORDLE_CHARS) + ocint * WORDLE_N * 3
|
||||
state[oc_offset + 3 * i:oc_offset + 3 * i + 3] = [1, 0, 0]
|
||||
elif c in goal_word:
|
||||
# Char at position i = no, other chars stay as they are
|
||||
state[offset + 3 * i:offset + 3 * i + 3] = [1, 0, 0]
|
||||
else:
|
||||
# Char at all positions = no
|
||||
state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
||||
|
||||
return state
|
||||
|
173
wordle/wordle.py
Normal file
173
wordle/wordle.py
Normal file
@ -0,0 +1,173 @@
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
import wordle.state
|
||||
from wordle.const import WORDLE_N, REWARD
|
||||
|
||||
CUR_PATH = os.environ.get('PYTHONPATH', '.')
|
||||
import os
|
||||
dirname = os.path.dirname(__file__)
|
||||
VALID_WORDS_PATH = f'{dirname}/wordle_words.txt'
|
||||
|
||||
|
||||
def _load_words(limit: Optional[int]=None) -> List[str]:
|
||||
with open(VALID_WORDS_PATH, 'r') as f:
|
||||
lines = [x.strip().upper() for x in f.readlines()]
|
||||
if not limit:
|
||||
return lines
|
||||
else:
|
||||
return lines[:limit]
|
||||
|
||||
|
||||
class WordleEnvBase(gym.Env):
|
||||
"""
|
||||
Actions:
|
||||
Can play any 5 letter word in vocabulary
|
||||
* 13k for full vocab
|
||||
State space is defined as:
|
||||
* 6 possibilities for turns (WORDLE_TURNS)
|
||||
* Each VALID_CHAR has a state of 0/1 for whether it's been guessed before
|
||||
* For each in VALID_CHARS [A-Z] can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
|
||||
for full game, this is (3^5)^26
|
||||
Each state has 1 + 5*26 possibilities
|
||||
Reward:
|
||||
Reward is 10 for guessing the right word, -10 for not guessing the right word after 6 guesses.
|
||||
Starting State:
|
||||
Random goal word
|
||||
Initial state with turn 0, all chars Unvisited + Maybe
|
||||
"""
|
||||
def __init__(self, words: List[str],
|
||||
max_turns: int,
|
||||
allowable_words: Optional[int] = None,
|
||||
frequencies: Optional[List[float]]=None,
|
||||
mask_based_state_updates: bool=False):
|
||||
assert all(len(w) == WORDLE_N for w in words), f'Not all words of length {WORDLE_N}, {words}'
|
||||
self.words = words
|
||||
self.max_turns = max_turns
|
||||
self.allowable_words = allowable_words
|
||||
self.mask_based_state_updates = mask_based_state_updates
|
||||
if not self.allowable_words:
|
||||
self.allowable_words = len(self.words)
|
||||
|
||||
self.frequencies = None
|
||||
if frequencies:
|
||||
assert len(words) == len(frequencies), f'{len(words), len(frequencies)}'
|
||||
self.frequencies = np.array(frequencies, dtype=np.float32) / sum(frequencies)
|
||||
|
||||
self.action_space = spaces.Discrete(len(self.words))
|
||||
self.observation_space = spaces.MultiDiscrete(wordle.state.get_nvec(self.max_turns))
|
||||
|
||||
self.done = True
|
||||
self.goal_word: int = -1
|
||||
|
||||
self.state: wordle.state.WordleState = None
|
||||
self.state_updater = wordle.state.update
|
||||
if self.mask_based_state_updates:
|
||||
self.state_updater = wordle.state.update_mask
|
||||
|
||||
def step(self, action: int):
|
||||
if self.done:
|
||||
raise ValueError(
|
||||
"You are calling 'step()' even though this "
|
||||
"environment has already returned done = True. You "
|
||||
"should always call 'reset()' once you receive 'done = "
|
||||
"True' -- any further steps are undefined behavior."
|
||||
)
|
||||
self.state = self.state_updater(state=self.state,
|
||||
word=self.words[action],
|
||||
goal_word=self.words[self.goal_word])
|
||||
|
||||
reward = 0
|
||||
if action == self.goal_word:
|
||||
self.done = True
|
||||
#reward = REWARD
|
||||
if wordle.state.remaining_steps(self.state) == self.max_turns-1:
|
||||
reward = 0#-10*REWARD # No reward for guessing off the bat
|
||||
else:
|
||||
#reward = REWARD*(self.state.remaining_steps() + 1) / self.max_turns
|
||||
reward = REWARD
|
||||
elif wordle.state.remaining_steps(self.state) == 0:
|
||||
self.done = True
|
||||
reward = -REWARD
|
||||
|
||||
return self.state.copy(), reward, self.done, False, {"goal_id": self.goal_word}
|
||||
|
||||
def reset(self, options = None, seed: Optional[int] = None):
|
||||
self.state = wordle.state.new(self.max_turns)
|
||||
self.done = False
|
||||
self.goal_word = int(np.random.random()*self.allowable_words)
|
||||
|
||||
return self.state.copy(), {"goal_id": self.goal_word}
|
||||
|
||||
def set_goal_word(self, goal_word: str):
|
||||
self.goal_word = self.words.index(goal_word)
|
||||
|
||||
def set_goal_id(self, goal_id: int):
|
||||
self.goal_word = goal_id
|
||||
|
||||
|
||||
class WordleEnv10(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(10), max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv100(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(100), max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv100OneAction(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(100), allowable_words=1, max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv100WithMask(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(100), max_turns=6,
|
||||
mask_based_state_updates=True)
|
||||
|
||||
|
||||
class WordleEnv100TwoAction(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(100), allowable_words=2, max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv100FullAction(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(), allowable_words=100, max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv1000(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(1000), max_turns=6)
|
||||
|
||||
|
||||
class WordleEnv1000WithMask(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(1000), max_turns=6,
|
||||
mask_based_state_updates=True)
|
||||
|
||||
|
||||
class WordleEnv1000FullAction(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(), allowable_words=1000, max_turns=6)
|
||||
|
||||
|
||||
class WordleEnvFull(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(), max_turns=6)
|
||||
|
||||
|
||||
class WordleEnvReal(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(), allowable_words=2315, max_turns=6)
|
||||
|
||||
|
||||
class WordleEnvRealWithMask(WordleEnvBase):
|
||||
def __init__(self):
|
||||
super().__init__(words=_load_words(), allowable_words=2315, max_turns=6,
|
||||
mask_based_state_updates=True)
|
12972
wordle/wordle_words.txt
Normal file
12972
wordle/wordle_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user