# cse151b-final-project/dqn_wordle.py
# 40 lines, 979 B, Python — last changed 2024-03-14 04:36:26 +00:00
import gym
import sys
from stable_baselines3 import DQN
2024-03-14 04:36:26 +00:00
from stable_baselines3.common.env_util import make_vec_env
import wordle_gym
import numpy as np
2024-03-14 04:36:26 +00:00
from tqdm import tqdm
2024-03-14 04:36:26 +00:00
def train(model, env, total_timesteps=100000, save_path="dqn_wordle"):
    """Train ``model`` for ``total_timesteps`` steps, then save it to disk.

    Args:
        model: A stable-baselines3 algorithm instance (e.g. ``DQN``) that
            exposes ``learn`` and ``save``.
        env: Not used here; kept so existing call sites keep working.
            ``model.learn`` trains on the environment the model was
            constructed with.
        total_timesteps: Number of environment steps passed to
            ``model.learn``.
        save_path: Path handed to ``model.save``. Defaults to
            ``"dqn_wordle"`` (the previously hard-coded name).

    Returns:
        The same ``model`` object, now trained, to allow chaining.
    """
    model.learn(total_timesteps=total_timesteps, progress_bar=True)
    model.save(save_path)
    return model
2024-03-14 04:36:26 +00:00
def test(model, env, test_num=1000, *, is_success=None, verbose=True):
    """Roll out ``model`` for ``test_num`` episodes and return the success rate.

    Args:
        model: Trained policy exposing ``predict(obs)``. If ``None``, the
            checkpoint ``"dqn_wordle"`` is loaded once with ``DQN.load``.
        env: Environment to evaluate in. If ``None``, one ``"wordle-v0"``
            environment is created for all episodes and closed afterwards.
            A caller-supplied env is left open.
        test_num: Number of episodes to play; must be positive.
        is_success: Optional callable ``(final_reward, info) -> bool`` that
            decides whether a finished episode counts as solved. Defaults to
            ``final_reward > 0``.
        verbose: When true, print ``action, obs, reward`` after every step.

    Returns:
        Fraction of episodes judged successful, in ``[0, 1]``.

    Raises:
        ValueError: If ``test_num`` is not positive (previously this
            surfaced as a ZeroDivisionError).
    """
    if test_num <= 0:
        raise ValueError(f"test_num must be positive, got {test_num}")

    if is_success is None:
        # NOTE(review): assumes the Wordle env returns a positive reward on the
        # final step only when the word was guessed — confirm against wordle_gym.
        def is_success(reward, info):
            return reward > 0

    # Load / create at most once. The previous version re-read the checkpoint
    # and built a new env on every episode, silently ignoring the arguments.
    if model is None:
        model = DQN.load("dqn_wordle")
    owns_env = env is None
    if owns_env:
        env = gym.make("wordle-v0")

    total_correct = 0
    try:
        for _ in tqdm(range(test_num)):
            # Classic gym API (as used throughout this file):
            # reset() -> obs, step() -> (obs, reward, done, info).
            obs = env.reset()
            done = False
            reward, info = 0, {}
            while not done:
                action, _states = model.predict(obs)
                obs, reward, done, info = env.step(action)
                if verbose:
                    print(action, obs, reward)
            # Judge the episode by its final step's reward/info.
            if is_success(reward, info):
                total_correct += 1
    finally:
        # Only close environments this function created itself.
        if owns_env:
            env.close()

    return total_correct / test_num
2024-03-14 04:36:26 +00:00
# Script entry point: build the Wordle env and a DQN agent, train and save the
# agent, then print the value returned by ``test`` for one evaluation episode.
if __name__ == "__main__":
    env = gym.make("wordle-v0")
    try:
        model = DQN("MlpPolicy", env, verbose=0)
        print(env)
        print(model)

        train(model, env, total_timesteps=10000)
        print(test(model, env, test_num=1))
    finally:
        # Release the environment even if training or evaluation raises.
        env.close()