From cf977e479736b71d2df83292931269550e9bff97 Mon Sep 17 00:00:00 2001 From: Arthur Lu Date: Fri, 15 Mar 2024 18:48:21 -0700 Subject: [PATCH] try penalizing duplicate guesses --- dqn_wordle.ipynb | 425 ++++++++++++++++++------------------------- gym_wordle/wordle.py | 13 +- 2 files changed, 188 insertions(+), 250 deletions(-) diff --git a/dqn_wordle.ipynb b/dqn_wordle.ipynb index 1e31b94..55fb002 100644 --- a/dqn_wordle.ipynb +++ b/dqn_wordle.ipynb @@ -6,11 +6,10 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", "import gym_wordle\n", "from stable_baselines3 import DQN, PPO, common\n", "import numpy as np\n", - "import tqdm" + "from tqdm import tqdm" ] }, { @@ -42,213 +41,213 @@ "name": "stdout", "output_type": "stream", "text": [ - "Using cpu device\n", + "Using cuda device\n", "Wrapping the env in a DummyVecEnv.\n", "---------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.59 |\n", + "| ep_rew_mean | 2.14 |\n", "| time/ | |\n", - "| fps | 544 |\n", + "| fps | 750 |\n", "| iterations | 1 |\n", - "| time_elapsed | 3 |\n", + "| time_elapsed | 2 |\n", "| total_timesteps | 2048 |\n", "---------------------------------\n", "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.77 |\n", + "| ep_rew_mean | 4.59 |\n", "| time/ | |\n", - "| fps | 245 |\n", + "| fps | 625 |\n", "| iterations | 2 |\n", - "| time_elapsed | 16 |\n", + "| time_elapsed | 6 |\n", "| total_timesteps | 4096 |\n", "| train/ | |\n", - "| approx_kl | 0.021515464 |\n", - "| clip_fraction | 0.335 |\n", + "| approx_kl | 0.022059526 |\n", + "| clip_fraction | 0.331 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.00118 |\n", + "| explained_variance | -0.0118 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 89.5 |\n", + "| loss | 130 |\n", "| n_updates | 10 |\n", - "| policy_gradient_loss | -0.0854 |\n", - "| value_loss | 262 |\n", + "| policy_gradient_loss | -0.0851 |\n", + "| value_loss | 253 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | 5.86 |\n", + "| time/ | |\n", + "| fps | 585 |\n", + "| iterations | 3 |\n", + "| time_elapsed | 10 |\n", + "| total_timesteps | 6144 |\n", + "| train/ | |\n", + "| approx_kl | 0.024416003 |\n", + "| clip_fraction | 0.462 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | 0.152 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 85.2 |\n", + "| n_updates | 20 |\n", + "| policy_gradient_loss | -0.0987 |\n", + "| value_loss | 218 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | 4.75 |\n", + "| time/ | |\n", + "| fps | 566 |\n", + "| iterations | 4 |\n", + "| time_elapsed | 14 |\n", + "| total_timesteps | 8192 |\n", + "| train/ | |\n", + "| approx_kl | 0.026305672 |\n", + "| clip_fraction | 0.45 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.47 |\n", + "| explained_variance | 0.161 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 144 |\n", + "| n_updates | 30 |\n", + "| policy_gradient_loss | -0.105 |\n", + "| value_loss | 220 |\n", "-----------------------------------------\n", "----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.31 |\n", + "| ep_rew_mean | 1.47 |\n", "| time/ | |\n", - "| fps | 211 |\n", - "| iterations | 3 |\n", - "| time_elapsed | 29 |\n", - "| total_timesteps | 6144 |\n", + "| fps | 554 |\n", + "| iterations | 5 |\n", + "| time_elapsed | 18 |\n", + "| total_timesteps | 10240 |\n", "| train/ | |\n", - "| approx_kl | 0.02457875 |\n", - "| clip_fraction | 0.465 |\n", + "| approx_kl | 0.02928267 |\n", + "| clip_fraction | 0.498 |\n", "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.161 |\n", + "| entropy_loss | -9.46 |\n", + "| explained_variance | 0.167 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 118 |\n", - "| n_updates | 20 |\n", - "| policy_gradient_loss | -0.0987 |\n", - "| value_loss | 217 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 5.96 |\n", - "| ep_rew_mean | 5.79 |\n", - "| time/ | |\n", - "| fps | 196 |\n", - "| iterations | 4 |\n", - "| time_elapsed | 41 |\n", - "| total_timesteps | 8192 |\n", - "| train/ | |\n", - "| approx_kl | 0.02515613 |\n", - "| clip_fraction | 0.447 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.47 |\n", - "| explained_variance | 0.151 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 138 |\n", - "| n_updates | 30 |\n", - "| policy_gradient_loss | -0.103 |\n", - "| value_loss | 242 |\n", + "| loss | 127 |\n", + "| n_updates | 40 |\n", + "| policy_gradient_loss | -0.116 |\n", + "| value_loss | 207 |\n", "----------------------------------------\n", "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 4.9 |\n", + "| ep_rew_mean | 1.62 |\n", "| time/ | |\n", - "| fps | 177 |\n", - "| iterations | 5 |\n", - "| time_elapsed | 57 |\n", - "| total_timesteps | 10240 |\n", + "| fps | 546 |\n", + "| iterations | 6 |\n", + "| time_elapsed | 22 |\n", + "| total_timesteps | 12288 |\n", "| train/ | |\n", - "| approx_kl | 0.026685718 |\n", - "| clip_fraction | 0.444 |\n", + "| approx_kl | 0.028425258 |\n", + "| clip_fraction | 0.483 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.176 |\n", + "| explained_variance | 0.143 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 96.7 |\n", - "| n_updates | 40 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 211 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 1.19 |\n", - "| time/ | |\n", - "| fps | 164 |\n", - "| iterations | 6 |\n", - "| time_elapsed | 74 |\n", - "| total_timesteps | 12288 |\n", - "| train/ | |\n", - "| approx_kl | 0.02762504 |\n", - "| clip_fraction | 0.463 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.186 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 103 |\n", - "| n_updates | 50 |\n", - "| policy_gradient_loss | -0.115 |\n", - "| value_loss | 200 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 5.5 |\n", - "| time/ | |\n", - "| fps | 155 |\n", - "| iterations | 7 |\n", - "| time_elapsed | 92 |\n", - "| total_timesteps | 14336 |\n", - "| train/ | |\n", - "| approx_kl | 0.02694263 |\n", - "| clip_fraction | 0.458 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.46 |\n", - "| explained_variance | 0.15 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 84.1 |\n", - "| n_updates | 60 |\n", - "| policy_gradient_loss | -0.116 |\n", - "| value_loss | 225 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.27 |\n", - "| time/ | |\n", - "| fps | 154 |\n", - "| iterations | 8 |\n", - "| time_elapsed | 106 |\n", - "| total_timesteps | 16384 |\n", - "| train/ | |\n", - "| approx_kl | 0.024316464 |\n", - "| clip_fraction | 0.412 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.173 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 126 |\n", - "| n_updates | 70 |\n", - "| policy_gradient_loss | -0.112 |\n", - "| value_loss | 227 |\n", + "| loss | 109 |\n", + "| n_updates | 50 |\n", + "| policy_gradient_loss | -0.117 |\n", + "| value_loss | 240 |\n", "-----------------------------------------\n", "-----------------------------------------\n", "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", - "| ep_rew_mean | 7.8 |\n", - "| time/ | |\n", - "| fps | 151 |\n", - "| iterations | 9 |\n", - "| time_elapsed | 121 |\n", - "| total_timesteps | 18432 |\n", - "| train/ | |\n", - "| approx_kl | 0.022988513 |\n", - "| clip_fraction | 0.391 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.206 |\n", - "| learning_rate | 0.0003 |\n", - "| loss | 139 |\n", - "| n_updates | 80 |\n", - "| policy_gradient_loss | -0.111 |\n", - "| value_loss | 228 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| rollout/ | |\n", - "| ep_len_mean | 6 |\n", + "| ep_len_mean | 5.98 |\n", "| ep_rew_mean | 6.14 |\n", "| time/ | |\n", - "| fps | 153 |\n", - "| iterations | 10 |\n", - "| time_elapsed | 133 |\n", - "| total_timesteps | 20480 |\n", + "| fps | 541 |\n", + "| iterations | 7 |\n", + "| time_elapsed | 26 |\n", + "| total_timesteps | 14336 |\n", "| train/ | |\n", - "| approx_kl | 0.022813996 |\n", - "| clip_fraction | 0.372 |\n", + "| approx_kl | 0.026178032 |\n", + "| clip_fraction | 0.453 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.46 |\n", + "| explained_variance | 0.174 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 141 |\n", + "| n_updates | 60 |\n", + "| policy_gradient_loss | -0.116 |\n", + "| value_loss | 235 |\n", + "-----------------------------------------\n", + "----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | 3.03 |\n", + "| time/ | |\n", + "| fps | 537 |\n", + "| iterations | 8 |\n", + "| time_elapsed | 30 |\n", + "| total_timesteps | 16384 |\n", + "| train/ | |\n", + "| approx_kl | 0.02457074 |\n", + "| clip_fraction | 0.423 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.45 |\n", + "| explained_variance | 0.171 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 111 |\n", + "| n_updates | 70 |\n", + "| policy_gradient_loss | -0.112 |\n", + "| value_loss | 212 |\n", + "----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | 9.54 |\n", + "| time/ | |\n", + "| fps | 532 |\n", + "| iterations | 9 |\n", + "| time_elapsed | 34 |\n", + "| total_timesteps | 18432 |\n", + "| train/ | |\n", + "| approx_kl | 0.024578478 |\n", + "| clip_fraction | 0.417 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -9.45 |\n", - "| explained_variance | 0.199 |\n", + "| explained_variance | 0.178 |\n", "| learning_rate | 0.0003 |\n", - "| loss | 117 |\n", + "| loss | 121 |\n", + "| n_updates | 80 |\n", + "| policy_gradient_loss | -0.114 |\n", + "| value_loss | 232 |\n", + "-----------------------------------------\n", + "-----------------------------------------\n", + "| rollout/ | |\n", + "| ep_len_mean | 6 |\n", + "| ep_rew_mean | 3.81 |\n", + "| time/ | |\n", + "| fps | 527 |\n", + "| iterations | 10 |\n", + "| time_elapsed | 38 |\n", + "| total_timesteps | 20480 |\n", + "| train/ | |\n", + "| approx_kl | 0.022704324 |\n", + "| clip_fraction | 0.379 |\n", + "| clip_range | 0.2 |\n", + "| entropy_loss | -9.45 |\n", + "| explained_variance | 0.194 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | 108 |\n", "| n_updates | 90 |\n", - "| policy_gradient_loss | -0.108 |\n", - "| value_loss | 212 |\n", + "| policy_gradient_loss | -0.112 |\n", + "| value_loss | 216 |\n", "-----------------------------------------\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -275,101 +274,48 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.\n", - "Exception: code() argument 13 must be str, not int\n", - " warnings.warn(\n", - "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", - "Exception: code() argument 13 must be str, not int\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "model = PPO.load(\"dqn_wordle\")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "[[16 1 9 19 5 3 2 3 3 1]\n", - " [18 8 5 13 5 2 3 2 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]\n", - " [16 1 9 19 5 3 2 3 3 1]]\n", - "[[16 1 9 19 5 3 3 3 3 3]\n", - " [18 8 5 13 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 3 3 3]]\n", - "[[16 1 9 19 5 3 3 1 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]\n", - " [16 1 9 19 5 3 3 1 3 3]]\n", - "[[16 1 9 19 5 1 2 2 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]\n", - " [16 1 9 19 5 1 2 2 3 3]]\n", - "[[16 1 9 19 5 1 1 3 1 1]\n", - " [18 1 11 5 4 2 1 3 2 3]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]\n", - " [16 1 9 19 5 1 1 3 1 1]]\n", - "[[16 1 9 19 5 3 3 1 1 3]\n", - " [18 1 11 5 4 3 3 3 3 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]\n", - " [16 1 9 19 5 3 3 1 1 3]]\n", - "[[16 1 9 19 5 3 3 3 2 1]\n", - " [18 1 11 5 4 3 3 3 2 3]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]\n", - " [16 1 9 19 5 3 3 3 2 1]]\n", - "[[16 1 9 19 5 3 2 3 3 3]\n", - " [18 8 5 13 5 1 3 3 2 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]\n", - " [16 1 9 19 5 3 2 3 3 3]]\n", - "[[16 1 9 19 5 3 1 3 3 3]\n", - " [18 8 5 13 5 3 3 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]\n", - " [16 1 9 19 5 3 1 3 3 3]]\n", - "[[16 1 9 19 5 3 3 3 3 2]\n", - " [18 8 5 13 5 3 3 1 1 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]\n", - " [16 1 9 19 5 3 3 3 3 2]]\n", + "[[ 7 18 1 19 16 3 3 3 2 3]\n", + " [16 9 5 14 4 3 3 3 3 3]\n", + " [16 9 5 14 4 3 3 3 3 3]\n", + " [16 9 5 14 4 3 3 3 3 3]\n", + " [ 7 18 1 19 16 3 3 3 2 3]\n", + " [ 7 18 1 19 16 3 3 3 2 3]] -54 {'correct': False, 'guesses': defaultdict(, {'grasp': 3, 'piend': 3})}\n", "0\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ "env = gym_wordle.wordle.WordleEnv()\n", "\n", - "for i in range(10):\n", + "for i in tqdm(range(1000)):\n", " \n", " state, info = env.reset()\n", "\n", @@ -386,18 +332,11 @@ " if info[\"correct\"]:\n", " wins += 1\n", "\n", + "print(state, reward, info)\n", + "\n", "print(wins)\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -422,7 +361,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/gym_wordle/wordle.py b/gym_wordle/wordle.py index ff6586a..3c445e7 100644 --- a/gym_wordle/wordle.py +++ b/gym_wordle/wordle.py @@ -3,11 +3,10 @@ import numpy as np import numpy.typing as npt from sty import fg, bg, ef, rs -from collections import Counter +from collections import Counter, defaultdict from gym_wordle.utils import to_english, to_array, get_words from typing import Optional - class WordList(gym.spaces.Discrete): """Super class for defining a space of valid words according to a specified list. @@ -160,7 +159,7 @@ class WordleEnv(gym.Env): self.n_rounds = 6 self.n_letters = 5 - self.info = {'correct': False, 'guesses': set()} + self.info = {'correct': False, 'guesses': defaultdict(int)} def _highlighter(self, char: str, flag: int) -> str: """Terminal renderer functionality. Properly highlights a character @@ -195,7 +194,7 @@ class WordleEnv(gym.Env): self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64) - self.info = {'correct': False, 'guesses': set()} + self.info = {'correct': False, 'guesses': defaultdict(int)} return self.state, self.info @@ -269,13 +268,13 @@ class WordleEnv(gym.Env): reward += np.sum(self.state[:, 5:] == 3) * -1 # guess same word as before - hashable_action = tuple(action) + hashable_action = to_english(action) if hashable_action in self.info['guesses']: - reward += -10 + reward += -10 * self.info['guesses'][hashable_action] else: # guess different word reward += 10 - self.info['guesses'].add(hashable_action) + self.info['guesses'][hashable_action] += 1 # for game ending in win or loss reward += 10 if correct else -10 if done else 0