mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-09-08 08:17:21 +00:00
added state saving
This commit is contained in:
@@ -6,6 +6,7 @@ import numpy as np
|
||||
from stable_baselines3 import PPO, DQN
|
||||
from letter_guess import LetterGuessingEnv
|
||||
|
||||
|
||||
def load_valid_words(file_path='wordle_words.txt'):
|
||||
"""
|
||||
Load valid five-letter words from a specified text file.
|
||||
@@ -20,8 +21,9 @@ def load_valid_words(file_path='wordle_words.txt'):
|
||||
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
||||
return valid_words
|
||||
|
||||
|
||||
class AI:
|
||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False):
|
||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
|
||||
self.vocab_file = vocab_file
|
||||
self.num_letters = num_letters
|
||||
self.num_guesses = 6
|
||||
@@ -36,18 +38,21 @@ class AI:
|
||||
if use_q_model:
|
||||
# We initialize the same q env that was used to train the model ONLY to simplify storing/calculating the gym state; it is not used to control the game at all.
|
||||
self.q_env = LetterGuessingEnv(vocab_file)
|
||||
self.q_env_state, _ = self.q_env.reset()
|
||||
|
||||
# load model
|
||||
self.q_model = PPO.load(model_file)
|
||||
self.q_env_state = None
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
def solve_eval(self, results_callback):
|
||||
num_guesses = 0
|
||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||
num_guesses += 1
|
||||
|
||||
# sample a word, this would use the q_env_state if the q_model is used
|
||||
word = self.sample()
|
||||
|
||||
# get emulated results
|
||||
results = results_callback(word)
|
||||
if self.use_q_model:
|
||||
@@ -59,7 +64,7 @@ class AI:
|
||||
|
||||
self.arc_consistency(word, results)
|
||||
return num_guesses, word
|
||||
|
||||
|
||||
def solve(self):
|
||||
num_guesses = 0
|
||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||
@@ -132,9 +137,31 @@ class AI:
|
||||
pattern = re.compile(regex_string)
|
||||
|
||||
# From the words with the highest scores, only return the best word that matches the regex pattern
|
||||
max_qval = float('-inf')
|
||||
best_word = None
|
||||
for word, _ in self.best_words:
|
||||
# reset the state back to before we guessed a word
|
||||
self.q_env.set_state(self.q_env_state)
|
||||
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||
return word
|
||||
if self.use_q_model:
|
||||
# Use policy to grade word
|
||||
# get the state and action pairs
|
||||
curr_qval = 0
|
||||
|
||||
for l in word:
|
||||
action = ord(l) - ord('a')
|
||||
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
|
||||
_, _, _, _, _ = self.q_env.step(action)
|
||||
curr_qval += q_val
|
||||
|
||||
if curr_qval > max_qval:
|
||||
max_qval = curr_qval
|
||||
best_word = word
|
||||
else:
|
||||
# otherwise return the word from eric heuristic
|
||||
return word
|
||||
self.q_env.set_state(self.q_env_state)
|
||||
return best_word
|
||||
|
||||
def get_vocab(self, vocab_file):
|
||||
vocab = []
|
||||
|
Reference in New Issue
Block a user