mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			2 Commits
		
	
	
		
			9172326013
			...
			gymnasium-
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | cf977e4797 | ||
|  | bbe9a1891c | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,3 +1,4 @@ | |||||||
| **/data/* | **/data/* | ||||||
| **/*.zip | **/*.zip | ||||||
| **/__pycache__ | **/__pycache__ | ||||||
|  | /env | ||||||
							
								
								
									
										1179
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										1179
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]: | |||||||
|     return np.array([_char_d[c] for c in word]) |     return np.array([_char_d[c] for c in word]) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: | def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]: | ||||||
|     """Loads a list of words in array form.  |     """Loads a list of words in array form.  | ||||||
|  |  | ||||||
|     If specified, this will recompute the list from the human-readable list of |     If specified, this will recompute the list from the human-readable list of | ||||||
| @@ -53,14 +53,14 @@ def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: | |||||||
|         five. |         five. | ||||||
|     """ |     """ | ||||||
|     assert category in {'guess', 'solution'} |     assert category in {'guess', 'solution'} | ||||||
|      |  | ||||||
|     arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' |     arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' | ||||||
|     if build: |     if build: | ||||||
|        list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' |         list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' | ||||||
|  |  | ||||||
|        with open(list_path, 'r') as f: |         with open(list_path, 'r') as f: | ||||||
|            words = np.array([to_array(line.strip()) for line in f]) |             words = np.array([to_array(line.strip()) for line in f]) | ||||||
|            np.save(arr_path, words) |             np.save(arr_path, words) | ||||||
|  |  | ||||||
|     return np.load(arr_path) |     return np.load(arr_path) | ||||||
|  |  | ||||||
| @@ -69,16 +69,16 @@ def play(): | |||||||
|     """Play Wordle yourself!""" |     """Play Wordle yourself!""" | ||||||
|     import gym |     import gym | ||||||
|     import gym_wordle |     import gym_wordle | ||||||
|      |  | ||||||
|     env = gym.make('Wordle-v0')  # load the environment |     env = gym.make('Wordle-v0')  # load the environment | ||||||
|      |  | ||||||
|     env.reset() |     env.reset() | ||||||
|     solution = to_english(env.unwrapped.solution_space[env.solution]).upper()  # no peeking! |     solution = to_english(env.unwrapped.solution_space[env.solution]).upper()  # no peeking! | ||||||
|  |  | ||||||
|     done = False |     done = False | ||||||
|      |  | ||||||
|     while not done: |     while not done: | ||||||
|         action = -1  |         action = -1 | ||||||
|  |  | ||||||
|         # in general, the environment won't be forgiving if you input an |         # in general, the environment won't be forgiving if you input an | ||||||
|         # invalid word, but for this function I want to let you screw up user |         # invalid word, but for this function I want to let you screw up user | ||||||
| @@ -86,8 +86,8 @@ def play(): | |||||||
|         while not env.action_space.contains(action): |         while not env.action_space.contains(action): | ||||||
|             guess = input('Guess: ') |             guess = input('Guess: ') | ||||||
|             action = env.unwrapped.action_space.index_of(to_array(guess)) |             action = env.unwrapped.action_space.index_of(to_array(guess)) | ||||||
|      |  | ||||||
|         state, reward, done, info = env.step(action) |         state, reward, done, info = env.step(action) | ||||||
|         env.render() |         env.render() | ||||||
|      |  | ||||||
|     print(f"The word was {solution}") |     print(f"The word was {solution}") | ||||||
|   | |||||||
| @@ -1,17 +1,16 @@ | |||||||
| import gym | import gymnasium as gym | ||||||
| import numpy as np | import numpy as np | ||||||
| import numpy.typing as npt | import numpy.typing as npt | ||||||
| from sty import fg, bg, ef, rs | from sty import fg, bg, ef, rs | ||||||
|  |  | ||||||
| from collections import Counter  | from collections import Counter, defaultdict | ||||||
| from gym_wordle.utils import to_english, to_array, get_words | from gym_wordle.utils import to_english, to_array, get_words | ||||||
| from typing import Optional  | from typing import Optional | ||||||
|  |  | ||||||
| class WordList(gym.spaces.Discrete): | class WordList(gym.spaces.Discrete): | ||||||
|     """Super class for defining a space of valid words according to a specified |     """Super class for defining a space of valid words according to a specified | ||||||
|     list. |     list. | ||||||
|  |  | ||||||
|     TODO: Fix these paragraphs |  | ||||||
|     The space is a subclass of gym.spaces.Discrete, where each element |     The space is a subclass of gym.spaces.Discrete, where each element | ||||||
|     corresponds to an index of a valid word in the word list. The obfuscation |     corresponds to an index of a valid word in the word list. The obfuscation | ||||||
|     is necessary for more direct implementation of RL algorithms, which expect |     is necessary for more direct implementation of RL algorithms, which expect | ||||||
| @@ -66,16 +65,15 @@ class SolutionList(WordList): | |||||||
|  |  | ||||||
|     In the game Wordle, there are two different collections of words: |     In the game Wordle, there are two different collections of words: | ||||||
|  |  | ||||||
|         * "guesses", which the game accepts as valid words to use to guess the |     * "guesses", which the game accepts as valid words to use to guess the | ||||||
|           answer. |       answer. | ||||||
|         * "solutions", which the game uses to choose solutions from. |     * "solutions", which the game uses to choose solutions from. | ||||||
|  |  | ||||||
|     Of course, the set of solutions is a strict subset of the set of guesses. |     Of course, the set of solutions is a strict subset of the set of guesses. | ||||||
|  |  | ||||||
|     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ |  | ||||||
|  |  | ||||||
|     This class represents the set of solution words. |     This class represents the set of solution words. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
|         """ |         """ | ||||||
|         Args: |         Args: | ||||||
| @@ -87,7 +85,7 @@ class SolutionList(WordList): | |||||||
|  |  | ||||||
| class WordleObsSpace(gym.spaces.Box): | class WordleObsSpace(gym.spaces.Box): | ||||||
|     """Implementation of the state (observation) space in terms of gym |     """Implementation of the state (observation) space in terms of gym | ||||||
|     primatives, in this case, gym.spaces.Box. |     primitives, in this case, gym.spaces.Box. | ||||||
|  |  | ||||||
|     The Wordle observation space can be thought of as a 6x5 array with two |     The Wordle observation space can be thought of as a 6x5 array with two | ||||||
|     channels: |     channels: | ||||||
| @@ -100,20 +98,11 @@ class WordleObsSpace(gym.spaces.Box): | |||||||
|     where there are 6 rows, one for each turn in the game, and 5 columns, since |     where there are 6 rows, one for each turn in the game, and 5 columns, since | ||||||
|     the solution will always be a word of length 5. |     the solution will always be a word of length 5. | ||||||
|  |  | ||||||
|     For simplicity, and compatibility with the stable_baselines algorithms, |     For simplicity, and compatibility with stable_baselines algorithms, | ||||||
|     this multichannel is modeled as a 6x10 array, where the two channels are |     this multichannel is modeled as a 6x10 array, where the two channels are | ||||||
|     horizontally appended (along columns). Thus each row in the observation |     horizontally appended (along columns). Thus each row in the observation | ||||||
|     should be interpreted as  |     should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is | ||||||
|  |     c0...c4 and its associated flags are f0...f4. | ||||||
|         c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 |  | ||||||
|  |  | ||||||
|     when the word is c0...c4 and its associated flags are f0...f4. |  | ||||||
|  |  | ||||||
|     While the superclass method `sample` is available to the WordleObsSpace, it |  | ||||||
|     should be emphasized that the output of `sample` will (almost surely) not |  | ||||||
|     correspond to a real game configuration, because the sampling is not out of |  | ||||||
|     possible game configurations. Instead, the Box superclass just samples the |  | ||||||
|     integer array space uniformly. |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
| @@ -130,20 +119,11 @@ class WordleObsSpace(gym.spaces.Box): | |||||||
|  |  | ||||||
|  |  | ||||||
| class GuessList(WordList): | class GuessList(WordList): | ||||||
|     """Space for *solution* words to the Wordle environment. |     """Space for *guess* words to the Wordle environment. | ||||||
|  |  | ||||||
|     In the game Wordle, there are two different collections of words: |  | ||||||
|  |  | ||||||
|         * "guesses", which the game accepts as valid words to use to guess the |  | ||||||
|           answer. |  | ||||||
|         * "solutions", which the game uses to choose solutions from. |  | ||||||
|  |  | ||||||
|     Of course, the set of solutions is a strict subset of the set of guesses. |  | ||||||
|  |  | ||||||
|     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ |  | ||||||
|  |  | ||||||
|     This class represents the set of guess words. |     This class represents the set of guess words. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, **kwargs): |     def __init__(self, **kwargs): | ||||||
|         """ |         """ | ||||||
|         Args: |         Args: | ||||||
| @@ -154,10 +134,9 @@ class GuessList(WordList): | |||||||
|  |  | ||||||
|  |  | ||||||
| class WordleEnv(gym.Env): | class WordleEnv(gym.Env): | ||||||
|  |  | ||||||
|     metadata = {'render.modes': ['human']} |     metadata = {'render.modes': ['human']} | ||||||
|  |  | ||||||
|     # character flag codes |     # Character flag codes | ||||||
|     no_char = 0 |     no_char = 0 | ||||||
|     right_pos = 1 |     right_pos = 1 | ||||||
|     wrong_pos = 2 |     wrong_pos = 2 | ||||||
| @@ -166,7 +145,6 @@ class WordleEnv(gym.Env): | |||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
|         self.seed() |  | ||||||
|         self.action_space = GuessList() |         self.action_space = GuessList() | ||||||
|         self.solution_space = SolutionList() |         self.solution_space = SolutionList() | ||||||
|  |  | ||||||
| @@ -181,6 +159,7 @@ class WordleEnv(gym.Env): | |||||||
|  |  | ||||||
|         self.n_rounds = 6 |         self.n_rounds = 6 | ||||||
|         self.n_letters = 5 |         self.n_letters = 5 | ||||||
|  |         self.info = {'correct': False, 'guesses': defaultdict(int)} | ||||||
|  |  | ||||||
|     def _highlighter(self, char: str, flag: int) -> str: |     def _highlighter(self, char: str, flag: int) -> str: | ||||||
|         """Terminal renderer functionality. Properly highlights a character |         """Terminal renderer functionality. Properly highlights a character | ||||||
| @@ -201,73 +180,74 @@ class WordleEnv(gym.Env): | |||||||
|         front, back = self._highlights[flag] |         front, back = self._highlights[flag] | ||||||
|         return front + char + back |         return front + char + back | ||||||
|  |  | ||||||
|     def reset(self): |     def reset(self, seed=None, options=None): | ||||||
|  |         """Reset the environment to an initial state and returns an initial | ||||||
|  |         observation. | ||||||
|  |  | ||||||
|  |         Note: The observation space instance should be a Box space. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             state (object): The initial observation of the space. | ||||||
|  |         """ | ||||||
|         self.round = 0 |         self.round = 0 | ||||||
|         self.solution = self.solution_space.sample() |         self.solution = self.solution_space.sample() | ||||||
|  |  | ||||||
|         self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) |         self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) | ||||||
|  |  | ||||||
|         return self.state |         self.info = {'correct': False, 'guesses': defaultdict(int)} | ||||||
|  |  | ||||||
|     def render(self, mode: str ='human'): |         return self.state, self.info | ||||||
|  |  | ||||||
|  |     def render(self, mode: str = 'human'): | ||||||
|         """Renders the Wordle environment. |         """Renders the Wordle environment. | ||||||
|  |  | ||||||
|         Currently supported render modes: |         Currently supported render modes: | ||||||
|  |  | ||||||
|         - human: renders the Wordle game to the terminal. |         - human: renders the Wordle game to the terminal. | ||||||
|  |  | ||||||
|         Args: |         Args: | ||||||
|             mode: the mode to render with |             mode: the mode to render with. | ||||||
|         """ |         """ | ||||||
|         if mode == 'human': |         if mode == 'human': | ||||||
|             for row in self.states: |             for row in self.state: | ||||||
|                 text = ''.join(map( |                 text = ''.join(map( | ||||||
|                     self._highlighter,  |                     self._highlighter, | ||||||
|                     to_english(row[:self.n_letters]).upper(),  |                     to_english(row[:self.n_letters]).upper(), | ||||||
|                     row[self.n_letters:] |                     row[self.n_letters:] | ||||||
|                 )) |                 )) | ||||||
|  |  | ||||||
|                 print(text) |                 print(text) | ||||||
|         else: |         else: | ||||||
|             super(WordleEnv, self).render(mode=mode) |             super().render(mode=mode) | ||||||
|                  |  | ||||||
|     def step(self, action): |     def step(self, action): | ||||||
|         """Run one step of the Wordle game. Every game must be previously |         """Run one step of the Wordle game. Every game must be previously | ||||||
|         initialized by a call to the `reset` method. |         initialized by a call to the `reset` method. | ||||||
|  |  | ||||||
|         Args: |         Args: | ||||||
|             action: Word guessed by the agent. |             action: Word guessed by the agent. | ||||||
|  |  | ||||||
|         Returns: |         Returns: | ||||||
|             state (object): Wordle game state after the guess. |             state (object): Wordle game state after the guess. | ||||||
|             reward (float): Reward associated with the guess (-1 for incorrect, |             reward (float): Reward associated with the guess. | ||||||
|               0 for correct) |             done (bool): Whether the game has ended. | ||||||
|             done (bool): Whether the game has ended (by a correct guess or |             info (dict): Auxiliary diagnostic information. | ||||||
|               after six guesses). |  | ||||||
|             info (dict): Auxiliary diagnostic information (empty). |  | ||||||
|         """ |         """ | ||||||
|         assert self.action_space.contains(action), 'Invalid word!' |         assert self.action_space.contains(action), 'Invalid word!' | ||||||
|  |  | ||||||
|         # transform the action, solution indices to their words |         action = self.action_space[action] | ||||||
|         action = self.action_space[action]   |  | ||||||
|         solution = self.solution_space[self.solution] |         solution = self.solution_space[self.solution] | ||||||
|  |  | ||||||
|         # populate the word chars into the row (character channel) |  | ||||||
|         self.state[self.round][:self.n_letters] = action |         self.state[self.round][:self.n_letters] = action | ||||||
|  |  | ||||||
|         # populate the flag characters into the row (flag channel) |  | ||||||
|         counter = Counter() |         counter = Counter() | ||||||
|         for i, char in enumerate(action): |         for i, char in enumerate(action): | ||||||
|             flag_i = i + self.n_letters  # starts at 5 |             flag_i = i + self.n_letters | ||||||
|             counter[char] += 1 |             counter[char] += 1 | ||||||
|  |  | ||||||
|             if char == solution[i]:  # character is in correct position |             if char == solution[i]: | ||||||
|                 self.state[self.round, flag_i] = self.right_pos |                 self.state[self.round, flag_i] = self.right_pos | ||||||
|             elif counter[char] <= (char == solution).sum():   |             elif counter[char] <= (char == solution).sum(): | ||||||
|                 # current character has been seen within correct number of |  | ||||||
|                 # occurrences  |  | ||||||
|                 self.state[self.round, flag_i] = self.wrong_pos |                 self.state[self.round, flag_i] = self.wrong_pos | ||||||
|             else: |             else: | ||||||
|                 # wrong character, or "correct" character too many times |  | ||||||
|                 self.state[self.round, flag_i] = self.wrong_char |                 self.state[self.round, flag_i] = self.wrong_char | ||||||
|  |  | ||||||
|         self.round += 1 |         self.round += 1 | ||||||
| @@ -277,20 +257,29 @@ class WordleEnv(gym.Env): | |||||||
|  |  | ||||||
|         done = correct or game_over |         done = correct or game_over | ||||||
|  |  | ||||||
|         # Total reward equals -(number of incorrect guesses) |  | ||||||
|         # reward = 0. if correct else -1. |  | ||||||
|          |  | ||||||
|         # correct +10 |  | ||||||
|         # guesses new letter +1 |  | ||||||
|         # guesses correct letter +1 |  | ||||||
|         # spent another guess -1 |  | ||||||
|  |  | ||||||
|         reward = 0 |         reward = 0 | ||||||
|         reward += np.sum(self.state[:, 5:] == 1) * 1 |         # correct spot | ||||||
|         reward += np.sum(self.state[:, 5:] == 2) * 0.5 |         reward += np.sum(self.state[:, 5:] == 1) * 2 | ||||||
|  |  | ||||||
|  |         # correct letter not correct spot | ||||||
|  |         reward += np.sum(self.state[:, 5:] == 2) * 1 | ||||||
|  |  | ||||||
|  |         # incorrect letter | ||||||
|         reward += np.sum(self.state[:, 5:] == 3) * -1 |         reward += np.sum(self.state[:, 5:] == 3) * -1 | ||||||
|  |  | ||||||
|  |         # guess same word as before | ||||||
|  |         hashable_action = to_english(action) | ||||||
|  |         if hashable_action in self.info['guesses']: | ||||||
|  |             reward += -10 * self.info['guesses'][hashable_action] | ||||||
|  |         else:  # guess different word | ||||||
|  |             reward += 10 | ||||||
|  |  | ||||||
|  |         self.info['guesses'][hashable_action] += 1 | ||||||
|  |  | ||||||
|  |         # for game ending in win or loss | ||||||
|         reward += 10 if correct else -10 if done else 0 |         reward += 10 if correct else -10 if done else 0 | ||||||
|  |  | ||||||
|         info = {'correct': correct} |         self.info['correct'] = correct | ||||||
|  |  | ||||||
|         return self.state, reward, done, info |         # observation, reward, terminated, truncated, info | ||||||
|  |         return self.state, reward, done, False, self.info | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user