mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			7 Commits
		
	
	
		
			ethan-test
			...
			9172326013
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 9172326013 | ||
|  | 4836be8121 | ||
|  | 5672169073 | ||
|  | 5ec123e0f1 | ||
|  | e9622b6f68 | ||
|  | 83e81722d2 | ||
|  | 320f2f81b7 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,3 @@ | |||||||
| **/data/* | **/data/* | ||||||
|  | **/*.zip | ||||||
|  | **/__pycache__ | ||||||
							
								
								
									
										1160
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1160
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | from gym.envs.registration import register | ||||||
|  | from .wordle import WordleEnv | ||||||
|  |  | ||||||
|  | register( | ||||||
|  |     id='Wordle-v0', | ||||||
|  |     entry_point='gym_wordle.wordle:WordleEnv' | ||||||
|  | ) | ||||||
							
								
								
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | |||||||
|  | import numpy as np | ||||||
|  | import numpy.typing as npt | ||||||
|  |  | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _chars = ' abcdefghijklmnopqrstuvwxyz' | ||||||
|  | _char_d = {c: i for i, c in enumerate(_chars)} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def to_english(array: npt.NDArray[np.int64]) -> str: | ||||||
|  |     """Converts a numpy integer array into a corresponding English string. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         array: Word in array (int) form. It is assumed that each integer in the | ||||||
|  |           array is between 0,...,26 (inclusive). | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         A (lowercase) string representation of the word.     | ||||||
|  |     """ | ||||||
|  |     return ''.join(_chars[i] for i in array) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def to_array(word: str) -> npt.NDArray[np.int64]: | ||||||
|  |     """Converts a string of characters into a corresponding numpy array. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         word: Word in string form. It is assumed that each character in the | ||||||
|  |           string is either an empty space ' ' or lowercase alphabetical | ||||||
|  |           character. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         An array representation of the word. | ||||||
|  |     """ | ||||||
|  |     return np.array([_char_d[c] for c in word]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_words(category: str, build: bool=False) -> npt.NDArray[np.int64]: | ||||||
|  |     """Loads a list of words in array form.  | ||||||
|  |  | ||||||
|  |     If specified, this will recompute the list from the human-readable list of | ||||||
|  |     words, and save the results in array form. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         category: Either 'guess' or 'solution', which corresponds to the list | ||||||
|  |           of acceptable guess words and the list of acceptable solution words. | ||||||
|  |         build: If True, recomputes and saves the array-version of the computed | ||||||
|  |           list for future access. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         An array representation of the list of words specified by the category. | ||||||
|  |         This array has two dimensions, and the number of columns is fixed at | ||||||
|  |         five. | ||||||
|  |     """ | ||||||
|  |     assert category in {'guess', 'solution'} | ||||||
|  |      | ||||||
|  |     arr_path = Path(__file__).parent / f'dictionary/{category}_list.npy' | ||||||
|  |     if build: | ||||||
|  |        list_path = Path(__file__).parent / f'dictionary/{category}_list.csv' | ||||||
|  |  | ||||||
|  |        with open(list_path, 'r') as f: | ||||||
|  |            words = np.array([to_array(line.strip()) for line in f]) | ||||||
|  |            np.save(arr_path, words) | ||||||
|  |  | ||||||
|  |     return np.load(arr_path) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def play(): | ||||||
|  |     """Play Wordle yourself!""" | ||||||
|  |     import gym | ||||||
|  |     import gym_wordle | ||||||
|  |      | ||||||
|  |     env = gym.make('Wordle-v0')  # load the environment | ||||||
|  |      | ||||||
|  |     env.reset() | ||||||
|  |     solution = to_english(env.unwrapped.solution_space[env.solution]).upper()  # no peeking! | ||||||
|  |  | ||||||
|  |     done = False | ||||||
|  |      | ||||||
|  |     while not done: | ||||||
|  |         action = -1  | ||||||
|  |  | ||||||
|  |         # in general, the environment won't be forgiving if you input an | ||||||
|  |         # invalid word, but for this function I want to let you screw up user | ||||||
|  |         # input without consequence, so just loops until valid input is taken | ||||||
|  |         while not env.action_space.contains(action): | ||||||
|  |             guess = input('Guess: ') | ||||||
|  |             action = env.unwrapped.action_space.index_of(to_array(guess)) | ||||||
|  |      | ||||||
|  |         state, reward, done, info = env.step(action) | ||||||
|  |         env.render() | ||||||
|  |      | ||||||
|  |     print(f"The word was {solution}") | ||||||
							
								
								
									
										296
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										296
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,296 @@ | |||||||
|  | import gym | ||||||
|  | import numpy as np | ||||||
|  | import numpy.typing as npt | ||||||
|  | from sty import fg, bg, ef, rs | ||||||
|  |  | ||||||
|  | from collections import Counter  | ||||||
|  | from gym_wordle.utils import to_english, to_array, get_words | ||||||
|  | from typing import Optional  | ||||||
|  |  | ||||||
|  | class WordList(gym.spaces.Discrete): | ||||||
|  |     """Super class for defining a space of valid words according to a specified | ||||||
|  |     list. | ||||||
|  |  | ||||||
|  |     TODO: Fix these paragraphs | ||||||
|  |     The space is a subclass of gym.spaces.Discrete, where each element | ||||||
|  |     corresponds to an index of a valid word in the word list. The obfuscation | ||||||
|  |     is necessary for more direct implementation of RL algorithms, which expect | ||||||
|  |     spaces of less sophisticated form. | ||||||
|  |  | ||||||
|  |     In addition to the default methods of the Discrete space, it implements | ||||||
|  |     a __getitem__ method for easy index lookup, and an index_of method to | ||||||
|  |     convert potential words into their corresponding index (if they exist). | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, words: npt.NDArray[np.int64], **kwargs): | ||||||
|  |         """ | ||||||
|  |         Args: | ||||||
|  |             words: Collection of words in array form with shape (_, 5), where | ||||||
|  |               each word is a row of the array. Each array element is an integer | ||||||
|  |               between 0,...,26 (inclusive). | ||||||
|  |             kwargs: See documentation for gym.spaces.Discrete | ||||||
|  |         """ | ||||||
|  |         super().__init__(words.shape[0], **kwargs) | ||||||
|  |         self.words = words | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index: int) -> npt.NDArray[np.int64]: | ||||||
|  |         """Obtains the (int-encoded) word associated with the given index. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             index: Index for the list of words. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             Associated word at the position specified by index. | ||||||
|  |         """ | ||||||
|  |         return self.words[index] | ||||||
|  |  | ||||||
|  |     def index_of(self, word: npt.NDArray[np.int64]) -> int: | ||||||
|  |         """Given a word, determine its index in the list (if it exists), | ||||||
|  |         otherwise returning -1 if no index exists. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             word: Word to find in the word list. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             The index of the given word if it exists, otherwise -1. | ||||||
|  |         """ | ||||||
|  |         try: | ||||||
|  |             index, = np.nonzero((word == self.words).all(axis=1)) | ||||||
|  |             return index[0] | ||||||
|  |         except: | ||||||
|  |             return -1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SolutionList(WordList): | ||||||
|  |     """Space for *solution* words to the Wordle environment. | ||||||
|  |  | ||||||
|  |     In the game Wordle, there are two different collections of words: | ||||||
|  |  | ||||||
|  |         * "guesses", which the game accepts as valid words to use to guess the | ||||||
|  |           answer. | ||||||
|  |         * "solutions", which the game uses to choose solutions from. | ||||||
|  |  | ||||||
|  |     Of course, the set of solutions is a strict subset of the set of guesses. | ||||||
|  |  | ||||||
|  |     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ | ||||||
|  |  | ||||||
|  |     This class represents the set of solution words. | ||||||
|  |     """ | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         """ | ||||||
|  |         Args: | ||||||
|  |             kwargs: See documentation for gym.spaces.Discrete | ||||||
|  |         """ | ||||||
|  |         words = get_words('solution') | ||||||
|  |         super().__init__(words, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WordleObsSpace(gym.spaces.Box): | ||||||
|  |     """Implementation of the state (observation) space in terms of gym | ||||||
|  |     primitives, in this case, gym.spaces.Box. | ||||||
|  |  | ||||||
|  |     The Wordle observation space can be thought of as a 6x5 array with two | ||||||
|  |     channels: | ||||||
|  |  | ||||||
|  |       - the character channel, indicating which characters are placed on the | ||||||
|  |         board (unfilled rows are marked with the empty character, 0) | ||||||
|  |       - the flag channel, indicating the in-game information associated with | ||||||
|  |         each character's placement (green highlight, yellow highlight, etc.) | ||||||
|  |  | ||||||
|  |     where there are 6 rows, one for each turn in the game, and 5 columns, since | ||||||
|  |     the solution will always be a word of length 5. | ||||||
|  |  | ||||||
|  |     For simplicity, and compatibility with the stable_baselines algorithms, | ||||||
|  |     this multichannel is modeled as a 6x10 array, where the two channels are | ||||||
|  |     horizontally appended (along columns). Thus each row in the observation | ||||||
|  |     should be interpreted as  | ||||||
|  |  | ||||||
|  |         c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 | ||||||
|  |  | ||||||
|  |     when the word is c0...c4 and its associated flags are f0...f4. | ||||||
|  |  | ||||||
|  |     While the superclass method `sample` is available to the WordleObsSpace, it | ||||||
|  |     should be emphasized that the output of `sample` will (almost surely) not | ||||||
|  |     correspond to a real game configuration, because the sampling is not out of | ||||||
|  |     possible game configurations. Instead, the Box superclass just samples the | ||||||
|  |     integer array space uniformly. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         self.n_rows = 6 | ||||||
|  |         self.n_cols = 5 | ||||||
|  |         self.max_char = 26 | ||||||
|  |         self.max_flag = 4 | ||||||
|  |  | ||||||
|  |         low = np.zeros((self.n_rows, 2*self.n_cols)) | ||||||
|  |         high = np.c_[np.full((self.n_rows, self.n_cols), self.max_char), | ||||||
|  |                      np.full((self.n_rows, self.n_cols), self.max_flag)] | ||||||
|  |  | ||||||
|  |         super().__init__(low, high, dtype=np.int64, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class GuessList(WordList): | ||||||
|  |     """Space for *guess* words to the Wordle environment. | ||||||
|  |  | ||||||
|  |     In the game Wordle, there are two different collections of words: | ||||||
|  |  | ||||||
|  |         * "guesses", which the game accepts as valid words to use to guess the | ||||||
|  |           answer. | ||||||
|  |         * "solutions", which the game uses to choose solutions from. | ||||||
|  |  | ||||||
|  |     Of course, the set of solutions is a strict subset of the set of guesses. | ||||||
|  |  | ||||||
|  |     Reference: https://fivethirtyeight.com/features/when-the-riddler-met-wordle/ | ||||||
|  |  | ||||||
|  |     This class represents the set of guess words. | ||||||
|  |     """ | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         """ | ||||||
|  |         Args: | ||||||
|  |             kwargs: See documentation for gym.spaces.Discrete | ||||||
|  |         """ | ||||||
|  |         words = get_words('guess') | ||||||
|  |         super().__init__(words, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class WordleEnv(gym.Env): | ||||||
|  |  | ||||||
|  |     metadata = {'render.modes': ['human']} | ||||||
|  |  | ||||||
|  |     # character flag codes | ||||||
|  |     no_char = 0 | ||||||
|  |     right_pos = 1 | ||||||
|  |     wrong_pos = 2 | ||||||
|  |     wrong_char = 3 | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         super().__init__() | ||||||
|  |  | ||||||
|  |         self.seed() | ||||||
|  |         self.action_space = GuessList() | ||||||
|  |         self.solution_space = SolutionList() | ||||||
|  |  | ||||||
|  |         self.observation_space = WordleObsSpace() | ||||||
|  |  | ||||||
|  |         self._highlights = { | ||||||
|  |             self.right_pos: (bg.green, bg.rs), | ||||||
|  |             self.wrong_pos: (bg.yellow, bg.rs), | ||||||
|  |             self.wrong_char: ('', ''), | ||||||
|  |             self.no_char: ('', ''), | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         self.n_rounds = 6 | ||||||
|  |         self.n_letters = 5 | ||||||
|  |  | ||||||
|  |     def _highlighter(self, char: str, flag: int) -> str: | ||||||
|  |         """Terminal renderer functionality. Properly highlights a character | ||||||
|  |         based on the flag associated with it. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             char: Character in question. | ||||||
|  |             flag: Associated flag, one of: | ||||||
|  |                 - 0: no character (render no background) | ||||||
|  |                 - 1: right position (render green background) | ||||||
|  |                 - 2: wrong position (render yellow background) | ||||||
|  |                 - 3: wrong character (render no background) | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             Correct ASCII sequence producing the desired character in the | ||||||
|  |             correct background. | ||||||
|  |         """ | ||||||
|  |         front, back = self._highlights[flag] | ||||||
|  |         return front + char + back | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self.round = 0 | ||||||
|  |         self.solution = self.solution_space.sample() | ||||||
|  |  | ||||||
|  |         self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) | ||||||
|  |  | ||||||
|  |         return self.state | ||||||
|  |  | ||||||
|  |     def render(self, mode: str ='human'): | ||||||
|  |         """Renders the Wordle environment. | ||||||
|  |  | ||||||
|  |         Currently supported render modes: | ||||||
|  |  | ||||||
|  |         - human: renders the Wordle game to the terminal. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             mode: the mode to render with | ||||||
|  |         """ | ||||||
|  |         if mode == 'human': | ||||||
|  |             for row in self.states: | ||||||
|  |                 text = ''.join(map( | ||||||
|  |                     self._highlighter,  | ||||||
|  |                     to_english(row[:self.n_letters]).upper(),  | ||||||
|  |                     row[self.n_letters:] | ||||||
|  |                 )) | ||||||
|  |  | ||||||
|  |                 print(text) | ||||||
|  |         else: | ||||||
|  |             super(WordleEnv, self).render(mode=mode) | ||||||
|  |                  | ||||||
|  |     def step(self, action): | ||||||
|  |         """Run one step of the Wordle game. Every game must be previously | ||||||
|  |         initialized by a call to the `reset` method. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             action: Word guessed by the agent. | ||||||
|  |         Returns: | ||||||
|  |             state (object): Wordle game state after the guess. | ||||||
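|  |             reward (float): Shaped reward computed over the whole board so | ||||||
|  |               far: +1 per right-position flag, +0.5 per wrong-position flag, | ||||||
|  |               -1 per wrong-character flag, plus +10 if the guess is correct | ||||||
|  |               or -10 if the game ends without a correct guess. | ||||||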
|  |             done (bool): Whether the game has ended (by a correct guess or | ||||||
|  |               after six guesses). | ||||||
|  |             info (dict): Auxiliary diagnostic information; contains the key | ||||||
|  |               'correct' (whether the guess matched the solution). | ||||||
|  |         """ | ||||||
|  |         assert self.action_space.contains(action), 'Invalid word!' | ||||||
|  |  | ||||||
|  |         # transform the action, solution indices to their words | ||||||
|  |         action = self.action_space[action]   | ||||||
|  |         solution = self.solution_space[self.solution] | ||||||
|  |  | ||||||
|  |         # populate the word chars into the row (character channel) | ||||||
|  |         self.state[self.round][:self.n_letters] = action | ||||||
|  |  | ||||||
|  |         # populate the flag characters into the row (flag channel) | ||||||
|  |         counter = Counter() | ||||||
|  |         for i, char in enumerate(action): | ||||||
|  |             flag_i = i + self.n_letters  # starts at 5 | ||||||
|  |             counter[char] += 1 | ||||||
|  |  | ||||||
|  |             if char == solution[i]:  # character is in correct position | ||||||
|  |                 self.state[self.round, flag_i] = self.right_pos | ||||||
|  |             elif counter[char] <= (char == solution).sum():   | ||||||
|  |                 # current character has been seen within correct number of | ||||||
|  |                 # occurrences  | ||||||
|  |                 self.state[self.round, flag_i] = self.wrong_pos | ||||||
|  |             else: | ||||||
|  |                 # wrong character, or "correct" character too many times | ||||||
|  |                 self.state[self.round, flag_i] = self.wrong_char | ||||||
|  |  | ||||||
|  |         self.round += 1 | ||||||
|  |  | ||||||
|  |         correct = (action == solution).all() | ||||||
|  |         game_over = (self.round == self.n_rounds) | ||||||
|  |  | ||||||
|  |         done = correct or game_over | ||||||
|  |  | ||||||
|  |         # Total reward equals -(number of incorrect guesses) | ||||||
|  |         # reward = 0. if correct else -1. | ||||||
|  |          | ||||||
|  |         # correct +10 | ||||||
|  |         # guesses new letter +1 | ||||||
|  |         # guesses correct letter +1 | ||||||
|  |         # spent another guess -1 | ||||||
|  |  | ||||||
|  |         reward = 0 | ||||||
|  |         reward += np.sum(self.state[:, 5:] == 1) * 1 | ||||||
|  |         reward += np.sum(self.state[:, 5:] == 2) * 0.5 | ||||||
|  |         reward += np.sum(self.state[:, 5:] == 3) * -1 | ||||||
|  |         reward += 10 if correct else -10 if done else 0 | ||||||
|  |  | ||||||
|  |         info = {'correct': correct} | ||||||
|  |  | ||||||
|  |         return self.state, reward, done, info | ||||||
							
								
								
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										61
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								test.py
									
									
									
									
									
								
							| @@ -1,61 +0,0 @@ | |||||||
|  |  | ||||||
| from torch.utils.data import Dataset |  | ||||||
| from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer |  | ||||||
| from tqdm import tqdm as progress_bar |  | ||||||
| import torch |  | ||||||
| import matplotlib |  | ||||||
|  |  | ||||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |  | ||||||
| print(device) |  | ||||||
|  |  | ||||||
| encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102) |  | ||||||
| # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token |  | ||||||
| decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102) |  | ||||||
| model = EncoderDecoderModel(encoder=encoder, decoder=decoder) |  | ||||||
|  |  | ||||||
| # create tokenizer... |  | ||||||
| tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") |  | ||||||
|  |  | ||||||
| import json |  | ||||||
|  |  | ||||||
| class CodeDataset(Dataset): |  | ||||||
|     def __init__(self): |  | ||||||
|         with open("data/conala-train.json") as f: |  | ||||||
|             self.data = json.load(f) |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self.data) |  | ||||||
|  |  | ||||||
|     def __getitem__(self, idx): |  | ||||||
|         intent = self.data[idx]["rewritten_intent"] if self.data[idx]["rewritten_intent"] else self.data[idx]["intent"] |  | ||||||
|         return intent, self.data[idx]["snippet"] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3) |  | ||||||
| dataloader = CodeDataset() |  | ||||||
| model = model.to(device) |  | ||||||
|  |  | ||||||
| losses = [] |  | ||||||
| epochs = 10 |  | ||||||
| for i in range(epochs): |  | ||||||
|  |  | ||||||
|     epoch_loss = 0 |  | ||||||
|  |  | ||||||
|     for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)): |  | ||||||
|  |  | ||||||
|         input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device) |  | ||||||
|         label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device) |  | ||||||
|  |  | ||||||
|         loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss |  | ||||||
|  |  | ||||||
|         optimizer.zero_grad() |  | ||||||
|         loss.backward() |  | ||||||
|         optimizer.step() |  | ||||||
|  |  | ||||||
|         epoch_loss += loss.item() |  | ||||||
|  |  | ||||||
|     losses.append(epoch_loss) |  | ||||||
|  |  | ||||||
| plt.plot(losses, color="green", label="Training Loss") |  | ||||||
| plt.legend(loc = 'upper left') |  | ||||||
| plt.savefig("plot.png") |  | ||||||
		Reference in New Issue
	
	Block a user