mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			4 Commits
		
	
	
		
			83e81722d2
			...
			ethan-test
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 335d56ac88 | ||
|  | 496b8ad796 | ||
|  | 8d3ce990e3 | ||
|  | 7ad5b97463 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,3 @@ | ||||
| **/data/* | ||||
| **/data/* | ||||
| /env | ||||
| **/*.zip | ||||
							
								
								
									
										129
									
								
								Gym-Wordle-main/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								Gym-Wordle-main/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,129 @@ | ||||
| # Byte-compiled / optimized / DLL files | ||||
| __pycache__/ | ||||
| *.py[cod] | ||||
| *$py.class | ||||
|  | ||||
| # C extensions | ||||
| *.so | ||||
|  | ||||
| # Distribution / packaging | ||||
| .Python | ||||
| build/ | ||||
| develop-eggs/ | ||||
| dist/ | ||||
| downloads/ | ||||
| eggs/ | ||||
| .eggs/ | ||||
| lib/ | ||||
| lib64/ | ||||
| parts/ | ||||
| sdist/ | ||||
| var/ | ||||
| wheels/ | ||||
| pip-wheel-metadata/ | ||||
| share/python-wheels/ | ||||
| *.egg-info/ | ||||
| .installed.cfg | ||||
| *.egg | ||||
| MANIFEST | ||||
|  | ||||
| # PyInstaller | ||||
| #  Usually these files are written by a python script from a template | ||||
| #  before PyInstaller builds the exe, so as to inject date/other infos into it. | ||||
| *.manifest | ||||
| *.spec | ||||
|  | ||||
| # Installer logs | ||||
| pip-log.txt | ||||
| pip-delete-this-directory.txt | ||||
|  | ||||
| # Unit test / coverage reports | ||||
| htmlcov/ | ||||
| .tox/ | ||||
| .nox/ | ||||
| .coverage | ||||
| .coverage.* | ||||
| .cache | ||||
| nosetests.xml | ||||
| coverage.xml | ||||
| *.cover | ||||
| *.py,cover | ||||
| .hypothesis/ | ||||
| .pytest_cache/ | ||||
|  | ||||
| # Translations | ||||
| *.mo | ||||
| *.pot | ||||
|  | ||||
| # Django stuff: | ||||
| *.log | ||||
| local_settings.py | ||||
| db.sqlite3 | ||||
| db.sqlite3-journal | ||||
|  | ||||
| # Flask stuff: | ||||
| instance/ | ||||
| .webassets-cache | ||||
|  | ||||
| # Scrapy stuff: | ||||
| .scrapy | ||||
|  | ||||
| # Sphinx documentation | ||||
| docs/_build/ | ||||
|  | ||||
| # PyBuilder | ||||
| target/ | ||||
|  | ||||
| # Jupyter Notebook | ||||
| .ipynb_checkpoints | ||||
|  | ||||
| # IPython | ||||
| profile_default/ | ||||
| ipython_config.py | ||||
|  | ||||
| # pyenv | ||||
| .python-version | ||||
|  | ||||
| # pipenv | ||||
| #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||||
| #   However, in case of collaboration, if having platform-specific dependencies or dependencies | ||||
| #   having no cross-platform support, pipenv may install dependencies that don't work, or not | ||||
| #   install all needed dependencies. | ||||
| #Pipfile.lock | ||||
|  | ||||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||||
| __pypackages__/ | ||||
|  | ||||
| # Celery stuff | ||||
| celerybeat-schedule | ||||
| celerybeat.pid | ||||
|  | ||||
| # SageMath parsed files | ||||
| *.sage.py | ||||
|  | ||||
| # Environments | ||||
| .env | ||||
| .venv | ||||
| env/ | ||||
| venv/ | ||||
| ENV/ | ||||
| env.bak/ | ||||
| venv.bak/ | ||||
|  | ||||
| # Spyder project settings | ||||
| .spyderproject | ||||
| .spyproject | ||||
|  | ||||
| # Rope project settings | ||||
| .ropeproject | ||||
|  | ||||
| # mkdocs documentation | ||||
| /site | ||||
|  | ||||
| # mypy | ||||
| .mypy_cache/ | ||||
| .dmypy.json | ||||
| dmypy.json | ||||
|  | ||||
| # Pyre type checker | ||||
| .pyre/ | ||||
							
								
								
									
										21
									
								
								Gym-Wordle-main/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								Gym-Wordle-main/LICENSE
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | ||||
| MIT License | ||||
|  | ||||
| Copyright (c) 2022 David Kraemer | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
							
								
								
									
										78
									
								
								Gym-Wordle-main/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								Gym-Wordle-main/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | ||||
