From 3747af9d2231073afde9ed2aa38fa172d3c39228 Mon Sep 17 00:00:00 2001 From: Ethan Shapiro <46407744+Ethan-Shapiro@users.noreply.github.com> Date: Wed, 20 Mar 2024 19:52:13 -0700 Subject: [PATCH] added state saving --- dqn_letter_gssr.ipynb | 172 +++++++++++++++++++++++++++++++++++++++++- eric_wordle/ai.py | 37 +++++++-- letter_guess.py | 35 +++++++-- 3 files changed, 230 insertions(+), 14 deletions(-) diff --git a/dqn_letter_gssr.ipynb b/dqn_letter_gssr.ipynb index 2b8a960..bfbc0a7 100644 --- a/dqn_letter_gssr.ipynb +++ b/dqn_letter_gssr.ipynb @@ -43,6 +43,174 @@ "check_env(env) # Optional: Verify the environment is compatible with SB3" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "initial_state = env.clone_state()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "obs, _ = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "model_save_path = \"wordle_ppo_model\"\n", + "model = PPO.load(model_save_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "action, _ = model.predict(obs)\n", + "obs, reward, done, _, info = env.step(action)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "action % 26" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ord('f') - ord('a')" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'f'" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chr(ord('a') + action % 26)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n", + " 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n", + " 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n", + " 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n", + " 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n", + " 1, 1])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "env.set_state(initial_state)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all(env.get_obs() == obs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Perform your action to see the outcome\n", + "action = # Define your action\n", + "observation, reward, done, info = env.step(action)\n", + "\n", + "# Revert to the initial state\n", + "env.env.set_state(initial_state)" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -128,7 +296,7 @@ } ], "source": [ - "model_save_path = \"wordle_ppo_model\"\n", + "model_save_path = \"wordle_ppo_model_test\"\n", "config = {\n", " \"policy_type\": \"MlpPolicy\",\n", " \"total_timesteps\": 200_000\n", @@ -369,7 +537,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/eric_wordle/ai.py b/eric_wordle/ai.py index 0b0cce2..7af906f 100644 --- a/eric_wordle/ai.py +++ b/eric_wordle/ai.py @@ -6,6 +6,7 @@ import numpy as np from stable_baselines3 import PPO, DQN from letter_guess import LetterGuessingEnv + def load_valid_words(file_path='wordle_words.txt'): """ Load valid five-letter words from a specified text file. @@ -20,8 +21,9 @@ def load_valid_words(file_path='wordle_words.txt'): valid_words = [line.strip() for line in file if len(line.strip()) == 5] return valid_words + class AI: - def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False): + def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False): self.vocab_file = vocab_file self.num_letters = num_letters self.num_guesses = 6 @@ -36,18 +38,21 @@ class AI: if use_q_model: # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all self.q_env = LetterGuessingEnv(vocab_file) + self.q_env_state, _ = self.q_env.reset() + # load model self.q_model = PPO.load(model_file) - self.q_env_state = None self.reset() - + def solve_eval(self, results_callback): num_guesses = 0 while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: num_guesses += 1 + # sample a word, this would use the q_env_state if the q_model is used word = self.sample() + # get emulated results results = results_callback(word) if self.use_q_model: @@ -59,7 +64,7 @@ class AI: self.arc_consistency(word, results) return num_guesses, word - + def solve(self): num_guesses = 0 while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: @@ -132,9 +137,31 @@ class AI: pattern = re.compile(regex_string) # From the words with the highest scores, only return the best word that match the regex pattern + max_qval = float('-inf') + best_word = None for word, _ in self.best_words: + # reset the state back to before we guessed a word + self.q_env.set_state(self.q_env_state) if pattern.match(word) and False not in [e in word for e in self.possible_letters]: - return word + if self.use_q_model: + # Use policy to grade word + # get the state and action pairs + curr_qval = 0 + + for l in word: + action = ord(l) - ord('a') + q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action) + _, _, _, _, _ = self.q_env.step(action) + curr_qval += q_val + + if curr_qval > max_qval: + max_qval = curr_qval + best_word = word + else: + # otherwise return the word from eric heuristic + return word + self.q_env.set_state(self.q_env_state) + return best_word def get_vocab(self, vocab_file): vocab = [] diff --git a/letter_guess.py b/letter_guess.py index 8822f8e..c073584 100644 --- a/letter_guess.py +++ b/letter_guess.py @@ -3,6 +3,7 @@ from gymnasium import spaces import numpy as np import random import re +import copy class LetterGuessingEnv(gym.Env): @@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env): self.reset() + def clone_state(self): + # Clone the current state + return { + 'target_word': self.target_word, + 'letter_flags': copy.deepcopy(self.letter_flags), + 'letter_positions': copy.deepcopy(self.letter_positions), + 'guessed_letters': copy.deepcopy(self.guessed_letters), + 'guess_prefix': self.guess_prefix, + 'round': self.round + } + + def set_state(self, state): + # Restore the state + self.target_word = state['target_word'] + self.letter_flags = copy.deepcopy(state['letter_flags']) + self.letter_positions = copy.deepcopy(state['letter_positions']) + self.guessed_letters = copy.deepcopy(state['guessed_letters']) + self.guess_prefix = state['guess_prefix'] + self.round = state['round'] + def step(self, action): - letter_index = action % 26 # Assuming action is the letter index directly + letter_index = action # Assuming action is the letter index directly position = len(self.guess_prefix) # The next position in the prefix is determined by its current length letter = chr(ord('a') + letter_index) @@ -56,8 +77,8 @@ class LetterGuessingEnv(gym.Env): reward = 1 # Reward for adding new information by trying a new letter # Update the letter_positions matrix to reflect the new guess - if position == 4: - self.letter_positions[:,:] = 1 + if position == 4: + self.letter_positions[:, :] = 1 else: self.letter_positions[:, position] = 0 self.letter_positions[letter_index, position] = 1 @@ -77,8 +98,8 @@ class LetterGuessingEnv(gym.Env): # reward = 5 done = True - obs = self._get_obs() - + obs = self.get_obs() + if reward < -5: print(obs, reward, done) exit(0) @@ -93,7 +114,7 @@ class LetterGuessingEnv(gym.Env): self.guessed_letters = set() self.guess_prefix = "" # Reset the guess prefix for the new episode self.round = 0 - return self._get_obs(), {} + return self.get_obs(), {} def encode_word(self, word): encoded = np.zeros((26,)) @@ -102,7 +123,7 @@ class LetterGuessingEnv(gym.Env): encoded[index] = 1 return encoded - def _get_obs(self): + def get_obs(self): return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()]) def render(self, mode='human'):