Mirror of https://github.com/ltcptgeneral/cse151b-final-project.git
Synced 2025-10-22 18:49:21 +00:00

Compare commits: ethan-test ... 9172326013 (7 commits)
Author | SHA1 | Date
---|---|---
 | 9172326013 |
 | 4836be8121 |
 | 5672169073 |
 | 5ec123e0f1 |
 | e9622b6f68 |
 | 83e81722d2 |
 | 320f2f81b7 |
.gitignore (vendored): 6 changed lines

@@ -1,3 +1,3 @@
-**/data/*
-/env
-**/*.zip
+**/data/*
+**/*.zip
+**/__pycache__
Gym-Wordle-main/.gitignore (vendored): 129 deleted lines

@@ -1,129 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2022 David Kraemer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -1,78 +0,0 @@
# Gym-Wordle

An OpenAI gym compatible environment for training agents to play Wordle.

<p align='center'>
  <img src="https://user-images.githubusercontent.com/8514041/152437216-d78e85f6-8049-4cb9-ae61-3c015a8a0e4f.gif"><br/>
  <em>User-input demo of the environment</em>
</p>

## Installation

My goal is a minimalist package that lets you install quickly and get on
with your research. Installation is just a simple call to `pip`:

```
$ pip install gym_wordle
```

### Requirements

In keeping with my desire to have a minimalist package, there are only three
major requirements:

* `numpy`
* `gym`
* `sty`, a lovely little package for stylizing text in terminals

## Usage

The basic flow for training agents with the `Wordle-v0` environment is the same
as with gym environments generally:

```Python
import gym
import gym_wordle

env = gym.make("Wordle-v0")

done = False
while not done:
    action = ...  # RL magic
    state, reward, done, info = env.step(action)
```

If you're like millions of other people, you're a Wordle-obsessive in your own
right. I have good news for you! The `Wordle-v0` environment currently has an
implemented `render` method, which allows you to see a human-friendly version
of the game. And it isn't so hard to set up the environment to play for
yourself. Here's an example script:

```Python
from gym_wordle.utils import play

play()
```

## Documentation

Coming soon!

## Examples

Coming soon!

## Citing

If you decide to use this project in your work, please consider a citation!

```bibtex
@misc{gym_wordle,
  author = {Kraemer, David},
  title = {An Environment for Reinforcement Learning with Wordle},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/DavidNKraemer/Gym-Wordle}},
}
```
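The `action = ...  # RL magic` placeholder in the usage snippet above can be filled by any policy. A minimal sketch using a random policy, which simply samples from the environment's discrete action space (the same fallback the project's agent code uses during exploration):

```python
import gym
import gym_wordle

env = gym.make("Wordle-v0")
state = env.reset()

done = False
while not done:
    action = env.action_space.sample()  # stand-in for a learned policy
    state, reward, done, info = env.step(action)
    env.render()  # prints a human-friendly view of the board
```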
@@ -1,7 +0,0 @@
[build-system]

requires = [
    "setuptools>=42",
    "wheel"
]
build-backend = "setuptools.build_meta"
@@ -1,35 +0,0 @@
from setuptools import setup, find_packages

with open('README.md', 'r', encoding='utf-8') as fh:
    long_description = fh.read()

setup(
    name='gym_wordle',
    version='0.1.3',
    author='David Kraemer',
    author_email='david.kraemer@stonybrook.edu',
    description='OpenAI gym environment for training agents on Wordle',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/DavidNKraemer/Gym-Wordle',
    packages=find_packages(
        include=[
            'gym_wordle',
            'gym_wordle.*'
        ]
    ),
    package_data={
        'gym_wordle': ['dictionary/*']
    },
    python_requires='>=3.7',
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        'numpy>=1.20',
        'gym==0.19',
        'sty==1.0',
    ],
)
@@ -1,91 +0,0 @@
import math
import random

import torch
import torch.nn as nn
import torch.optim as optim

# DQN, ReplayMemory, and Transition are assumed to be defined alongside this
# class (see the model definitions further down in this diff).


class Agent:

    def __init__(self, n_observations, n_actions, device) -> None:
        # batch_size is the number of transitions sampled from the replay buffer
        # gamma is the discount factor
        # eps_start is the starting value of epsilon
        # eps_end is the final value of epsilon
        # eps_decay controls the rate of exponential decay of epsilon, higher means a slower decay
        # tau is the update rate of the target network
        # lr is the learning rate of the ``AdamW`` optimizer
        self.batch_size = 128
        self.gamma = 0.99
        self.eps_start = 0.9
        self.eps_end = 0.05
        self.eps_decay = 1000
        self.tau = 0.005
        self.lr = 1e-4
        self.n_actions = n_actions
        self.device = device
        self.steps_done = 0

        self.policy_net = DQN(n_observations, n_actions).to(device)
        self.target_net = DQN(n_observations, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)
        self.memory = ReplayMemory(10000)

    def get_state(self, game):
        pass

    def select_action(self, state, env):
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # policy_net(state).max(1) returns the largest column value of each row;
                # its indices mark where the max element was found, so we pick the
                # action with the larger expected reward.
                return self.policy_net(state).max(1).indices.view(1, 1)
        else:
            return torch.tensor([[env.action_space.sample()]], device=self.device, dtype=torch.long)

    def optimize_model(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for a
        # detailed explanation). This converts a batch-array of Transitions
        # into a Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net; selecting their best reward with max(1).values
        # This is merged based on the mask, such that we'll have either the expected
        # state value or 0 in case the state was final.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
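A minimal sketch of how this agent might be wired into a training loop with the `Wordle-v0` environment and the `DQN`/`ReplayMemory` classes shown below. The board flattening, the episode count, and the soft target-network update using `tau` are illustrative assumptions, not the project's confirmed training code (that presumably lives in the suppressed `dqn_wordle.ipynb`):

```python
import gym
import gym_wordle  # registers Wordle-v0
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("Wordle-v0")

board = env.reset()
n_observations = int(np.prod(board.shape))   # flattened (rounds, 2 * letters) board
n_actions = env.action_space.n               # one action per word in the guess list
agent = Agent(n_observations, n_actions, device)  # Agent as sketched above


def to_tensor(board):
    # Flatten the integer board into a (1, n_observations) float tensor for the DQN.
    return torch.tensor(board, dtype=torch.float32, device=device).flatten().unsqueeze(0)


for episode in range(500):
    state = to_tensor(env.reset())
    done = False
    while not done:
        action = agent.select_action(state, env)
        board, reward, done, info = env.step(action.item())
        next_state = None if done else to_tensor(board)
        agent.memory.push(state, action, next_state,
                          torch.tensor([reward], device=device, dtype=torch.float32))
        if next_state is not None:
            state = next_state
        agent.optimize_model()

    # Soft update of the target network toward the policy network (uses tau).
    target_dict = agent.target_net.state_dict()
    policy_dict = agent.policy_net.state_dict()
    for key in policy_dict:
        target_dict[key] = policy_dict[key] * agent.tau + target_dict[key] * (1 - agent.tau)
    agent.target_net.load_state_dict(target_dict)
```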
@@ -1,16 +0,0 @@
import pathlib
import sys
from string import ascii_letters

in_path = pathlib.Path(sys.argv[1])
out_path = pathlib.Path(sys.argv[2])

words = sorted(
    {
        word.lower()
        for word in in_path.read_text(encoding="utf-8").split()
        if all(letter in ascii_letters for letter in word)
    },
    key=lambda word: (len(word), word),
)
out_path.write_text("\n".join(words))
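The script above reads a raw word list from its first argument and writes a cleaned, deduplicated list sorted by length and then alphabetically to its second argument (e.g. `python <script>.py raw_words.txt wordlist.txt`; the script's filename is not shown in this diff). A toy illustration of the filter and sort key:

```python
from string import ascii_letters

raw = ["Apple", "apple", "zebra", "ox", "don't"]  # "don't" is dropped (apostrophe)
words = sorted(
    {w.lower() for w in raw if all(c in ascii_letters for c in w)},
    key=lambda w: (len(w), w),
)
print(words)  # ['ox', 'apple', 'zebra'] -- deduplicated, shortest words first
```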
File diff suppressed because it is too large
@@ -1,44 +0,0 @@
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity: int) -> None:
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):

    def __init__(self, n_observations: int, n_actions: int) -> None:
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
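A small sketch of feeding a Wordle board through this network, assuming the (n_rounds, 2 * n_letters) integer state built by the environment code in this diff (6 rounds, 5 letters, so 60 features once flattened); the action count here is a placeholder, since the real value comes from the environment's word list:

```python
import numpy as np
import torch

n_observations = 6 * 2 * 5    # flattened (6, 10) board
n_actions = 1000              # placeholder for the size of the guess word list

net = DQN(n_observations, n_actions)  # DQN as defined above

board = np.zeros((6, 10), dtype=np.int64)  # a fresh, empty board
x = torch.tensor(board, dtype=torch.float32).flatten().unsqueeze(0)  # shape (1, 60)

with torch.no_grad():
    q_values = net(x)                       # shape (1, n_actions)
greedy_guess = int(q_values.argmax(dim=1))  # index of the highest-valued action
```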
@@ -1,61 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from string import ascii_letters, ascii_uppercase, ascii_lowercase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'ABCDEFGHIJKLMNOPQRSTUVWXYZ'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ascii_uppercase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
File diff suppressed because it is too large
@@ -1,119 +0,0 @@
import contextlib
import pathlib
import random
from string import ascii_letters, ascii_lowercase

from rich.console import Console
from rich.theme import Theme

console = Console(width=40, theme=Theme({"warning": "red on yellow"}))

NUM_LETTERS = 5
NUM_GUESSES = 6
WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt"


class Wordle:

    def __init__(self) -> None:
        self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n")
        self.n_guesses = 6
        self.num_letters = 5
        self.curr_word = None
        self.reset()

    def refresh_page(self, headline):
        console.clear()
        console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n")

    def start_game(self):
        # get a new random word
        word = self.get_random_word(self.word_list)

        self.curr_word = word

    def get_state(self):
        return

    def action_to_word(self, action):
        # Calculate the word from the array: one block of 26 entries per letter position
        word = ''
        for i in range(0, len(action), 26):
            # Find the index of 1 in each block of 26
            letter_index = action[i:i+26].index(1)
            # Append the corresponding letter to the word
            word += ascii_lowercase[letter_index]

        return word

    def play_guess(self, action):
        # probably an array of length 26 * 5 for 26 letters and 5 positions
        guess = action

    def get_random_word(self, word_list):
        if words := [
            word.upper()
            for word in word_list
            if len(word) == NUM_LETTERS
            and all(letter in ascii_letters for letter in word)
        ]:
            return random.choice(words)
        else:
            console.print(
                f"No words of length {NUM_LETTERS} in the word list",
                style="warning",
            )
            raise SystemExit()

    def show_guesses(self, guesses, word):
        letter_status = {letter: letter for letter in ascii_lowercase}
        for guess in guesses:
            styled_guess = []
            for letter, correct in zip(guess, word):
                if letter == correct:
                    style = "bold white on green"
                elif letter in word:
                    style = "bold white on yellow"
                elif letter in ascii_letters:
                    style = "white on #666666"
                else:
                    style = "dim"
                styled_guess.append(f"[{style}]{letter}[/]")
                if letter != "_":
                    letter_status[letter] = f"[{style}]{letter}[/]"

            console.print("".join(styled_guess), justify="center")
        console.print("\n" + "".join(letter_status.values()), justify="center")

    def guess_word(self, previous_guesses):
        guess = console.input("\nGuess word: ").upper()

        if guess in previous_guesses:
            console.print(f"You've already guessed {guess}.", style="warning")
            return self.guess_word(previous_guesses)

        if len(guess) != NUM_LETTERS:
            console.print(
                f"Your guess must be {NUM_LETTERS} letters.", style="warning"
            )
            return self.guess_word(previous_guesses)

        if any((invalid := letter) not in ascii_letters for letter in guess):
            console.print(
                f"Invalid letter: '{invalid}'. Please use English letters.",
                style="warning",
            )
            return self.guess_word(previous_guesses)

        return guess

    def reset(self, guesses=None, word=None, guessed_correctly=False, n_episodes=0):
        self.refresh_page(headline=f"Game: {n_episodes}")

        if word is None:
            # First reset (called from __init__): nothing to report yet.
            return

        if guessed_correctly:
            console.print(f"\n[bold white on green]Correct, the word is {word}[/]")
        else:
            console.print(f"\n[bold white on red]Sorry, the word was {word}[/]")


if __name__ == "__main__":
    # NOTE: this module does not define main(); it is work-in-progress scaffolding.
    pass
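The comments in `action_to_word` and `play_guess` describe the intended action encoding: a flat array of length 26 * 5, one one-hot block of 26 letters per board position. A minimal sketch of encoding a word into that representation and decoding it back (the encoder helper is illustrative and not part of the repo):

```python
from string import ascii_lowercase


def word_to_action(word: str) -> list:
    # One block of 26 entries per letter position, with exactly one 1 per block.
    action = []
    for letter in word.lower():
        block = [0] * 26
        block[ascii_lowercase.index(letter)] = 1
        action.extend(block)
    return action


def action_to_word(action: list) -> str:
    # Inverse of the encoding above, mirroring Wordle.action_to_word.
    return "".join(
        ascii_lowercase[action[i:i + 26].index(1)]
        for i in range(0, len(action), 26)
    )


encoded = word_to_action("crane")
assert len(encoded) == 26 * 5
assert action_to_word(encoded) == "crane"
```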
dqn_wordle.ipynb: 1144 changed lines
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -90,5 +90,4 @@ def play():
         state, reward, done, info = env.step(action)
         env.render()
 
-    print(f"The word was {solution}")
-
+    print(f"The word was {solution}")
@@ -7,7 +7,6 @@ from collections import Counter
from gym_wordle.utils import to_english, to_array, get_words
from typing import Optional


class WordList(gym.spaces.Discrete):
    """Super class for defining a space of valid words according to a specified
    list.

@@ -206,8 +205,7 @@ class WordleEnv(gym.Env):
         self.round = 0
         self.solution = self.solution_space.sample()
 
-        self.state = np.zeros((self.n_rounds, 2 * self.n_letters),
-                              dtype=np.int64)
+        self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
 
         return self.state
 
@@ -263,14 +261,14 @@ class WordleEnv(gym.Env):
             counter[char] += 1
 
             if char == solution[i]:  # character is in correct position
-                self.state[self.round, i] = self.right_pos
+                self.state[self.round, flag_i] = self.right_pos
             elif counter[char] <= (char == solution).sum():
                 # current character has been seen within correct number of
                 # occurrences
-                self.state[self.round, i] = self.wrong_pos
+                self.state[self.round, flag_i] = self.wrong_pos
             else:
                 # wrong character, or "correct" character too many times
-                self.state[self.round, i] = self.wrong_char
+                self.state[self.round, flag_i] = self.wrong_char
 
         self.round += 1
 
@@ -280,7 +278,19 @@
         done = correct or game_over
 
         # Total reward equals -(number of incorrect guesses)
-        reward = 0. if correct else -1.
+        # reward = 0. if correct else -1.
 
+        # correct +10
+        # guesses new letter +1
+        # guesses correct letter +1
+        # spent another guess -1
 
-        return self.state, reward, done, {}
+        reward = 0
+        reward += np.sum(self.state[:, 5:] == 1) * 1
+        reward += np.sum(self.state[:, 5:] == 2) * 0.5
+        reward += np.sum(self.state[:, 5:] == 3) * -1
+        reward += 10 if correct else -10 if done else 0
+
+        info = {'correct': correct}
+
+        return self.state, reward, done, info
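Read as a standalone helper, the shaping introduced in this hunk weights flag values 1, 2, and 3 in the feedback half of the state array by +1, +0.5, and -1 per cell, plus a terminal bonus or penalty. A minimal sketch mirroring those weights (the mapping from flag values to letter outcomes is defined by the environment and not restated here):

```python
import numpy as np


def shaped_reward(state: np.ndarray, correct: bool, done: bool) -> float:
    """Mirror of the reward computation in the hunk above.

    `state` is the (n_rounds, 2 * n_letters) board; columns 5 onward hold the
    per-letter feedback flags written by step().
    """
    flags = state[:, 5:]
    reward = float(np.sum(flags == 1)) * 1.0
    reward += float(np.sum(flags == 2)) * 0.5
    reward += float(np.sum(flags == 3)) * -1.0
    reward += 10 if correct else (-10 if done else 0)
    return reward


# Example: an empty board contributes only the terminal term.
empty = np.zeros((6, 10), dtype=np.int64)
assert shaped_reward(empty, correct=False, done=False) == 0.0
assert shaped_reward(empty, correct=True, done=True) == 10.0
```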