mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			4 Commits
		
	
	
		
			ethan-test
			...
			5ec123e0f1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 5ec123e0f1 | ||
|  | e9622b6f68 | ||
|  | 83e81722d2 | ||
|  | 320f2f81b7 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1 +1,2 @@ | |||||||
| **/data/* | **/data/* | ||||||
|  | **/*.zip | ||||||
							
								
								
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | |||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "import gym\n", | ||||||
|  |     "import gym_wordle\n", | ||||||
|  |     "from stable_baselines3 import DQN\n", | ||||||
|  |     "import numpy as np\n", | ||||||
|  |     "import tqdm" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "env = gym.make(\"Wordle-v0\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "print(env)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 35, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "total_timesteps = 100000\n", | ||||||
|  |     "model = DQN(\"MlpPolicy\", env, verbose=0)\n", | ||||||
|  |     "model.learn(total_timesteps=total_timesteps, progress_bar=True)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "def test(model):\n", | ||||||
|  |     "\n", | ||||||
|  |     "    end_rewards = []\n", | ||||||
|  |     "\n", | ||||||
|  |     "    for i in range(1000):\n", | ||||||
|  |     "        \n", | ||||||
|  |     "        state = env.reset()\n", | ||||||
|  |     "\n", | ||||||
|  |     "        done = False\n", | ||||||
|  |     "\n", | ||||||
|  |     "        while not done:\n", | ||||||
|  |     "\n", | ||||||
|  |     "            action, _states = model.predict(state, deterministic=True)\n", | ||||||
|  |     "\n", | ||||||
|  |     "            state, reward, done, info = env.step(action)\n", | ||||||
|  |     "            \n", | ||||||
|  |     "        end_rewards.append(reward == 0)\n", | ||||||
|  |     "        \n", | ||||||
|  |     "    return np.sum(end_rewards) / len(end_rewards)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "model.save(\"dqn_wordle\")" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "model = DQN.load(\"dqn_wordle\")" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "print(test(model))" | ||||||
|  |    ] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "Python 3 (ipykernel)", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.8.10" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 2 | ||||||
|  | } | ||||||
							
								
								
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								test.ipynb
									
									
									
									
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										61
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								test.py
									
									
									
									
									
								
							| @@ -1,61 +0,0 @@ | |||||||
|  |  | ||||||
| from torch.utils.data import Dataset |  | ||||||
| from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer |  | ||||||
| from tqdm import tqdm as progress_bar |  | ||||||
| import torch |  | ||||||
| import matplotlib |  | ||||||
|  |  | ||||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |  | ||||||
| print(device) |  | ||||||
|  |  | ||||||
| encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102) |  | ||||||
| # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token |  | ||||||
| decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102) |  | ||||||
| model = EncoderDecoderModel(encoder=encoder, decoder=decoder) |  | ||||||
|  |  | ||||||
| # create tokenizer... |  | ||||||
| tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") |  | ||||||
|  |  | ||||||
| import json |  | ||||||
|  |  | ||||||
| class CodeDataset(Dataset): |  | ||||||
|     def __init__(self): |  | ||||||
|         with open("data/conala-train.json") as f: |  | ||||||
|             self.data = json.load(f) |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self.data) |  | ||||||
|  |  | ||||||
|     def __getitem__(self, idx): |  | ||||||
|         intent = self.data[idx]["rewritten_intent"] if self.data[idx]["rewritten_intent"] else self.data[idx]["intent"] |  | ||||||
|         return intent, self.data[idx]["snippet"] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3) |  | ||||||
| dataloader = CodeDataset() |  | ||||||
| model = model.to(device) |  | ||||||
|  |  | ||||||
| losses = [] |  | ||||||
| epochs = 10 |  | ||||||
| for i in range(epochs): |  | ||||||
|  |  | ||||||
|     epoch_loss = 0 |  | ||||||
|  |  | ||||||
|     for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)): |  | ||||||
|  |  | ||||||
|         input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device) |  | ||||||
|         label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device) |  | ||||||
|  |  | ||||||
|         loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss |  | ||||||
|  |  | ||||||
|         optimizer.zero_grad() |  | ||||||
|         loss.backward() |  | ||||||
|         optimizer.step() |  | ||||||
|  |  | ||||||
|         epoch_loss += loss.item() |  | ||||||
|  |  | ||||||
|     losses.append(epoch_loss) |  | ||||||
|  |  | ||||||
| plt.plot(losses, color="green", label="Training Loss") |  | ||||||
| plt.legend(loc = 'upper left') |  | ||||||
| plt.savefig("plot.png") |  | ||||||
		Reference in New Issue
	
	Block a user