mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-23 02:59:19 +00:00
Compare commits
9 Commits
ethan-test
...
e799c14ece
Author | SHA1 | Date | |
---|---|---|---|
|
e799c14ece | ||
|
bbe9a1891c | ||
|
9172326013 | ||
|
4836be8121 | ||
|
5672169073 | ||
|
5ec123e0f1 | ||
|
e9622b6f68 | ||
|
83e81722d2 | ||
|
320f2f81b7 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
**/data/*
|
**/data/*
|
||||||
/env
|
|
||||||
**/*.zip
|
**/*.zip
|
||||||
|
**/__pycache__
|
||||||
|
/env
|
129
Gym-Wordle-main/.gitignore
vendored
129
Gym-Wordle-main/.gitignore
vendored
@@ -1,129 +0,0 @@
|
|||||||
# Byte-compiled / optimized / DLL files
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
develop-eggs/
|
|
||||||
dist/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
pip-wheel-metadata/
|
|
||||||
share/python-wheels/
|
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
*.egg
|
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
# Usually these files are written by a python script from a template
|
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
|
|
||||||
# Translations
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# IPython
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
.python-version
|
|
||||||
|
|
||||||
# pipenv
|
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
||||||
# install all needed dependencies.
|
|
||||||
#Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# Celery stuff
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
|
||||||
.pyre/
|
|
@@ -1,21 +0,0 @@
|
|||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2022 David Kraemer
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
@@ -1,78 +0,0 @@
|
|||||||
# Gym-Wordle
|
|
||||||
|
|
||||||
An OpenAI gym compatible environment for training agents to play Wordle.
|
|
||||||
|
|
||||||
<p align='center'>
|
|
||||||
<img src="https://user-images.githubusercontent.com/8514041/152437216-d78e85f6-8049-4cb9-ae61-3c015a8a0e4f.gif"><br/>
|
|
||||||
<em>User-input demo of the environment</em>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
My goal is for a minimalist package that lets you install quickly and get on
|
|
||||||
with your research. Installation is just a simple call to `pip`:
|
|
||||||
|
|
||||||
```
|
|
||||||
$ pip install gym_wordle
|
|
||||||
```
|
|
||||||
|
|
||||||
### Requirements
|
|
||||||
|
|
||||||
In keeping with my desire to have a minimalist package, there are only three
|
|
||||||
major requirements:
|
|
||||||
|
|
||||||
* `numpy`
|
|
||||||
* `gym`
|
|
||||||
* `sty`, a lovely little package for stylizing text in terminals
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
The basic flow for training agents with the `Wordle-v0` environment is the same
|
|
||||||
as with gym environments generally:
|
|
||||||
|
|
||||||
```Python
|
|
||||||
import gym
|
|
||||||
import gym_wordle
|
|
||||||
|
|
||||||
eng = gym.make("Wordle-v0")
|
|
||||||
|
|
||||||
done = False
|
|
||||||
while not done:
|
|
||||||
action = ... # RL magic
|
|
||||||
state, reward, done, info = env.step(action)
|
|
||||||
```
|
|
||||||
|
|
||||||
If you're like millions of other people, you're a Wordle-obsessive in your own
|
|
||||||
right. I have good news for you! The `Wordle-v0` environment currently has an
|
|
||||||
implemented `render` method, which allows you to see a human-friendly version
|
|
||||||
of the game. And it isn't so hard to set up the environment to play for
|
|
||||||
yourself. Here's an example script:
|
|
||||||
|
|
||||||
```Python
|
|
||||||
from gym_wordle.utils import play
|
|
||||||
|
|
||||||
play()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Documentation
|
|
||||||
|
|
||||||
Coming soon!
|
|
||||||
|
|
||||||
## Examples
|
|
||||||
|
|
||||||
Coming soon!
|
|
||||||
|
|
||||||
## Citing
|
|
||||||
|
|
||||||
If you decide to use this project in your work, please consider a citation!
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@misc{gym_wordle,
|
|
||||||
author = {Kraemer, David},
|
|
||||||
title = {An Environment for Reinforcement Learning with Wordle},
|
|
||||||
year = {2022},
|
|
||||||
publisher = {GitHub},
|
|
||||||
journal = {GitHub repository},
|
|
||||||
howpublished = {\url{https://github.com/DavidNKraemer/Gym-Wordle}},
|
|
||||||
}
|
|
||||||
```
|
|
@@ -1,7 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
|
|
||||||
requires = [
|
|
||||||
"setuptools>=42",
|
|
||||||
"wheel"
|
|
||||||
]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
@@ -1,286 +0,0 @@
|
|||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
import numpy.typing as npt
|
|
||||||
from sty import fg, bg, ef, rs
|
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
from gym_wordle.utils import to_english, to_array, get_words
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class WordList(gym.spaces.Discrete):
|
|
||||||
"""Super class for defining a space of valid words according to a specified
|
|
||||||
list.
|
|
||||||
|
|
||||||
TODO: Fix these paragraphs
|
|
||||||
The space is a subclass of gym.spaces.Discrete, where each element
|
|
||||||
corresponds to an index of a valid word in the word list. The obfuscation
|
|
||||||
is necessary for more direct implementation of RL algorithms, which expect
|
|
||||||
spaces of less sophisticated form.
|
|
||||||
|
|
||||||
In addition to the default methods of the Discrete space, it implements
|
|
||||||
a __getitem__ method for easy index lookup, and an index_of method to
|
|
||||||
convert potential words into their corresponding index (if they exist).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
words: Collection of words in array form with shape (_, 5), where
|
|
||||||
each word is a row of the array. Each array element is an integer
|
|
||||||
between 0,...,26 (inclusive).
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
super().__init__(words.shape[0], **kwargs)
|
|
||||||
self.words = words
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
|
|
||||||
"""Obtains the (int-encoded) word associated with the given index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index: Index for the list of words.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Associated word at the position specified by index.
|
|
||||||
"""
|
|
||||||
return self.words[index]
|
|
||||||
|
|
||||||
def index_of(self, word: npt.NDArray[np.int64]) -> int:
|
|
||||||
"""Given a word, determine its index in the list (if it exists),
|
|
||||||
otherwise returning -1 if no index exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
word: Word to find in the word list.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The index of the given word if it exists, otherwise -1.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
index, = np.nonzero((word == self.words).all(axis=1))
|
|
||||||
return index[0]
|
|
||||||
except:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
class SolutionList(WordList):
|
|
||||||
"""Space for *solution* words to the Wordle environment.
|
|
||||||
|
|
||||||
In the game Wordle, there are two different collections of words:
|
|
||||||
|
|
||||||
* "guesses", which the game accepts as valid words to use to guess the
|
|
||||||
answer.
|
|
||||||
* "solutions", which the game uses to choose solutions from.
|
|
||||||
|
|
||||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
|
||||||
|
|
||||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
|
||||||
|
|
||||||
This class represents the set of solution words.
|
|
||||||
"""
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
words = get_words('solution')
|
|
||||||
super().__init__(words, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleObsSpace(gym.spaces.Box):
|
|
||||||
"""Implementation of the state (observation) space in terms of gym
|
|
||||||
primatives, in this case, gym.spaces.Box.
|
|
||||||
|
|
||||||
The Wordle observation space can be thought of as a 6x5 array with two
|
|
||||||
channels:
|
|
||||||
|
|
||||||
- the character channel, indicating which characters are placed on the
|
|
||||||
board (unfilled rows are marked with the empty character, 0)
|
|
||||||
- the flag channel, indicating the in-game information associated with
|
|
||||||
each character's placement (green highlight, yellow highlight, etc.)
|
|
||||||
|
|
||||||
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
|
||||||
the solution will always be a word of length 5.
|
|
||||||
|
|
||||||
For simplicity, and compatibility with the stable_baselines algorithms,
|
|
||||||
this multichannel is modeled as a 6x10 array, where the two channels are
|
|
||||||
horizontally appended (along columns). Thus each row in the observation
|
|
||||||
should be interpreted as
|
|
||||||
|
|
||||||
c0 c1 c2 c3 c4 f0 f1 f2 f3 f4
|
|
||||||
|
|
||||||
when the word is c0...c4 and its associated flags are f0...f4.
|
|
||||||
|
|
||||||
While the superclass method `sample` is available to the WordleObsSpace, it
|
|
||||||
should be emphasized that the output of `sample` will (almost surely) not
|
|
||||||
correspond to a real game configuration, because the sampling is not out of
|
|
||||||
possible game configurations. Instead, the Box superclass just samples the
|
|
||||||
integer array space uniformly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.n_rows = 6
|
|
||||||
self.n_cols = 5
|
|
||||||
self.max_char = 26
|
|
||||||
self.max_flag = 4
|
|
||||||
|
|
||||||
low = np.zeros((self.n_rows, 2*self.n_cols))
|
|
||||||
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
|
|
||||||
np.full((self.n_rows, self.n_cols), self.max_flag)]
|
|
||||||
|
|
||||||
super().__init__(low, high, dtype=np.int64, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class GuessList(WordList):
|
|
||||||
"""Space for *solution* words to the Wordle environment.
|
|
||||||
|
|
||||||
In the game Wordle, there are two different collections of words:
|
|
||||||
|
|
||||||
* "guesses", which the game accepts as valid words to use to guess the
|
|
||||||
answer.
|
|
||||||
* "solutions", which the game uses to choose solutions from.
|
|
||||||
|
|
||||||
Of course, the set of solutions is a strict subset of the set of guesses.
|
|
||||||
|
|
||||||
Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/
|
|
||||||
|
|
||||||
This class represents the set of guess words.
|
|
||||||
"""
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
kwargs: See documentation for gym.spaces.MultiDiscrete
|
|
||||||
"""
|
|
||||||
words = get_words('guess')
|
|
||||||
super().__init__(words, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class WordleEnv(gym.Env):
|
|
||||||
|
|
||||||
metadata = {'render.modes': ['human']}
|
|
||||||
|
|
||||||
# character flag codes
|
|
||||||
no_char = 0
|
|
||||||
right_pos = 1
|
|
||||||
wrong_pos = 2
|
|
||||||
wrong_char = 3
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.seed()
|
|
||||||
self.action_space = GuessList()
|
|
||||||
self.solution_space = SolutionList()
|
|
||||||
|
|
||||||
self.observation_space = WordleObsSpace()
|
|
||||||
|
|
||||||
self._highlights = {
|
|
||||||
self.right_pos: (bg.green, bg.rs),
|
|
||||||
self.wrong_pos: (bg.yellow, bg.rs),
|
|
||||||
self.wrong_char: ('', ''),
|
|
||||||
self.no_char: ('', ''),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.n_rounds = 6
|
|
||||||
self.n_letters = 5
|
|
||||||
|
|
||||||
def _highlighter(self, char: str, flag: int) -> str:
|
|
||||||
"""Terminal renderer functionality. Properly highlights a character
|
|
||||||
based on the flag associated with it.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
char: Character in question.
|
|
||||||
flag: Associated flag, one of:
|
|
||||||
- 0: no character (render no background)
|
|
||||||
- 1: right position (render green background)
|
|
||||||
- 2: wrong position (render yellow background)
|
|
||||||
- 3: wrong character (render no background)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Correct ASCII sequence producing the desired character in the
|
|
||||||
correct background.
|
|
||||||
"""
|
|
||||||
front, back = self._highlights[flag]
|
|
||||||
return front + char + back
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.round = 0
|
|
||||||
self.solution = self.solution_space.sample()
|
|
||||||
|
|
||||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters),
|
|
||||||
dtype=np.int64)
|
|
||||||
|
|
||||||
return self.state
|
|
||||||
|
|
||||||
def render(self, mode: str ='human'):
|
|
||||||
"""Renders the Wordle environment.
|
|
||||||
|
|
||||||
Currently supported render modes:
|
|
||||||
|
|
||||||
- human: renders the Wordle game to the terminal.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mode: the mode to render with
|
|
||||||
"""
|
|
||||||
if mode == 'human':
|
|
||||||
for row in self.states:
|
|
||||||
text = ''.join(map(
|
|
||||||
self._highlighter,
|
|
||||||
to_english(row[:self.n_letters]).upper(),
|
|
||||||
row[self.n_letters:]
|
|
||||||
))
|
|
||||||
|
|
||||||
print(text)
|
|
||||||
else:
|
|
||||||
super(WordleEnv, self).render(mode=mode)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
"""Run one step of the Wordle game. Every game must be previously
|
|
||||||
initialized by a call to the `reset` method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: Word guessed by the agent.
|
|
||||||
Returns:
|
|
||||||
state (object): Wordle game state after the guess.
|
|
||||||
reward (float): Reward associated with the guess (-1 for incorrect,
|
|
||||||
0 for correct)
|
|
||||||
done (bool): Whether the game has ended (by a correct guess or
|
|
||||||
after six guesses).
|
|
||||||
info (dict): Auxiliary diagnostic information (empty).
|
|
||||||
"""
|
|
||||||
assert self.action_space.contains(action), 'Invalid word!'
|
|
||||||
|
|
||||||
# transform the action, solution indices to their words
|
|
||||||
action = self.action_space[action]
|
|
||||||
solution = self.solution_space[self.solution]
|
|
||||||
|
|
||||||
# populate the word chars into the row (character channel)
|
|
||||||
self.state[self.round][:self.n_letters] = action
|
|
||||||
|
|
||||||
# populate the flag characters into the row (flag channel)
|
|
||||||
counter = Counter()
|
|
||||||
for i, char in enumerate(action):
|
|
||||||
flag_i = i + self.n_letters # starts at 5
|
|
||||||
counter[char] += 1
|
|
||||||
|
|
||||||
if char == solution[i]: # character is in correct position
|
|
||||||
self.state[self.round, i] = self.right_pos
|
|
||||||
elif counter[char] <= (char == solution).sum():
|
|
||||||
# current character has been seen within correct number of
|
|
||||||
# occurrences
|
|
||||||
self.state[self.round, i] = self.wrong_pos
|
|
||||||
else:
|
|
||||||
# wrong character, or "correct" character too many times
|
|
||||||
self.state[self.round, i] = self.wrong_char
|
|
||||||
|
|
||||||
self.round += 1
|
|
||||||
|
|
||||||
correct = (action == solution).all()
|
|
||||||
game_over = (self.round == self.n_rounds)
|
|
||||||
|
|
||||||
done = correct or game_over
|
|
||||||
|
|
||||||
# Total reward equals -(number of incorrect guesses)
|
|
||||||
reward = 0. if correct else -1.
|
|
||||||
|
|
||||||
return self.state, reward, done, {}
|
|
||||||
|
|
@@ -1,35 +0,0 @@
|
|||||||
from setuptools import setup, find_packages
|
|
||||||
|
|
||||||
with open('README.md', 'r', encoding='utf-8') as fh:
|
|
||||||
long_description = fh.read()
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='gym_wordle',
|
|
||||||
version='0.1.3',
|
|
||||||
author='David Kraemer',
|
|
||||||
author_email='david.kraemer@stonybrook.edu',
|
|
||||||
description='OpenAI gym environment for training agents on Wordle',
|
|
||||||
long_description=long_description,
|
|
||||||
long_description_content_type='text/markdown',
|
|
||||||
url='https://github.com/DavidNKraemer/Gym-Wordle',
|
|
||||||
packages=find_packages(
|
|
||||||
include=[
|
|
||||||
'gym_wordle',
|
|
||||||
'gym_wordle.*'
|
|
||||||
]
|
|
||||||
),
|
|
||||||
package_data={
|
|
||||||
'gym_wordle': ['dictionary/*']
|
|
||||||
},
|
|
||||||
python_requires='>=3.7',
|
|
||||||
classifiers=[
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
],
|
|
||||||
install_requires=[
|
|
||||||
'numpy>=1.20',
|
|
||||||
'gym==0.19',
|
|
||||||
'sty==1.0',
|
|
||||||
],
|
|
||||||
)
|
|
@@ -1,91 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
|
||||||
|
|
||||||
def __init__(self, ) -> None:
|
|
||||||
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
|
||||||
# GAMMA is the discount factor as mentioned in the previous section
|
|
||||||
# EPS_START is the starting value of epsilon
|
|
||||||
# EPS_END is the final value of epsilon
|
|
||||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
|
||||||
# TAU is the update rate of the target network
|
|
||||||
# LR is the learning rate of the ``AdamW`` optimizer
|
|
||||||
self.batch_size = 128
|
|
||||||
self.gamma = 0.99
|
|
||||||
self.eps_start = 0.9
|
|
||||||
self.eps_end = 0.05
|
|
||||||
self.eps_decay = 1000
|
|
||||||
self.tau = 0.005
|
|
||||||
self.lr = 1e-4
|
|
||||||
self.n_actions = n_actions
|
|
||||||
|
|
||||||
policy_net = DQN(n_observations, n_actions).to(device)
|
|
||||||
target_net = DQN(n_observations, n_actions).to(device)
|
|
||||||
target_net.load_state_dict(policy_net.state_dict())
|
|
||||||
|
|
||||||
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
|
||||||
memory = ReplayMemory(10000)
|
|
||||||
|
|
||||||
def get_state(self, game):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def select_action(state):
|
|
||||||
sample = random.random()
|
|
||||||
eps_threshold = EPS_END + (EPS_START - EPS_END) * \
|
|
||||||
math.exp(-1. * steps_done / EPS_DECAY)
|
|
||||||
steps_done += 1
|
|
||||||
if sample > eps_threshold:
|
|
||||||
with torch.no_grad():
|
|
||||||
# t.max(1) will return the largest column value of each row.
|
|
||||||
# second column on max result is index of where max element was
|
|
||||||
# found, so we pick action with the larger expected reward.
|
|
||||||
return policy_net(state).max(1).indices.view(1, 1)
|
|
||||||
else:
|
|
||||||
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
|
|
||||||
|
|
||||||
def optimize_model():
|
|
||||||
if len(memory) < BATCH_SIZE:
|
|
||||||
return
|
|
||||||
transitions = memory.sample(BATCH_SIZE)
|
|
||||||
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
|
||||||
# detailed explanation). This converts batch-array of Transitions
|
|
||||||
# to Transition of batch-arrays.
|
|
||||||
batch = Transition(*zip(*transitions))
|
|
||||||
|
|
||||||
# Compute a mask of non-final states and concatenate the batch elements
|
|
||||||
# (a final state would've been the one after which simulation ended)
|
|
||||||
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
|
|
||||||
batch.next_state)), device=device, dtype=torch.bool)
|
|
||||||
non_final_next_states = torch.cat([s for s in batch.next_state
|
|
||||||
if s is not None])
|
|
||||||
state_batch = torch.cat(batch.state)
|
|
||||||
action_batch = torch.cat(batch.action)
|
|
||||||
reward_batch = torch.cat(batch.reward)
|
|
||||||
|
|
||||||
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
|
||||||
# columns of actions taken. These are the actions which would've been taken
|
|
||||||
# for each batch state according to policy_net
|
|
||||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
|
||||||
|
|
||||||
# Compute V(s_{t+1}) for all next states.
|
|
||||||
# Expected values of actions for non_final_next_states are computed based
|
|
||||||
# on the "older" target_net; selecting their best reward with max(1).values
|
|
||||||
# This is merged based on the mask, such that we'll have either the expected
|
|
||||||
# state value or 0 in case the state was final.
|
|
||||||
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
|
||||||
with torch.no_grad():
|
|
||||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
|
|
||||||
# Compute the expected Q values
|
|
||||||
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
|
||||||
|
|
||||||
# Compute Huber loss
|
|
||||||
criterion = nn.SmoothL1Loss()
|
|
||||||
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
|
||||||
|
|
||||||
# Optimize the model
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
# In-place gradient clipping
|
|
||||||
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
|
||||||
optimizer.step()
|
|
@@ -1,16 +0,0 @@
|
|||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
from string import ascii_letters
|
|
||||||
|
|
||||||
in_path = pathlib.Path(sys.argv[1])
|
|
||||||
out_path = pathlib.Path(sys.argv[2])
|
|
||||||
|
|
||||||
words = sorted(
|
|
||||||
{
|
|
||||||
word.lower()
|
|
||||||
for word in in_path.read_text(encoding="utf-8").split()
|
|
||||||
if all(letter in ascii_letters for letter in word)
|
|
||||||
},
|
|
||||||
key=lambda word: (len(word), word),
|
|
||||||
)
|
|
||||||
out_path.write_text("\n".join(words))
|
|
File diff suppressed because it is too large
Load Diff
@@ -1,44 +0,0 @@
|
|||||||
import math
|
|
||||||
import random
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from collections import namedtuple, deque
|
|
||||||
from itertools import count
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
Transition = namedtuple('Transition',
|
|
||||||
('state', 'action', 'next_state', 'reward'))
|
|
||||||
|
|
||||||
|
|
||||||
class ReplayMemory(object):
|
|
||||||
|
|
||||||
def __init__(self, capacity: int) -> None:
|
|
||||||
self.memory = deque([], maxlen=capacity)
|
|
||||||
|
|
||||||
def push(self, *args):
|
|
||||||
self.memory.append(Transition(*args))
|
|
||||||
|
|
||||||
def sample(self, batch_size):
|
|
||||||
return random.sample(self.memory, batch_size)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.memory)
|
|
||||||
|
|
||||||
|
|
||||||
class DQN(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, n_observations: int, n_actions: int) -> None:
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
self.layer1 = nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = nn.Linear(128, 128)
|
|
||||||
self.layer3 = nn.Linear(128, n_actions)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.layer1(x))
|
|
||||||
x = F.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)
|
|
@@ -1,61 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from string import ascii_letters, ascii_uppercase, ascii_lowercase"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"'ABCDEFGHIJKLMNOPQRSTUVWXYZ'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"ascii_uppercase"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "env",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
@@ -1,119 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import pathlib
|
|
||||||
import random
|
|
||||||
from string import ascii_letters, ascii_lowercase
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.theme import Theme
|
|
||||||
|
|
||||||
console = Console(width=40, theme=Theme({"warning": "red on yellow"}))
|
|
||||||
|
|
||||||
NUM_LETTERS = 5
|
|
||||||
NUM_GUESSES = 6
|
|
||||||
WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt"
|
|
||||||
|
|
||||||
|
|
||||||
class Wordle:
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n")
|
|
||||||
self.n_guesses = 6
|
|
||||||
self.num_letters = 5
|
|
||||||
self.curr_word = None
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def refresh_page(self, headline):
|
|
||||||
console.clear()
|
|
||||||
console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n")
|
|
||||||
|
|
||||||
def start_game(self):
|
|
||||||
# get a new random word
|
|
||||||
word = self.get_random_word(self.word_list)
|
|
||||||
|
|
||||||
self.curr_word = word
|
|
||||||
|
|
||||||
def get_state(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
def action_to_word(self, action):
|
|
||||||
# Calculate the word from the array
|
|
||||||
word = ''
|
|
||||||
for i in range(0, len(ascii_lowercase), 26):
|
|
||||||
# Find the index of 1 in each block of 26
|
|
||||||
letter_index = action[i:i+26].index(1)
|
|
||||||
# Append the corresponding letter to the word
|
|
||||||
word += ascii_lowercase[letter_index]
|
|
||||||
|
|
||||||
return word
|
|
||||||
|
|
||||||
def play_guess(self, action):
|
|
||||||
# probably an array of length 26 * 5 for 26 letters and 5 positions
|
|
||||||
guess = action
|
|
||||||
|
|
||||||
def get_random_word(self, word_list):
|
|
||||||
if words := [
|
|
||||||
word.upper()
|
|
||||||
for word in word_list
|
|
||||||
if len(word) == NUM_LETTERS
|
|
||||||
and all(letter in ascii_letters for letter in word)
|
|
||||||
]:
|
|
||||||
return random.choice(words)
|
|
||||||
else:
|
|
||||||
console.print(
|
|
||||||
f"No words of length {NUM_LETTERS} in the word list",
|
|
||||||
style="warning",
|
|
||||||
)
|
|
||||||
raise SystemExit()
|
|
||||||
|
|
||||||
def show_guesses(self, guesses, word):
|
|
||||||
letter_status = {letter: letter for letter in ascii_lowercase}
|
|
||||||
for guess in guesses:
|
|
||||||
styled_guess = []
|
|
||||||
for letter, correct in zip(guess, word):
|
|
||||||
if letter == correct:
|
|
||||||
style = "bold white on green"
|
|
||||||
elif letter in word:
|
|
||||||
style = "bold white on yellow"
|
|
||||||
elif letter in ascii_letters:
|
|
||||||
style = "white on #666666"
|
|
||||||
else:
|
|
||||||
style = "dim"
|
|
||||||
styled_guess.append(f"[{style}]{letter}[/]")
|
|
||||||
if letter != "_":
|
|
||||||
letter_status[letter] = f"[{style}]{letter}[/]"
|
|
||||||
|
|
||||||
console.print("".join(styled_guess), justify="center")
|
|
||||||
console.print("\n" + "".join(letter_status.values()), justify="center")
|
|
||||||
|
|
||||||
def guess_word(self, previous_guesses):
|
|
||||||
guess = console.input("\nGuess word: ").upper()
|
|
||||||
|
|
||||||
if guess in previous_guesses:
|
|
||||||
console.print(f"You've already guessed {guess}.", style="warning")
|
|
||||||
return guess_word(previous_guesses)
|
|
||||||
|
|
||||||
if len(guess) != NUM_LETTERS:
|
|
||||||
console.print(
|
|
||||||
f"Your guess must be {NUM_LETTERS} letters.", style="warning"
|
|
||||||
)
|
|
||||||
return guess_word(previous_guesses)
|
|
||||||
|
|
||||||
if any((invalid := letter) not in ascii_letters for letter in guess):
|
|
||||||
console.print(
|
|
||||||
f"Invalid letter: '{invalid}'. Please use English letters.",
|
|
||||||
style="warning",
|
|
||||||
)
|
|
||||||
return guess_word(previous_guesses)
|
|
||||||
|
|
||||||
return guess
|
|
||||||
|
|
||||||
def reset(self, guesses, word, guessed_correctly, n_episodes):
|
|
||||||
refresh_page(headline=f"Game: {n_episodes}")
|
|
||||||
|
|
||||||
if guessed_correctly:
|
|
||||||
console.print(f"\n[bold white on green]Correct, the word is {word}[/]")
|
|
||||||
else:
|
|
||||||
console.print(f"\n[bold white on red]Sorry, the word was {word}[/]")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
643
dqn_wordle.ipynb
643
dqn_wordle.ipynb
@@ -6,9 +6,11 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import gymnasium as gym\n",
|
"import gym\n",
|
||||||
"from stable_baselines3 import DQN\n",
|
"import gym_wordle\n",
|
||||||
"import numpy as np"
|
"from stable_baselines3 import DQN, PPO, common\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import tqdm"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -17,63 +19,624 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "NameNotFound",
|
"name": "stdout",
|
||||||
"evalue": "Environment `Wordle` doesn't exist.",
|
"output_type": "stream",
|
||||||
"output_type": "error",
|
"text": [
|
||||||
"traceback": [
|
"<Monitor<WordleEnv instance>>\n"
|
||||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[1;31mNameNotFound\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mgym\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mWordle-v0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(env)\n",
|
|
||||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:741\u001b[0m, in \u001b[0;36mmake\u001b[1;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[0;32m 738\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mid\u001b[39m, \u001b[38;5;28mstr\u001b[39m)\n\u001b[0;32m 740\u001b[0m \u001b[38;5;66;03m# The environment name can include an unloaded module in \"module:env_name\" style\u001b[39;00m\n\u001b[1;32m--> 741\u001b[0m env_spec \u001b[38;5;241m=\u001b[39m \u001b[43m_find_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(env_spec, EnvSpec)\n\u001b[0;32m 745\u001b[0m \u001b[38;5;66;03m# Update the env spec kwargs with the `make` kwargs\u001b[39;00m\n",
|
|
||||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:527\u001b[0m, in \u001b[0;36m_find_spec\u001b[1;34m(env_id)\u001b[0m\n\u001b[0;32m 521\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m 522\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing the latest versioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnew_env_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 523\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minstead of the unversioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 524\u001b[0m )\n\u001b[0;32m 526\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m env_spec \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 527\u001b[0m \u001b[43m_check_version_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mversion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 528\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mError(\n\u001b[0;32m 529\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo registered env with id: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Did you register it, or import the package that registers it? Use `gymnasium.pprint_registry()` to see all of the registered environments.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 530\u001b[0m )\n\u001b[0;32m 532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m env_spec\n",
|
|
||||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:393\u001b[0m, in \u001b[0;36m_check_version_exists\u001b[1;34m(ns, name, version)\u001b[0m\n\u001b[0;32m 390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m get_env_id(ns, name, version) \u001b[38;5;129;01min\u001b[39;00m registry:\n\u001b[0;32m 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m--> 393\u001b[0m \u001b[43m_check_name_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m version \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n",
|
|
||||||
"File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:370\u001b[0m, in \u001b[0;36m_check_name_exists\u001b[1;34m(ns, name)\u001b[0m\n\u001b[0;32m 367\u001b[0m namespace_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in namespace \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mns\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ns \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 368\u001b[0m suggestion_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Did you mean: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`?\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m suggestion \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m--> 370\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mNameNotFound(\n\u001b[0;32m 371\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEnvironment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt exist\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnamespace_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 372\u001b[0m )\n",
|
|
||||||
"\u001b[1;31mNameNotFound\u001b[0m: Environment `Wordle` doesn't exist."
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"env = gym.make(\"Wordle-v0\")\n",
|
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||||
|
"env = common.monitor.Monitor(env)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(env)"
|
"print(env)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 35,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "7c52630b65904d5e8e200be505d2121a",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Output()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Using cuda device\n",
|
||||||
|
"Wrapping the env with a `Monitor` wrapper\n",
|
||||||
|
"Wrapping the env in a DummyVecEnv.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -175 |\n",
|
||||||
|
"| exploration_rate | 0.525 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 10000 |\n",
|
||||||
|
"| fps | 4606 |\n",
|
||||||
|
"| time_elapsed | 10 |\n",
|
||||||
|
"| total_timesteps | 49989 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -208 |\n",
|
||||||
|
"| exploration_rate | 0.0502 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 20000 |\n",
|
||||||
|
"| fps | 1118 |\n",
|
||||||
|
"| time_elapsed | 89 |\n",
|
||||||
|
"| total_timesteps | 99980 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 24.6 |\n",
|
||||||
|
"| n_updates | 12494 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -230 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 30000 |\n",
|
||||||
|
"| fps | 856 |\n",
|
||||||
|
"| time_elapsed | 175 |\n",
|
||||||
|
"| total_timesteps | 149974 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 18.7 |\n",
|
||||||
|
"| n_updates | 24993 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -242 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 40000 |\n",
|
||||||
|
"| fps | 766 |\n",
|
||||||
|
"| time_elapsed | 260 |\n",
|
||||||
|
"| total_timesteps | 199967 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 24 |\n",
|
||||||
|
"| n_updates | 37491 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -186 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 50000 |\n",
|
||||||
|
"| fps | 722 |\n",
|
||||||
|
"| time_elapsed | 346 |\n",
|
||||||
|
"| total_timesteps | 249962 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 21.5 |\n",
|
||||||
|
"| n_updates | 49990 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -183 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 60000 |\n",
|
||||||
|
"| fps | 694 |\n",
|
||||||
|
"| time_elapsed | 431 |\n",
|
||||||
|
"| total_timesteps | 299957 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 17.6 |\n",
|
||||||
|
"| n_updates | 62489 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -181 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 70000 |\n",
|
||||||
|
"| fps | 675 |\n",
|
||||||
|
"| time_elapsed | 517 |\n",
|
||||||
|
"| total_timesteps | 349953 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 26.8 |\n",
|
||||||
|
"| n_updates | 74988 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -196 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 80000 |\n",
|
||||||
|
"| fps | 663 |\n",
|
||||||
|
"| time_elapsed | 603 |\n",
|
||||||
|
"| total_timesteps | 399936 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 22.5 |\n",
|
||||||
|
"| n_updates | 87483 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -174 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 90000 |\n",
|
||||||
|
"| fps | 653 |\n",
|
||||||
|
"| time_elapsed | 688 |\n",
|
||||||
|
"| total_timesteps | 449928 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 21.1 |\n",
|
||||||
|
"| n_updates | 99981 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -155 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 100000 |\n",
|
||||||
|
"| fps | 645 |\n",
|
||||||
|
"| time_elapsed | 774 |\n",
|
||||||
|
"| total_timesteps | 499920 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 22.8 |\n",
|
||||||
|
"| n_updates | 112479 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -153 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 110000 |\n",
|
||||||
|
"| fps | 638 |\n",
|
||||||
|
"| time_elapsed | 860 |\n",
|
||||||
|
"| total_timesteps | 549916 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 16 |\n",
|
||||||
|
"| n_updates | 124978 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -164 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 120000 |\n",
|
||||||
|
"| fps | 633 |\n",
|
||||||
|
"| time_elapsed | 947 |\n",
|
||||||
|
"| total_timesteps | 599915 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 17.8 |\n",
|
||||||
|
"| n_updates | 137478 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -145 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 130000 |\n",
|
||||||
|
"| fps | 628 |\n",
|
||||||
|
"| time_elapsed | 1033 |\n",
|
||||||
|
"| total_timesteps | 649910 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 17.8 |\n",
|
||||||
|
"| n_updates | 149977 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -154 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 140000 |\n",
|
||||||
|
"| fps | 624 |\n",
|
||||||
|
"| time_elapsed | 1120 |\n",
|
||||||
|
"| total_timesteps | 699902 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 20.9 |\n",
|
||||||
|
"| n_updates | 162475 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -192 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 150000 |\n",
|
||||||
|
"| fps | 621 |\n",
|
||||||
|
"| time_elapsed | 1206 |\n",
|
||||||
|
"| total_timesteps | 749884 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 18.3 |\n",
|
||||||
|
"| n_updates | 174970 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -170 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 160000 |\n",
|
||||||
|
"| fps | 618 |\n",
|
||||||
|
"| time_elapsed | 1293 |\n",
|
||||||
|
"| total_timesteps | 799869 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 17.7 |\n",
|
||||||
|
"| n_updates | 187467 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -233 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 170000 |\n",
|
||||||
|
"| fps | 615 |\n",
|
||||||
|
"| time_elapsed | 1380 |\n",
|
||||||
|
"| total_timesteps | 849855 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 21.6 |\n",
|
||||||
|
"| n_updates | 199963 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -146 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 180000 |\n",
|
||||||
|
"| fps | 613 |\n",
|
||||||
|
"| time_elapsed | 1466 |\n",
|
||||||
|
"| total_timesteps | 899847 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 19.4 |\n",
|
||||||
|
"| n_updates | 212461 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -142 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 190000 |\n",
|
||||||
|
"| fps | 611 |\n",
|
||||||
|
"| time_elapsed | 1553 |\n",
|
||||||
|
"| total_timesteps | 949846 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 22.9 |\n",
|
||||||
|
"| n_updates | 224961 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"----------------------------------\n",
|
||||||
|
"| rollout/ | |\n",
|
||||||
|
"| ep_len_mean | 5 |\n",
|
||||||
|
"| ep_rew_mean | -171 |\n",
|
||||||
|
"| exploration_rate | 0.05 |\n",
|
||||||
|
"| time/ | |\n",
|
||||||
|
"| episodes | 200000 |\n",
|
||||||
|
"| fps | 609 |\n",
|
||||||
|
"| time_elapsed | 1640 |\n",
|
||||||
|
"| total_timesteps | 999839 |\n",
|
||||||
|
"| train/ | |\n",
|
||||||
|
"| learning_rate | 0.0001 |\n",
|
||||||
|
"| loss | 20.3 |\n",
|
||||||
|
"| n_updates | 237459 |\n",
|
||||||
|
"----------------------------------\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||||||
|
],
|
||||||
|
"text/plain": []
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||||||
|
"</pre>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<stable_baselines3.dqn.dqn.DQN at 0x294981ca090>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"total_timesteps = 100000\n",
|
"total_timesteps = 1_000_000\n",
|
||||||
"model = DQN(\"MlpPolicy\", env, verbose=0)\n",
|
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
||||||
"model.learn(total_timesteps=total_timesteps, progress_bar=True)"
|
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def test(model):\n",
|
"model.save(\"dqn_new_rewards\")"
|
||||||
"\n",
|
]
|
||||||
" end_rewards = []\n",
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||||
|
"Exception: code() argument 13 must be str, not int\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
||||||
|
"Exception: code() argument 13 must be str, not int\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# model = DQN.load(\"dqn_wordle\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"0\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"env = gym_wordle.wordle.WordleEnv()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for i in range(1000):\n",
|
"for i in range(1000):\n",
|
||||||
" \n",
|
" \n",
|
||||||
" state = env.reset()\n",
|
" state, info = env.reset()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" done = False\n",
|
" done = False\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" wins = 0\n",
|
||||||
|
"\n",
|
||||||
" while not done:\n",
|
" while not done:\n",
|
||||||
"\n",
|
"\n",
|
||||||
" action, _states = model.predict(state, deterministic=True)\n",
|
" action, _states = model.predict(state, deterministic=True)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" state, reward, done, info = env.step(action)\n",
|
" state, reward, done, truncated, info = env.step(action)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" end_rewards.append(reward == 0)\n",
|
" if info[\"correct\"]:\n",
|
||||||
|
" wins += 1\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return np.sum(end_rewards) / len(end_rewards)"
|
"print(wins)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n",
|
||||||
|
" [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n",
|
||||||
|
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
||||||
|
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n",
|
||||||
|
" [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n",
|
||||||
|
" [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n",
|
||||||
|
" -130)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"state, reward"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"blah = (14, 1, 9, 22, 5)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"True"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 23,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"blah in info['guesses']"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -81,27 +644,7 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": []
|
||||||
"model.save(\"dqn_wordle\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = DQN.load(\"dqn_wordle\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(test(model))"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@@ -91,4 +91,3 @@ def play():
|
|||||||
env.render()
|
env.render()
|
||||||
|
|
||||||
print(f"The word was {solution}")
|
print(f"The word was {solution}")
|
||||||
|
|
340
gym_wordle/wordle.py
Normal file
340
gym_wordle/wordle.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
from sty import fg, bg, ef, rs
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
from gym_wordle.utils import to_english, to_array, get_words
|
||||||
|
from typing import Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
class WordList(gym.spaces.Discrete):
|
||||||
|
"""Super class for defining a space of valid words according to a specified
|
||||||
|
list.
|
||||||
|
|
||||||
|
The space is a subclass of gym.spaces.Discrete, where each element
|
||||||
|
corresponds to an index of a valid word in the word list. The obfuscation
|
||||||
|
is necessary for more direct implementation of RL algorithms, which expect
|
||||||
|
spaces of less sophisticated form.
|
||||||
|
|
||||||
|
In addition to the default methods of the Discrete space, it implements
|
||||||
|
a __getitem__ method for easy index lookup, and an index_of method to
|
||||||
|
convert potential words into their corresponding index (if they exist).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, words: npt.NDArray[np.int64], **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
words: Collection of words in array form with shape (_, 5), where
|
||||||
|
each word is a row of the array. Each array element is an integer
|
||||||
|
between 0,...,26 (inclusive).
|
||||||
|
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||||
|
"""
|
||||||
|
super().__init__(words.shape[0], **kwargs)
|
||||||
|
self.words = words
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
|
||||||
|
"""Obtains the (int-encoded) word associated with the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: Index for the list of words.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Associated word at the position specified by index.
|
||||||
|
"""
|
||||||
|
return self.words[index]
|
||||||
|
|
||||||
|
def index_of(self, word: npt.NDArray[np.int64]) -> int:
|
||||||
|
"""Given a word, determine its index in the list (if it exists),
|
||||||
|
otherwise returning -1 if no index exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
word: Word to find in the word list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The index of the given word if it exists, otherwise -1.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
index, = np.nonzero((word == self.words).all(axis=1))
|
||||||
|
return index[0]
|
||||||
|
except:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
class SolutionList(WordList):
|
||||||
|
"""Space for *solution* words to the Wordle environment.
|
||||||
|
|
||||||
|
In the game Wordle, there are two different collections of words:
|
||||||
|
|
||||||
|
* "guesses", which the game accepts as valid words to use to guess the
|
||||||
|
answer.
|
||||||
|
* "solutions", which the game uses to choose solutions from.
|
||||||
|
|
||||||
|
Of course, the set of solutions is a strict subset of the set of guesses.
|
||||||
|
|
||||||
|
This class represents the set of solution words.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||||
|
"""
|
||||||
|
words = get_words('solution')
|
||||||
|
super().__init__(words, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class WordleObsSpace(gym.spaces.Box):
|
||||||
|
"""Implementation of the state (observation) space in terms of gym
|
||||||
|
primitives, in this case, gym.spaces.Box.
|
||||||
|
|
||||||
|
The Wordle observation space can be thought of as a 6x5 array with two
|
||||||
|
channels:
|
||||||
|
|
||||||
|
- the character channel, indicating which characters are placed on the
|
||||||
|
board (unfilled rows are marked with the empty character, 0)
|
||||||
|
- the flag channel, indicating the in-game information associated with
|
||||||
|
each character's placement (green highlight, yellow highlight, etc.)
|
||||||
|
|
||||||
|
where there are 6 rows, one for each turn in the game, and 5 columns, since
|
||||||
|
the solution will always be a word of length 5.
|
||||||
|
|
||||||
|
For simplicity, and compatibility with stable_baselines algorithms,
|
||||||
|
this multichannel is modeled as a 6x10 array, where the two channels are
|
||||||
|
horizontally appended (along columns). Thus each row in the observation
|
||||||
|
should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is
|
||||||
|
c0...c4 and its associated flags are f0...f4.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.n_rows = 6
|
||||||
|
self.n_cols = 5
|
||||||
|
self.max_char = 26
|
||||||
|
self.max_flag = 4
|
||||||
|
|
||||||
|
low = np.zeros((self.n_rows, 2*self.n_cols))
|
||||||
|
high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char),
|
||||||
|
np.full((self.n_rows, self.n_cols), self.max_flag)]
|
||||||
|
|
||||||
|
super().__init__(low, high, dtype=np.int64, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class GuessList(WordList):
|
||||||
|
"""Space for *guess* words to the Wordle environment.
|
||||||
|
|
||||||
|
This class represents the set of guess words.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
kwargs: See documentation for gym.spaces.MultiDiscrete
|
||||||
|
"""
|
||||||
|
words = get_words('guess')
|
||||||
|
super().__init__(words, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class WordleEnv(gym.Env):
|
||||||
|
metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
|
# Character flag codes
|
||||||
|
no_char = 0
|
||||||
|
right_pos = 1
|
||||||
|
wrong_pos = 2
|
||||||
|
wrong_char = 3
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.action_space = GuessList()
|
||||||
|
self.solution_space = SolutionList()
|
||||||
|
|
||||||
|
self.observation_space = WordleObsSpace()
|
||||||
|
|
||||||
|
self._highlights = {
|
||||||
|
self.right_pos: (bg.green, bg.rs),
|
||||||
|
self.wrong_pos: (bg.yellow, bg.rs),
|
||||||
|
self.wrong_char: ('', ''),
|
||||||
|
self.no_char: ('', ''),
|
||||||
|
}
|
||||||
|
|
||||||
|
self.n_rounds = 6
|
||||||
|
self.n_letters = 5
|
||||||
|
self.info = {
|
||||||
|
'correct': False,
|
||||||
|
'guesses': set(),
|
||||||
|
'known_positions': np.full(5, -1), # -1 for unknown, else letter index
|
||||||
|
'known_letters': set(), # Letters known to be in the word
|
||||||
|
'not_in_word': set(), # Letters known not to be in the word
|
||||||
|
'tried_positions': defaultdict(set) # Positions tried for each letter
|
||||||
|
}
|
||||||
|
|
||||||
|
def _highlighter(self, char: str, flag: int) -> str:
|
||||||
|
"""Terminal renderer functionality. Properly highlights a character
|
||||||
|
based on the flag associated with it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char: Character in question.
|
||||||
|
flag: Associated flag, one of:
|
||||||
|
- 0: no character (render no background)
|
||||||
|
- 1: right position (render green background)
|
||||||
|
- 2: wrong position (render yellow background)
|
||||||
|
- 3: wrong character (render no background)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Correct ASCII sequence producing the desired character in the
|
||||||
|
correct background.
|
||||||
|
"""
|
||||||
|
front, back = self._highlights[flag]
|
||||||
|
return front + char + back
|
||||||
|
|
||||||
|
def reset(self, seed=None, options=None):
|
||||||
|
"""Reset the environment to an initial state and returns an initial
|
||||||
|
observation.
|
||||||
|
|
||||||
|
Note: The observation space instance should be a Box space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state (object): The initial observation of the space.
|
||||||
|
"""
|
||||||
|
self.round = 0
|
||||||
|
self.solution = self.solution_space.sample()
|
||||||
|
self.soln_hash = set(self.solution_space[self.solution])
|
||||||
|
|
||||||
|
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||||
|
|
||||||
|
self.info = {
|
||||||
|
'correct': False,
|
||||||
|
'guesses': set(),
|
||||||
|
'known_positions': np.full(5, -1),
|
||||||
|
'known_letters': set(),
|
||||||
|
'not_in_word': set(),
|
||||||
|
'tried_positions': defaultdict(set)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.simulate_first_guess()
|
||||||
|
|
||||||
|
return self.state, self.info
|
||||||
|
|
||||||
|
def simulate_first_guess(self):
|
||||||
|
fixed_first_guess = "rates"
|
||||||
|
fixed_first_guess_array = to_array(fixed_first_guess)
|
||||||
|
|
||||||
|
# Simulate the feedback for each letter in the fixed first guess
|
||||||
|
feedback = np.zeros(self.n_letters, dtype=int) # Initialize feedback array
|
||||||
|
for i, letter in enumerate(fixed_first_guess_array):
|
||||||
|
if letter in self.solution_space[self.solution]:
|
||||||
|
if letter == self.solution_space[self.solution][i]:
|
||||||
|
feedback[i] = 1 # Correct position
|
||||||
|
else:
|
||||||
|
feedback[i] = 2 # Correct letter, wrong position
|
||||||
|
else:
|
||||||
|
feedback[i] = 3 # Letter not in word
|
||||||
|
|
||||||
|
# Update the state to reflect the fixed first guess and its feedback
|
||||||
|
self.state[0, :self.n_letters] = fixed_first_guess_array
|
||||||
|
self.state[0, self.n_letters:] = feedback
|
||||||
|
|
||||||
|
# Update self.info based on the feedback
|
||||||
|
for i, flag in enumerate(feedback):
|
||||||
|
if flag == self.right_pos:
|
||||||
|
# Mark letter as correctly placed
|
||||||
|
self.info['known_positions'][i] = fixed_first_guess_array[i]
|
||||||
|
elif flag == self.wrong_pos:
|
||||||
|
# Note the letter is in the word but in a different position
|
||||||
|
self.info['known_letters'].add(fixed_first_guess_array[i])
|
||||||
|
elif flag == self.wrong_char:
|
||||||
|
# Note the letter is not in the word
|
||||||
|
self.info['not_in_word'].add(fixed_first_guess_array[i])
|
||||||
|
|
||||||
|
# Since we're simulating the first guess, increment the round counter
|
||||||
|
self.round = 1
|
||||||
|
|
||||||
|
def render(self, mode: str = 'human'):
|
||||||
|
"""Renders the Wordle environment.
|
||||||
|
|
||||||
|
Currently supported render modes:
|
||||||
|
- human: renders the Wordle game to the terminal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: the mode to render with.
|
||||||
|
"""
|
||||||
|
if mode == 'human':
|
||||||
|
for row in self.state:
|
||||||
|
text = ''.join(map(
|
||||||
|
self._highlighter,
|
||||||
|
to_english(row[:self.n_letters]).upper(),
|
||||||
|
row[self.n_letters:]
|
||||||
|
))
|
||||||
|
print(text)
|
||||||
|
else:
|
||||||
|
super().render(mode=mode)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
assert self.action_space.contains(action), 'Invalid word!'
|
||||||
|
|
||||||
|
guessed_word = self.action_space[action]
|
||||||
|
solution_word = self.solution_space[self.solution]
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
correct_guess = np.array_equal(guessed_word, solution_word)
|
||||||
|
|
||||||
|
# Initialize flags for current guess
|
||||||
|
current_flags = np.full(self.n_letters, self.wrong_char)
|
||||||
|
|
||||||
|
# Track newly discovered information
|
||||||
|
new_info = False
|
||||||
|
|
||||||
|
for i in range(self.n_letters):
|
||||||
|
guessed_letter = guessed_word[i]
|
||||||
|
if guessed_letter in solution_word:
|
||||||
|
# Penalize for reusing a letter found to not be in the word
|
||||||
|
if guessed_letter in self.info['not_in_word']:
|
||||||
|
reward -= 2
|
||||||
|
|
||||||
|
# Handle correct letter in the correct position
|
||||||
|
if guessed_letter == solution_word[i]:
|
||||||
|
current_flags[i] = self.right_pos
|
||||||
|
if self.info['known_positions'][i] != guessed_letter:
|
||||||
|
reward += 10 # Large reward for new correct placement
|
||||||
|
new_info = True
|
||||||
|
self.info['known_positions'][i] = guessed_letter
|
||||||
|
else:
|
||||||
|
reward += 20 # Large reward for repeating correct placement
|
||||||
|
else:
|
||||||
|
current_flags[i] = self.wrong_pos
|
||||||
|
if guessed_letter not in self.info['known_letters'] or i not in self.info['tried_positions'][guessed_letter]:
|
||||||
|
reward += 10 # Reward for guessing a letter in a new position
|
||||||
|
new_info = True
|
||||||
|
else:
|
||||||
|
reward -= 20 # Penalize for not leveraging known information
|
||||||
|
self.info['known_letters'].add(guessed_letter)
|
||||||
|
self.info['tried_positions'][guessed_letter].add(i)
|
||||||
|
else:
|
||||||
|
# New incorrect letter
|
||||||
|
if guessed_letter not in self.info['not_in_word']:
|
||||||
|
reward -= 2 # Penalize for guessing a letter not in the word
|
||||||
|
self.info['not_in_word'].add(guessed_letter)
|
||||||
|
new_info = True
|
||||||
|
else:
|
||||||
|
reward -= 15 # Larger penalty for repeating an incorrect letter
|
||||||
|
|
||||||
|
# Update observation state with the current guess and flags
|
||||||
|
self.state[self.round, :self.n_letters] = guessed_word
|
||||||
|
self.state[self.round, self.n_letters:] = current_flags
|
||||||
|
|
||||||
|
# Check if the game is over
|
||||||
|
done = self.round == self.n_rounds - 1 or correct_guess
|
||||||
|
self.info['correct'] = correct_guess
|
||||||
|
|
||||||
|
if correct_guess:
|
||||||
|
reward += 100 # Major reward for winning
|
||||||
|
elif done:
|
||||||
|
reward -= 50 # Penalty for losing without using new information effectively
|
||||||
|
elif not new_info:
|
||||||
|
reward -= 10 # Penalty if no new information was used in this guess
|
||||||
|
|
||||||
|
self.round += 1
|
||||||
|
|
||||||
|
return self.state, reward, done, False, self.info
|
189
test.ipynb
Normal file
189
test.ipynb
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from collections import defaultdict"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def my_func()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"t = defaultdict(lambda: [0, 1, 2, 3, 4])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"defaultdict(<function __main__.<lambda>()>, {})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"t"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[0, 1, 2, 3, 4]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"t['t']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"defaultdict(<function __main__.<lambda>()>, {'t': [0, 1, 2, 3, 4]})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"t"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"'x' in t"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"x = np.array([1, 1, 1])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"x[:] = 0"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"array([0, 0, 0])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"x"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"'abcde'aaa\n",
|
||||||
|
" 33221\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Reference in New Issue
Block a user