diff --git a/dqn_wordle.py b/dqn_wordle.py index b1e6f08..b7f3fa1 100644 --- a/dqn_wordle.py +++ b/dqn_wordle.py @@ -24,9 +24,7 @@ def test(model, env, test_num=1000): while not done: action, _states = model.predict(obs) obs, rewards, done, info = env.step(action) - - print(action, obs, rewards) - + return total_correct / test_num if __name__ == "__main__": @@ -36,5 +34,5 @@ if __name__ == "__main__": print(env) print(model) - train(model, env, total_timesteps=10000) - print(test(model, env, test_num=1)) \ No newline at end of file + train(model, env, total_timesteps=500000) + print(test(model, env)) \ No newline at end of file