| # Gym-Wordle | ||||
|  | ||||
| An OpenAI gym compatible environment for training agents to play Wordle. | ||||
|  | ||||
| <p align='center'> | ||||
|   <img src="https://user-images.githubusercontent.com/8514041/152437216-d78e85f6-8049-4cb9-ae61-3c015a8a0e4f.gif"><br/> | ||||
|   <em>User-input demo of the environment</em> | ||||
| </p> | ||||
|  | ||||
| ## Installation | ||||
|  | ||||
| My goal is for a minimalist package that lets you install quickly and get on | ||||
| with your research. Installation is just a simple call to `pip`: | ||||
|  | ||||
| ``` | ||||
| $ pip install gym_wordle | ||||
| ``` | ||||
|  | ||||
| ### Requirements | ||||
|  | ||||
| In keeping with my desire to have a minimalist package, there are only three | ||||
| major requirements: | ||||
|  | ||||
| * `numpy` | ||||
| * `gym` | ||||
| * `sty`, a lovely little package for stylizing text in terminals | ||||
|  | ||||
| ## Usage | ||||
|  | ||||
| The basic flow for training agents with the `Wordle-v0` environment is the same | ||||
| as with gym environments generally: | ||||
|  | ||||
| ```Python | ||||
| import gym | ||||
| import gym_wordle | ||||
|  | ||||
| env = gym.make("Wordle-v0") | ||||
| state = env.reset()  # every game must start with a reset | ||||
|  | ||||
| done = False | ||||
| while not done: | ||||
|     action = ...  # RL magic | ||||
|     state, reward, done, info = env.step(action) | ||||
| ``` | ||||
|  | ||||
| If you're like millions of other people, you're a Wordle-obsessive in your own | ||||
| right. I have good news for you! The `Wordle-v0` environment currently has an | ||||
| implemented `render` method, which allows you to see a human-friendly version | ||||
| of the game. And it isn't so hard to set up the environment to play for | ||||
| yourself. Here's an example script: | ||||
|  | ||||
| ```Python | ||||
| from gym_wordle.utils import play | ||||
|  | ||||
| play() | ||||
| ``` | ||||
|  | ||||
| ## Documentation | ||||
|  | ||||
| Coming soon! | ||||
|  | ||||
| ## Examples | ||||
|  | ||||
| Coming soon! | ||||
|  | ||||
| ## Citing | ||||
|  | ||||
| If you decide to use this project in your work, please consider a citation! | ||||
|  | ||||
| ```bibtex | ||||
| @misc{gym_wordle, | ||||
|   author = {Kraemer, David}, | ||||
|   title = {An Environment for Reinforcement Learning with Wordle}, | ||||
|   year = {2022}, | ||||
|   publisher = {GitHub}, | ||||
|   journal = {GitHub repository}, | ||||
|   howpublished = {\url{https://github.com/DavidNKraemer/Gym-Wordle}}, | ||||
| } | ||||
| ``` | ||||
							
								
								
									
										7
									
								
								Gym-Wordle-main/gym-wordle.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								Gym-Wordle-main/gym-wordle.toml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| [build-system] | ||||
|  | ||||
| requires = [ | ||||
|   "setuptools>=42", | ||||
|   "wheel" | ||||
| ] | ||||
| build-backend = "setuptools.build_meta" | ||||
							
								
								
									
										7
									
								
								Gym-Wordle-main/gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								Gym-Wordle-main/gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| from gym.envs.registration import register | ||||
| from .wordle import WordleEnv | ||||
|  | ||||
# Make the environment available through ``gym.make('Wordle-v0')``; the entry
# point string names the WordleEnv class in gym_wordle/wordle.py.
register(id='Wordle-v0', entry_point='gym_wordle.wordle:WordleEnv')
							
								
								
									
										12972
									
								
								Gym-Wordle-main/gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12972
									
								
								Gym-Wordle-main/gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								Gym-Wordle-main/gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								Gym-Wordle-main/gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2315
									
								
								Gym-Wordle-main/gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2315
									
								
								Gym-Wordle-main/gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								Gym-Wordle-main/gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								Gym-Wordle-main/gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										94
									
								
								Gym-Wordle-main/gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								Gym-Wordle-main/gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
|  | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
# Alphabet of the integer encoding: code 0 is the blank character ' ' and
# codes 1..26 are the lowercase letters 'a'..'z'.
_chars = ' abcdefghijklmnopqrstuvwxyz'
# Reverse lookup table: character -> integer code.
_char_d = {char: code for code, char in enumerate(_chars)}


def to_english(array: npt.NDArray[np.int64]) -> str:
    """Converts an integer-encoded word into the corresponding English string.

    Args:
        array: Word in array (int) form. It is assumed that each integer in the
          array is between 0,...,26 (inclusive); 0 decodes to a blank space.

    Returns:
        A (lowercase) string representation of the word.
    """
    return ''.join(_chars[code] for code in array)


def to_array(word: str) -> npt.NDArray[np.int64]:
    """Converts a string of characters into the corresponding integer array.

    Args:
        word: Word in string form. Each character must be either a blank
          space ' ' or a lowercase alphabetical character.

    Returns:
        A one-dimensional int64 array with one code per character of `word`.

    Raises:
        KeyError: If `word` contains a character outside ' ' and 'a'..'z'.
    """
    # The dtype is pinned explicitly: without it NumPy infers the type from
    # the data, so an empty word would come back as a float64 array and the
    # integer width would depend on the platform default.
    return np.array([_char_d[char] for char in word], dtype=np.int64)
