mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-23 10:59:21 +00:00
Compare commits
4 Commits
9172326013
...
ethan-test
Author | SHA1 | Date | |
---|---|---|---|
|
335d56ac88 | ||
|
496b8ad796 | ||
|
8d3ce990e3 | ||
|
7ad5b97463 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,3 +1,3 @@
|
|||||||
**/data/*
|
**/data/*
|
||||||
|
/env
|
||||||
**/*.zip
|
**/*.zip
|
||||||
**/__pycache__
|
|
129
Gym-Wordle-main/.gitignore
vendored
Normal file
129
Gym-Wordle-main/.gitignore
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
21
Gym-Wordle-main/LICENSE
Normal file
21
Gym-Wordle-main/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2022 David Kraemer
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
78
Gym-Wordle-main/README.md
Normal file
78
Gym-Wordle-main/README.md
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# Gym-Wordle
|
||||||
|
|
||||||
|
An OpenAI gym compatible environment for training agents to play Wordle.
|
||||||
|
|
||||||
|
<p align='center'>
|
||||||
|
<img src="https://user-images.githubusercontent.com/8514041/152437216-d78e85f6-8049-4cb9-ae61-3c015a8a0e4f.gif"><br/>
|
||||||
|
<em>User-input demo of the environment</em>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
My goal is for a minimalist package that lets you install quickly and get on
|
||||||
|
with your research. Installation is just a simple call to `pip`:
|
||||||
|
|
||||||
|
```
|
||||||
|
$ pip install gym_wordle
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirements
|
||||||
|
|
||||||
|
In keeping with my desire to have a minimalist package, there are only three
|
||||||
|
major requirements:
|
||||||
|
|
||||||
|
* `numpy`
|
||||||
|
* `gym`
|
||||||
|
* `sty`, a lovely little package for stylizing text in terminals
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
The basic flow for training agents with the `Wordle-v0` environment is the same
|
||||||
|
as with gym environments generally:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
import gym
|
||||||
|
import gym_wordle
|
||||||
|
|
||||||
|
eng = gym.make("Wordle-v0")
|
||||||
|
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action = ... # RL magic
|
||||||
|
state, reward, done, info = env.step(action)
|
||||||
|
```
|
||||||
|
|
||||||
|
If you're like millions of other people, you're a Wordle-obsessive in your own
|
||||||
|
right. I have good news for you! The `Wordle-v0` environment currently has an
|
||||||
|
implemented `render` method, which allows you to see a human-friendly version
|
||||||
|
of the game. And it isn't so hard to set up the environment to play for
|
||||||
|
yourself. Here's an example script:
|
||||||
|
|
||||||
|
```Python
|
||||||
|
from gym_wordle.utils import play
|
||||||
|
|
||||||
|
play()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
Coming soon!
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
Coming soon!
|
||||||
|
|
||||||
|
## Citing
|
||||||
|
|
||||||
|
If you decide to use this project in your work, please consider a citation!
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{gym_wordle,
|
||||||
|
author = {Kraemer, David},
|
||||||
|
title = {An Environment for Reinforcement Learning with Wordle},
|
||||||
|
year = {2022},
|
||||||
|
publisher = {GitHub},
|
||||||
|
journal = {GitHub repository},
|
||||||
|
howpublished = {\url{https://github.com/DavidNKraemer/Gym-Wordle}},
|
||||||
|
}
|
||||||
|
```
|
7
Gym-Wordle-main/gym-wordle.toml
Normal file
7
Gym-Wordle-main/gym-wordle.toml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
[build-system]
|
||||||
|
|
||||||
|
requires = [
|
||||||
|
"setuptools>=42",
|
||||||
|
"wheel"
|
||||||
|
]
|
||||||
|
build-backend = "setuptools.build_meta"
|
@@ -91,3 +91,4 @@ def play():
|
|||||||
env.render()
|
env.render()
|
||||||
|
|
||||||
print(f"The word was {solution}")
|
print(f"The word was {solution}")
|
||||||
|
|
@@ -7,6 +7,7 @@ from collections import Counter
|
|||||||
from gym_wordle.utils import to_english, to_array, get_words
|
from gym_wordle.utils import to_english, to_array, get_words
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class WordList(gym.spaces.Discrete):
|
class WordList(gym.spaces.Discrete):
|
||||||
"""Super class for defining a space of valid words according to a specified
|
"""Super class for defining a space of valid words according to a specified
|
||||||
list.
|
list.
|
||||||
@@ -205,7 +206,8 @@ class WordleEnv(gym.Env):
|
|||||||
self.round = 0
|
self.round = 0
|
||||||
self.solution = self.solution_space.sample()
|
self.solution = self.solution_space.sample()
|
||||||
|
|
||||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
self.state = np.zeros((self.n_rounds, 2 * self.n_letters),
|
||||||
|
dtype=np.int64)
|
||||||
|
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@@ -261,14 +263,14 @@ class WordleEnv(gym.Env):
|
|||||||
counter[char] += 1
|
counter[char] += 1
|
||||||
|
|
||||||
if char == solution[i]: # character is in correct position
|
if char == solution[i]: # character is in correct position
|
||||||
self.state[self.round, flag_i] = self.right_pos
|
self.state[self.round, i] = self.right_pos
|
||||||
elif counter[char] <= (char == solution).sum():
|
elif counter[char] <= (char == solution).sum():
|
||||||
# current character has been seen within correct number of
|
# current character has been seen within correct number of
|
||||||
# occurrences
|
# occurrences
|
||||||
self.state[self.round, flag_i] = self.wrong_pos
|
self.state[self.round, i] = self.wrong_pos
|
||||||
else:
|
else:
|
||||||
# wrong character, or "correct" character too many times
|
# wrong character, or "correct" character too many times
|
||||||
self.state[self.round, flag_i] = self.wrong_char
|
self.state[self.round, i] = self.wrong_char
|
||||||
|
|
||||||
self.round += 1
|
self.round += 1
|
||||||
|
|
||||||
@@ -278,19 +280,7 @@ class WordleEnv(gym.Env):
|
|||||||
done = correct or game_over
|
done = correct or game_over
|
||||||
|
|
||||||
# Total reward equals -(number of incorrect guesses)
|
# Total reward equals -(number of incorrect guesses)
|
||||||
# reward = 0. if correct else -1.
|
reward = 0. if correct else -1.
|
||||||
|
|
||||||
# correct +10
|
return self.state, reward, done, {}
|
||||||
# guesses new letter +1
|
|
||||||
# guesses correct letter +1
|
|
||||||
# spent another guess -1
|
|
||||||
|
|
||||||
reward = 0
|
|
||||||
reward += np.sum(self.state[:, 5:] == 1) * 1
|
|
||||||
reward += np.sum(self.state[:, 5:] == 2) * 0.5
|
|
||||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
|
||||||
reward += 10 if correct else -10 if done else 0
|
|
||||||
|
|
||||||
info = {'correct': correct}
|
|
||||||
|
|
||||||
return self.state, reward, done, info
|
|
35
Gym-Wordle-main/setup.py
Normal file
35
Gym-Wordle-main/setup.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
with open('README.md', 'r', encoding='utf-8') as fh:
|
||||||
|
long_description = fh.read()
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='gym_wordle',
|
||||||
|
version='0.1.3',
|
||||||
|
author='David Kraemer',
|
||||||
|
author_email='david.kraemer@stonybrook.edu',
|
||||||
|
description='OpenAI gym environment for training agents on Wordle',
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
url='https://github.com/DavidNKraemer/Gym-Wordle',
|
||||||
|
packages=find_packages(
|
||||||
|
include=[
|
||||||
|
'gym_wordle',
|
||||||
|
'gym_wordle.*'
|
||||||
|
]
|
||||||
|
),
|
||||||
|
package_data={
|
||||||
|
'gym_wordle': ['dictionary/*']
|
||||||
|
},
|
||||||
|
python_requires='>=3.7',
|
||||||
|
classifiers=[
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
],
|
||||||
|
install_requires=[
|
||||||
|
'numpy>=1.20',
|
||||||
|
'gym==0.19',
|
||||||
|
'sty==1.0',
|
||||||
|
],
|
||||||
|
)
|
91
custom_env/agent.py
Normal file
91
custom_env/agent.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Agent:
|
||||||
|
|
||||||
|
def __init__(self, ) -> None:
|
||||||
|
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
||||||
|
# GAMMA is the discount factor as mentioned in the previous section
|
||||||
|
# EPS_START is the starting value of epsilon
|
||||||
|
# EPS_END is the final value of epsilon
|
||||||
|
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||||
|
# TAU is the update rate of the target network
|
||||||
|
# LR is the learning rate of the ``AdamW`` optimizer
|
||||||
|
self.batch_size = 128
|
||||||
|
self.gamma = 0.99
|
||||||
|
self.eps_start = 0.9
|
||||||
|
self.eps_end = 0.05
|
||||||
|
self.eps_decay = 1000
|
||||||
|
self.tau = 0.005
|
||||||
|
self.lr = 1e-4
|
||||||
|
self.n_actions = n_actions
|
||||||
|
|
||||||
|
policy_net = DQN(n_observations, n_actions).to(device)
|
||||||
|
target_net = DQN(n_observations, n_actions).to(device)
|
||||||
|
target_net.load_state_dict(policy_net.state_dict())
|
||||||
|
|
||||||
|
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
||||||
|
memory = ReplayMemory(10000)
|
||||||
|
|
||||||
|
def get_state(self, game):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def select_action(state):
|
||||||
|
sample = random.random()
|
||||||
|
eps_threshold = EPS_END + (EPS_START - EPS_END) * \
|
||||||
|
math.exp(-1. * steps_done / EPS_DECAY)
|
||||||
|
steps_done += 1
|
||||||
|
if sample > eps_threshold:
|
||||||
|
with torch.no_grad():
|
||||||
|
# t.max(1) will return the largest column value of each row.
|
||||||
|
# second column on max result is index of where max element was
|
||||||
|
# found, so we pick action with the larger expected reward.
|
||||||
|
return policy_net(state).max(1).indices.view(1, 1)
|
||||||
|
else:
|
||||||
|
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
def optimize_model():
|
||||||
|
if len(memory) < BATCH_SIZE:
|
||||||
|
return
|
||||||
|
transitions = memory.sample(BATCH_SIZE)
|
||||||
|
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
||||||
|
# detailed explanation). This converts batch-array of Transitions
|
||||||
|
# to Transition of batch-arrays.
|
||||||
|
batch = Transition(*zip(*transitions))
|
||||||
|
|
||||||
|
# Compute a mask of non-final states and concatenate the batch elements
|
||||||
|
# (a final state would've been the one after which simulation ended)
|
||||||
|
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
|
||||||
|
batch.next_state)), device=device, dtype=torch.bool)
|
||||||
|
non_final_next_states = torch.cat([s for s in batch.next_state
|
||||||
|
if s is not None])
|
||||||
|
state_batch = torch.cat(batch.state)
|
||||||
|
action_batch = torch.cat(batch.action)
|
||||||
|
reward_batch = torch.cat(batch.reward)
|
||||||
|
|
||||||
|
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
||||||
|
# columns of actions taken. These are the actions which would've been taken
|
||||||
|
# for each batch state according to policy_net
|
||||||
|
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
||||||
|
|
||||||
|
# Compute V(s_{t+1}) for all next states.
|
||||||
|
# Expected values of actions for non_final_next_states are computed based
|
||||||
|
# on the "older" target_net; selecting their best reward with max(1).values
|
||||||
|
# This is merged based on the mask, such that we'll have either the expected
|
||||||
|
# state value or 0 in case the state was final.
|
||||||
|
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
||||||
|
with torch.no_grad():
|
||||||
|
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
|
||||||
|
# Compute the expected Q values
|
||||||
|
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
||||||
|
|
||||||
|
# Compute Huber loss
|
||||||
|
criterion = nn.SmoothL1Loss()
|
||||||
|
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
||||||
|
|
||||||
|
# Optimize the model
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
# In-place gradient clipping
|
||||||
|
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
||||||
|
optimizer.step()
|
16
custom_env/create_wordlist.py
Normal file
16
custom_env/create_wordlist.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
from string import ascii_letters
|
||||||
|
|
||||||
|
in_path = pathlib.Path(sys.argv[1])
|
||||||
|
out_path = pathlib.Path(sys.argv[2])
|
||||||
|
|
||||||
|
words = sorted(
|
||||||
|
{
|
||||||
|
word.lower()
|
||||||
|
for word in in_path.read_text(encoding="utf-8").split()
|
||||||
|
if all(letter in ascii_letters for letter in word)
|
||||||
|
},
|
||||||
|
key=lambda word: (len(word), word),
|
||||||
|
)
|
||||||
|
out_path.write_text("\n".join(words))
|
5757
custom_env/five_letter_words.txt
Normal file
5757
custom_env/five_letter_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
44
custom_env/model.py
Normal file
44
custom_env/model.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from collections import namedtuple, deque
|
||||||
|
from itertools import count
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
Transition = namedtuple('Transition',
|
||||||
|
('state', 'action', 'next_state', 'reward'))
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayMemory(object):
|
||||||
|
|
||||||
|
def __init__(self, capacity: int) -> None:
|
||||||
|
self.memory = deque([], maxlen=capacity)
|
||||||
|
|
||||||
|
def push(self, *args):
|
||||||
|
self.memory.append(Transition(*args))
|
||||||
|
|
||||||
|
def sample(self, batch_size):
|
||||||
|
return random.sample(self.memory, batch_size)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.memory)
|
||||||
|
|
||||||
|
|
||||||
|
class DQN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, n_observations: int, n_actions: int) -> None:
|
||||||
|
super(DQN, self).__init__()
|
||||||
|
self.layer1 = nn.Linear(n_observations, 128)
|
||||||
|
self.layer2 = nn.Linear(128, 128)
|
||||||
|
self.layer3 = nn.Linear(128, n_actions)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(self.layer1(x))
|
||||||
|
x = F.relu(self.layer2(x))
|
||||||
|
return self.layer3(x)
|
61
custom_env/test2.ipynb
Normal file
61
custom_env/test2.ipynb
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from string import ascii_letters, ascii_uppercase, ascii_lowercase"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'ABCDEFGHIJKLMNOPQRSTUVWXYZ'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ascii_uppercase"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
1098
custom_env/wordlist.txt
Normal file
1098
custom_env/wordlist.txt
Normal file
File diff suppressed because it is too large
Load Diff
119
custom_env/wyrdl.py
Normal file
119
custom_env/wyrdl.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import contextlib
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
from string import ascii_letters, ascii_lowercase
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.theme import Theme
|
||||||
|
|
||||||
|
console = Console(width=40, theme=Theme({"warning": "red on yellow"}))
|
||||||
|
|
||||||
|
NUM_LETTERS = 5
|
||||||
|
NUM_GUESSES = 6
|
||||||
|
WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt"
|
||||||
|
|
||||||
|
|
||||||
|
class Wordle:
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n")
|
||||||
|
self.n_guesses = 6
|
||||||
|
self.num_letters = 5
|
||||||
|
self.curr_word = None
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def refresh_page(self, headline):
|
||||||
|
console.clear()
|
||||||
|
console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n")
|
||||||
|
|
||||||
|
def start_game(self):
|
||||||
|
# get a new random word
|
||||||
|
word = self.get_random_word(self.word_list)
|
||||||
|
|
||||||
|
self.curr_word = word
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def action_to_word(self, action):
|
||||||
|
# Calculate the word from the array
|
||||||
|
word = ''
|
||||||
|
for i in range(0, len(ascii_lowercase), 26):
|
||||||
|
# Find the index of 1 in each block of 26
|
||||||
|
letter_index = action[i:i+26].index(1)
|
||||||
|
# Append the corresponding letter to the word
|
||||||
|
word += ascii_lowercase[letter_index]
|
||||||
|
|
||||||
|
return word
|
||||||
|
|
||||||
|
def play_guess(self, action):
|
||||||
|
# probably an array of length 26 * 5 for 26 letters and 5 positions
|
||||||
|
guess = action
|
||||||
|
|
||||||
|
def get_random_word(self, word_list):
|
||||||
|
if words := [
|
||||||
|
word.upper()
|
||||||
|
for word in word_list
|
||||||
|
if len(word) == NUM_LETTERS
|
||||||
|
and all(letter in ascii_letters for letter in word)
|
||||||
|
]:
|
||||||
|
return random.choice(words)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
f"No words of length {NUM_LETTERS} in the word list",
|
||||||
|
style="warning",
|
||||||
|
)
|
||||||
|
raise SystemExit()
|
||||||
|
|
||||||
|
def show_guesses(self, guesses, word):
|
||||||
|
letter_status = {letter: letter for letter in ascii_lowercase}
|
||||||
|
for guess in guesses:
|
||||||
|
styled_guess = []
|
||||||
|
for letter, correct in zip(guess, word):
|
||||||
|
if letter == correct:
|
||||||
|
style = "bold white on green"
|
||||||
|
elif letter in word:
|
||||||
|
style = "bold white on yellow"
|
||||||
|
elif letter in ascii_letters:
|
||||||
|
style = "white on #666666"
|
||||||
|
else:
|
||||||
|
style = "dim"
|
||||||
|
styled_guess.append(f"[{style}]{letter}[/]")
|
||||||
|
if letter != "_":
|
||||||
|
letter_status[letter] = f"[{style}]{letter}[/]"
|
||||||
|
|
||||||
|
console.print("".join(styled_guess), justify="center")
|
||||||
|
console.print("\n" + "".join(letter_status.values()), justify="center")
|
||||||
|
|
||||||
|
def guess_word(self, previous_guesses):
|
||||||
|
guess = console.input("\nGuess word: ").upper()
|
||||||
|
|
||||||
|
if guess in previous_guesses:
|
||||||
|
console.print(f"You've already guessed {guess}.", style="warning")
|
||||||
|
return guess_word(previous_guesses)
|
||||||
|
|
||||||
|
if len(guess) != NUM_LETTERS:
|
||||||
|
console.print(
|
||||||
|
f"Your guess must be {NUM_LETTERS} letters.", style="warning"
|
||||||
|
)
|
||||||
|
return guess_word(previous_guesses)
|
||||||
|
|
||||||
|
if any((invalid := letter) not in ascii_letters for letter in guess):
|
||||||
|
console.print(
|
||||||
|
f"Invalid letter: '{invalid}'. Please use English letters.",
|
||||||
|
style="warning",
|
||||||
|
)
|
||||||
|
return guess_word(previous_guesses)
|
||||||
|
|
||||||
|
return guess
|
||||||
|
|
||||||
|
def reset(self, guesses, word, guessed_correctly, n_episodes):
|
||||||
|
refresh_page(headline=f"Game: {n_episodes}")
|
||||||
|
|
||||||
|
if guessed_correctly:
|
||||||
|
console.print(f"\n[bold white on green]Correct, the word is {word}[/]")
|
||||||
|
else:
|
||||||
|
console.print(f"\n[bold white on red]Sorry, the word was {word}[/]")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1144
dqn_wordle.ipynb
1144
dqn_wordle.ipynb
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user