From 5672169073d6ab228c38da49db21844077698e88 Mon Sep 17 00:00:00 2001
From: Arthur Lu
Date: Thu, 14 Mar 2024 14:49:17 -0700
Subject: [PATCH] copy the wordle env locally and fix the obs return

---
 .gitignore       |   3 +-
 dqn_wordle.ipynb | 157 ++++++++++++++++++++++++++++++++++-------------
 2 files changed, 118 insertions(+), 42 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1244f49..b3ea66f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 **/data/*
-**/*.zip
\ No newline at end of file
+**/*.zip
+**/
\ No newline at end of file
diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb
index a3b8f14..36185ba 100644
--- a/dqn_wordle.ipynb
+++ b/dqn_wordle.ipynb
@@ -2,64 +2,55 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 83,
    "metadata": {},
    "outputs": [],
    "source": [
     "import gym\n",
     "import gym_wordle\n",
-    "from stable_baselines3 import DQN\n",
+    "from stable_baselines3 import DQN, PPO, common\n",
     "import numpy as np\n",
     "import tqdm"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 84,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">\n"
+     ]
+    }
+   ],
    "source": [
-    "env = gym.make(\"Wordle-v0\")\n",
+    "env = gym_wordle.wordle.WordleEnv()\n",
+    "env = common.monitor.Monitor(env)\n",
     "\n",
     "print(env)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 85,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using cuda device\n",
+      "Wrapping the env in a DummyVecEnv.\n"
+     ]
+    }
+   ],
    "source": [
-    "total_timesteps = 100000\n",
-    "model = DQN(\"MlpPolicy\", env, verbose=0)\n",
-    "model.learn(total_timesteps=total_timesteps, progress_bar=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def test(model):\n",
-    "\n",
-    "    end_rewards = []\n",
-    "\n",
-    "    for i in range(1000):\n",
-    "        \n",
-    "        state = env.reset()\n",
-    "\n",
-    "        done = False\n",
-    "\n",
-    "        while not done:\n",
-    "\n",
-    "            action, _states = model.predict(state, deterministic=True)\n",
-    "\n",
-    "            state, reward, done, info = env.step(action)\n",
-    "        \n",
-    "        end_rewards.append(reward == 0)\n",
-    "    \n",
-    "    return np.sum(end_rewards) / len(end_rewards)"
+    "total_timesteps = 1000\n",
+    "model = PPO(\"MlpPolicy\", env, verbose=1)\n",
+    "model.learn(total_timesteps=total_timesteps)"
    ]
   },
   {
@@ -77,7 +68,93 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = DQN.load(\"dqn_wordle\")"
+    "model = PPO.load(\"dqn_wordle\")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]] -1.0 False {}\n",
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]] -1.0 False {}\n",
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]] -1.0 False {}\n",
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]] -1.0 False {}\n",
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [ 0  0  0  0  0  0  0  0  0  0]] -1.0 False {}\n",
+      "[[16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]\n",
+      " [16 18  5 15 14  3  3  1  3  3]] -1.0 True {}\n"
+     ]
+    },
+    {
+     "ename": "KeyError",
+     "evalue": "'correct'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[82], line 19\u001b[0m\n\u001b[1;32m 15\u001b[0m state, reward, done, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(action)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mprint\u001b[39m(state, reward, done, info)\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43minfo\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcorrect\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m:\n\u001b[1;32m 20\u001b[0m wins \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m#end_rewards.append(reward == 0)\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m#return np.sum(end_rewards) / len(end_rewards)\u001b[39;00m\n",
+      "\u001b[0;31mKeyError\u001b[0m: 'correct'"
+     ]
+    }
+   ],
+   "source": [
+    "env = gym_wordle.wordle.WordleEnv()\n",
+    "\n",
+    "for i in range(1):\n",
+    "    \n",
+    "    state = env.reset()\n",
+    "\n",
+    "    done = False\n",
+    "\n",
+    "    wins = 0\n",
+    "\n",
+    "    while not done:\n",
+    "\n",
+    "        action, _states = model.predict(state, deterministic=True)\n",
+    "\n",
+    "        state, reward, done, info = env.step(action)\n",
+    "\n",
+    "        print(state, reward, done, info)\n",
+    "\n",
+    "        if info[\"correct\"]:\n",
+    "            wins += 1\n",
+    "    \n",
+    "    #end_rewards.append(reward == 0)\n",
+    "    \n",
+    "#return np.sum(end_rewards) / len(end_rewards)\n"
+   ]
+  },
   {
    "cell_type": "code",
@@ -85,9 +162,7 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "print(test(model))"
-   ]
+   "source": []
   }
  ],
  "metadata": {