mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			7 Commits
		
	
	
		
			ethan-test
			...
			9172326013
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 9172326013 | ||
|  | 4836be8121 | ||
|  | 5672169073 | ||
|  | 5ec123e0f1 | ||
|  | e9622b6f68 | ||
|  | 83e81722d2 | ||
|  | 320f2f81b7 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,3 @@ | ||||
| **/data/* | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
							
								
								
									
										1160
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1160
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| from gym.envs.registration import register | ||||
| from .wordle import WordleEnv | ||||
|  | ||||
# Register the environment so that it can be built with gym.make('Wordle-v0').
register(id='Wordle-v0', entry_point='gym_wordle.wordle:WordleEnv')
							
								
								
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
|  | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
# Alphabet of the integer encoding: code 0 is the blank (unfilled) character
# and codes 1..26 are the lowercase letters 'a'..'z'.
_chars = ' abcdefghijklmnopqrstuvwxyz'
# Reverse lookup: character -> integer code.
_char_d = {c: i for i, c in enumerate(_chars)}


def to_english(array: npt.NDArray[np.int64]) -> str:
    """Converts a numpy integer array into a corresponding English string.

    Args:
        array: Word in array (int) form. It is assumed that each integer in the
          array is between 0,...,26 (inclusive).

    Returns:
        A (lowercase) string representation of the word, where the code 0 is
        rendered as a space.
    """
    return ''.join(_chars[i] for i in array)


def to_array(word: str) -> npt.NDArray[np.int64]:
    """Converts a string of characters into a corresponding numpy array.

    Args:
        word: Word in string form. It is assumed that each character in the
          string is either an empty space ' ' or lowercase alphabetical
          character.

    Returns:
        An int64 array representation of the word, one code per character.

    Raises:
        KeyError: If the word contains a character outside the alphabet
          (for instance an uppercase letter or a digit).
    """
    # The dtype is pinned explicitly: without it an empty word would produce a
    # float64 array, and the integer width would follow the platform default.
    return np.array([_char_d[c] for c in word], dtype=np.int64)
|  | ||||
|  | ||||
def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]:
    """Loads a list of words in array form.

    If specified, this will recompute the list from the human-readable list of
    words, and save the results in array form.

    Args:
        category: Either 'guess' or 'solution', which corresponds to the list
          of acceptable guess words and the list of acceptable solution words.
        build: If True, recomputes and saves the array-version of the computed
          list for future access.

    Returns:
        An array representation of the list of words specified by the category,
        as loaded from the `dictionary/{category}_list.npy` file that sits next
        to this module. Each row is one word.
    """
    assert category in {'guess', 'solution'}

    dictionary_dir = Path(__file__).parent / 'dictionary'
    arr_path = dictionary_dir / f'{category}_list.npy'
    if build:
        list_path = dictionary_dir / f'{category}_list.csv'

        with open(list_path, 'r', encoding='utf-8') as f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # would become zero-length rows and make the array ragged, so
            # they are skipped.
            stripped = (line.strip() for line in f)
            rows = [to_array(word) for word in stripped if word]

        # Save only after the CSV handle has been closed.
        words = np.array(rows, dtype=np.int64)
        np.save(arr_path, words)

    return np.load(arr_path)
|  | ||||
|  | ||||
def play():
    """Play Wordle yourself!

    Creates the 'Wordle-v0' environment, then repeatedly prompts for a guess
    on standard input, steps the environment and renders the board until the
    environment reports that the game is done. The solution is printed at the
    end.
    """
    import gym
    import gym_wordle  # noqa: F401  (imported so that 'Wordle-v0' is registered)

    env = gym.make('Wordle-v0')  # load the environment

    env.reset()
    # Read game internals from the unwrapped environment rather than through
    # whatever wrappers gym.make() may have added.
    game = env.unwrapped
    solution = to_english(game.solution_space[game.solution]).upper()  # no peeking!

    done = False

    while not done:
        action = -1

        # in general, the environment won't be forgiving if you input an
        # invalid word, but for this function I want to let you screw up user
        # input without consequence, so just loops until valid input is taken
        while not env.action_space.contains(action):
            guess = input('Guess: ').strip().lower()
            try:
                action = game.action_space.index_of(to_array(guess))
            except KeyError:
                # The guess contains a character outside the alphabet
                # (digit, punctuation, ...): treat it as an invalid word and
                # ask again instead of crashing.
                action = -1

        state, reward, done, info = env.step(action)
        env.render()

    print(f"The word was {solution}")
							
								
								
									
										296
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										296
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,296 @@ | ||||
| import gym | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
| from sty import fg, bg, ef, rs | ||||
|  | ||||
| from collections import Counter  | ||||
| from gym_wordle.utils import to_english, to_array, get_words | ||||
| from typing import Optional  | ||||
|  | ||||
class WordList(gym.spaces.Discrete):
    """Base class for a discrete space of valid words drawn from a fixed list.

    The space is a subclass of gym.spaces.Discrete, where each element
    corresponds to an index of a valid word in the word list. Exposing indices
    instead of the words themselves keeps the space in the simple form that RL
    algorithms expect.

    In addition to the default methods of the Discrete space, it implements
    a __getitem__ method for easy index lookup, and an index_of method to
    convert potential words into their corresponding index (if they exist).
    """

    def __init__(self, words: npt.NDArray[np.int64], **kwargs):
        """
        Args:
            words: Collection of words in array form with shape (_, 5), where
              each word is a row of the array. Each array element is an integer
              between 0,...,26 (inclusive).
            kwargs: Forwarded to gym.spaces.Discrete.
        """
        super().__init__(words.shape[0], **kwargs)
        self.words = words

    def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
        """Obtains the (int-encoded) word associated with the given index.

        Args:
            index: Index for the list of words.

        Returns:
            Associated word at the position specified by index.
        """
        return self.words[index]

    def index_of(self, word: npt.NDArray[np.int64]) -> int:
        """Given a word, determine its index in the list (if it exists),
        otherwise returning -1 if no index exists.

        Args:
            word: Word to find in the word list.

        Returns:
            The index of the first row equal to the given word if it exists,
            otherwise -1. A word whose shape differs from that of a row of the
            list (e.g. a guess of the wrong length) is reported as -1.
        """
        word = np.asarray(word)
        # A word of the wrong length can never match; checking this up front
        # replaces the former catch-all `except` that hid every other error.
        if word.shape != self.words.shape[1:]:
            return -1

        matches = np.flatnonzero((self.words == word).all(axis=1))
        return int(matches[0]) if matches.size else -1
|  | ||||
|  | ||||
class SolutionList(WordList):
    """Space of *solution* words for the Wordle environment.

    Wordle works with two collections of words:

        * "guesses": every word the game accepts as a valid guess.
        * "solutions": the words the game picks its answers from.

    Of course, the set of solutions is a strict subset of the set of guesses.

    Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/

    This class represents the set of solution words.
    """
    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded unchanged to WordList.
        """
        super().__init__(get_words('solution'), **kwargs)
|  | ||||
|  | ||||
class WordleObsSpace(gym.spaces.Box):
    """State (observation) space of Wordle expressed with a gym primitive,
    namely gym.spaces.Box.

    Conceptually an observation is a 6x5 board with two channels:

      - the character channel: which character sits in each cell (cells of
        unfilled rows hold the empty character, 0)
      - the flag channel: the in-game information attached to each character
        (green highlight, yellow highlight, etc.)

    There are 6 rows, one per turn of the game, and 5 columns, because the
    solution is always a word of length 5.

    For simplicity, and compatibility with the stable_baselines algorithms,
    the two channels are laid side by side (along columns) in a single 6x10
    array. A row therefore reads

        c0 c1 c2 c3 c4 f0 f1 f2 f3 f4

    for the word c0...c4 with associated flags f0...f4.

    The inherited `sample` method is available, but its output will (almost
    surely) not be a reachable game configuration: the Box superclass samples
    the integer array space uniformly rather than sampling from valid games.
    """

    def __init__(self, **kwargs):
        self.n_rows = 6
        self.n_cols = 5
        self.max_char = 26
        # NOTE(review): WordleEnv only defines the flag codes 0..3; the upper
        # bound of 4 is kept as is -- confirm whether it is intentional.
        self.max_flag = 4

        channel_shape = (self.n_rows, self.n_cols)

        # Lower bound: zero everywhere, across both channels.
        lower = np.zeros((self.n_rows, 2 * self.n_cols))
        # Upper bound: character channel on the left, flag channel on the
        # right.
        upper = np.hstack([
            np.full(channel_shape, self.max_char),
            np.full(channel_shape, self.max_flag),
        ])

        super().__init__(lower, upper, dtype=np.int64, **kwargs)
|  | ||||
|  | ||||
class GuessList(WordList):
    """Space of *guess* words for the Wordle environment.

    Wordle works with two collections of words:

        * "guesses": every word the game accepts as a valid guess.
        * "solutions": the words the game picks its answers from.

    Of course, the set of solutions is a strict subset of the set of guesses.

    Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/

    This class represents the set of guess words.
    """
    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded unchanged to WordList.
        """
        super().__init__(get_words('guess'), **kwargs)
|  | ||||
|  | ||||
class WordleEnv(gym.Env):
    """Gym environment for the game of Wordle.

    Actions are indices into the guess list (`action_space`). The solution of
    the current game is an index into the solution list (`solution_space`).
    The state is an integer array with one row per round; in each row the
    first `n_letters` entries hold the guessed characters and the remaining
    `n_letters` entries hold the flag attached to each character.
    """

    metadata = {'render.modes': ['human']}

    # character flag codes
    no_char = 0
    right_pos = 1
    wrong_pos = 2
    wrong_char = 3

    def __init__(self):
        super().__init__()

        self.seed()
        self.action_space = GuessList()
        self.solution_space = SolutionList()

        self.observation_space = WordleObsSpace()

        # flag code -> (prefix, suffix) escape sequences used by the renderer
        self._highlights = {
            self.right_pos: (bg.green, bg.rs),
            self.wrong_pos: (bg.yellow, bg.rs),
            self.wrong_char: ('', ''),
            self.no_char: ('', ''),
        }

        self.n_rounds = 6
        self.n_letters = 5

    def _highlighter(self, char: str, flag: int) -> str:
        """Terminal renderer functionality. Properly highlights a character
        based on the flag associated with it.

        Args:
            char: Character in question.
            flag: Associated flag, one of:
                - 0: no character (render no background)
                - 1: right position (render green background)
                - 2: wrong position (render yellow background)
                - 3: wrong character (render no background)

        Returns:
            Correct ASCII sequence producing the desired character in the
            correct background.
        """
        front, back = self._highlights[flag]
        return front + char + back

    def _compute_flags(self, guess, solution):
        """Computes the flag of every character of a guess.

        Exact matches are resolved first; the solution letters that are left
        over are then handed out, left to right, to the remaining guessed
        characters. This way a letter is never flagged as "wrong position"
        more often than it is actually still available in the solution.

        Args:
            guess: Guessed word in array form.
            solution: Solution word in array form, same length as the guess.

        Returns:
            An int64 array with one flag code per character of the guess.
        """
        exact = guess == solution

        flags = np.full(len(guess), self.wrong_char, dtype=np.int64)
        flags[exact] = self.right_pos

        # solution letters not already consumed by an exact match
        remaining = Counter(solution[~exact].tolist())
        for i in np.flatnonzero(~exact):
            char = int(guess[i])
            if remaining[char] > 0:
                remaining[char] -= 1
                flags[i] = self.wrong_pos

        return flags

    def reset(self):
        """Starts a new game: samples a new solution and clears the board.

        Returns:
            The initial, all-zero state.
        """
        self.round = 0
        self.solution = self.solution_space.sample()

        self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)

        return self.state

    def render(self, mode: str = 'human'):
        """Renders the Wordle environment.

        Currently supported render modes:

        - human: renders the Wordle game to the terminal.

        Args:
            mode: the mode to render with
        """
        if mode == 'human':
            # one printed line per row of the board
            for row in self.state:
                text = ''.join(map(
                    self._highlighter,
                    to_english(row[:self.n_letters]).upper(),
                    row[self.n_letters:]
                ))

                print(text)
        else:
            super(WordleEnv, self).render(mode=mode)

    def step(self, action):
        """Run one step of the Wordle game. Every game must be previously
        initialized by a call to the `reset` method.

        Args:
            action: Index (in the action space) of the word guessed by the
              agent.
        Returns:
            state (object): Wordle game state after the guess.
            reward (float): Computed from the flags of the whole board so far:
              +1 per right-position flag, +0.5 per wrong-position flag and -1
              per wrong-character flag, plus 10 if the guess is correct or
              minus 10 if the game ends without a correct guess.
            done (bool): Whether the game has ended (by a correct guess or
              after six guesses).
            info (dict): Auxiliary diagnostic information; the key 'correct'
              tells whether the guess equals the solution.
        """
        assert self.action_space.contains(action), 'Invalid word!'

        # transform the action, solution indices to their words
        guess = self.action_space[action]
        solution = self.solution_space[self.solution]

        # populate the word chars (character channel) and their flags (flag
        # channel) into the row of the current round
        row = self.state[self.round]
        row[:self.n_letters] = guess
        row[self.n_letters:] = self._compute_flags(guess, solution)

        self.round += 1

        correct = bool((guess == solution).all())
        game_over = (self.round == self.n_rounds)

        done = correct or game_over

        # The flag counts are taken over every row of the board, so flags of
        # earlier rounds contribute again at each step.
        board_flags = self.state[:, self.n_letters:]

        reward = 0
        reward += np.sum(board_flags == self.right_pos) * 1
        reward += np.sum(board_flags == self.wrong_pos) * 0.5
        reward += np.sum(board_flags == self.wrong_char) * -1
        reward += 10 if correct else -10 if done else 0

        info = {'correct': correct}

        return self.state, float(reward), done, info
							
								
								
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										61
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								test.py
									
									
									
									
									
								
							| @@ -1,61 +0,0 @@ | ||||
|  | ||||
| from torch.utils.data import Dataset | ||||
| from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer | ||||
| from tqdm import tqdm as progress_bar | ||||
| import torch | ||||
| import matplotlib | ||||
|  | ||||
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Encoder initialised from the pretrained bert-base-uncased checkpoint.
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
# Tie encoder and decoder together into a single sequence-to-sequence model.
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# create tokenizer...
# NOTE(review): the tokenizer is loaded from the bert-large-uncased checkpoint
# while the encoder/decoder use bert-base-uncased -- confirm the vocabularies
# are the same.
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
|  | ||||
| import json | ||||
|  | ||||
class CodeDataset(Dataset):
    """Dataset of (intent, snippet) pairs read from data/conala-train.json."""

    def __init__(self):
        # The whole JSON file is parsed once and kept in memory.
        with open("data/conala-train.json") as handle:
            self.data = json.load(handle)

    def __len__(self):
        """Returns the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Returns the (intent, snippet) pair of the record at position idx.

        The rewritten intent is preferred; the original intent is used when
        the rewritten one is missing or empty.
        """
        record = self.data[idx]
        intent = record["rewritten_intent"] or record["intent"]
        return intent, record["snippet"]
|  | ||||
|  | ||||
# AdamW over every parameter of the encoder-decoder model.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
# NOTE: despite its name this is the Dataset itself, not a torch DataLoader;
# the loop below consumes it one (question, answer) pair at a time.
dataloader = CodeDataset()
model = model.to(device)

losses = []  # summed training loss of each epoch
epochs = 10
for i in range(epochs):

    epoch_loss = 0

    for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):

        # Tokenize the natural-language question and the target code snippet.
        input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
        label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)

        # NOTE(review): the label ids are also passed as decoder inputs --
        # confirm this matches the shifting the model applies for its loss.
        loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    losses.append(epoch_loss)

# Only the top-level `matplotlib` package is imported by this script, so the
# name `plt` used below was undefined; bring pyplot into scope here.
import matplotlib.pyplot as plt

plt.plot(losses, color="green", label="Training Loss")
plt.legend(loc = 'upper left')
plt.savefig("plot.png")
		Reference in New Issue
	
	Block a user