From 4fb81317f0d60dea7bbdffe99f30d3169832c0a5 Mon Sep 17 00:00:00 2001 From: Arthur Lu Date: Wed, 20 Mar 2024 17:31:27 -0700 Subject: [PATCH] add letter_guess symlink, add model loading into ai.py --- eric_wordle/ai.py | 41 +++++++++++++++++++++++++++++++++---- eric_wordle/eval.py | 11 +++++++--- eric_wordle/letter_guess.py | 1 + eval.sh | 3 ++- 4 files changed, 48 insertions(+), 8 deletions(-) create mode 120000 eric_wordle/letter_guess.py mode change 100644 => 100755 eval.sh diff --git a/eric_wordle/ai.py b/eric_wordle/ai.py index 9f34246..0b0cce2 100644 --- a/eric_wordle/ai.py +++ b/eric_wordle/ai.py @@ -3,9 +3,25 @@ import string import numpy as np +from stable_baselines3 import PPO, DQN +from letter_guess import LetterGuessingEnv + +def load_valid_words(file_path='wordle_words.txt'): + """ + Load valid five-letter words from a specified text file. + + Parameters: + - file_path (str): The path to the text file containing valid words. + + Returns: + - list[str]: A list of valid words loaded from the file. + """ + with open(file_path, 'r') as file: + valid_words = [line.strip() for line in file if len(line.strip()) == 5] + return valid_words class AI: - def __init__(self, vocab_file, num_letters=5, num_guesses=6): + def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False): self.vocab_file = vocab_file self.num_letters = num_letters self.num_guesses = 6 @@ -16,16 +32,32 @@ class AI: self.domains = None self.possible_letters = None + self.use_q_model = use_q_model + if use_q_model: + # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all + self.q_env = LetterGuessingEnv(vocab_file) + # load model + self.q_model = PPO.load(model_file) + self.q_env_state = None + self.reset() def solve_eval(self, results_callback): num_guesses = 0 while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: num_guesses += 1 + # sample a word, this would use the q_env_state if the q_model is used word = self.sample() + # get emulated results results = results_callback(word) + if self.use_q_model: + # step the q_env to match the guess we just made + for i in range(len(word)): + char = word[i] + action = ord(char) - ord('a') + self.q_env_state, _, _, _, _ = self.q_env.step(action) + self.arc_consistency(word, results) - # print(num_guesses, word, results) return num_guesses, word def solve(self): @@ -58,8 +90,6 @@ class AI: results.append(result) break - print(results) - self.arc_consistency(word, results) print(f'You did it! The word is {"".join([e[0] for e in self.domains])}') @@ -87,6 +117,9 @@ class AI: self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)] self.possible_letters = [] + if self.use_q_model: + self.q_env_state, _ = self.q_env.reset() + def sample(self): """ Samples a best word given the current domains diff --git a/eric_wordle/eval.py b/eric_wordle/eval.py index 49943c2..71f7a5d 100644 --- a/eric_wordle/eval.py +++ b/eric_wordle/eval.py @@ -1,6 +1,7 @@ import argparse from ai import AI import numpy as np +from tqdm import tqdm global solution @@ -27,12 +28,13 @@ def main(args): if args.n is None: raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') - ai = AI(args.vocab_file) + ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) total_guesses = 0 + wins = 0 num_eval = args.num_eval - for i in range(num_eval): + for i in tqdm(range(num_eval)): idx = np.random.choice(range(len(ai.vocab))) solution = ai.vocab[idx] guesses, word = ai.solve_eval(results_callback=result_callback) @@ -40,14 +42,17 @@ def main(args): total_guesses += 5 else: total_guesses += guesses + wins += 1 ai.reset() - print(total_guesses / num_eval) + print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--n', dest='n', type=int, default=None) parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) + parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') + parser.add_argument('--q_model', dest="q_model", type=bool, default=False) args = parser.parse_args() main(args) \ No newline at end of file diff --git a/eric_wordle/letter_guess.py b/eric_wordle/letter_guess.py new file mode 120000 index 0000000..b99d852 --- /dev/null +++ b/eric_wordle/letter_guess.py @@ -0,0 +1 @@ +../letter_guess.py \ No newline at end of file diff --git a/eval.sh b/eval.sh old mode 100644 new mode 100755 index 202cf7a..831f9e3 --- a/eval.sh +++ b/eval.sh @@ -1 +1,2 @@ -python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 1000 \ No newline at end of file +python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 +python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 --q_model True --model_file wordle_ppo_model \ No newline at end of file