mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			2 Commits
		
	
	
		
			9172326013
			...
			gymnasium-
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | cf977e4797 | ||
|  | bbe9a1891c | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,3 +1,4 @@ | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
| /env | ||||
							
								
								
									
										1179
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										1179
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -35,7 +35,7 @@ def to_array(word: str) -> npt.NDArray[np.int64]: | ||||
|     return np.array([_char_d[c] for c in word]) | ||||
|  | ||||
|  | ||||
| def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: | ||||
| def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]: | ||||
|     """Loads a list of words in array form.  | ||||
|  | ||||
|     If specified, this will recompute the list from the human-readable list of | ||||
| @@ -56,11 +56,11 @@ def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: | ||||
|  | ||||
|     arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' | ||||
|     if build: | ||||
|        list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' | ||||
|         list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' | ||||
|  | ||||
|        with open(list_path, 'r') as f: | ||||
|            words = np.array([to_array(line.strip()) for line in f]) | ||||
|            np.save(arr_path, words) | ||||
|         with open(list_path, 'r') as f: | ||||
|             words = np.array([to_array(line.strip()) for line in f]) | ||||
|             np.save(arr_path, words) | ||||
|  | ||||
|     return np.load(arr_path) | ||||
|  | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| import gym | ||||
| import gymnasium as gym | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
| from sty import fg, bg, ef, rs | ||||
|  | ||||
| from collections import Counter  | ||||
| from collections import Counter, defaultdict | ||||
| from gym_wordle.utils import to_english, to_array, get_words | ||||
| from typing import Optional | ||||
|  | ||||
| @@ -11,7 +11,6 @@ class WordList(gym.spaces.Discrete): | ||||
|     """Super class for defining a space of valid words according to a specified | ||||
|     list. | ||||
|  | ||||
|     TODO: Fix these paragraphs | ||||
|     The space is a subclass of gym.spaces.Discrete, where each element | ||||
|     corresponds to an index of a valid word in the word list. The obfuscation | ||||
|     is necessary for more direct implementation of RL algorithms, which expect | ||||
| @@ -66,16 +65,15 @@ class SolutionList(WordList): | ||||
|  | ||||
|     In the game Wordle, there are two different collections of words: | ||||
|  | ||||
|         * "guesses", which the game accepts as valid words to use to guess the | ||||
|           answer. | ||||
|         * "solutions", which the game uses to choose solutions from. | ||||
|     * "guesses", which the game accepts as valid words to use to guess the | ||||
|       answer. | ||||
|     * "solutions", which the game uses to choose solutions from. | ||||
|  | ||||
|     Of course, the set of solutions is a strict subset of the set of guesses. | ||||
|  | ||||
|     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ | ||||
|  | ||||
|     This class represents the set of solution words. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         """ | ||||
|         Args: | ||||
| @@ -87,7 +85,7 @@ class SolutionList(WordList): | ||||
|  | ||||
| class WordleObsSpace(gym.spaces.Box): | ||||
|     """Implementation of the state (observation) space in terms of gym | ||||
|     primatives, in this case, gym.spaces.Box. | ||||
|     primitives, in this case, gym.spaces.Box. | ||||
|  | ||||
|     The Wordle observation space can be thought of as a 6x5 array with two | ||||
|     channels: | ||||
| @@ -100,20 +98,11 @@ class WordleObsSpace(gym.spaces.Box): | ||||
|     where there are 6 rows, one for each turn in the game, and 5 columns, since | ||||
|     the solution will always be a word of length 5. | ||||
|  | ||||
|     For simplicity, and compatibility with the stable_baselines algorithms, | ||||
|     For simplicity, and compatibility with stable_baselines algorithms, | ||||
|     this multichannel is modeled as a 6x10 array, where the two channels are | ||||
|     horizontally appended (along columns). Thus each row in the observation | ||||
|     should be interpreted as  | ||||
|  | ||||
|         c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 | ||||
|  | ||||
|     when the word is c0...c4 and its associated flags are f0...f4. | ||||
|  | ||||
|     While the superclass method `sample` is available to the WordleObsSpace, it | ||||
|     should be emphasized that the output of `sample` will (almost surely) not | ||||
|     correspond to a real game configuration, because the sampling is not out of | ||||
|     possible game configurations. Instead, the Box superclass just samples the | ||||
|     integer array space uniformly. | ||||
|     should be interpreted as c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 when the word is | ||||
|     c0...c4 and its associated flags are f0...f4. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
| @@ -130,20 +119,11 @@ class WordleObsSpace(gym.spaces.Box): | ||||
|  | ||||
|  | ||||
| class GuessList(WordList): | ||||
|     """Space for *solution* words to the Wordle environment. | ||||
|  | ||||
|     In the game Wordle, there are two different collections of words: | ||||
|  | ||||
|         * "guesses", which the game accepts as valid words to use to guess the | ||||
|           answer. | ||||
|         * "solutions", which the game uses to choose solutions from. | ||||
|  | ||||
|     Of course, the set of solutions is a strict subset of the set of guesses. | ||||
|  | ||||
|     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ | ||||
|     """Space for *guess* words to the Wordle environment. | ||||
|  | ||||
|     This class represents the set of guess words. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, **kwargs): | ||||
|         """ | ||||
|         Args: | ||||
| @@ -154,10 +134,9 @@ class GuessList(WordList): | ||||
|  | ||||
|  | ||||
| class WordleEnv(gym.Env): | ||||
|  | ||||
|     metadata = {'render.modes': ['human']} | ||||
|  | ||||
|     # character flag codes | ||||
|     # Character flag codes | ||||
|     no_char = 0 | ||||
|     right_pos = 1 | ||||
|     wrong_pos = 2 | ||||
| @@ -166,7 +145,6 @@ class WordleEnv(gym.Env): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.seed() | ||||
|         self.action_space = GuessList() | ||||
|         self.solution_space = SolutionList() | ||||
|  | ||||
| @@ -181,6 +159,7 @@ class WordleEnv(gym.Env): | ||||
|  | ||||
|         self.n_rounds = 6 | ||||
|         self.n_letters = 5 | ||||
|         self.info = {'correct': False, 'guesses': defaultdict(int)} | ||||
|  | ||||
|     def _highlighter(self, char: str, flag: int) -> str: | ||||
|         """Terminal renderer functionality. Properly highlights a character | ||||
| @@ -201,35 +180,43 @@ class WordleEnv(gym.Env): | ||||
|         front, back = self._highlights[flag] | ||||
|         return front + char + back | ||||
|  | ||||
|     def reset(self): | ||||
|     def reset(self, seed=None, options=None): | ||||
|         """Reset the environment to an initial state and returns an initial | ||||
|         observation. | ||||
|  | ||||
|         Note: The observation space instance should be a Box space. | ||||
|  | ||||
|         Returns: | ||||
|             state (object): The initial observation of the space. | ||||
|         """ | ||||
|         self.round = 0 | ||||
|         self.solution = self.solution_space.sample() | ||||
|  | ||||
|         self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) | ||||
|  | ||||
|         return self.state | ||||
|         self.info = {'correct': False, 'guesses': defaultdict(int)} | ||||
|  | ||||
|     def render(self, mode: str ='human'): | ||||
|         return self.state, self.info | ||||
|  | ||||
|     def render(self, mode: str = 'human'): | ||||
|         """Renders the Wordle environment. | ||||
|  | ||||
|         Currently supported render modes: | ||||
|  | ||||
|         - human: renders the Wordle game to the terminal. | ||||
|  | ||||
|         Args: | ||||
|             mode: the mode to render with | ||||
|             mode: the mode to render with. | ||||
|         """ | ||||
|         if mode == 'human': | ||||
|             for row in self.states: | ||||
|             for row in self.state: | ||||
|                 text = ''.join(map( | ||||
|                     self._highlighter, | ||||
|                     to_english(row[:self.n_letters]).upper(), | ||||
|                     row[self.n_letters:] | ||||
|                 )) | ||||
|  | ||||
|                 print(text) | ||||
|         else: | ||||
|             super(WordleEnv, self).render(mode=mode) | ||||
|             super().render(mode=mode) | ||||
|  | ||||
|     def step(self, action): | ||||
|         """Run one step of the Wordle game. Every game must be previously | ||||
| @@ -237,37 +224,30 @@ class WordleEnv(gym.Env): | ||||
|  | ||||
|         Args: | ||||
|             action: Word guessed by the agent. | ||||
|  | ||||
|         Returns: | ||||
|             state (object): Wordle game state after the guess. | ||||
|             reward (float): Reward associated with the guess (-1 for incorrect, | ||||
|               0 for correct) | ||||
|             done (bool): Whether the game has ended (by a correct guess or | ||||
|               after six guesses). | ||||
|             info (dict): Auxiliary diagnostic information (empty). | ||||
|             reward (float): Reward associated with the guess. | ||||
|             done (bool): Whether the game has ended. | ||||
|             info (dict): Auxiliary diagnostic information. | ||||
|         """ | ||||
|         assert self.action_space.contains(action), 'Invalid word!' | ||||
|  | ||||
|         # transform the action, solution indices to their words | ||||
|         action = self.action_space[action] | ||||
|         solution = self.solution_space[self.solution] | ||||
|  | ||||
|         # populate the word chars into the row (character channel) | ||||
|         self.state[self.round][:self.n_letters] = action | ||||
|  | ||||
|         # populate the flag characters into the row (flag channel) | ||||
|         counter = Counter() | ||||
|         for i, char in enumerate(action): | ||||
|             flag_i = i + self.n_letters  # starts at 5 | ||||
|             flag_i = i + self.n_letters | ||||
|             counter[char] += 1 | ||||
|  | ||||
|             if char == solution[i]:  # character is in correct position | ||||
|             if char == solution[i]: | ||||
|                 self.state[self.round, flag_i] = self.right_pos | ||||
|             elif counter[char] <= (char == solution).sum(): | ||||
|                 # current character has been seen within correct number of | ||||
|                 # occurrences  | ||||
|                 self.state[self.round, flag_i] = self.wrong_pos | ||||
|             else: | ||||
|                 # wrong character, or "correct" character too many times | ||||
|                 self.state[self.round, flag_i] = self.wrong_char | ||||
|  | ||||
|         self.round += 1 | ||||
| @@ -277,20 +257,29 @@ class WordleEnv(gym.Env): | ||||
|  | ||||
|         done = correct or game_over | ||||
|  | ||||
|         # Total reward equals -(number of incorrect guesses) | ||||
|         # reward = 0. if correct else -1. | ||||
|          | ||||
|         # correct +10 | ||||
|         # guesses new letter +1 | ||||
|         # guesses correct letter +1 | ||||
|         # spent another guess -1 | ||||
|  | ||||
|         reward = 0 | ||||
|         reward += np.sum(self.state[:, 5:] == 1) * 1 | ||||
|         reward += np.sum(self.state[:, 5:] == 2) * 0.5 | ||||
|         # correct spot | ||||
|         reward += np.sum(self.state[:, 5:] == 1) * 2 | ||||
|  | ||||
|         # correct letter not correct spot | ||||
|         reward += np.sum(self.state[:, 5:] == 2) * 1 | ||||
|  | ||||
|         # incorrect letter | ||||
|         reward += np.sum(self.state[:, 5:] == 3) * -1 | ||||
|  | ||||
|         # guess same word as before | ||||
|         hashable_action = to_english(action) | ||||
|         if hashable_action in self.info['guesses']: | ||||
|             reward += -10 * self.info['guesses'][hashable_action] | ||||
|         else:  # guess different word | ||||
|             reward += 10 | ||||
|  | ||||
|         self.info['guesses'][hashable_action] += 1 | ||||
|  | ||||
|         # for game ending in win or loss | ||||
|         reward += 10 if correct else -10 if done else 0 | ||||
|  | ||||
|         info = {'correct': correct} | ||||
|         self.info['correct'] = correct | ||||
|  | ||||
|         return self.state, reward, done, info | ||||
|         # observation, reward, terminated, truncated, info | ||||
|         return self.state, reward, done, False, self.info | ||||
|   | ||||
		Reference in New Issue
	
	Block a user