mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-26 01:59:10 +00:00
add letter_guess symlink, add model loading into ai.py
This commit is contained in:
parent
12601964bd
commit
4fb81317f0
@ -3,9 +3,25 @@ import string
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from stable_baselines3 import PPO, DQN
|
||||||
|
from letter_guess import LetterGuessingEnv
|
||||||
|
|
||||||
|
def load_valid_words(file_path='wordle_words.txt'):
|
||||||
|
"""
|
||||||
|
Load valid five-letter words from a specified text file.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- file_path (str): The path to the text file containing valid words.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- list[str]: A list of valid words loaded from the file.
|
||||||
|
"""
|
||||||
|
with open(file_path, 'r') as file:
|
||||||
|
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
||||||
|
return valid_words
|
||||||
|
|
||||||
class AI:
|
class AI:
|
||||||
def __init__(self, vocab_file, num_letters=5, num_guesses=6):
|
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False):
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.num_letters = num_letters
|
self.num_letters = num_letters
|
||||||
self.num_guesses = 6
|
self.num_guesses = 6
|
||||||
@ -16,16 +32,32 @@ class AI:
|
|||||||
self.domains = None
|
self.domains = None
|
||||||
self.possible_letters = None
|
self.possible_letters = None
|
||||||
|
|
||||||
|
self.use_q_model = use_q_model
|
||||||
|
if use_q_model:
|
||||||
|
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
|
||||||
|
self.q_env = LetterGuessingEnv(vocab_file)
|
||||||
|
# load model
|
||||||
|
self.q_model = PPO.load(model_file)
|
||||||
|
self.q_env_state = None
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def solve_eval(self, results_callback):
|
def solve_eval(self, results_callback):
|
||||||
num_guesses = 0
|
num_guesses = 0
|
||||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||||
num_guesses += 1
|
num_guesses += 1
|
||||||
|
# sample a word, this would use the q_env_state if the q_model is used
|
||||||
word = self.sample()
|
word = self.sample()
|
||||||
|
# get emulated results
|
||||||
results = results_callback(word)
|
results = results_callback(word)
|
||||||
|
if self.use_q_model:
|
||||||
|
# step the q_env to match the guess we just made
|
||||||
|
for i in range(len(word)):
|
||||||
|
char = word[i]
|
||||||
|
action = ord(char) - ord('a')
|
||||||
|
self.q_env_state, _, _, _, _ = self.q_env.step(action)
|
||||||
|
|
||||||
self.arc_consistency(word, results)
|
self.arc_consistency(word, results)
|
||||||
# print(num_guesses, word, results)
|
|
||||||
return num_guesses, word
|
return num_guesses, word
|
||||||
|
|
||||||
def solve(self):
|
def solve(self):
|
||||||
@ -58,8 +90,6 @@ class AI:
|
|||||||
results.append(result)
|
results.append(result)
|
||||||
break
|
break
|
||||||
|
|
||||||
print(results)
|
|
||||||
|
|
||||||
self.arc_consistency(word, results)
|
self.arc_consistency(word, results)
|
||||||
|
|
||||||
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
|
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}')
|
||||||
@ -87,6 +117,9 @@ class AI:
|
|||||||
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
|
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
|
||||||
self.possible_letters = []
|
self.possible_letters = []
|
||||||
|
|
||||||
|
if self.use_q_model:
|
||||||
|
self.q_env_state, _ = self.q_env.reset()
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""
|
"""
|
||||||
Samples a best word given the current domains
|
Samples a best word given the current domains
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from ai import AI
|
from ai import AI
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
global solution
|
global solution
|
||||||
|
|
||||||
@ -27,12 +28,13 @@ def main(args):
|
|||||||
if args.n is None:
|
if args.n is None:
|
||||||
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
|
|
||||||
ai = AI(args.vocab_file)
|
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
|
||||||
|
|
||||||
total_guesses = 0
|
total_guesses = 0
|
||||||
|
wins = 0
|
||||||
num_eval = args.num_eval
|
num_eval = args.num_eval
|
||||||
|
|
||||||
for i in range(num_eval):
|
for i in tqdm(range(num_eval)):
|
||||||
idx = np.random.choice(range(len(ai.vocab)))
|
idx = np.random.choice(range(len(ai.vocab)))
|
||||||
solution = ai.vocab[idx]
|
solution = ai.vocab[idx]
|
||||||
guesses, word = ai.solve_eval(results_callback=result_callback)
|
guesses, word = ai.solve_eval(results_callback=result_callback)
|
||||||
@ -40,14 +42,17 @@ def main(args):
|
|||||||
total_guesses += 5
|
total_guesses += 5
|
||||||
else:
|
else:
|
||||||
total_guesses += guesses
|
total_guesses += guesses
|
||||||
|
wins += 1
|
||||||
ai.reset()
|
ai.reset()
|
||||||
|
|
||||||
print(total_guesses / num_eval)
|
print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--n', dest='n', type=int, default=None)
|
parser.add_argument('--n', dest='n', type=int, default=None)
|
||||||
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||||
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
|
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
|
||||||
|
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
||||||
|
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
1
eric_wordle/letter_guess.py
Symbolic link
1
eric_wordle/letter_guess.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../letter_guess.py
|
3
eval.sh
Normal file → Executable file
3
eval.sh
Normal file → Executable file
@ -1 +1,2 @@
|
|||||||
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 1000
|
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000
|
||||||
|
python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt --num_eval 5000 --q_model True --model_file wordle_ppo_model
|
Loading…
Reference in New Issue
Block a user