2024-03-14 04:36:26 +00:00
|
|
|
import gym
|
|
|
|
import sys
|
2024-03-13 21:27:34 +00:00
|
|
|
from stable_baselines3 import DQN
|
2024-03-14 04:36:26 +00:00
|
|
|
from stable_baselines3.common.env_util import make_vec_env
|
|
|
|
import wordle_gym
|
2024-03-13 21:27:34 +00:00
|
|
|
import numpy as np
|
2024-03-14 04:36:26 +00:00
|
|
|
from tqdm import tqdm
|
2024-03-13 21:27:34 +00:00
|
|
|
|
2024-03-14 04:36:26 +00:00
|
|
|
def train(model, env, total_timesteps=100000, save_path="dqn_wordle"):
    """Train ``model`` for ``total_timesteps`` steps, then save it to ``save_path``.

    Args:
        model: An agent exposing ``learn(total_timesteps=..., progress_bar=...)``
            and ``save(path)``, e.g. a stable-baselines3 ``DQN``.
        env: Not referenced by this function. stable-baselines3's ``learn``
            trains on the environment the model already holds. The argument
            is kept so existing call sites keep working.
        total_timesteps: Number of environment steps to train for.
        save_path: Destination passed to ``model.save``. Defaults to
            ``"dqn_wordle"`` so existing behavior is unchanged.

    Returns:
        The same ``model`` object, now trained, for convenient chaining.
    """
    # progress_bar=True shows a live progress bar while learning.
    model.learn(total_timesteps=total_timesteps, progress_bar=True)
    model.save(save_path)
    return model
|
2024-03-13 21:27:34 +00:00
|
|
|
|
2024-03-14 04:36:26 +00:00
|
|
|
def test(model, env, test_num=1000, *, deterministic=True, is_success=None):
    """Play ``test_num`` episodes of ``env`` with ``model``; return the success rate.

    Args:
        model: A trained agent exposing ``predict(obs, deterministic=...)``
            that returns ``(action, state)``, e.g. a stable-baselines3 ``DQN``.
        env: Environment to evaluate in. It uses the same API as the rest of
            this file: ``reset()`` returns an observation and ``step(action)``
            returns ``(obs, reward, done, info)``. It is reset before every
            episode and is not closed here.
        test_num: Number of episodes to play. Must be positive.
        deterministic: Forwarded to ``model.predict``. ``True`` turns off
            exploration, so the score reflects the learned greedy policy.
        is_success: Optional ``callable(reward, info) -> bool``. It is applied
            to the final step of each episode to decide whether that episode
            counts as solved. Defaults to ``reward > 0``.

    Returns:
        Fraction of episodes judged successful, in ``[0, 1]``.

    Raises:
        ValueError: If ``test_num`` is not positive.
    """
    if test_num <= 0:
        raise ValueError(f"test_num must be positive, got {test_num}")

    # NOTE(review): the default assumes wordle-v0 gives a positive reward on
    # the final step only when the hidden word is guessed. Confirm this against
    # wordle_gym's reward scheme, or pass a custom is_success.
    check = is_success if is_success is not None else (lambda reward, info: reward > 0)

    total_correct = 0
    for _ in tqdm(range(test_num)):
        # Reuse the caller's model and env. Reloading from disk and creating a
        # fresh env every episode is slow and leaves environments unclosed.
        obs = env.reset()
        done = False
        reward, info = 0, {}
        while not done:
            action, _states = model.predict(obs, deterministic=deterministic)
            obs, reward, done, info = env.step(action)
        if check(reward, info):
            total_correct += 1

    return total_correct / test_num
|
2024-03-13 21:27:34 +00:00
|
|
|
|
2024-03-14 04:36:26 +00:00
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Wordle environment and a fresh DQN agent
    # that uses an MLP policy with quiet logging.
    env = gym.make("wordle-v0")
    model = DQN("MlpPolicy", env, verbose=0)

    # Echo both objects so the configuration is visible in the run log.
    for configured in (env, model):
        print(configured)

    # Train (train() also saves the model), then report the evaluation score.
    train(model, env, total_timesteps=500000)
    accuracy = test(model, env)
    print(accuracy)
|