Compare commits

...

2 Commits

Author SHA1 Message Date
Arthur Lu
4f8ca4aa06 add ppo model, add cpu and cuda device support 2024-03-20 21:17:44 -07:00
Arthur Lu
b46d335044 working!! 2024-03-20 21:01:08 -07:00
6 changed files with 45 additions and 27 deletions

4
.gitignore vendored
View File

@ -1,6 +1,6 @@
**/data/* **/data/*
**/*.zip
**/__pycache__ **/__pycache__
/env /env
**/runs/* **/runs/*
**/wandb/* **/wandb/*
**/models/*

View File

@ -5,7 +5,7 @@ import numpy as np
from stable_baselines3 import PPO, DQN from stable_baselines3 import PPO, DQN
from letter_guess import LetterGuessingEnv from letter_guess import LetterGuessingEnv
import torch
def load_valid_words(file_path='wordle_words.txt'): def load_valid_words(file_path='wordle_words.txt'):
""" """
@ -23,7 +23,8 @@ def load_valid_words(file_path='wordle_words.txt'):
class AI: class AI:
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False): def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"):
self.device = device
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.num_letters = num_letters self.num_letters = num_letters
self.num_guesses = 6 self.num_guesses = 6
@ -37,26 +38,28 @@ class AI:
self.use_q_model = use_q_model self.use_q_model = use_q_model
if use_q_model: if use_q_model:
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
self.q_env = LetterGuessingEnv(vocab_file) self.q_env = LetterGuessingEnv(load_valid_words(vocab_file))
self.q_env_state, _ = self.q_env.reset() self.q_env_state, _ = self.q_env.reset()
# load model # load model
self.q_model = PPO.load(model_file) self.q_model = PPO.load(model_file, device=self.device)
self.reset() self.reset("")
def solve_eval(self, results_callback): def solve_eval(self, results_callback):
num_guesses = 0 num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1 num_guesses += 1
if self.use_q_model:
self.freeze_state = self.q_env.clone_state()
# sample a word, this would use the q_env_state if the q_model is used # sample a word, this would use the q_env_state if the q_model is used
word = self.sample() word = self.sample(num_guesses)
# get emulated results # get emulated results
results = results_callback(word) results = results_callback(word)
if self.use_q_model: if self.use_q_model:
self.q_env.set_state(self.q_env_state) self.q_env.set_state(self.freeze_state)
# step the q_env to match the guess we just made # step the q_env to match the guess we just made
for i in range(len(word)): for i in range(len(word)):
char = word[i] char = word[i]
@ -70,13 +73,11 @@ class AI:
num_guesses = 0 num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1 num_guesses += 1
word = self.sample() if self.use_q_model:
self.freeze_state = self.q_env.clone_state()
# # Always start with these two words # sample a word, this would use the q_env_state if the q_model is used
# if num_guesses == 1: word = self.sample(num_guesses)
# word = 'soare'
# elif num_guesses == 2:
# word = 'culti'
print('-----------------------------------------------') print('-----------------------------------------------')
print(f'Guess #{num_guesses}/{self.num_guesses}: {word}') print(f'Guess #{num_guesses}/{self.num_guesses}: {word}')
@ -96,10 +97,16 @@ class AI:
results.append(result) results.append(result)
break break
self.arc_consistency(word, results) if self.use_q_model:
self.q_env.set_state(self.freeze_state)
# step the q_env to match the guess we just made
for i in range(len(word)):
char = word[i]
action = ord(char) - ord('a')
self.q_env_state, _, _, _, _ = self.q_env.step(action)
print(f'You did it! The word is {"".join([e[0] for e in self.domains])}') self.arc_consistency(word, results)
return num_guesses return num_guesses, word
def arc_consistency(self, word, results): def arc_consistency(self, word, results):
self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1'] self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1']
@ -119,14 +126,15 @@ class AI:
if results[i] == '2': if results[i] == '2':
self.domains[i] = [word[i]] self.domains[i] = [word[i]]
def reset(self): def reset(self, target_word):
self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)] self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)]
self.possible_letters = [] self.possible_letters = []
if self.use_q_model: if self.use_q_model:
self.q_env_state, _ = self.q_env.reset() self.q_env_state, _ = self.q_env.reset()
self.q_env.target_word = target_word
def sample(self): def sample(self, num_guesses):
""" """
Samples a best word given the current domains Samples a best word given the current domains
:return: :return:
@ -143,15 +151,15 @@ class AI:
for word, _ in self.best_words: for word, _ in self.best_words:
# reset the state back to before we guessed a word # reset the state back to before we guessed a word
if pattern.match(word) and False not in [e in word for e in self.possible_letters]: if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
if self.use_q_model: if self.use_q_model and num_guesses == 3:
self.q_env.set_state(self.q_env_state) self.q_env.set_state(self.freeze_state)
# Use policy to grade word # Use policy to grade word
# get the state and action pairs # get the state and action pairs
curr_qval = 0 curr_qval = 0
for l in word: for l in word:
action = ord(l) - ord('a') action = ord(l) - ord('a')
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action) q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device))
_, _, _, _, _ = self.q_env.step(action) _, _, _, _, _ = self.q_env.step(action)
curr_qval += q_val curr_qval += q_val

View File

@ -28,22 +28,26 @@ def main(args):
if args.n is None: if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
total_guesses = 0 total_guesses = 0
wins = 0 wins = 0
num_eval = args.num_eval num_eval = args.num_eval
np.random.seed(0)
for i in tqdm(range(num_eval)): for i in tqdm(range(num_eval)):
idx = np.random.choice(range(len(ai.vocab))) idx = np.random.choice(range(len(ai.vocab)))
solution = ai.vocab[idx] solution = ai.vocab[idx]
ai.reset(solution)
guesses, word = ai.solve_eval(results_callback=result_callback) guesses, word = ai.solve_eval(results_callback=result_callback)
if word != solution: if word != solution:
total_guesses += 5 total_guesses += 5
else: else:
total_guesses += guesses total_guesses += guesses
wins += 1 wins += 1
ai.reset()
print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}") print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}")
@ -54,5 +58,6 @@ if __name__ == '__main__':
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
parser.add_argument('--q_model', dest="q_model", type=bool, default=False) parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
parser.add_argument('--device', dest="device", type=str, default="cuda")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -5,8 +5,9 @@ from ai import AI
def main(args): def main(args):
if args.n is None: if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
print(f"using q model? {args.q_model}")
ai = AI(args.vocab_file) ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
ai.reset("lingo")
ai.solve() ai.solve()
@ -14,5 +15,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--n', dest='n', type=int, default=None) parser.add_argument('--n', dest='n', type=int, default=None)
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
parser.add_argument('--device', dest="device", type=str, default="cuda")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

1
inference.sh Executable file
View File

@ -0,0 +1 @@
python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt --q_model True --model_file wordle_ppo_model --device cpu

BIN
wordle_ppo_model.zip Normal file

Binary file not shown.