From b46d335044e337b1f98235306127ff0c690a0c3f Mon Sep 17 00:00:00 2001
From: Arthur Lu
Date: Wed, 20 Mar 2024 21:01:08 -0700
Subject: [PATCH] working!!

---
 eric_wordle/ai.py   | 45 ++++++++++++++++++++++++++-------------------
 eric_wordle/eval.py |  6 +++++-
 eric_wordle/main.py |  7 +++++--
 3 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/eric_wordle/ai.py b/eric_wordle/ai.py
index 1ab280c..73e0cc1 100644
--- a/eric_wordle/ai.py
+++ b/eric_wordle/ai.py
@@ -5,7 +5,7 @@ import numpy as np
 from stable_baselines3 import PPO, DQN
 from letter_guess import LetterGuessingEnv
 
-
+import torch
 
 def load_valid_words(file_path='wordle_words.txt'):
     """
@@ -37,26 +37,28 @@ class AI:
         self.use_q_model = use_q_model
         if use_q_model:
             # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
-            self.q_env = LetterGuessingEnv(vocab_file)
+            self.q_env = LetterGuessingEnv(load_valid_words(vocab_file))
             self.q_env_state, _ = self.q_env.reset()
             # load model
             self.q_model = PPO.load(model_file)
 
-        self.reset()
+        self.reset("")
 
     def solve_eval(self, results_callback):
         num_guesses = 0
         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
             num_guesses += 1
+            if self.use_q_model:
+                self.freeze_state = self.q_env.clone_state()
 
             # sample a word, this would use the q_env_state if the q_model is used
-            word = self.sample()
+            word = self.sample(num_guesses)
 
             # get emulated results
             results = results_callback(word)
 
             if self.use_q_model:
-                self.q_env.set_state(self.q_env_state)
+                self.q_env.set_state(self.freeze_state)
                 # step the q_env to match the guess we just made
                 for i in range(len(word)):
                     char = word[i]
@@ -70,13 +72,11 @@ class AI:
         num_guesses = 0
         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
             num_guesses += 1
-            word = self.sample()
+            if self.use_q_model:
+                self.freeze_state = self.q_env.clone_state()
 
-            # # Always start with these two words
-            # if num_guesses == 1:
-            #     word = 'soare'
-            # elif num_guesses == 2:
-            #     word = 'culti'
+            # sample a word, this would use the q_env_state if the q_model is used
+            word = self.sample(num_guesses)
 
             print('-----------------------------------------------')
             print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
@@ -96,10 +96,16 @@ class AI:
                         results.append(result)
                         break
 
-            self.arc_consistency(word, results)
+            if self.use_q_model:
+                self.q_env.set_state(self.freeze_state)
+                # step the q_env to match the guess we just made
+                for i in range(len(word)):
+                    char = word[i]
+                    action = ord(char) - ord('a')
+                    self.q_env_state, _, _, _, _ = self.q_env.step(action)
 
-        print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
-        return num_guesses
+            self.arc_consistency(word, results)
+        return num_guesses, word
 
     def arc_consistency(self, word, results):
         self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
@@ -119,14 +125,15 @@ class AI:
             if results[i] == '2':
                 self.domains[i] = [word[i]]
 
-    def reset(self):
+    def reset(self, target_word):
         self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
         self.possible_letters = []
 
         if self.use_q_model:
             self.q_env_state, _ = self.q_env.reset()
+            self.q_env.target_word = target_word
 
-    def sample(self):
+    def sample(self, num_guesses):
         """
         Samples a best word given the current domains
         :return:
@@ -143,15 +150,15 @@ class AI:
         for word, _ in self.best_words:
             # reset the state back to before we guessed a word
             if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
-                if self.use_q_model:
-                    self.q_env.set_state(self.q_env_state)
+                if self.use_q_model and num_guesses == 3:
+                    self.q_env.set_state(self.freeze_state)
                     # Use policy to grade word
                     # get the state and action pairs
                     curr_qval = 0
 
                     for l in word:
                         action = ord(l) - ord('a')
-                        q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
+                        q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to("cuda"))
                         _, _, _, _, _ = self.q_env.step(action)
                         curr_qval += q_val
 
diff --git a/eric_wordle/eval.py b/eric_wordle/eval.py
index 71f7a5d..aeca830 100644
--- a/eric_wordle/eval.py
+++ b/eric_wordle/eval.py
@@ -34,16 +34,20 @@ def main(args):
     wins = 0
     num_eval = args.num_eval
 
+    np.random.seed(0)
+
     for i in tqdm(range(num_eval)):
         idx = np.random.choice(range(len(ai.vocab)))
         solution = ai.vocab[idx]
+
+        ai.reset(solution)
+
         guesses, word = ai.solve_eval(results_callback=result_callback)
         if word != solution:
             total_guesses += 5
         else:
             total_guesses += guesses
             wins += 1
-        ai.reset()
 
 
     print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
diff --git a/eric_wordle/main.py b/eric_wordle/main.py
index 9632cab..71898e4 100644
--- a/eric_wordle/main.py
+++ b/eric_wordle/main.py
@@ -5,8 +5,9 @@ from ai import AI
 def main(args):
     if args.n is None:
         raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
-
-    ai = AI(args.vocab_file)
+    print(f"using q model? {args.q_model}")
+    ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
+    ai.reset("lingo")
     ai.solve()
 
 
@@ -14,5 +15,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--n', dest='n', type=int, default=None)
     parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
+    parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
+    parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
     args = parser.parse_args()
     main(args)
\ No newline at end of file