diff --git a/.gitignore b/.gitignore index 5bd8281..d9ef7a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ **/data/* -**/*.zip **/__pycache__ /env **/runs/* -**/wandb/* \ No newline at end of file +**/wandb/* +**/models/* \ No newline at end of file diff --git a/eric_wordle/ai.py b/eric_wordle/ai.py index 73e0cc1..2c16f1c 100644 --- a/eric_wordle/ai.py +++ b/eric_wordle/ai.py @@ -23,7 +23,8 @@ def load_valid_words(file_path='wordle_words.txt'): class AI: - def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False): + def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"): + self.device = device self.vocab_file = vocab_file self.num_letters = num_letters self.num_guesses = 6 @@ -41,7 +42,7 @@ class AI: self.q_env_state, _ = self.q_env.reset() # load model - self.q_model = PPO.load(model_file) + self.q_model = PPO.load(model_file, device=self.device) self.reset("") @@ -158,7 +159,7 @@ class AI: for l in word: action = ord(l) - ord('a') - q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to("cuda")) + q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device)) _, _, _, _, _ = self.q_env.step(action) curr_qval += q_val diff --git a/eric_wordle/eval.py b/eric_wordle/eval.py index aeca830..6646c6b 100644 --- a/eric_wordle/eval.py +++ b/eric_wordle/eval.py @@ -28,7 +28,7 @@ def main(args): if args.n is None: raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') - ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) + ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device) total_guesses = 0 wins = 0 @@ -58,5 +58,6 @@ if __name__ == '__main__': parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') parser.add_argument('--q_model', dest="q_model", type=bool, default=False) + parser.add_argument('--device', dest="device", type=str, default="cuda") args = parser.parse_args() main(args) \ No newline at end of file diff --git a/eric_wordle/main.py b/eric_wordle/main.py index 71898e4..a6f830e 100644 --- a/eric_wordle/main.py +++ b/eric_wordle/main.py @@ -6,7 +6,7 @@ def main(args): if args.n is None: raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') print(f"using q model? {args.q_model}") - ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) + ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device) ai.reset("lingo") ai.solve() @@ -17,5 +17,6 @@ if __name__ == '__main__': parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') parser.add_argument('--q_model', dest="q_model", type=bool, default=False) + parser.add_argument('--device', dest="device", type=str, default="cuda") args = parser.parse_args() main(args) \ No newline at end of file diff --git a/inference.sh b/inference.sh new file mode 100755 index 0000000..804c037 --- /dev/null +++ b/inference.sh @@ -0,0 +1 @@ +python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt --q_model True --model_file wordle_ppo_model --device cpu \ No newline at end of file diff --git a/wordle_ppo_model.zip b/wordle_ppo_model.zip new file mode 100644 index 0000000..005007e Binary files /dev/null and b/wordle_ppo_model.zip differ