mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-26 01:59:10 +00:00
add ppo model, add cpu and cuda device support
This commit is contained in:
parent
b46d335044
commit
4f8ca4aa06
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,6 @@
|
|||||||
**/data/*
|
**/data/*
|
||||||
**/*.zip
|
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
/env
|
/env
|
||||||
**/runs/*
|
**/runs/*
|
||||||
**/wandb/*
|
**/wandb/*
|
||||||
|
**/models/*
|
@ -23,7 +23,8 @@ def load_valid_words(file_path='wordle_words.txt'):
|
|||||||
|
|
||||||
|
|
||||||
class AI:
|
class AI:
|
||||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
|
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"):
|
||||||
|
self.device = device
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.num_letters = num_letters
|
self.num_letters = num_letters
|
||||||
self.num_guesses = 6
|
self.num_guesses = 6
|
||||||
@ -41,7 +42,7 @@ class AI:
|
|||||||
self.q_env_state, _ = self.q_env.reset()
|
self.q_env_state, _ = self.q_env.reset()
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
self.q_model = PPO.load(model_file)
|
self.q_model = PPO.load(model_file, device=self.device)
|
||||||
|
|
||||||
self.reset("")
|
self.reset("")
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ class AI:
|
|||||||
|
|
||||||
for l in word:
|
for l in word:
|
||||||
action = ord(l) - ord('a')
|
action = ord(l) - ord('a')
|
||||||
q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to("cuda"))
|
q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device))
|
||||||
_, _, _, _, _ = self.q_env.step(action)
|
_, _, _, _, _ = self.q_env.step(action)
|
||||||
curr_qval += q_val
|
curr_qval += q_val
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ def main(args):
|
|||||||
if args.n is None:
|
if args.n is None:
|
||||||
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
|
|
||||||
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
|
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
|
||||||
|
|
||||||
total_guesses = 0
|
total_guesses = 0
|
||||||
wins = 0
|
wins = 0
|
||||||
@ -58,5 +58,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
|
parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000)
|
||||||
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
||||||
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
||||||
|
parser.add_argument('--device', dest="device", type=str, default="cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
@ -6,7 +6,7 @@ def main(args):
|
|||||||
if args.n is None:
|
if args.n is None:
|
||||||
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).')
|
||||||
print(f"using q model? {args.q_model}")
|
print(f"using q model? {args.q_model}")
|
||||||
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model)
|
ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device)
|
||||||
ai.reset("lingo")
|
ai.reset("lingo")
|
||||||
ai.solve()
|
ai.solve()
|
||||||
|
|
||||||
@ -17,5 +17,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt')
|
||||||
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model')
|
||||||
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
parser.add_argument('--q_model', dest="q_model", type=bool, default=False)
|
||||||
|
parser.add_argument('--device', dest="device", type=str, default="cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
1
inference.sh
Executable file
1
inference.sh
Executable file
@ -0,0 +1 @@
|
|||||||
|
python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt --q_model True --model_file wordle_ppo_model --device cpu
|
BIN
wordle_ppo_model.zip
Normal file
BIN
wordle_ppo_model.zip
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user