{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gym\n", "import gym_wordle\n", "from stable_baselines3 import DQN, PPO, common\n", "import numpy as np\n", "import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">\n" ] } ], "source": [ "# Create the Wordle environment and wrap it in stable-baselines3's Monitor\n", "# in a single step; `env` is the object handed to the agent for training.\n", "env = common.monitor.Monitor(gym_wordle.wordle.WordleEnv())\n", "\n", "# Show the wrapper chain as a quick sanity check.\n", "print(env)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c52630b65904d5e8e200be505d2121a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n", "Wrapping the env with a `Monitor` wrapper\n", "Wrapping the env in a DummyVecEnv.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -175 |\n", "| exploration_rate | 0.525 |\n", "| time/ | |\n", "| episodes | 10000 |\n", "| fps | 4606 |\n", "| time_elapsed | 10 |\n", "| total_timesteps | 49989 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -208 |\n", "| exploration_rate | 0.0502 |\n", "| time/ | |\n", "| episodes | 20000 |\n", "| fps | 1118 |\n", "| time_elapsed | 89 |\n", "| total_timesteps | 99980 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 24.6 |\n", "| n_updates | 12494 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -230 |\n", "| exploration_rate | 0.05 
|\n", "| time/ | |\n", "| episodes | 30000 |\n", "| fps | 856 |\n", "| time_elapsed | 175 |\n", "| total_timesteps | 149974 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 18.7 |\n", "| n_updates | 24993 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -242 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 40000 |\n", "| fps | 766 |\n", "| time_elapsed | 260 |\n", "| total_timesteps | 199967 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 24 |\n", "| n_updates | 37491 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -186 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 50000 |\n", "| fps | 722 |\n", "| time_elapsed | 346 |\n", "| total_timesteps | 249962 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 21.5 |\n", "| n_updates | 49990 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -183 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 60000 |\n", "| fps | 694 |\n", "| time_elapsed | 431 |\n", "| total_timesteps | 299957 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 17.6 |\n", "| n_updates | 62489 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -181 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 70000 |\n", "| fps | 675 |\n", "| time_elapsed | 517 |\n", "| total_timesteps | 349953 |\n", "| 
train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 26.8 |\n", "| n_updates | 74988 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -196 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 80000 |\n", "| fps | 663 |\n", "| time_elapsed | 603 |\n", "| total_timesteps | 399936 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 22.5 |\n", "| n_updates | 87483 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -174 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 90000 |\n", "| fps | 653 |\n", "| time_elapsed | 688 |\n", "| total_timesteps | 449928 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 21.1 |\n", "| n_updates | 99981 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -155 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 100000 |\n", "| fps | 645 |\n", "| time_elapsed | 774 |\n", "| total_timesteps | 499920 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 22.8 |\n", "| n_updates | 112479 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -153 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 110000 |\n", "| fps | 638 |\n", "| time_elapsed | 860 |\n", "| total_timesteps | 549916 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 16 |\n", "| n_updates | 124978 |\n", 
"----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -164 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 120000 |\n", "| fps | 633 |\n", "| time_elapsed | 947 |\n", "| total_timesteps | 599915 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 17.8 |\n", "| n_updates | 137478 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -145 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 130000 |\n", "| fps | 628 |\n", "| time_elapsed | 1033 |\n", "| total_timesteps | 649910 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 17.8 |\n", "| n_updates | 149977 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -154 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 140000 |\n", "| fps | 624 |\n", "| time_elapsed | 1120 |\n", "| total_timesteps | 699902 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 20.9 |\n", "| n_updates | 162475 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -192 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 150000 |\n", "| fps | 621 |\n", "| time_elapsed | 1206 |\n", "| total_timesteps | 749884 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 18.3 |\n", "| n_updates | 174970 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -170 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 160000 |\n", "| fps | 618 |\n", "| time_elapsed | 1293 |\n", "| total_timesteps | 799869 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 17.7 |\n", "| n_updates | 187467 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -233 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 170000 |\n", "| fps | 615 |\n", "| time_elapsed | 1380 |\n", "| total_timesteps | 849855 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 21.6 |\n", "| n_updates | 199963 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -146 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 180000 |\n", "| fps | 613 |\n", "| time_elapsed | 1466 |\n", "| total_timesteps | 899847 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 19.4 |\n", "| n_updates | 212461 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -142 |\n", "| exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 190000 |\n", "| fps | 611 |\n", "| time_elapsed | 1553 |\n", "| total_timesteps | 949846 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 22.9 |\n", "| n_updates | 224961 |\n", "----------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 5 |\n", "| ep_rew_mean | -171 |\n", "| 
exploration_rate | 0.05 |\n", "| time/ | |\n", "| episodes | 200000 |\n", "| fps | 609 |\n", "| time_elapsed | 1640 |\n", "| total_timesteps | 999839 |\n", "| train/ | |\n", "| learning_rate | 0.0001 |\n", "| loss | 20.3 |\n", "| n_updates | 237459 |\n", "----------------------------------\n" ] }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train a DQN agent with the default MLP policy on `env`.\n", "total_timesteps = 1_000_000\n", "\n", "# device=\"auto\" uses CUDA when it is available and falls back to CPU otherwise,\n", "# so the notebook also runs on machines without a GPU.\n", "model = DQN(\"MlpPolicy\", env, verbose=1, device=\"auto\")\n", "\n", "# log_interval is counted in episodes: one stats table every 10,000 episodes.\n", "model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model.save(\"dqn_new_rewards\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", "Exception: code() argument 13 must be str, not int\n", " warnings.warn(\n", "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. 
Consider using `custom_objects` argument to replace this object.\n", "Exception: code() argument 13 must be str, not int\n", " warnings.warn(\n" ] } ], "source": [ "# model = DQN.load(\"dqn_wordle\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n" ] } ], "source": [ "# Evaluate the trained agent with greedy (deterministic) actions.\n", "# A separate name keeps the Monitor-wrapped training `env` intact, so the\n", "# training cell can be re-run after this one without picking up a bare env.\n", "eval_env = gym_wordle.wordle.WordleEnv()\n", "n_eval_episodes = 1000\n", "\n", "# Tally across ALL episodes. Initialising this inside the loop resets it on\n", "# every game, so the printed value would only reflect the final episode.\n", "wins = 0\n", "\n", "for _ in range(n_eval_episodes):\n", "    state, info = eval_env.reset()\n", "    done = False\n", "\n", "    while not done:\n", "        action, _states = model.predict(state, deterministic=True)\n", "        state, reward, terminated, truncated, info = eval_env.step(action)\n", "        # step() returns a 5-tuple; stop on truncation as well as termination\n", "        # so a truncated game cannot keep the loop spinning.\n", "        done = terminated or truncated\n", "\n", "    # Score each game once, from the info dict of its final step.\n", "    if info[\"correct\"]:\n", "        wins += 1\n", "\n", "print(wins)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[18, 1, 20, 5, 19, 3, 3, 3, 3, 3],\n", " [14, 15, 9, 12, 25, 2, 3, 2, 2, 2],\n", " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3],\n", " [ 1, 20, 13, 15, 19, 3, 3, 3, 3, 3],\n", " [25, 21, 3, 11, 15, 2, 3, 3, 3, 3]], dtype=int64),\n", " -130)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state, reward" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "blah = (14, 1, 9, 22, 5)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "blah in info['guesses']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }