mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			4 Commits
		
	
	
		
			ethan-test
			...
			5ec123e0f1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 5ec123e0f1 | ||
|  | e9622b6f68 | ||
|  | 83e81722d2 | ||
|  | 320f2f81b7 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,2 @@ | ||||
| **/data/* | ||||
| **/data/* | ||||
| **/*.zip | ||||
							
								
								
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gym\n", | ||||
|     "import gym_wordle\n", | ||||
|     "from stable_baselines3 import DQN\n", | ||||
|     "import numpy as np\n", | ||||
|     "import tqdm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "env = gym.make(\"Wordle-v0\")\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 35, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=0)\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, progress_bar=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def test(model):\n", | ||||
|     "\n", | ||||
|     "    end_rewards = []\n", | ||||
|     "\n", | ||||
|     "    for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "        state = env.reset()\n", | ||||
|     "\n", | ||||
|     "        done = False\n", | ||||
|     "\n", | ||||
|     "        while not done:\n", | ||||
|     "\n", | ||||
|     "            action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "            state, reward, done, info = env.step(action)\n", | ||||
|     "            \n", | ||||
|     "        end_rewards.append(reward == 0)\n", | ||||
|     "        \n", | ||||
|     "    return np.sum(end_rewards) / len(end_rewards)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model.save(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model = DQN.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "print(test(model))" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.10" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
							
								
								
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										61
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								test.py
									
									
									
									
									
								
							| @@ -1,61 +0,0 @@ | ||||
|  | ||||
| from torch.utils.data import Dataset | ||||
| from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer | ||||
| from tqdm import tqdm as progress_bar | ||||
| import torch | ||||
| import matplotlib | ||||
|  | ||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| print(device) | ||||
|  | ||||
| encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102) | ||||
| # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token | ||||
| decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102) | ||||
| model = EncoderDecoderModel(encoder=encoder, decoder=decoder) | ||||
|  | ||||
| # create tokenizer... | ||||
| tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") | ||||
|  | ||||
| import json | ||||
|  | ||||
| class CodeDataset(Dataset): | ||||
|     def __init__(self): | ||||
|         with open("data/conala-train.json") as f: | ||||
|             self.data = json.load(f) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.data) | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         intent = self.data[idx]["rewritten_intent"] if self.data[idx]["rewritten_intent"] else self.data[idx]["intent"] | ||||
|         return intent, self.data[idx]["snippet"] | ||||
|  | ||||
|  | ||||
| optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3) | ||||
| dataloader = CodeDataset() | ||||
| model = model.to(device) | ||||
|  | ||||
| losses = [] | ||||
| epochs = 10 | ||||
| for i in range(epochs): | ||||
|  | ||||
|     epoch_loss = 0 | ||||
|  | ||||
|     for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)): | ||||
|  | ||||
|         input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device) | ||||
|         label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device) | ||||
|  | ||||
|         loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss | ||||
|  | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|  | ||||
|         epoch_loss += loss.item() | ||||
|  | ||||
|     losses.append(epoch_loss) | ||||
|  | ||||
| plt.plot(losses, color="green", label="Training Loss") | ||||
| plt.legend(loc = 'upper left') | ||||
| plt.savefig("plot.png") | ||||
		Reference in New Issue
	
	Block a user