|  | ||||
|  | ||||
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
    """Loads a list of words in array form.

    The array is read from ``dictionary/<category>_list.npy`` next to this
    module. If `build` is set, that file is first regenerated from the
    human-readable ``dictionary/<category>_list.csv`` (one word per line).

    Args:
        category: Either 'guess' or 'solution', which corresponds to the list
          of acceptable guess words and the list of acceptable solution words.
        build: If True, recomputes and saves the array-version of the word
          list before loading it.

    Returns:
        An array representation of the list of words specified by the
        category, one word per row.

    Raises:
        AssertionError: If `category` is not 'guess' or 'solution'.
        KeyError: When rebuilding, if a word contains a character that
          `to_array` cannot encode.
    """
    assert category in {'guess', 'solution'}, (
        f"category must be 'guess' or 'solution', got {category!r}"
    )

    dictionary_dir = Path(__file__).parent / 'dictionary'
    arr_path = dictionary_dir / f'{category}_list.npy'

    if build:
        list_path = dictionary_dir / f'{category}_list.csv'

        # Explicit encoding so the result does not depend on the locale.
        with open(list_path, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            # Blank lines (e.g. a trailing newline at the end of the file)
            # would encode to zero-length rows and make the array ragged.
            rows = [to_array(entry) for entry in stripped if entry]

        # Write the cache only after the source file has been closed.
        np.save(arr_path, np.array(rows))

    return np.load(arr_path)
|  | ||||
|  | ||||
def play():
    """Play Wordle yourself in the terminal.

    Creates a 'Wordle-v0' environment, repeatedly prompts for a guess on
    standard input, renders the board after every accepted guess, and prints
    the solution once the game is over. Input that is not an accepted guess
    word is ignored and the prompt is shown again.
    """
    import gym
    # Imported for its side effect only: the package registers 'Wordle-v0'.
    import gym_wordle  # noqa: F401

    env = gym.make('Wordle-v0')  # load the environment

    env.reset()
    game = env.unwrapped
    solution = to_english(game.solution_space[game.solution]).upper()  # no peeking!

    done = False

    while not done:
        action = -1

        # in general, the environment won't be forgiving if you input an
        # invalid word, but for this function I want to let you screw up user
        # input without consequence, so just loop until valid input is taken
        while not env.action_space.contains(action):
            # Surrounding whitespace and letter case are not part of the guess.
            guess = input('Guess: ').strip().lower()
            try:
                action = game.action_space.index_of(to_array(guess))
            except KeyError:
                # to_array only knows ' ' and 'a'..'z'; any other character
                # means the input cannot be a word, so ask again.
                action = -1

        _, _, done, _ = env.step(action)
        env.render()

    print(f"The word was {solution}")
|  | ||||
							
								
								
									
										286
									
								
								Gym-Wordle-main/gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										286
									
								
								Gym-Wordle-main/gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,286 @@ | ||||
| import gym | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
| from sty import fg, bg, ef, rs | ||||
|  | ||||
| from collections import Counter  | ||||
| from gym_wordle.utils import to_english, to_array, get_words | ||||
| from typing import Optional  | ||||
|  | ||||
|  | ||||
class WordList(gym.spaces.Discrete):
    """Base class for a space of valid words taken from a fixed word list.

    The space is a subclass of gym.spaces.Discrete with one element per row of
    the word list, so an element of the space is the *index* of a word rather
    than the word itself. Working with plain indices keeps the space simple
    enough for RL algorithms that expect a discrete action space.

    In addition to the default methods of the Discrete space, it implements
    a __getitem__ method for easy index lookup, and an index_of method to
    convert potential words into their corresponding index (if they exist).
    """

    def __init__(self, words: npt.NDArray[np.int64], **kwargs):
        """
        Args:
            words: Collection of words in array form with shape (_, 5), where
              each word is a row of the array. Each array element is an integer
              between 0,...,26 (inclusive).
            kwargs: See documentation for gym.spaces.Discrete
        """
        super().__init__(words.shape[0], **kwargs)
        self.words = words

    def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
        """Obtains the (int-encoded) word associated with the given index.

        Args:
            index: Index for the list of words.

        Returns:
            Associated word at the position specified by index.
        """
        return self.words[index]

    def index_of(self, word: npt.NDArray[np.int64]) -> int:
        """Given a word, determine its index in the list (if it exists),
        otherwise returning -1 if no index exists.

        Args:
            word: Int-encoded word to find in the word list.

        Returns:
            The index of the first row equal to the given word, or -1 when
            the word has the wrong length, is not numeric, or is not in the
            list.
        """
        candidate = np.asarray(word)

        # A word of the wrong length or type can never match a row; rejecting
        # it up front avoids relying on a failed comparison raising.
        same_length = candidate.shape == self.words.shape[1:]
        if not same_length or not np.issubdtype(candidate.dtype, np.number):
            return -1

        matches = np.flatnonzero((self.words == candidate).all(axis=1))
        # Convert to a builtin int so the result matches the annotation.
        return int(matches[0]) if matches.size else -1
|  | ||||
|  | ||||
class SolutionList(WordList):
    """Space of the words that may be chosen as the *solution* of a game.

    Wordle distinguishes two collections of words:

        * "guesses": every word the game accepts as a guess.
        * "solutions": the words the game picks its answers from.

    The solutions form a strict subset of the guesses.

    Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/

    This class represents the set of solution words.
    """
    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded to WordList, and from there to
              gym.spaces.Discrete.
        """
        super().__init__(get_words('solution'), **kwargs)
|  | ||||
|  | ||||
class WordleObsSpace(gym.spaces.Box):
    """Implementation of the state (observation) space in terms of gym
    primitives, in this case, gym.spaces.Box.

    Conceptually a Wordle observation is a 6x5 board with two channels:

      - the character channel, holding the characters placed on the board
        (unfilled rows carry the empty character, 0)
      - the flag channel, holding the in-game information attached to each
        character (green highlight, yellow highlight, etc.)

    The 6 rows are the turns of a game and the 5 columns are the letters of a
    word.

    For simplicity, and compatibility with the stable_baselines algorithms,
    the two channels are placed side by side in a single 6x10 array, so each
    row of an observation reads

        c0 c1 c2 c3 c4 f0 f1 f2 f3 f4

    for the word c0...c4 with flags f0...f4.

    The inherited `sample` method draws uniformly from the integer box, not
    from reachable game positions, so a sample will (almost surely) not be a
    real game configuration.
    """

    def __init__(self, **kwargs):
        self.n_rows = 6
        self.n_cols = 5
        self.max_char = 26
        # NOTE(review): WordleEnv only defines flag codes 0..3 while the
        # bound used here is 4 -- confirm which one is intended.
        self.max_flag = 4

        channel_shape = (self.n_rows, self.n_cols)

        # Lower bound: zero everywhere, across both channels.
        lower = np.zeros((self.n_rows, 2 * self.n_cols))
        # Upper bound: character channel on the left, flag channel on the
        # right, matching the row layout described above.
        upper = np.hstack([
            np.full(channel_shape, self.max_char),
            np.full(channel_shape, self.max_flag),
        ])

        super().__init__(lower, upper, dtype=np.int64, **kwargs)
|  | ||||
|  | ||||
class GuessList(WordList):
    """Space of the words that the game accepts as a *guess*.

    Wordle distinguishes two collections of words:

        * "guesses": every word the game accepts as a guess.
        * "solutions": the words the game picks its answers from.

    The solutions form a strict subset of the guesses.

    Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/

    This class represents the set of guess words.
    """
    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded to WordList, and from there to
              gym.spaces.Discrete.
        """
        super().__init__(get_words('guess'), **kwargs)
|  | ||||
|  | ||||
class WordleEnv(gym.Env):
    """Gym environment for a game of Wordle.

    An action is an index into the list of accepted guesses
    (``action_space``); the hidden solution is an index into the list of
    solutions (``solution_space``). The observation is an integer board with
    one row per round: the first ``n_letters`` columns of a row hold the
    guessed characters and the last ``n_letters`` columns hold the flag of
    each character (see the flag codes below).
    """

    metadata = {'render.modes': ['human']}

    # character flag codes
    no_char = 0
    right_pos = 1
    wrong_pos = 2
    wrong_char = 3

    def __init__(self):
        super().__init__()

        self.seed()
        self.action_space = GuessList()
        self.solution_space = SolutionList()

        self.observation_space = WordleObsSpace()

        # flag code -> (prefix, suffix) terminal sequences wrapped around a
        # character when rendering
        self._highlights = {
            self.right_pos: (bg.green, bg.rs),
            self.wrong_pos: (bg.yellow, bg.rs),
            self.wrong_char: ('', ''),
            self.no_char: ('', ''),
        }

        self.n_rounds = 6
        self.n_letters = 5

    def _highlighter(self, char: str, flag: int) -> str:
        """Terminal renderer functionality. Properly highlights a character
        based on the flag associated with it.

        Args:
            char: Character in question.
            flag: Associated flag, one of:
                - 0: no character (render no background)
                - 1: right position (render green background)
                - 2: wrong position (render yellow background)
                - 3: wrong character (render no background)

        Returns:
            The character wrapped in the terminal sequences that produce the
            background belonging to its flag.
        """
        front, back = self._highlights[flag]
        return front + char + back

    def reset(self):
        """Starts a new game.

        Resets the round counter, samples a new solution index from the
        solution space and clears the board.

        Returns:
            The initial state: an all-zero array of shape
            (n_rounds, 2 * n_letters).
        """
        self.round = 0
        self.solution = self.solution_space.sample()

        self.state = np.zeros((self.n_rounds, 2 * self.n_letters),
                              dtype=np.int64)

        return self.state

    def render(self, mode: str ='human'):
        """Renders the Wordle environment.

        Currently supported render modes:

        - human: renders the Wordle game to the terminal.

        Any other mode is delegated to the base class.

        Args:
            mode: the mode to render with
        """
        if mode == 'human':
            # The board lives in `self.state`, as created by `reset`.
            for row in self.state:
                text = ''.join(map(
                    self._highlighter,
                    to_english(row[:self.n_letters]).upper(),
                    row[self.n_letters:]
                ))

                print(text)
        else:
            super(WordleEnv, self).render(mode=mode)

    def step(self, action):
        """Run one step of the Wordle game. Every game must be previously
        initialized by a call to the `reset` method.

        Args:
            action: Index (in the action space) of the word guessed by the
              agent.
        Returns:
            state (object): Wordle game state after the guess.
            reward (float): Reward associated with the guess (-1 for incorrect,
              0 for correct)
            done (bool): Whether the game has ended (by a correct guess or
              after six guesses).
            info (dict): Auxiliary diagnostic information (empty).
        """
        assert self.action_space.contains(action), 'Invalid word!'

        # transform the action, solution indices to their words
        guess = self.action_space[action]
        solution = self.solution_space[self.solution]

        # populate the word chars into the row (character channel)
        self.state[self.round, :self.n_letters] = guess

        # Compute the flags in two passes, as in the real game. Exact matches
        # are marked first; each remaining guess letter then turns yellow only
        # while the solution still has an unmatched copy of that letter. A
        # single left-to-right pass would hand out a yellow for a letter whose
        # only occurrence is matched exactly further to the right.
        exact = guess == solution
        flags = np.full(self.n_letters, self.wrong_char, dtype=np.int64)
        flags[exact] = self.right_pos

        unmatched = Counter(solution[~exact].tolist())
        for i in np.flatnonzero(~exact):
            letter = int(guess[i])
            if unmatched[letter] > 0:
                flags[i] = self.wrong_pos
                unmatched[letter] -= 1

        # populate the flags into the row (flag channel); the flag channel
        # occupies the columns after the characters, so it must not be
        # written over the character columns.
        self.state[self.round, self.n_letters:] = flags

        self.round += 1

        correct = bool(exact.all())
        game_over = (self.round == self.n_rounds)

        done = correct or game_over

        # Total reward equals -(number of incorrect guesses)
        reward = 0. if correct else -1.

        return self.state, reward, done, {}
|  | ||||
							
								
								
									
										35
									
								
								Gym-Wordle-main/setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								Gym-Wordle-main/setup.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| from setuptools import setup, find_packages | ||||
|  | ||||
# The README is reused verbatim as the long description of the package.
with open('README.md', 'r', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

# Ship the gym_wordle package together with any of its subpackages.
_packages = find_packages(include=['gym_wordle', 'gym_wordle.*'])

# Runtime dependencies of the environment.
_requirements = ['numpy>=1.20', 'gym==0.19', 'sty==1.0']

_classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

setup(
    name='gym_wordle',
    version='0.1.3',
    description='OpenAI gym environment for training agents on Wordle',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='David Kraemer',
    author_email='david.kraemer@stonybrook.edu',
    url='https://github.com/DavidNKraemer/Gym-Wordle',
    packages=_packages,
    # The word lists under gym_wordle/dictionary are shipped with the package.
    package_data={'gym_wordle': ['dictionary/*']},
    python_requires='>=3.7',
    classifiers=_classifiers,
    install_requires=_requirements,
)
							
								
								
									
										91
									
								
								custom_env/agent.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								custom_env/agent.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,91 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class Agent: | ||||
|  | ||||
|     def __init__(self, ) -> None: | ||||
|         # BATCH_SIZE is the number of transitions sampled from the replay buffer | ||||
|         # GAMMA is the discount factor as mentioned in the previous section | ||||
|         # EPS_START is the starting value of epsilon | ||||
|         # EPS_END is the final value of epsilon | ||||
|         # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay | ||||
|         # TAU is the update rate of the target network | ||||
|         # LR is the learning rate of the ``AdamW`` optimizer | ||||
|         self.batch_size = 128 | ||||
|         self.gamma = 0.99 | ||||
|         self.eps_start = 0.9 | ||||
|         self.eps_end = 0.05 | ||||
|         self.eps_decay = 1000 | ||||
|         self.tau = 0.005 | ||||
|         self.lr = 1e-4 | ||||
|         self.n_actions = n_actions | ||||
|  | ||||
|         policy_net = DQN(n_observations, n_actions).to(device) | ||||
|         target_net = DQN(n_observations, n_actions).to(device) | ||||
|         target_net.load_state_dict(policy_net.state_dict()) | ||||
|  | ||||
|         optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True) | ||||
|         memory = ReplayMemory(10000) | ||||
|  | ||||
|     def get_state(self, game): | ||||
|         pass | ||||
|  | ||||
|     def select_action(state): | ||||
|         sample = random.random() | ||||
|         eps_threshold = EPS_END + (EPS_START - EPS_END) * \ | ||||
|             math.exp(-1. * steps_done / EPS_DECAY) | ||||
|         steps_done += 1 | ||||
|         if sample > eps_threshold: | ||||
|             with torch.no_grad(): | ||||
|                 # t.max(1) will return the largest column value of each row. | ||||
|                 # second column on max result is index of where max element was | ||||
|                 # found, so we pick action with the larger expected reward. | ||||
|                 return policy_net(state).max(1).indices.view(1, 1) | ||||
|         else: | ||||
|             return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long) | ||||
|  | ||||
|     def optimize_model(): | ||||
|         if len(memory) < BATCH_SIZE: | ||||
|             return | ||||
|         transitions = memory.sample(BATCH_SIZE) | ||||
|         # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for | ||||
|         # detailed explanation). This converts batch-array of Transitions | ||||
|         # to Transition of batch-arrays. | ||||
|         batch = Transition(*zip(*transitions)) | ||||
|  | ||||
|         # Compute a mask of non-final states and concatenate the batch elements | ||||
|         # (a final state would've been the one after which simulation ended) | ||||
|         non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, | ||||
|                                             batch.next_state)), device=device, dtype=torch.bool) | ||||
|         non_final_next_states = torch.cat([s for s in batch.next_state | ||||
|                                            if s is not None]) | ||||
|         state_batch = torch.cat(batch.state) | ||||
|         action_batch = torch.cat(batch.action) | ||||
|         reward_batch = torch.cat(batch.reward) | ||||
|  | ||||
|         # Compute Q(s_t, a) - the model computes Q(s_t), then we select the | ||||
|         # columns of actions taken. These are the actions which would've been taken | ||||
|         # for each batch state according to policy_net | ||||
|         state_action_values = policy_net(state_batch).gather(1, action_batch) | ||||
|  | ||||
|         # Compute V(s_{t+1}) for all next states. | ||||
|         # Expected values of actions for non_final_next_states are computed based | ||||
|         # on the "older" target_net; selecting their best reward with max(1).values | ||||
|         # This is merged based on the mask, such that we'll have either the expected | ||||
|         # state value or 0 in case the state was final. | ||||
|         next_state_values = torch.zeros(BATCH_SIZE, device=device) | ||||
|         with torch.no_grad(): | ||||
|             next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values | ||||
|         # Compute the expected Q values | ||||
|         expected_state_action_values = (next_state_values * GAMMA) + reward_batch | ||||
|  | ||||
|         # Compute Huber loss | ||||
|         criterion = nn.SmoothL1Loss() | ||||
|         loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1)) | ||||
|  | ||||
|         # Optimize the model | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         # In-place gradient clipping | ||||
|         torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100) | ||||
|         optimizer.step() | ||||
							
								
								
									
										16
									
								
								custom_env/create_wordlist.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								custom_env/create_wordlist.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| import pathlib | ||||
