add letter_guess symlink, add model loading into ai.py

This commit is contained in:
Arthur Lu 2024-03-20 17:31:27 -07:00
parent 12601964bd
commit 4fb81317f0
4 changed files with 48 additions and 8 deletions

View File

@ -3,9 +3,25 @@ import string
import numpy as np import numpy as np
from stable_baselines3 import PPO, DQN
from letter_guess import LetterGuessingEnv
def load_valid_words(file_path='wordle_words.txt'):
"""
Load valid five-letter words from a specified text file.
Parameters:
- file_path (str): The path to the text file containing valid words.
Returns:
- list[str]: A list of valid words loaded from the file.
"""
with open(file_path, 'r') as file:
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
return valid_words
class AI: class AI:
def __init__(self, vocab_file, num_letters=5, num_guesses=6): def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False):
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.num_letters = num_letters self.num_letters = num_letters
self.num_guesses = 6 self.num_guesses = 6
@ -16,16 +32,32 @@ class AI:
self.domains = None self.domains = None
self.possible_letters = None self.possible_letters = None
self.use_q_model = use_q_model
if use_q_model:
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
self.q_env = LetterGuessingEnv(vocab_file)
# load model
self.q_model = PPO.load(model_file)
self.q_env_state = None
self.reset() self.reset()
def solve_eval(self, results_callback): def solve_eval(self, results_callback):
num_guesses = 0 num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1 num_guesses += 1
# sample a word, this would use the q_env_state if the q_model is used
word = self.sample() word = self.sample()
# get emulated results
results = results_callback(word) results = results_callback(word)
if self.use_q_model:
# step the q_env to match the guess we just made
for i in range(len(word)):
char = word[i]
action = ord(char) - ord('a')
self.q_env_state, _, _, _, _ = self.q_env.step(action)
self.arc_consistency(word, results) self.arc_consistency(word, results)
# print(num_guesses, word, results)
return num_guesses, word return num_guesses, word
def solve(self): def solve(self):
@ -58,8 +90,6 @@ class AI:
results.append(result) results.append(result)
break break
print(results)
self.arc_consistency(word, results) self.arc_consistency(word, results)
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}') print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
@ -87,6 +117,9 @@ class AI:
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)] self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
self.possible_letters = [] self.possible_letters = []
if self.use_q_model:
self.q_env_state, _ = self.q_env.reset()
def sample(self): def sample(self):
""" """
Samples a best word given the current domains Samples a best word given the current domains

View File

@ -1,6 +1,7 @@
import argparse import argparse
from ai import AI from ai import AI
import numpy as np import numpy as np
from tqdm import tqdm
global solution global solution
@ -27,12 +28,13 @@ def main(args):
if args.n is None: if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
ai = AI(args.vocab_file) ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
total_guesses = 0 total_guesses = 0
wins = 0
num_eval = args.num_eval num_eval = args.num_eval
for i in range(num_eval): for i in tqdm(range(num_eval)):
idx = np.random.choice(range(len(ai.vocab))) idx = np.random.choice(range(len(ai.vocab)))
solution = ai.vocab[idx] solution = ai.vocab[idx]
guesses, word = ai.solve_eval(results_callback=result_callback) guesses, word = ai.solve_eval(results_callback=result_callback)
@ -40,14 +42,17 @@ def main(args):
total_guesses += 5 total_guesses += 5
else: else:
total_guesses += guesses total_guesses += guesses
wins += 1
ai.reset() ai.reset()
print(total_guesses / num_eval) print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--n', dest='n', type=int, default=None) parser.add_argument('--n', dest='n', type=int, default=None)
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

1
eric_wordle/letter_guess.py Symbolic link
View File

@ -0,0 +1 @@
../letter_guess.py

3
eval.sh Normal file → Executable file
View File

@ -1 +1,2 @@
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 1000 python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 --q_model True --model_file wordle_ppo_model