mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			f40301cac9
			...
			arthur-tes
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | dd5889da33 | ||
|  | 848ea719b7 | ||
|  | f641d77c47 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,2 +1,3 @@ | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
							
								
								
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										114
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							| @@ -1,114 +0,0 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gym\n", | ||||
|     "import gym_wordle\n", | ||||
|     "from stable_baselines3 import DQN\n", | ||||
|     "import numpy as np\n", | ||||
|     "import tqdm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "env = gym.make(\"Wordle-v0\")\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 35, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=0)\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, progress_bar=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def test(model):\n", | ||||
|     "\n", | ||||
|     "    end_rewards = []\n", | ||||
|     "\n", | ||||
|     "    for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "        state = env.reset()\n", | ||||
|     "\n", | ||||
|     "        done = False\n", | ||||
|     "\n", | ||||
|     "        while not done:\n", | ||||
|     "\n", | ||||
|     "            action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "            state, reward, done, info = env.step(action)\n", | ||||
|     "            \n", | ||||
|     "        end_rewards.append(reward == 0)\n", | ||||
|     "        \n", | ||||
|     "    return np.sum(end_rewards) / len(end_rewards)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model.save(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model = DQN.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "print(test(model))" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.10" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
							
								
								
									
										38
									
								
								dqn_wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								dqn_wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| import gym | ||||
| import sys | ||||
| from stable_baselines3 import DQN | ||||
| from stable_baselines3.common.env_util import make_vec_env | ||||
| import wordle_gym | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
|  | ||||
| def train (model, env, total_timesteps = 100000):  | ||||
|     model.learn(total_timesteps=total_timesteps, progress_bar=True) | ||||
|     model.save("dqn_wordle") | ||||
|  | ||||
| def test(model, env, test_num=1000): | ||||
|  | ||||
|     total_correct = 0 | ||||
|  | ||||
|     for i in tqdm(range(test_num)): | ||||
|  | ||||
|         model = DQN.load("dqn_wordle") | ||||
|  | ||||
|         env = gym.make("wordle-v0") | ||||
|         obs = env.reset() | ||||
|         done = False | ||||
|         while not done: | ||||
|             action, _states = model.predict(obs) | ||||
|             obs, rewards, done, info = env.step(action) | ||||
|              | ||||
|     return total_correct / test_num | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|      | ||||
|     env = gym.make("wordle-v0") | ||||
|     model = DQN("MlpPolicy", env, verbose=0) | ||||
|     print(env) | ||||
|     print(model) | ||||
|  | ||||
|     train(model, env, total_timesteps=500000) | ||||
|     print(test(model, env)) | ||||
							
								
								
									
										9
									
								
								wordle_gym/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								wordle_gym/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| from gym.envs.registration import register | ||||
|  | ||||
| register( | ||||
|     id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv", | ||||
| ) | ||||
|  | ||||
| register( | ||||
|     id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv", | ||||
| ) | ||||
							
								
								
									
										0
									
								
								wordle_gym/envs/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								wordle_gym/envs/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										15
									
								
								wordle_gym/envs/strategies/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								wordle_gym/envs/strategies/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| from enum import Enum | ||||
|  | ||||
| from typing import List | ||||
|  | ||||
| class StrategyType(Enum): | ||||
|     RANDOM = 1 | ||||
|     ELIMINATION = 2 | ||||
|     PROBABILITY = 3 | ||||
|  | ||||
| class Strategy: | ||||
|     def __init__(self, type: StrategyType): | ||||
|         self.type = type | ||||
|      | ||||
|     def get_best_word(self, guesses: List[List[str]], state: List[List[int]]): | ||||
|         raise NotImplementedError("Strategy.get_best_word() not implemented") | ||||
							
								
								
									
										2
									
								
								wordle_gym/envs/strategies/elimination.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								wordle_gym/envs/strategies/elimination.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| def get_best_word(state): | ||||
|      | ||||
							
								
								
									
										20
									
								
								wordle_gym/envs/strategies/probabilistic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								wordle_gym/envs/strategies/probabilistic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| from random import sample | ||||
| from typing import List | ||||
|  | ||||
| from base import Strategy | ||||
| from base import StrategyType | ||||
|  | ||||
| from utils import freq | ||||
|  | ||||
| class Random(Strategy): | ||||
|     def __init__(self): | ||||
|         self.words = freq.get_5_letter_word_freqs() | ||||
|         super().__init__(StrategyType.RANDOM) | ||||
|  | ||||
|     def get_best_word(self, state: List[List[int]]): | ||||
|          | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     r = Random() | ||||
|     print(r.get_best_word([])) | ||||
							
								
								
									
										29
									
								
								wordle_gym/envs/strategies/rand.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								wordle_gym/envs/strategies/rand.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| from random import sample | ||||
| from typing import List | ||||
|  | ||||
| from base import Strategy | ||||
| from base import StrategyType | ||||
|  | ||||
| from utils import freq | ||||
|  | ||||
| class Random(Strategy): | ||||
|     def __init__(self): | ||||
|         self.words = freq.get_5_letter_word_freqs() | ||||
|         super().__init__(StrategyType.RANDOM) | ||||
|  | ||||
|     def get_best_word(self, guesses: List[List[str]], state: List[List[int]]): | ||||
|         correct_letters = [] | ||||
|         regex = "" | ||||
|         for g, s in zip(guesses, state): | ||||
|             for c, s in zip(g, s): | ||||
|                 if s == 2: | ||||
|                     correct_letters.append(c) | ||||
|                     regex += c | ||||
|                  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     r = Random() | ||||
|     print(r.get_best_word([])) | ||||
							
								
								
									
										27
									
								
								wordle_gym/envs/strategies/utils/freq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								wordle_gym/envs/strategies/utils/freq.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| from os import path | ||||
|  | ||||
| def get_5_letter_word_freqs(): | ||||
|     """ | ||||
|     Returns a list of words with 5 letters. | ||||
|     """ | ||||
|     FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt") | ||||
|     lines = read_file(FILEPATH) | ||||
|     return {k:v for k, v in get_freq(lines).items() if len(k) == 5} | ||||
|  | ||||
|  | ||||
| def read_file(filename): | ||||
|     """ | ||||
|     Reads a file and returns a list of words and frequencies | ||||
|     """ | ||||
|     with open(filename, 'r') as f: | ||||
|         return f.readlines() | ||||
|  | ||||
|  | ||||
| def get_freq(lines): | ||||
|     """ | ||||
|     Returns a dictionary of words and their frequencies | ||||
|     """ | ||||
|     freqs = {} | ||||
|     for word, freq in map(lambda x: x.split("\t"), lines): | ||||
|         freqs[word] = int(freq) | ||||
|     return freqs | ||||
							
								
								
									
										131
									
								
								wordle_gym/envs/wordle_env.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								wordle_gym/envs/wordle_env.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | ||||
| import os | ||||
|  | ||||
|  | ||||
| import gym | ||||
| from gym import error, spaces, utils | ||||
| from gym.utils import seeding | ||||
|  | ||||
| from enum import Enum | ||||
| from collections import Counter | ||||
| import numpy as np | ||||
|  | ||||
| WORD_LENGTH = 5 | ||||
| TOTAL_GUESSES = 6 | ||||
| SOLUTION_PATH = "../words/solution.csv" | ||||
| VALID_WORDS_PATH = "../words/guess.csv" | ||||
|  | ||||
| class LetterState(Enum): | ||||
|     ABSENT = 0 | ||||
|     PRESENT = 1 | ||||
|     CORRECT_POSITION = 2 | ||||
|  | ||||
|  | ||||
| class WordleEnv(gym.Env): | ||||
|     metadata = {"render.modes": ["human"]} | ||||
|  | ||||
|     def _current_path(self): | ||||
|         return os.path.dirname(os.path.abspath(__file__)) | ||||
|  | ||||
|     def _read_solutions(self): | ||||
|         return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines() | ||||
|      | ||||
|     def _get_valid_words(self): | ||||
|         words = [] | ||||
|         for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines(): | ||||
|             words.append((word, Counter(word))) | ||||
|         return words | ||||
|  | ||||
|     def get_valid(self): | ||||
|         return self._valid_words | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._solutions = self._read_solutions() | ||||
|         self._valid_words = self._get_valid_words() | ||||
|         self.action_space = spaces.Discrete(len(self._valid_words)) | ||||
|         self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH) | ||||
|         np.random.seed(0) | ||||
|         self.reset() | ||||
|      | ||||
|     def _check_guess(self, guess, guess_counter): | ||||
|         c = guess_counter & self.solution_ct | ||||
|         result = [] | ||||
|         correct = True | ||||
|         reward = 0 | ||||
|         for i, char in enumerate(guess): | ||||
|             if c.get(char, 0) > 0: | ||||
|                 if self.solution[i] == char: | ||||
|                     result.append(2) | ||||
|                     reward += 2 | ||||
|                 else: | ||||
|                     result.append(1) | ||||
|                     correct = False | ||||
|                     reward += 1 | ||||
|                 c[char] -= 1 | ||||
|             else: | ||||
|                 result.append(0) | ||||
|                 correct = False | ||||
|         return result, correct, reward | ||||
|  | ||||
|     def step(self, action): | ||||
|         """ | ||||
|         action: index of word in valid_words | ||||
|  | ||||
|         returns: | ||||
|             observation: (TOTAL_GUESSES, WORD_LENGTH) | ||||
|             reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained | ||||
|             done: True if game over, w/ or w/o correct answer | ||||
|             additional_info: empty | ||||
|         """ | ||||
|         guess, guess_counter = self._valid_words[action] | ||||
|         if guess in self.guesses: | ||||
|             return self.obs, -1, False, {} | ||||
|         self.guesses.append(guess) | ||||
|         result, correct, reward = self._check_guess(guess, guess_counter) | ||||
|         done = False | ||||
|  | ||||
|         for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH): | ||||
|             self.obs[i] = result[i - self.guess_no*WORD_LENGTH] | ||||
|          | ||||
|         self.guess_no += 1 | ||||
|         if correct: | ||||
|             done = True | ||||
|             reward = 1200 | ||||
|         if self.guess_no == TOTAL_GUESSES: | ||||
|             done = True | ||||
|             if not correct: | ||||
|                 reward = -15 | ||||
|         return self.obs, reward, done, {} | ||||
|  | ||||
|     def reset(self): | ||||
|         self.solution = self._solutions[np.random.randint(len(self._solutions))] | ||||
|         self.solution_ct = Counter(self.solution) | ||||
|         self.guess_no = 0 | ||||
|         self.guesses = [] | ||||
|         self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, )) | ||||
|         return self.obs | ||||
|  | ||||
|     def render(self, mode="human"): | ||||
|         m = { | ||||
|             0: "⬜", | ||||
|             1: "🟨", | ||||
|             2: "🟩" | ||||
|         } | ||||
|         print("Solution:", self.solution) | ||||
|         for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))): | ||||
|             o_n = "".join(map(lambda x: m[x], o)) | ||||
|             print(g, o_n) | ||||
|  | ||||
|     def close(self): | ||||
|         pass | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     env = WordleEnv() | ||||
|     print(env.action_space) | ||||
|     print(env.observation_space) | ||||
|     print(env.solution) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
							
								
								
									
										12972
									
								
								wordle_gym/words/guess.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12972
									
								
								wordle_gym/words/guess.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										2315
									
								
								wordle_gym/words/solution.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2315
									
								
								wordle_gym/words/solution.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user