Merge pull request #2 from ltcptgeneral/gym-2

add ppo model, add cpu and cuda device support
This commit is contained in:
Arthur Lu 2024-03-20 21:18:19 -07:00 committed by GitHub
commit 3595fc1b07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 11 additions and 7 deletions

2
.gitignore vendored
View File

@ -1,6 +1,6 @@
**/data/* **/data/*
**/*.zip
**/__pycache__ **/__pycache__
/env /env
**/runs/* **/runs/*
**/wandb/* **/wandb/*
**/models/*

View File

@ -23,7 +23,8 @@ def load_valid_words(file_path='wordle_words.txt'):
class AI: class AI:
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False): def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"):
self.device = device
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.num_letters = num_letters self.num_letters = num_letters
self.num_guesses = 6 self.num_guesses = 6
@ -41,7 +42,7 @@ class AI:
self.q_env_state, _ = self.q_env.reset() self.q_env_state, _ = self.q_env.reset()
# load model # load model
self.q_model = PPO.load(model_file) self.q_model = PPO.load(model_file, device=self.device)
self.reset("") self.reset("")
@ -158,7 +159,7 @@ class AI:
for l in word: for l in word:
action = ord(l) - ord('a') action = ord(l) - ord('a')
q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to("cuda")) q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device))
_, _, _, _, _ = self.q_env.step(action) _, _, _, _, _ = self.q_env.step(action)
curr_qval += q_val curr_qval += q_val

View File

@ -28,7 +28,7 @@ def main(args):
if args.n is None: if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
total_guesses = 0 total_guesses = 0
wins = 0 wins = 0
@ -58,5 +58,6 @@ if __name__ == '__main__':
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
parser.add_argument('--q_model', dest="q_model", type=bool, default=False) parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
parser.add_argument('--device', dest="device", type=str, default="cuda")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -6,7 +6,7 @@ def main(args):
if args.n is None: if args.n is None:
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
print(f"using q model? {args.q_model}") print(f"using q model? {args.q_model}")
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model) ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
ai.reset("lingo") ai.reset("lingo")
ai.solve() ai.solve()
@ -17,5 +17,6 @@ if __name__ == '__main__':
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
parser.add_argument('--q_model', dest="q_model", type=bool, default=False) parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
parser.add_argument('--device', dest="device", type=str, default="cuda")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

1
inference.sh Executable file
View File

@ -0,0 +1 @@
python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt --q_model True --model_file wordle_ppo_model --device cpu

BIN
wordle_ppo_model.zip Normal file

Binary file not shown.