| import sys | ||||
| from string import ascii_letters | ||||
|  | ||||
| # Input and output paths come from the first two command-line arguments; | ||||
| # a missing argument raises IndexError here. | ||||
| in_path = pathlib.Path(sys.argv[1]) | ||||
| out_path = pathlib.Path(sys.argv[2]) | ||||
|  | ||||
| # Split the input on whitespace, keep only tokens made purely of ASCII | ||||
| # letters, lowercase them and de-duplicate through a set. The result is | ||||
| # ordered by length first, then alphabetically. | ||||
| words = sorted( | ||||
|     { | ||||
|         word.lower() | ||||
|         for word in in_path.read_text(encoding="utf-8").split() | ||||
|         if all(letter in ascii_letters for letter in word) | ||||
|     }, | ||||
|     key=lambda word: (len(word), word), | ||||
| ) | ||||
| # One word per line, with no newline after the last word. | ||||
| out_path.write_text("\n".join(words)) | ||||
							
								
								
									
										5757
									
								
								custom_env/five_letter_words.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5757
									
								
								custom_env/five_letter_words.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										44
									
								
								custom_env/model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								custom_env/model.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| import math | ||||
| import random | ||||
| import matplotlib | ||||
| import matplotlib.pyplot as plt | ||||
| from collections import namedtuple, deque | ||||
| from itertools import count | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.optim as optim | ||||
| import torch.nn.functional as F | ||||
|  | ||||
|  | ||||
| # One replay-memory record with fields (state, action, next_state, reward). | ||||
| Transition = namedtuple('Transition', | ||||
|                         ('state', 'action', 'next_state', 'reward')) | ||||
|  | ||||
|  | ||||
| class ReplayMemory(object): | ||||
|     """Bounded buffer of ``Transition`` records. | ||||
|  | ||||
|     Storage is a ``deque`` created with ``maxlen=capacity``, so appending to | ||||
|     a full buffer drops the oldest record. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, capacity: int) -> None: | ||||
|         """Create an empty buffer holding at most ``capacity`` records.""" | ||||
|         self.memory = deque([], maxlen=capacity) | ||||
|  | ||||
|     def push(self, *args): | ||||
|         """Build a ``Transition`` from ``args`` and append it to the buffer.""" | ||||
|         self.memory.append(Transition(*args)) | ||||
|  | ||||
|     def sample(self, batch_size): | ||||
|         """Return ``batch_size`` records chosen by ``random.sample``. | ||||
|  | ||||
|         ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds | ||||
|         the number of stored records. | ||||
|         """ | ||||
|         return random.sample(self.memory, batch_size) | ||||
|  | ||||
|     def __len__(self): | ||||
|         """Return the number of records currently stored.""" | ||||
|         return len(self.memory) | ||||
|  | ||||
|  | ||||
| class DQN(nn.Module): | ||||
|     """Three-layer fully connected network with 128 hidden units per layer.""" | ||||
|  | ||||
|     def __init__(self, n_observations: int, n_actions: int) -> None: | ||||
|         """Create the layers. | ||||
|  | ||||
|         Args: | ||||
|             n_observations: Input feature count of the first linear layer. | ||||
|             n_actions: Output feature count of the last linear layer. | ||||
|         """ | ||||
|         super(DQN, self).__init__() | ||||
|         self.layer1 = nn.Linear(n_observations, 128) | ||||
|         self.layer2 = nn.Linear(128, 128) | ||||
|         self.layer3 = nn.Linear(128, n_actions) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         """Apply the layers to ``x``, with ReLU after the first two. | ||||
|  | ||||
|         The last layer's output is returned without an activation. | ||||
|         """ | ||||
|         x = F.relu(self.layer1(x)) | ||||
|         x = F.relu(self.layer2(x)) | ||||
|         return self.layer3(x) | ||||
							
								
								
									
										61
									
								
								custom_env/test2.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								custom_env/test2.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from string import ascii_letters, ascii_uppercase, ascii_lowercase" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "'ABCDEFGHIJKLMNOPQRSTUVWXYZ'" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 3, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "ascii_uppercase" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "env", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.5" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
							
								
								
									
										1098
									
								
								custom_env/wordlist.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1098
									
								
								custom_env/wordlist.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										119
									
								
								custom_env/wyrdl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								custom_env/wyrdl.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,119 @@ | ||||
| import contextlib | ||||
| import pathlib | ||||
| import random | ||||
| from string import ascii_letters, ascii_lowercase | ||||
|  | ||||
| from rich.console import Console | ||||
| from rich.theme import Theme | ||||
|  | ||||
| # Shared rich console, 40 columns wide, with a "warning" style of red on yellow. | ||||
| console = Console(width=40, theme=Theme({"warning": "red on yellow"})) | ||||
|  | ||||
| # Length every playable word and guess must have. | ||||
| NUM_LETTERS = 5 | ||||
| # NOTE(review): presumably the number of guesses per game -- confirm against callers. | ||||
| NUM_GUESSES = 6 | ||||
| # The word list is read from wordlist.txt in this module's directory. | ||||
| WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt" | ||||
|  | ||||
|  | ||||
| class Wordle: | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n") | ||||
|         self.n_guesses = 6 | ||||
|         self.num_letters = 5 | ||||
|         self.curr_word = None | ||||
|         self.reset() | ||||
|  | ||||
|     def refresh_page(self, headline): | ||||
|         console.clear() | ||||
|         console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n") | ||||
|  | ||||
|     def start_game(self): | ||||
|         # get a new random word | ||||
|         word = self.get_random_word(self.word_list) | ||||
|  | ||||
|         self.curr_word = word | ||||
|  | ||||
|     def get_state(self): | ||||
|         return  | ||||
|  | ||||
|     def action_to_word(self, action): | ||||
|         # Calculate the word from the array | ||||
|         word = '' | ||||
|         for i in range(0, len(ascii_lowercase), 26): | ||||
|             # Find the index of 1 in each block of 26 | ||||
|             letter_index = action[i:i+26].index(1) | ||||
|             # Append the corresponding letter to the word | ||||
|             word += ascii_lowercase[letter_index] | ||||
|  | ||||
|         return word | ||||
|  | ||||
|     def play_guess(self, action): | ||||
|         # probably an array of length 26 * 5 for 26 letters and 5 positions | ||||
|         guess = action | ||||
|  | ||||
|     def get_random_word(self, word_list): | ||||
|         if words := [ | ||||
|             word.upper() | ||||
|             for word in word_list | ||||
|             if len(word) == NUM_LETTERS | ||||
|             and all(letter in ascii_letters for letter in word) | ||||
|         ]: | ||||
|             return random.choice(words) | ||||
|         else: | ||||
|             console.print( | ||||
|                 f"No words of length {NUM_LETTERS} in the word list", | ||||
|                 style="warning", | ||||
|             ) | ||||
|             raise SystemExit() | ||||
|  | ||||
|     def show_guesses(self, guesses, word): | ||||
|         letter_status = {letter: letter for letter in ascii_lowercase} | ||||
|         for guess in guesses: | ||||
|             styled_guess = [] | ||||
|             for letter, correct in zip(guess, word): | ||||
|                 if letter == correct: | ||||
|                     style = "bold white on green" | ||||
|                 elif letter in word: | ||||
|                     style = "bold white on yellow" | ||||
|                 elif letter in ascii_letters: | ||||
|                     style = "white on #666666" | ||||
|                 else: | ||||
|                     style = "dim" | ||||
|                 styled_guess.append(f"[{style}]{letter}[/]") | ||||
|                 if letter != "_": | ||||
|                     letter_status[letter] = f"[{style}]{letter}[/]" | ||||
|  | ||||
|             console.print("".join(styled_guess), justify="center") | ||||
|         console.print("\n" + "".join(letter_status.values()), justify="center") | ||||
|  | ||||
|     def guess_word(self, previous_guesses): | ||||
|         guess = console.input("\nGuess word: ").upper() | ||||
|  | ||||
|         if guess in previous_guesses: | ||||
|             console.print(f"You've already guessed {guess}.", style="warning") | ||||
|             return guess_word(previous_guesses) | ||||
|  | ||||
|         if len(guess) != NUM_LETTERS: | ||||
|             console.print( | ||||
|                 f"Your guess must be {NUM_LETTERS} letters.", style="warning" | ||||
|             ) | ||||
|             return guess_word(previous_guesses) | ||||
|  | ||||
|         if any((invalid := letter) not in ascii_letters for letter in guess): | ||||
|             console.print( | ||||
|                 f"Invalid letter: '{invalid}'. Please use English letters.", | ||||
|                 style="warning", | ||||
|             ) | ||||
|             return guess_word(previous_guesses) | ||||
|  | ||||
|         return guess | ||||
|  | ||||
|     def reset(self, guesses, word, guessed_correctly, n_episodes): | ||||
|         refresh_page(headline=f"Game: {n_episodes}") | ||||
|  | ||||
|         if guessed_correctly: | ||||
|             console.print(f"\n[bold white on green]Correct, the word is {word}[/]") | ||||
|         else: | ||||
|             console.print(f"\n[bold white on red]Sorry, the word was {word}[/]") | ||||
|  | ||||
|     if __name__ == "__main__": | ||||
|         main() | ||||
							
								
								
									
										128
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,128 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gymnasium as gym\n", | ||||
|     "from stable_baselines3 import DQN\n", | ||||
|     "import numpy as np" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "ename": "NameNotFound", | ||||
|      "evalue": "Environment `Wordle` doesn't exist.", | ||||
|      "output_type": "error", | ||||
|      "traceback": [ | ||||
|       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | ||||
|       "\u001b[1;31mNameNotFound\u001b[0m                              Traceback (most recent call last)", | ||||
|       "Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m env \u001b[38;5;241m=\u001b[39m \u001b[43mgym\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mWordle-v0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(env)\n", | ||||
|       "File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:741\u001b[0m, in \u001b[0;36mmake\u001b[1;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[0;32m    738\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mid\u001b[39m, \u001b[38;5;28mstr\u001b[39m)\n\u001b[0;32m    740\u001b[0m     \u001b[38;5;66;03m# The environment name can include an unloaded module in \"module:env_name\" style\u001b[39;00m\n\u001b[1;32m--> 741\u001b[0m     env_spec \u001b[38;5;241m=\u001b[39m \u001b[43m_find_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m    743\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(env_spec, EnvSpec)\n\u001b[0;32m    745\u001b[0m \u001b[38;5;66;03m# Update the env spec kwargs with the `make` kwargs\u001b[39;00m\n", | ||||
|       "File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:527\u001b[0m, in \u001b[0;36m_find_spec\u001b[1;34m(env_id)\u001b[0m\n\u001b[0;32m    521\u001b[0m     logger\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m    522\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing the latest versioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnew_env_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    523\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minstead of the unversioned environment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    524\u001b[0m     )\n\u001b[0;32m    526\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m env_spec \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 527\u001b[0m     \u001b[43m_check_version_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mversion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    528\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mError(\n\u001b[0;32m    529\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo registered env with id: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Did you register it, or import the package that registers it? Use `gymnasium.pprint_registry()` to see all of the registered environments.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    530\u001b[0m     )\n\u001b[0;32m    532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m env_spec\n", | ||||
|       "File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:393\u001b[0m, in \u001b[0;36m_check_version_exists\u001b[1;34m(ns, name, version)\u001b[0m\n\u001b[0;32m    390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m get_env_id(ns, name, version) \u001b[38;5;129;01min\u001b[39;00m registry:\n\u001b[0;32m    391\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m--> 393\u001b[0m \u001b[43m_check_name_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m version \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    395\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n", | ||||
|       "File \u001b[1;32mc:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\gymnasium\\envs\\registration.py:370\u001b[0m, in \u001b[0;36m_check_name_exists\u001b[1;34m(ns, name)\u001b[0m\n\u001b[0;32m    367\u001b[0m namespace_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in namespace \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mns\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ns \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    368\u001b[0m suggestion_msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Did you mean: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`?\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m suggestion \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m--> 370\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m error\u001b[38;5;241m.\u001b[39mNameNotFound(\n\u001b[0;32m    371\u001b[0m     \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEnvironment `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt exist\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnamespace_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestion_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    372\u001b[0m )\n", | ||||
|       "\u001b[1;31mNameNotFound\u001b[0m: Environment `Wordle` doesn't exist." | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym.make(\"Wordle-v0\")\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 35, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=0)\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, progress_bar=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def test(model):\n", | ||||
|     "\n", | ||||
|     "    end_rewards = []\n", | ||||
|     "\n", | ||||
|     "    for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "        state = env.reset()\n", | ||||
|     "\n", | ||||
|     "        done = False\n", | ||||
|     "\n", | ||||
|     "        while not done:\n", | ||||
|     "\n", | ||||
|     "            action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "            state, reward, done, info = env.step(action)\n", | ||||
|     "            \n", | ||||
|     "        end_rewards.append(reward == 0)\n", | ||||
|     "        \n", | ||||
|     "    return np.sum(end_rewards) / len(end_rewards)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model.save(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model = DQN.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "print(test(model))" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.5" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
							
								
								
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										61
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								test.py
									
									
									
									
									
								
							| @@ -1,61 +0,0 @@ | ||||
|  | ||||
| from torch.utils.data import Dataset | ||||
| from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer | ||||
| from tqdm import tqdm as progress_bar | ||||
| import torch | ||||
| import matplotlib | ||||
|  | ||||
| # Use the GPU when CUDA is available, otherwise the CPU. | ||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| print(device) | ||||
|  | ||||
| encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102) | ||||
| # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token | ||||
| decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102) | ||||
| model = EncoderDecoderModel(encoder=encoder, decoder=decoder) | ||||
|  | ||||
| # create tokenizer... | ||||
| # NOTE(review): the tokenizer comes from bert-large-uncased while the encoder | ||||
| # and decoder weights come from bert-base-uncased -- confirm this is intended. | ||||
| tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") | ||||
|  | ||||
| import json | ||||
|  | ||||
| class CodeDataset(Dataset): | ||||
|     """Dataset of (intent, snippet) string pairs loaded from a JSON file.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         """Load all records from ``data/conala-train.json`` into memory.""" | ||||
|         with open("data/conala-train.json") as f: | ||||
|             self.data = json.load(f) | ||||
|  | ||||
|     def __len__(self): | ||||
|         """Return the number of loaded records.""" | ||||
|         return len(self.data) | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         """Return ``(intent, snippet)`` for record ``idx``. | ||||
|  | ||||
|         The ``rewritten_intent`` field is used when it is truthy; otherwise | ||||
|         the record's ``intent`` field is used. | ||||
|         """ | ||||
|         intent = self.data[idx]["rewritten_intent"] if self.data[idx]["rewritten_intent"] else self.data[idx]["intent"] | ||||
|         return intent, self.data[idx]["snippet"] | ||||
|  | ||||
|  | ||||
| optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3) | ||||
| dataloader = CodeDataset() | ||||
| model = model.to(device) | ||||
|  | ||||
| losses = [] | ||||
| epochs = 10 | ||||
| for i in range(epochs): | ||||
|  | ||||
|     epoch_loss = 0 | ||||
|  | ||||
|     for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)): | ||||
|  | ||||
|         input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device) | ||||
|         label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device) | ||||
|  | ||||
|         loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss | ||||
|  | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|  | ||||
|         epoch_loss += loss.item() | ||||
|  | ||||
|     losses.append(epoch_loss) | ||||
|  | ||||
| plt.plot(losses, color="green", label="Training Loss") | ||||
| plt.legend(loc = 'upper left') | ||||
| plt.savefig("plot.png") | ||||
		Reference in New Issue
	
	Block a user