{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "def load_valid_words(file_path='wordle_words.txt'):\n", " \"\"\"\n", " Load valid five-letter words from a specified text file.\n", "\n", " Parameters:\n", " - file_path (str): The path to the text file containing valid words.\n", "\n", " Returns:\n", " - list[str]: A list of valid words loaded from the file.\n", " \"\"\"\n", " with open(file_path, 'r') as file:\n", " valid_words = [line.strip() for line in file if len(line.strip()) == 5]\n", " return valid_words" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from stable_baselines3 import PPO, DQN # Or any other suitable RL algorithm\n", "from stable_baselines3.common.env_checker import check_env\n", "from letter_guess import LetterGuessingEnv\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "env = LetterGuessingEnv(valid_words=load_valid_words()) # Make sure to load your valid words\n", "check_env(env) # Optional: Verify the environment is compatible with SB3" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "from wandb.integration.sb3 import WandbCallback" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mltcptgeneral\u001b[0m (\u001b[33mfulltime\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /home/art/cse151b-final-project/wandb/run-20240319_211220-cyh5nscz" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run distinctive-flower-20 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/fulltime/wordle" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/fulltime/wordle/runs/cyh5nscz" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_save_path = \"wordle_ppo_model\"\n", "config = {\n", " \"policy_type\": \"MlpPolicy\",\n", " \"total_timesteps\": 200_000\n", "}\n", "run = wandb.init(\n", " project=\"wordle\",\n", " config=config,\n", " sync_tensorboard=True\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n", "Wrapping the env with a `Monitor` wrapper\n", "Wrapping the env in a DummyVecEnv.\n", "Logging to runs/cyh5nscz/PPO_1\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca60c274a90b4dddaf275fe164012f16", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "---------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 2.54 |\n", "| ep_rew_mean | -3.66 |\n", "| time/ | |\n", "| fps | 721 |\n", "| iterations | 1 |\n", "| time_elapsed | 2 |\n", "| total_timesteps | 2048 |\n", "---------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 2.53 |\n", "| ep_rew_mean | -3.61 |\n", "| time/ | |\n", "| fps | 718 |\n", "| iterations | 2 |\n", "| time_elapsed | 5 |\n", "| total_timesteps | 4096 |\n", "| train/ | |\n", "| approx_kl | 0.011673957 |\n", "| clip_fraction | 0.0292 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -3.25 |\n", "| explained_variance | -0.126 |\n", "| learning_rate | 0.0003 |\n", "| loss | 0.576 |\n", "| n_updates | 10 |\n", "| policy_gradient_loss | -0.0197 |\n", "| value_loss | 3.58 |\n", "-----------------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 2.7 |\n", "| ep_rew_mean | -3.56 |\n", "| time/ | |\n", "| fps | 698 |\n", "| iterations | 3 |\n", "| time_elapsed | 8 |\n", "| total_timesteps | 6144 |\n", "| train/ | |\n", "| approx_kl | 0.019258872 |\n", "| clip_fraction | 0.198 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -3.22 |\n", "| explained_variance | -0.211 |\n", "| learning_rate | 0.0003 |\n", "| loss | 0.187 |\n", "| n_updates | 20 |\n", "| policy_gradient_loss | -0.0215 |\n", "| value_loss | 0.637 |\n", "-----------------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 2.73 |\n", "| ep_rew_mean | -3.43 |\n", "| time/ | |\n", "| fps | 681 |\n", "| iterations | 4 |\n", "| time_elapsed | 12 |\n", "| total_timesteps | 8192 |\n", "| train/ | |\n", "| approx_kl | 0.021500897 |\n", "| clip_fraction | 0.171 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -3.17 |\n", "| explained_variance | 0.378 |\n", "| learning_rate | 0.0003 |\n", "| loss | 0.185 |\n", "| n_updates | 30 |\n", "| policy_gradient_loss | -0.0214 |\n", "| value_loss | 0.479 |\n", "-----------------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-----------------------------------------\n", "| rollout/ | |\n", "| ep_len_mean | 2.92 |\n", "| ep_rew_mean | -3.36 |\n", "| time/ | |\n", "| fps | 682 |\n", "| iterations | 5 |\n", "| time_elapsed | 14 |\n", "| total_timesteps | 10240 |\n", "| train/ | |\n", "| approx_kl | 0.018113121 |\n", "| clip_fraction | 0.101 |\n", "| clip_range | 0.2 |\n", "| entropy_loss | -3.11 |\n", "| explained_variance | 0.448 |\n", "| learning_rate | 0.0003 |\n", "| loss | 0.203 |\n", "| n_updates | 40 |\n", "| policy_gradient_loss | -0.0183 |\n", "| value_loss | 0.455 |\n", "-----------------------------------------\n" ] } ], "source": [ "model = PPO(config[\"policy_type\"], env=env, verbose=2, tensorboard_log=f\"runs/{run.id}\", batch_size=64)\n", "\n", "# Train for a certain number of timesteps\n", "model.learn(\n", " total_timesteps=config[\"total_timesteps\"],\n", " callback=WandbCallback(\n", " model_save_path=f\"models/{run.id}\",\n", " verbose=2,\n", " ),\n", "\tprogress_bar=True\n", ")\n", "\n", "run.finish()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.save(model_save_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = PPO.load(model_save_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rewards = 0\n", "for i in tqdm(range(1000)):\n", " obs, _ = env.reset()\n", " done = False\n", " while not done:\n", " action, _ = model.predict(obs)\n", " obs, reward, done, _, info = env.step(action)\n", " rewards += reward\n", "print(rewards / 1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 2 }