mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-22 18:49:21 +00:00
Compare commits
4 Commits
f40301cac9
...
ethan-test
Author | SHA1 | Date | |
---|---|---|---|
|
335d56ac88 | ||
|
496b8ad796 | ||
|
8d3ce990e3 | ||
|
7ad5b97463 |
9
.gitignore
vendored
9
.gitignore
vendored
@@ -1,6 +1,3 @@
|
||||
**/data/*
|
||||
**/*.zip
|
||||
**/__pycache__
|
||||
/env
|
||||
**/runs/*
|
||||
**/wandb/*
|
||||
**/data/*
|
||||
/env
|
||||
**/*.zip
|
21
Gym-Wordle-main/LICENSE
Normal file
21
Gym-Wordle-main/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 David Kraemer
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
78
Gym-Wordle-main/README.md
Normal file
78
Gym-Wordle-main/README.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# Gym-Wordle
|
||||
|
||||
An OpenAI gym compatible environment for training agents to play Wordle.
|
||||
|
||||
<p align='center'>
|
||||
<img src="https://user-images.githubusercontent.com/8514041/152437216-d78e85f6-8049-4cb9-ae61-3c015a8a0e4f.gif"><br/>
|
||||
<em>User-input demo of the environment</em>
|
||||
</p>
|
||||
|
||||
## Installation
|
||||
|
||||
My goal is for a minimalist package that lets you install quickly and get on
|
||||
with your research. Installation is just a simple call to `pip`:
|
||||
|
||||
```
|
||||
$ pip install gym_wordle
|
||||
```
|
||||
|
||||
### Requirements
|
||||
|
||||
In keeping with my desire to have a minimalist package, there are only three
|
||||
major requirements:
|
||||
|
||||
* `numpy`
|
||||
* `gym`
|
||||
* `sty`, a lovely little package for stylizing text in terminals
|
||||
|
||||
## Usage
|
||||
|
||||
The basic flow for training agents with the `Wordle-v0` environment is the same
|
||||
as with gym environments generally:
|
||||
|
||||
```Python
|
||||
import gym
|
||||
import gym_wordle
|
||||
|
||||
eng = gym.make("Wordle-v0")
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
action = ... # RL magic
|
||||
state, reward, done, info = env.step(action)
|
||||
```
|
||||
|
||||
If you're like millions of other people, you're a Wordle-obsessive in your own
|
||||
right. I have good news for you! The `Wordle-v0` environment currently has an
|
||||
implemented `render` method, which allows you to see a human-friendly version
|
||||
of the game. And it isn't so hard to set up the environment to play for
|
||||
yourself. Here's an example script:
|
||||
|
||||
```Python
|
||||
from gym_wordle.utils import play
|
||||
|
||||
play()
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
Coming soon!
|
||||
|
||||
## Examples
|
||||
|
||||
Coming soon!
|
||||
|
||||
## Citing
|
||||
|
||||
If you decide to use this project in your work, please consider a citation!
|
||||
|
||||
```bibtex
|
||||
@misc{gym_wordle,
|
||||
author = {Kraemer, David},
|
||||
title = {An Environment for Reinforcement Learning with Wordle},
|
||||
year = {2022},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/DavidNKraemer/Gym-Wordle}},
|
||||
}
|
||||
```
|
7
Gym-Wordle-main/gym-wordle.toml
Normal file
7
Gym-Wordle-main/gym-wordle.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[build-system]
|
||||
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"wheel"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
7
Gym-Wordle-main/gym_wordle/__init__.py
Normal file
7
Gym-Wordle-main/gym_wordle/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from gym.envs.registration import register
|
||||
from .wordle import WordleEnv
|
||||
|
||||
register(
|
||||
id='Wordle-v0',
|
||||
entry_point='gym_wordle.wordle:WordleEnv'
|
||||
)
|
File diff suppressed because it is too large
Load Diff
BIN
Gym-Wordle-main/gym_wordle/dictionary/guess_list.npy
Normal file
BIN
Gym-Wordle-main/gym_wordle/dictionary/guess_list.npy
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
BIN
Gym-Wordle-main/gym_wordle/dictionary/solution_list.npy
Normal file
BIN
Gym-Wordle-main/gym_wordle/dictionary/solution_list.npy
Normal file
Binary file not shown.
94
Gym-Wordle-main/gym_wordle/utils.py
Normal file
94
Gym-Wordle-main/gym_wordle/utils.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_chars = ' abcdefghijklmnopqrstuvwxyz'
|
||||
_char_d = {c: i for i, c in enumerate(_chars)}
|
||||
|
||||
|
||||
def to_english(array: npt.NDArray[np.int64]) -> str:
|
||||
"""Converts a numpy integer array into a corresponding English string.
|
||||
|
||||
Args:
|
||||
array: Word in array (int) form. It is assumed that each integer in the
|
||||
array is between 0,...,26 (inclusive).
|
||||
|
||||
Returns:
|
||||
A (lowercase) string representation of the word.
|
||||
"""
|
||||
return ''.join(_chars[i] for i in array)
|
||||
|
||||
|
||||
def to_array(word: str) -> npt.NDArray[np.int64]:
|
||||
"""Converts a string of characters into a corresponding numpy array.
|
||||
|
||||
Args:
|
||||
word: Word in string form. It is assumed that each character in the
|
||||
string is either an empty space ' ' or lowercase alphabetical
|
||||
character.
|
||||
|
||||
Returns:
|
||||
An array representation of the word.
|
||||
"""
|
||||
return np.array([_char_d[c] for c in word])
|
||||
|
||||
|
||||
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
|
||||
"""Loads a list of words in array form.
|
||||
|
||||
If specified, this will recompute the list from the human-readable list of
|
||||
words, and save the results in array form.
|
||||
|
||||
Args:
|
||||
category: Either 'guess' or 'solution', which corresponds to the list
|
||||
of acceptable guess words and the list of acceptable solution words.
|
||||
build: If True, recomputes and saves the array-version of the computed
|
||||
list for future access.
|
||||
|
||||
Returns:
|
||||
An array representation of the list of words specified by the category.
|
||||
This array has two dimensions, and the number of columns is fixed at
|
||||
five.
|
||||
"""
|
||||
assert category in {'guess', 'solution'}
|
||||
|
||||
arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy'
|
||||
if build:
|
||||
list_path = Path(__file__).parent / f'dictionary/{category}_list.csv'
|
||||
|
||||
with open(list_path, 'r') as f:
|
||||
words = np.array([to_array(line.strip()) for line in f])
|
||||
np.save(arr_path, words)
|
||||
|
||||
return np.load(arr_path)
|
||||
|
||||
|
||||
def play():
|
||||
"""Play Wordle yourself!"""
|
||||
import gym
|
||||
import gym_wordle
|
||||
|
||||
env = gym.make('Wordle-v0') # load the environment
|
||||
|
||||
env.reset()
|
||||
solution = to_english(env.unwrapped.solution_space[env.solution]).upper() # no peeking!
|
||||
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = -1
|
||||
|
||||
# in general, the environment won't be forgiving if you input an
|
||||
# invalid word, but for this function I want to let you screw up user
|
||||
# input without consequence, so just loops until valid input is taken
|
||||
while not env.action_space.contains(action):
|
||||
guess = input('Guess: ')
|
||||
action = env.unwrapped.action_space.index_of(to_array(guess))
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
print(f"The word was {solution}")
|
||||
|
286
Gym-Wordle-main/gym_wordle/wordle.py
Normal file
286
Gym-Wordle-main/gym_wordle/wordle.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
from collections import Counter
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
"""Super class for defining a space of valid words according to a specified
|
||||
list.
|
||||
|
||||
TODO: Fix these paragraphs
|
||||
The space is a subclass of gym.spaces.Discrete, where each element
|
||||
corresponds to an index of a valid word in the word list. The obfuscation
|
||||
is necessary for more direct implementation of RL algorithms, which expect
|
||||
spaces of less sophisticated form.
|
||||
|
||||
In addition to the default methods of the Discrete space, it implements
|
||||
a __getitem__ method for easy index lookup, and an index_of method to
|
||||
convert potential words into their corresponding index (if they exist).
|
||||
"""
|
||||
|
||||
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
|
||||
"""
|
||||
Args:
|
||||
words: Collection of words in array form with shape (_, 5), where
|
||||
each word is a row of the array. Each array element is an integer
|
||||
between 0,...,26 (inclusive).
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
super().__init__(words.shape[0], **kwargs)
|
||||
self.words = words
|
||||
|
||||
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
|
||||
"""Obtains the (int-encoded) word associated with the given index.
|
||||
|
||||
Args:
|
||||
index: Index for the list of words.
|
||||
|
||||
Returns:
|
||||
Associated word at the position specified by index.
|
||||
"""
|
||||
return self.words[index]
|
||||
|
||||
def index_of(self, word: npt.NDArray[np.int64]) -> int:
|
||||
"""Given a word, determine its index in the list (if it exists),
|
||||
otherwise returning -1 if no index exists.
|
||||
|
||||
Args:
|
||||
word: Word to find in the word list.
|
||||
|
||||
Returns:
|
||||
The index of the given word if it exists, otherwise -1.
|
||||
"""
|
||||
try:
|
||||
index, = np.nonzero((word == self.words).all(axis=1))
|
||||
return index[0]
|
||||
except:
|
||||
return -1
|
||||
|
||||
|
||||
class SolutionList(WordList):
|
||||
"""Space for *solution* words to the Wordle environment.
|
||||
|
||||
In the game Wordle, there are two different collections of words:
|
||||
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
|
||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||
|
||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
||||
|
||||
This class represents the set of solution words.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
words = get_words('solution')
|
||||
super().__init__(words, **kwargs)
|
||||
|
||||
|
||||
class WordleObsSpace(gym.spaces.Box):
|
||||
"""Implementation of the state (observation) space in terms of gym
|
||||
primatives, in this case, gym.spaces.Box.
|
||||
|
||||
The Wordle observation space can be thought of as a 6x5 array with two
|
||||
channels:
|
||||
|
||||
- the character channel, indicating which characters are placed on the
|
||||
board (unfilled rows are marked with the empty character, 0)
|
||||
- the flag channel, indicating the in-game information associated with
|
||||
each character's placement (green highlight, yellow highlight, etc.)
|
||||
|
||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
||||
the solution will always be a word of length 5.
|
||||
|
||||
For simplicity, and compatibility with the stable_baselines algorithms,
|
||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
||||
horizontally appended (along columns). Thus each row in the observation
|
||||
should be interpreted as
|
||||
|
||||
c0 c1 c2 c3 c4 f0 f1 f2 f3 f4
|
||||
|
||||
when the word is c0...c4 and its associated flags are f0...f4.
|
||||
|
||||
While the superclass method `sample` is available to the WordleObsSpace, it
|
||||
should be emphasized that the output of `sample` will (almost surely) not
|
||||
correspond to a real game configuration, because the sampling is not out of
|
||||
possible game configurations. Instead, the Box superclass just samples the
|
||||
integer array space uniformly.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.n_rows = 6
|
||||
self.n_cols = 5
|
||||
self.max_char = 26
|
||||
self.max_flag = 4
|
||||
|
||||
low = np.zeros((self.n_rows, 2*self.n_cols))
|
||||
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
|
||||
np.full((self.n_rows, self.n_cols), self.max_flag)]
|
||||
|
||||
super().__init__(low, high, dtype=np.int64, **kwargs)
|
||||
|
||||
|
||||
class GuessList(WordList):
|
||||
"""Space for *solution* words to the Wordle environment.
|
||||
|
||||
In the game Wordle, there are two different collections of words:
|
||||
|
||||
* "guesses", which the game accepts as valid words to use to guess the
|
||||
answer.
|
||||
* "solutions", which the game uses to choose solutions from.
|
||||
|
||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||
|
||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
||||
|
||||
This class represents the set of guess words.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||
"""
|
||||
words = get_words('guess')
|
||||
super().__init__(words, **kwargs)
|
||||
|
||||
|
||||
class WordleEnv(gym.Env):
|
||||
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
# character flag codes
|
||||
no_char = 0
|
||||
right_pos = 1
|
||||
wrong_pos = 2
|
||||
wrong_char = 3
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.seed()
|
||||
self.action_space = GuessList()
|
||||
self.solution_space = SolutionList()
|
||||
|
||||
self.observation_space = WordleObsSpace()
|
||||
|
||||
self._highlights = {
|
||||
self.right_pos: (bg.green, bg.rs),
|
||||
self.wrong_pos: (bg.yellow, bg.rs),
|
||||
self.wrong_char: ('', ''),
|
||||
self.no_char: ('', ''),
|
||||
}
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
based on the flag associated with it.
|
||||
|
||||
Args:
|
||||
char: Character in question.
|
||||
flag: Associated flag, one of:
|
||||
- 0: no character (render no background)
|
||||
- 1: right position (render green background)
|
||||
- 2: wrong position (render yellow background)
|
||||
- 3: wrong character (render no background)
|
||||
|
||||
Returns:
|
||||
Correct ASCII sequence producing the desired character in the
|
||||
correct background.
|
||||
"""
|
||||
front, back = self._highlights[flag]
|
||||
return front + char + back
|
||||
|
||||
def reset(self):
|
||||
self.round = 0
|
||||
self.solution = self.solution_space.sample()
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters),
|
||||
dtype=np.int64)
|
||||
|
||||
return self.state
|
||||
|
||||
def render(self, mode: str ='human'):
|
||||
"""Renders the Wordle environment.
|
||||
|
||||
Currently supported render modes:
|
||||
|
||||
- human: renders the Wordle game to the terminal.
|
||||
|
||||
Args:
|
||||
mode: the mode to render with
|
||||
"""
|
||||
if mode == 'human':
|
||||
for row in self.states:
|
||||
text = ''.join(map(
|
||||
self._highlighter,
|
||||
to_english(row[:self.n_letters]).upper(),
|
||||
row[self.n_letters:]
|
||||
))
|
||||
|
||||
print(text)
|
||||
else:
|
||||
super(WordleEnv, self).render(mode=mode)
|
||||
|
||||
def step(self, action):
|
||||
"""Run one step of the Wordle game. Every game must be previously
|
||||
initialized by a call to the `reset` method.
|
||||
|
||||
Args:
|
||||
action: Word guessed by the agent.
|
||||
Returns:
|
||||
state (object): Wordle game state after the guess.
|
||||
reward (float): Reward associated with the guess (-1 for incorrect,
|
||||
0 for correct)
|
||||
done (bool): Whether the game has ended (by a correct guess or
|
||||
after six guesses).
|
||||
info (dict): Auxiliary diagnostic information (empty).
|
||||
"""
|
||||
assert self.action_space.contains(action), 'Invalid word!'
|
||||
|
||||
# transform the action, solution indices to their words
|
||||
action = self.action_space[action]
|
||||
solution = self.solution_space[self.solution]
|
||||
|
||||
# populate the word chars into the row (character channel)
|
||||
self.state[self.round][:self.n_letters] = action
|
||||
|
||||
# populate the flag characters into the row (flag channel)
|
||||
counter = Counter()
|
||||
for i, char in enumerate(action):
|
||||
flag_i = i + self.n_letters # starts at 5
|
||||
counter[char] += 1
|
||||
|
||||
if char == solution[i]: # character is in correct position
|
||||
self.state[self.round, i] = self.right_pos
|
||||
elif counter[char] <= (char == solution).sum():
|
||||
# current character has been seen within correct number of
|
||||
# occurrences
|
||||
self.state[self.round, i] = self.wrong_pos
|
||||
else:
|
||||
# wrong character, or "correct" character too many times
|
||||
self.state[self.round, i] = self.wrong_char
|
||||
|
||||
self.round += 1
|
||||
|
||||
correct = (action == solution).all()
|
||||
game_over = (self.round == self.n_rounds)
|
||||
|
||||
done = correct or game_over
|
||||
|
||||
# Total reward equals -(number of incorrect guesses)
|
||||
reward = 0. if correct else -1.
|
||||
|
||||
return self.state, reward, done, {}
|
||||
|
35
Gym-Wordle-main/setup.py
Normal file
35
Gym-Wordle-main/setup.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md', 'r', encoding='utf-8') as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setup(
|
||||
name='gym_wordle',
|
||||
version='0.1.3',
|
||||
author='David Kraemer',
|
||||
author_email='david.kraemer@stonybrook.edu',
|
||||
description='OpenAI gym environment for training agents on Wordle',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/DavidNKraemer/Gym-Wordle',
|
||||
packages=find_packages(
|
||||
include=[
|
||||
'gym_wordle',
|
||||
'gym_wordle.*'
|
||||
]
|
||||
),
|
||||
package_data={
|
||||
'gym_wordle': ['dictionary/*']
|
||||
},
|
||||
python_requires='>=3.7',
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
install_requires=[
|
||||
'numpy>=1.20',
|
||||
'gym==0.19',
|
||||
'sty==1.0',
|
||||
],
|
||||
)
|
91
custom_env/agent.py
Normal file
91
custom_env/agent.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Agent:
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
# EPS_START is the starting value of epsilon
|
||||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
# TAU is the update rate of the target network
|
||||
# LR is the learning rate of the ``AdamW`` optimizer
|
||||
self.batch_size = 128
|
||||
self.gamma = 0.99
|
||||
self.eps_start = 0.9
|
||||
self.eps_end = 0.05
|
||||
self.eps_decay = 1000
|
||||
self.tau = 0.005
|
||||
self.lr = 1e-4
|
||||
self.n_actions = n_actions
|
||||
|
||||
policy_net = DQN(n_observations, n_actions).to(device)
|
||||
target_net = DQN(n_observations, n_actions).to(device)
|
||||
target_net.load_state_dict(policy_net.state_dict())
|
||||
|
||||
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
||||
memory = ReplayMemory(10000)
|
||||
|
||||
def get_state(self, game):
|
||||
pass
|
||||
|
||||
def select_action(state):
|
||||
sample = random.random()
|
||||
eps_threshold = EPS_END + (EPS_START - EPS_END) * \
|
||||
math.exp(-1. * steps_done / EPS_DECAY)
|
||||
steps_done += 1
|
||||
if sample > eps_threshold:
|
||||
with torch.no_grad():
|
||||
# t.max(1) will return the largest column value of each row.
|
||||
# second column on max result is index of where max element was
|
||||
# found, so we pick action with the larger expected reward.
|
||||
return policy_net(state).max(1).indices.view(1, 1)
|
||||
else:
|
||||
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
|
||||
|
||||
def optimize_model():
|
||||
if len(memory) < BATCH_SIZE:
|
||||
return
|
||||
transitions = memory.sample(BATCH_SIZE)
|
||||
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
||||
# detailed explanation). This converts batch-array of Transitions
|
||||
# to Transition of batch-arrays.
|
||||
batch = Transition(*zip(*transitions))
|
||||
|
||||
# Compute a mask of non-final states and concatenate the batch elements
|
||||
# (a final state would've been the one after which simulation ended)
|
||||
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
|
||||
batch.next_state)), device=device, dtype=torch.bool)
|
||||
non_final_next_states = torch.cat([s for s in batch.next_state
|
||||
if s is not None])
|
||||
state_batch = torch.cat(batch.state)
|
||||
action_batch = torch.cat(batch.action)
|
||||
reward_batch = torch.cat(batch.reward)
|
||||
|
||||
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
||||
# columns of actions taken. These are the actions which would've been taken
|
||||
# for each batch state according to policy_net
|
||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
||||
|
||||
# Compute V(s_{t+1}) for all next states.
|
||||
# Expected values of actions for non_final_next_states are computed based
|
||||
# on the "older" target_net; selecting their best reward with max(1).values
|
||||
# This is merged based on the mask, such that we'll have either the expected
|
||||
# state value or 0 in case the state was final.
|
||||
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
||||
with torch.no_grad():
|
||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
|
||||
# Compute the expected Q values
|
||||
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
||||
|
||||
# Compute Huber loss
|
||||
criterion = nn.SmoothL1Loss()
|
||||
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
||||
|
||||
# Optimize the model
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# In-place gradient clipping
|
||||
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
||||
optimizer.step()
|
16
custom_env/create_wordlist.py
Normal file
16
custom_env/create_wordlist.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import pathlib
|
||||
import sys
|
||||
from string import ascii_letters
|
||||
|
||||
in_path = pathlib.Path(sys.argv[1])
|
||||
out_path = pathlib.Path(sys.argv[2])
|
||||
|
||||
words = sorted(
|
||||
{
|
||||
word.lower()
|
||||
for word in in_path.read_text(encoding="utf-8").split()
|
||||
if all(letter in ascii_letters for letter in word)
|
||||
},
|
||||
key=lambda word: (len(word), word),
|
||||
)
|
||||
out_path.write_text("\n".join(words))
|
5757
custom_env/five_letter_words.txt
Normal file
5757
custom_env/five_letter_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
44
custom_env/model.py
Normal file
44
custom_env/model.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import math
|
||||
import random
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import namedtuple, deque
|
||||
from itertools import count
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
Transition = namedtuple('Transition',
|
||||
('state', 'action', 'next_state', 'reward'))
|
||||
|
||||
|
||||
class ReplayMemory(object):
|
||||
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.memory = deque([], maxlen=capacity)
|
||||
|
||||
def push(self, *args):
|
||||
self.memory.append(Transition(*args))
|
||||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
|
||||
def __init__(self, n_observations: int, n_actions: int) -> None:
|
||||
super(DQN, self).__init__()
|
||||
self.layer1 = nn.Linear(n_observations, 128)
|
||||
self.layer2 = nn.Linear(128, 128)
|
||||
self.layer3 = nn.Linear(128, n_actions)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.layer1(x))
|
||||
x = F.relu(self.layer2(x))
|
||||
return self.layer3(x)
|
61
custom_env/test2.ipynb
Normal file
61
custom_env/test2.ipynb
Normal file
@@ -0,0 +1,61 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from string import ascii_letters, ascii_uppercase, ascii_lowercase"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'ABCDEFGHIJKLMNOPQRSTUVWXYZ'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ascii_uppercase"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
1098
custom_env/wordlist.txt
Normal file
1098
custom_env/wordlist.txt
Normal file
File diff suppressed because it is too large
Load Diff
119
custom_env/wyrdl.py
Normal file
119
custom_env/wyrdl.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import contextlib
|
||||
import pathlib
|
||||
import random
|
||||
from string import ascii_letters, ascii_lowercase
|
||||
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
|
||||
console = Console(width=40, theme=Theme({"warning": "red on yellow"}))
|
||||
|
||||
NUM_LETTERS = 5
|
||||
NUM_GUESSES = 6
|
||||
WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt"
|
||||
|
||||
|
||||
class Wordle:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n")
|
||||
self.n_guesses = 6
|
||||
self.num_letters = 5
|
||||
self.curr_word = None
|
||||
self.reset()
|
||||
|
||||
def refresh_page(self, headline):
|
||||
console.clear()
|
||||
console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n")
|
||||
|
||||
def start_game(self):
|
||||
# get a new random word
|
||||
word = self.get_random_word(self.word_list)
|
||||
|
||||
self.curr_word = word
|
||||
|
||||
def get_state(self):
|
||||
return
|
||||
|
||||
def action_to_word(self, action):
|
||||
# Calculate the word from the array
|
||||
word = ''
|
||||
for i in range(0, len(ascii_lowercase), 26):
|
||||
# Find the index of 1 in each block of 26
|
||||
letter_index = action[i:i+26].index(1)
|
||||
# Append the corresponding letter to the word
|
||||
word += ascii_lowercase[letter_index]
|
||||
|
||||
return word
|
||||
|
||||
def play_guess(self, action):
|
||||
# probably an array of length 26 * 5 for 26 letters and 5 positions
|
||||
guess = action
|
||||
|
||||
def get_random_word(self, word_list):
|
||||
if words := [
|
||||
word.upper()
|
||||
for word in word_list
|
||||
if len(word) == NUM_LETTERS
|
||||
and all(letter in ascii_letters for letter in word)
|
||||
]:
|
||||
return random.choice(words)
|
||||
else:
|
||||
console.print(
|
||||
f"No words of length {NUM_LETTERS} in the word list",
|
||||
style="warning",
|
||||
)
|
||||
raise SystemExit()
|
||||
|
||||
def show_guesses(self, guesses, word):
|
||||
letter_status = {letter: letter for letter in ascii_lowercase}
|
||||
for guess in guesses:
|
||||
styled_guess = []
|
||||
for letter, correct in zip(guess, word):
|
||||
if letter == correct:
|
||||
style = "bold white on green"
|
||||
elif letter in word:
|
||||
style = "bold white on yellow"
|
||||
elif letter in ascii_letters:
|
||||
style = "white on #666666"
|
||||
else:
|
||||
style = "dim"
|
||||
styled_guess.append(f"[{style}]{letter}[/]")
|
||||
if letter != "_":
|
||||
letter_status[letter] = f"[{style}]{letter}[/]"
|
||||
|
||||
console.print("".join(styled_guess), justify="center")
|
||||
console.print("\n" + "".join(letter_status.values()), justify="center")
|
||||
|
||||
def guess_word(self, previous_guesses):
|
||||
guess = console.input("\nGuess word: ").upper()
|
||||
|
||||
if guess in previous_guesses:
|
||||
console.print(f"You've already guessed {guess}.", style="warning")
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
if len(guess) != NUM_LETTERS:
|
||||
console.print(
|
||||
f"Your guess must be {NUM_LETTERS} letters.", style="warning"
|
||||
)
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
if any((invalid := letter) not in ascii_letters for letter in guess):
|
||||
console.print(
|
||||
f"Invalid letter: '{invalid}'. Please use English letters.",
|
||||
style="warning",
|
||||
)
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
return guess
|
||||
|
||||
def reset(self, guesses, word, guessed_correctly, n_episodes):
|
||||
refresh_page(headline=f"Game: {n_episodes}")
|
||||
|
||||
if guessed_correctly:
|
||||
console.print(f"\n[bold white on green]Correct, the word is {word}[/]")
|
||||
else:
|
||||
console.print(f"\n[bold white on red]Sorry, the word was {word}[/]")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
File diff suppressed because it is too large
Load Diff
346
dqn_wordle.ipynb
346
dqn_wordle.ipynb
@@ -6,11 +6,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import gym_wordle\n",
|
||||
"from stable_baselines3 import DQN, PPO, common\n",
|
||||
"import numpy as np\n",
|
||||
"import tqdm"
|
||||
"import gymnasium as gym\n",
|
||||
"from stable_baselines3 import DQN\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -19,291 +17,36 @@
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<Monitor<WordleEnv instance>>\n"
|
||||
"ename": "NameNotFound",
|
||||
"evalue": "Environment `Wordle` doesn't exist.",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mNameNotFound\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mgym\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mWordle-v0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(env)\n",
|
||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:741\u001b[0m, in \u001b[0;36mmake\u001b[1;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[0;32m 738\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mid\u001b[39m, \u001b[38;5;28mstr\u001b[39m)\n\u001b[0;32m 740\u001b[0m \u001b[38;5;66;03m# The environment name can include an unloaded module in \"module:env_name\" style\u001b[39;00m\n\u001b[1;32m--> 741\u001b[0m env_spec \u001b[38;5;241m=\u001b[39m \u001b[43m_find_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(env_spec, EnvSpec)\n\u001b[0;32m 745\u001b[0m \u001b[38;5;66;03m# Update the env spec kwargs with the `make` kwargs\u001b[39;00m\n",
|
||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:527\u001b[0m, in \u001b[0;36m_find_spec\u001b[1;34m(env_id)\u001b[0m\n\u001b[0;32m 521\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m 522\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing the latest versioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnew_env_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 523\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minstead of the unversioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 524\u001b[0m )\n\u001b[0;32m 526\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m env_spec \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 527\u001b[0m \u001b[43m_check_version_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mversion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 528\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mError(\n\u001b[0;32m 529\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo registered env with id: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Did you register it, or import the package that registers it? Use `gymnasium.pprint_registry()` to see all of the registered environments.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 530\u001b[0m )\n\u001b[0;32m 532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m env_spec\n",
|
||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:393\u001b[0m, in \u001b[0;36m_check_version_exists\u001b[1;34m(ns, name, version)\u001b[0m\n\u001b[0;32m 390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m get_env_id(ns, name, version) \u001b[38;5;129;01min\u001b[39;00m registry:\n\u001b[0;32m 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m--> 393\u001b[0m \u001b[43m_check_name_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m version \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n",
|
||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:370\u001b[0m, in \u001b[0;36m_check_name_exists\u001b[1;34m(ns, name)\u001b[0m\n\u001b[0;32m 367\u001b[0m namespace_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in namespace \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mns\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ns \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 368\u001b[0m suggestion_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Did you mean: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`?\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m suggestion \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m--> 370\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mNameNotFound(\n\u001b[0;32m 371\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEnvironment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt exist\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnamespace_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 372\u001b[0m )\n",
|
||||
"\u001b[1;31mNameNotFound\u001b[0m: Environment `Wordle` doesn't exist."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"env = common.monitor.Monitor(env)\n",
|
||||
"env = gym.make(\"Wordle-v0\")\n",
|
||||
"\n",
|
||||
"print(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using cuda device\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6921a0721569456abf5bceac7e7b6b34",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Output()"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 4.97 |\n",
|
||||
"| ep_rew_mean | -63.8 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 10000 |\n",
|
||||
"| fps | 1628 |\n",
|
||||
"| time_elapsed | 30 |\n",
|
||||
"| total_timesteps | 49995 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"----------------------------------\n",
|
||||
"| rollout/ | |\n",
|
||||
"| ep_len_mean | 5 |\n",
|
||||
"| ep_rew_mean | -70.5 |\n",
|
||||
"| exploration_rate | 0.05 |\n",
|
||||
"| time/ | |\n",
|
||||
"| episodes | 20000 |\n",
|
||||
"| fps | 662 |\n",
|
||||
"| time_elapsed | 150 |\n",
|
||||
"| total_timesteps | 99992 |\n",
|
||||
"| train/ | |\n",
|
||||
"| learning_rate | 0.0001 |\n",
|
||||
"| loss | 11.7 |\n",
|
||||
"| n_updates | 12497 |\n",
|
||||
"----------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||||
],
|
||||
"text/plain": []
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||||
"</pre>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"total_timesteps = 100_000\n",
|
||||
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"dqn_new_state\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n",
|
||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||
"Exception: code() argument 13 must be str, not int\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# model = DQN.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
||||
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||
"0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||
"\n",
|
||||
"for i in range(1000):\n",
|
||||
" \n",
|
||||
" state, info = env.reset()\n",
|
||||
"\n",
|
||||
" done = False\n",
|
||||
"\n",
|
||||
" wins = 0\n",
|
||||
"\n",
|
||||
" while not done:\n",
|
||||
"\n",
|
||||
" action, _states = model.predict(state, deterministic=True)\n",
|
||||
"\n",
|
||||
" state, reward, done, truncated, info = env.step(action)\n",
|
||||
"\n",
|
||||
" print(state)\n",
|
||||
" if info[\"correct\"]:\n",
|
||||
" wins += 1\n",
|
||||
"\n",
|
||||
"print(wins)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n",
|
||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||
" 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n",
|
||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
||||
" 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
|
||||
" -50)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"state, reward"
|
||||
"total_timesteps = 100000\n",
|
||||
"model = DQN(\"MlpPolicy\", env, verbose=0)\n",
|
||||
"model.learn(total_timesteps=total_timesteps, progress_bar=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -311,7 +54,54 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"def test(model):\n",
|
||||
"\n",
|
||||
" end_rewards = []\n",
|
||||
"\n",
|
||||
" for i in range(1000):\n",
|
||||
" \n",
|
||||
" state = env.reset()\n",
|
||||
"\n",
|
||||
" done = False\n",
|
||||
"\n",
|
||||
" while not done:\n",
|
||||
"\n",
|
||||
" action, _states = model.predict(state, deterministic=True)\n",
|
||||
"\n",
|
||||
" state, reward, done, info = env.step(action)\n",
|
||||
" \n",
|
||||
" end_rewards.append(reward == 0)\n",
|
||||
" \n",
|
||||
" return np.sum(end_rewards) / len(end_rewards)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = DQN.load(\"dqn_wordle\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(test(model))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@@ -1,11 +0,0 @@
|
||||
# N-dle Solver
|
||||
|
||||
A solver designed to beat New York Time's Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, can extend to solve the more general N-dle problem (for quordle, octordle, etc.)
|
||||
|
||||
I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6).
|
||||
|
||||
## Usage:
|
||||
1. Run `python main.py --n 1`
|
||||
2. Follow the prompts
|
||||
|
||||
Currently only supports solving for 1 word at a time (i.e. wordle).
|
@@ -1,126 +0,0 @@
|
||||
import re
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AI:
|
||||
def __init__(self, vocab_file, num_letters=5, num_guesses=6):
|
||||
self.vocab_file = vocab_file
|
||||
self.num_letters = num_letters
|
||||
self.num_guesses = 6
|
||||
|
||||
self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file)
|
||||
self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1]
|
||||
|
||||
self.domains = None
|
||||
self.possible_letters = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def solve(self):
|
||||
num_guesses = 0
|
||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||
num_guesses += 1
|
||||
word = self.sample()
|
||||
|
||||
# # Always start with these two words
|
||||
# if num_guesses == 1:
|
||||
# word = 'soare'
|
||||
# elif num_guesses == 2:
|
||||
# word = 'culti'
|
||||
|
||||
print('-----------------------------------------------')
|
||||
print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
|
||||
print('-----------------------------------------------')
|
||||
self.arc_consistency(word)
|
||||
|
||||
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
|
||||
|
||||
|
||||
def arc_consistency(self, word):
|
||||
print(f'Performing arc consistency check on {word}...')
|
||||
print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.')
|
||||
results = []
|
||||
|
||||
# Collect results
|
||||
for l in word:
|
||||
while True:
|
||||
result = input(f'{l}: ')
|
||||
if result not in ['0', '1', '2']:
|
||||
print('Incorrect option. Try again.')
|
||||
continue
|
||||
results.append(result)
|
||||
break
|
||||
|
||||
self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
|
||||
|
||||
for i in range(len(word)):
|
||||
if results[i] == '0':
|
||||
if word[i] in self.possible_letters:
|
||||
if word[i] in self.domains[i]:
|
||||
self.domains[i].remove(word[i])
|
||||
else:
|
||||
for j in range(len(self.domains)):
|
||||
if word[i] in self.domains[j] and len(self.domains[j]) > 1:
|
||||
self.domains[j].remove(word[i])
|
||||
if results[i] == '1':
|
||||
if word[i] in self.domains[i]:
|
||||
self.domains[i].remove(word[i])
|
||||
if results[i] == '2':
|
||||
self.domains[i] = [word[i]]
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
|
||||
self.possible_letters = []
|
||||
|
||||
def sample(self):
|
||||
"""
|
||||
Samples a best word given the current domains
|
||||
:return:
|
||||
"""
|
||||
# Compile a regex of possible words with the current domain
|
||||
regex_string = ''
|
||||
for domain in self.domains:
|
||||
regex_string += ''.join(['[', ''.join(domain), ']', '{1}'])
|
||||
pattern = re.compile(regex_string)
|
||||
|
||||
# From the words with the highest scores, only return the best word that match the regex pattern
|
||||
for word, _ in self.best_words:
|
||||
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||
return word
|
||||
|
||||
def get_vocab(self, vocab_file):
|
||||
vocab = []
|
||||
with open(vocab_file, 'r') as f:
|
||||
for l in f:
|
||||
vocab.append(l.strip())
|
||||
|
||||
# Count letter frequencies at each index
|
||||
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||
for word in vocab:
|
||||
for i, l in enumerate(word):
|
||||
letter_freqs[i][l] += 1
|
||||
|
||||
# Assign a score to each letter at each index by the probability of it appearing
|
||||
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)]
|
||||
for i in range(len(letter_scores)):
|
||||
max_freq = np.max(list(letter_freqs[i].values()))
|
||||
for l in letter_scores[i].keys():
|
||||
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||
|
||||
# Find a sorted list of words ranked by sum of letter scores
|
||||
vocab_scores = {} # (score, word)
|
||||
for word in vocab:
|
||||
score = 0
|
||||
for i, l in enumerate(word):
|
||||
score += letter_scores[i][l]
|
||||
|
||||
# # Optimization: If repeating letters, deduct a couple points
|
||||
# if len(set(word)) < len(word):
|
||||
# score -= 0.25 * (len(word) - len(set(word)))
|
||||
|
||||
vocab_scores[word] = score
|
||||
|
||||
return vocab, vocab_scores, letter_scores
|
@@ -1,37 +0,0 @@
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
words = []
|
||||
with open('words.txt', 'r') as f:
|
||||
for l in f:
|
||||
words.append(l.strip())
|
||||
|
||||
# Count letter frequencies at each index
|
||||
letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||
for word in words:
|
||||
for i, l in enumerate(word):
|
||||
letter_freqs[i][l] += 1
|
||||
|
||||
# Assign a score to each letter at each index by the probability of it appearing
|
||||
letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)]
|
||||
for i in range(len(letter_scores)):
|
||||
max_freq = np.max(list(letter_freqs[i].values()))
|
||||
for l in letter_scores[i].keys():
|
||||
letter_scores[i][l] = letter_freqs[i][l] / max_freq
|
||||
|
||||
# Find a sorted list of words ranked by sum of letter scores
|
||||
word_scores = [] # (score, word)
|
||||
for word in words:
|
||||
score = 0
|
||||
for i, l in enumerate(word):
|
||||
score += letter_scores[i][l]
|
||||
word_scores.append((score, word))
|
||||
|
||||
sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1]
|
||||
print(sorted_by_second[:10])
|
||||
|
||||
for i, (score, word) in enumerate(sorted_by_second):
|
||||
if word == 'soare':
|
||||
print(f'{word} with a score of {score} is found at index {i}')
|
||||
|
@@ -1,18 +0,0 @@
|
||||
import argparse
|
||||
from ai import AI
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.n is None:
|
||||
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||
|
||||
ai = AI(args.vocab_file)
|
||||
ai.solve()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--n', dest='n', type=int, default=None)
|
||||
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@@ -1,15 +0,0 @@
|
||||
import pandas
|
||||
|
||||
print('Loading in words dictionary; this may take a while...')
|
||||
df = pandas.read_json('words_dictionary.json')
|
||||
print('Done loading words dictionary.')
|
||||
words = []
|
||||
for word in df.axes[0].tolist():
|
||||
if len(word) != 5:
|
||||
continue
|
||||
words.append(word)
|
||||
words.sort()
|
||||
|
||||
with open('words.txt', 'w') as f:
|
||||
for word in words:
|
||||
f.write(word + '\n')
|
File diff suppressed because it is too large
Load Diff
108
letter_guess.py
108
letter_guess.py
@@ -1,108 +0,0 @@
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
class LetterGuessingEnv(gym.Env):
|
||||
"""
|
||||
Custom Gymnasium environment for a letter guessing game with a focus on forming
|
||||
valid prefixes and words from a list of valid Wordle words. The environment tracks
|
||||
the current guess prefix and validates it against known valid words, ending the game
|
||||
early with a negative reward for invalid prefixes.
|
||||
"""
|
||||
|
||||
metadata = {'render_modes': ['human']}
|
||||
|
||||
def __init__(self, valid_words, seed=None):
|
||||
self.action_space = spaces.Discrete(26)
|
||||
self.observation_space = spaces.Box(low=0, high=1, shape=(26*2 + 26*4,), dtype=np.int32)
|
||||
|
||||
self.valid_words = valid_words # List of valid Wordle words
|
||||
self.target_word = '' # Target word for the current episode
|
||||
self.valid_words_str = ' '.join(self.valid_words) + ' '
|
||||
self.letter_flags = None
|
||||
self.letter_positions = None
|
||||
self.guessed_letters = set()
|
||||
self.guess_prefix = "" # Tracks the current guess prefix
|
||||
|
||||
self.reset()
|
||||
|
||||
def step(self, action):
|
||||
letter_index = action % 26 # Assuming action is the letter index directly
|
||||
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
||||
letter = chr(ord('a') + letter_index)
|
||||
|
||||
reward = 0
|
||||
done = False
|
||||
|
||||
# Check if the letter has already been used in the guess prefix
|
||||
if letter in self.guessed_letters:
|
||||
reward = -1 # Penalize for repeating letters in the prefix
|
||||
else:
|
||||
# Add the new letter to the prefix and update guessed letters set
|
||||
self.guess_prefix += letter
|
||||
self.guessed_letters.add(letter)
|
||||
|
||||
# Update letter flags based on whether the letter is in the target word
|
||||
if self.target_word[position] == letter:
|
||||
self.letter_flags[letter_index, :] = [1, 0] # Update flag for correct guess
|
||||
elif letter in self.target_word:
|
||||
self.letter_flags[letter_index, :] = [0, 1] # Update flag for correct guess wrong position
|
||||
else:
|
||||
self.letter_flags[letter_index, :] = [0, 0] # Update flag for incorrect guess
|
||||
|
||||
reward = 1 # Reward for adding new information by trying a new letter
|
||||
|
||||
# Update the letter_positions matrix to reflect the new guess
|
||||
if position == 4:
|
||||
self.letter_positions[:,:] = 1
|
||||
else:
|
||||
self.letter_positions[:, position] = 0
|
||||
self.letter_positions[letter_index, position] = 1
|
||||
|
||||
# Use regex to check if the current prefix can lead to a valid word
|
||||
if not re.search(r'\b' + self.guess_prefix, self.valid_words_str):
|
||||
reward = -5 # Penalize for forming an invalid prefix
|
||||
done = True # End the episode if the prefix is invalid
|
||||
|
||||
# guessed a full word so we reset our guess prefix to guess next round
|
||||
if len(self.guess_prefix) == len(self.target_word):
|
||||
self.guess_prefix = ''
|
||||
self.round += 1
|
||||
|
||||
# end after 5 rounds of total guesses
|
||||
if self.round == 2:
|
||||
# reward = 5
|
||||
done = True
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
if reward < -50:
|
||||
print(obs, reward, done)
|
||||
|
||||
return obs, reward, done, False, {}
|
||||
|
||||
def reset(self, seed=None):
|
||||
self.target_word = random.choice(self.valid_words)
|
||||
# self.target_word_encoded = self.encode_word(self.target_word)
|
||||
self.letter_flags = np.ones((26, 2), dtype=np.int32)
|
||||
self.letter_positions = np.ones((26, 4), dtype=np.int32)
|
||||
self.guessed_letters = set()
|
||||
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||
self.round = 1
|
||||
return self._get_obs(), {}
|
||||
|
||||
def encode_word(self, word):
|
||||
encoded = np.zeros((26,))
|
||||
for char in word:
|
||||
index = ord(char) - ord('a')
|
||||
encoded[index] = 1
|
||||
return encoded
|
||||
|
||||
def _get_obs(self):
|
||||
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
||||
|
||||
def render(self, mode='human'):
|
||||
pass # Optional: Implement rendering logic if needed
|
Reference in New Issue
Block a user