added state saving

This commit is contained in:
Ethan Shapiro
2024-03-20 19:52:13 -07:00
parent 4fb81317f0
commit 3747af9d22
3 changed files with 230 additions and 14 deletions

View File

@@ -6,6 +6,7 @@ import numpy as np
from stable_baselines3 import PPO, DQN
from letter_guess import LetterGuessingEnv
def load_valid_words(file_path='wordle_words.txt'):
"""
Load valid five-letter words from a specified text file.
@@ -20,8 +21,9 @@ def load_valid_words(file_path='wordle_words.txt'):
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
return valid_words
class AI:
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False):
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
self.vocab_file = vocab_file
self.num_letters = num_letters
self.num_guesses = 6
@@ -36,18 +38,21 @@ class AI:
if use_q_model:
# We initialize the same q env that was used to train the model ONLY to simplify storing/calculating the gym state; it is not used to control the game at all
self.q_env = LetterGuessingEnv(vocab_file)
self.q_env_state, _ = self.q_env.reset()
# load model
self.q_model = PPO.load(model_file)
self.q_env_state = None
self.reset()
def solve_eval(self, results_callback):
num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1
# sample a word, this would use the q_env_state if the q_model is used
word = self.sample()
# get emulated results
results = results_callback(word)
if self.use_q_model:
@@ -59,7 +64,7 @@ class AI:
self.arc_consistency(word, results)
return num_guesses, word
def solve(self):
num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
@@ -132,9 +137,31 @@ class AI:
pattern = re.compile(regex_string)
# From the words with the highest scores, only return the best word that matches the regex pattern
max_qval = float('-inf')
best_word = None
for word, _ in self.best_words:
# reset the state back to before we guessed a word
self.q_env.set_state(self.q_env_state)
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
return word
if self.use_q_model:
# Use policy to grade word
# get the state and action pairs
curr_qval = 0
for l in word:
action = ord(l) - ord('a')
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
_, _, _, _, _ = self.q_env.step(action)
curr_qval += q_val
if curr_qval > max_qval:
max_qval = curr_qval
best_word = word
else:
# otherwise return the word from eric heuristic
return word
self.q_env.set_state(self.q_env_state)
return best_word
def get_vocab(self, vocab_file):
vocab = []