mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-09 22:54:45 +00:00
378 lines
12 KiB
Plaintext
378 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_valid_words(file_path='wordle_words.txt', word_length=5):\n",
"    \"\"\"\n",
"    Load valid words of a fixed length from a specified text file.\n",
"\n",
"    Each line is stripped of surrounding whitespace; only entries whose\n",
"    stripped length equals word_length are kept, in file order.\n",
"\n",
"    Parameters:\n",
"    - file_path (str): The path to the text file containing valid words.\n",
"    - word_length (int): Required word length (default 5, the Wordle length).\n",
"\n",
"    Returns:\n",
"    - list[str]: A list of valid words loaded from the file.\n",
"    \"\"\"\n",
"    # Explicit encoding so decoding does not depend on the platform locale.\n",
"    with open(file_path, 'r', encoding='utf-8') as file:\n",
"        # Strip each line once, then keep only entries of the requested length.\n",
"        stripped = (line.strip() for line in file)\n",
"        valid_words = [word for word in stripped if len(word) == word_length]\n",
"    return valid_words"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from stable_baselines3 import PPO, DQN # Or any other suitable RL algorithm\n",
|
|
"from stable_baselines3.common.env_checker import check_env\n",
|
|
"from letter_guess import LetterGuessingEnv\n",
|
|
"from tqdm import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"env = LetterGuessingEnv(\n",
"    valid_words=load_valid_words(),  # Make sure to load your valid words\n",
")\n",
"check_env(env=env)  # Optional: Verify the environment is compatible with SB3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import wandb\n",
|
|
"from wandb.integration.sb3 import WandbCallback"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
|
|
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mltcptgeneral\u001b[0m (\u001b[33mfulltime\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"Tracking run with wandb version 0.16.4"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"Run data is saved locally in <code>/home/art/cse151b-final-project/wandb/run-20240319_211220-cyh5nscz</code>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"Syncing run <strong><a href='https://wandb.ai/fulltime/wordle/runs/cyh5nscz' target=\"_blank\">distinctive-flower-20</a></strong> to <a href='https://wandb.ai/fulltime/wordle' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
" View project at <a href='https://wandb.ai/fulltime/wordle' target=\"_blank\">https://wandb.ai/fulltime/wordle</a>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
" View run at <a href='https://wandb.ai/fulltime/wordle/runs/cyh5nscz' target=\"_blank\">https://wandb.ai/fulltime/wordle/runs/cyh5nscz</a>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model_save_path = \"wordle_ppo_model\"\n",
"\n",
"# Run settings: logged to Weights & Biases and read by the training cell.\n",
"config = dict(\n",
"    policy_type=\"MlpPolicy\",\n",
"    total_timesteps=200_000,\n",
")\n",
"run = wandb.init(\n",
"    project=\"wordle\",\n",
"    config=config,\n",
"    sync_tensorboard=True,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using cuda device\n",
|
|
"Wrapping the env with a `Monitor` wrapper\n",
|
|
"Wrapping the env in a DummyVecEnv.\n",
|
|
"Logging to runs/cyh5nscz/PPO_1\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "ca60c274a90b4dddaf275fe164012f16",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Output()"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"---------------------------------\n",
|
|
"| rollout/ | |\n",
|
|
"| ep_len_mean | 2.54 |\n",
|
|
"| ep_rew_mean | -3.66 |\n",
|
|
"| time/ | |\n",
|
|
"| fps | 721 |\n",
|
|
"| iterations | 1 |\n",
|
|
"| time_elapsed | 2 |\n",
|
|
"| total_timesteps | 2048 |\n",
|
|
"---------------------------------\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-----------------------------------------\n",
|
|
"| rollout/ | |\n",
|
|
"| ep_len_mean | 2.53 |\n",
|
|
"| ep_rew_mean | -3.61 |\n",
|
|
"| time/ | |\n",
|
|
"| fps | 718 |\n",
|
|
"| iterations | 2 |\n",
|
|
"| time_elapsed | 5 |\n",
|
|
"| total_timesteps | 4096 |\n",
|
|
"| train/ | |\n",
|
|
"| approx_kl | 0.011673957 |\n",
|
|
"| clip_fraction | 0.0292 |\n",
|
|
"| clip_range | 0.2 |\n",
|
|
"| entropy_loss | -3.25 |\n",
|
|
"| explained_variance | -0.126 |\n",
|
|
"| learning_rate | 0.0003 |\n",
|
|
"| loss | 0.576 |\n",
|
|
"| n_updates | 10 |\n",
|
|
"| policy_gradient_loss | -0.0197 |\n",
|
|
"| value_loss | 3.58 |\n",
|
|
"-----------------------------------------\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-----------------------------------------\n",
|
|
"| rollout/ | |\n",
|
|
"| ep_len_mean | 2.7 |\n",
|
|
"| ep_rew_mean | -3.56 |\n",
|
|
"| time/ | |\n",
|
|
"| fps | 698 |\n",
|
|
"| iterations | 3 |\n",
|
|
"| time_elapsed | 8 |\n",
|
|
"| total_timesteps | 6144 |\n",
|
|
"| train/ | |\n",
|
|
"| approx_kl | 0.019258872 |\n",
|
|
"| clip_fraction | 0.198 |\n",
|
|
"| clip_range | 0.2 |\n",
|
|
"| entropy_loss | -3.22 |\n",
|
|
"| explained_variance | -0.211 |\n",
|
|
"| learning_rate | 0.0003 |\n",
|
|
"| loss | 0.187 |\n",
|
|
"| n_updates | 20 |\n",
|
|
"| policy_gradient_loss | -0.0215 |\n",
|
|
"| value_loss | 0.637 |\n",
|
|
"-----------------------------------------\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-----------------------------------------\n",
|
|
"| rollout/ | |\n",
|
|
"| ep_len_mean | 2.73 |\n",
|
|
"| ep_rew_mean | -3.43 |\n",
|
|
"| time/ | |\n",
|
|
"| fps | 681 |\n",
|
|
"| iterations | 4 |\n",
|
|
"| time_elapsed | 12 |\n",
|
|
"| total_timesteps | 8192 |\n",
|
|
"| train/ | |\n",
|
|
"| approx_kl | 0.021500897 |\n",
|
|
"| clip_fraction | 0.171 |\n",
|
|
"| clip_range | 0.2 |\n",
|
|
"| entropy_loss | -3.17 |\n",
|
|
"| explained_variance | 0.378 |\n",
|
|
"| learning_rate | 0.0003 |\n",
|
|
"| loss | 0.185 |\n",
|
|
"| n_updates | 30 |\n",
|
|
"| policy_gradient_loss | -0.0214 |\n",
|
|
"| value_loss | 0.479 |\n",
|
|
"-----------------------------------------\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-----------------------------------------\n",
|
|
"| rollout/ | |\n",
|
|
"| ep_len_mean | 2.92 |\n",
|
|
"| ep_rew_mean | -3.36 |\n",
|
|
"| time/ | |\n",
|
|
"| fps | 682 |\n",
|
|
"| iterations | 5 |\n",
|
|
"| time_elapsed | 14 |\n",
|
|
"| total_timesteps | 10240 |\n",
|
|
"| train/ | |\n",
|
|
"| approx_kl | 0.018113121 |\n",
|
|
"| clip_fraction | 0.101 |\n",
|
|
"| clip_range | 0.2 |\n",
|
|
"| entropy_loss | -3.11 |\n",
|
|
"| explained_variance | 0.448 |\n",
|
|
"| learning_rate | 0.0003 |\n",
|
|
"| loss | 0.203 |\n",
|
|
"| n_updates | 40 |\n",
|
|
"| policy_gradient_loss | -0.0183 |\n",
|
|
"| value_loss | 0.455 |\n",
|
|
"-----------------------------------------\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = PPO(\n",
"    config[\"policy_type\"],\n",
"    env=env,\n",
"    verbose=2,\n",
"    tensorboard_log=f\"runs/{run.id}\",\n",
"    batch_size=64,\n",
")\n",
"\n",
"# Train for a certain number of timesteps. try/finally guarantees the W&B run\n",
"# is closed even if training raises or is interrupted (the error still propagates).\n",
"try:\n",
"    model.learn(\n",
"        total_timesteps=config[\"total_timesteps\"],\n",
"        callback=WandbCallback(\n",
"            model_save_path=f\"models/{run.id}\",\n",
"            verbose=2,\n",
"        ),\n",
"        progress_bar=True,\n",
"    )\n",
"finally:\n",
"    run.finish()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Persist the trained policy to model_save_path (set in the config cell).\n",
"model.save(path=model_save_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Reload the saved policy; this rebinds `model` to the loaded copy.\n",
"model = PPO.load(path=model_save_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Mean total reward per episode over a fixed number of evaluation episodes.\n",
"n_eval_episodes = 1000\n",
"rewards = 0\n",
"for _ in tqdm(range(n_eval_episodes)):\n",
"    obs, _ = env.reset()\n",
"    done = False\n",
"    while not done:\n",
"        # Evaluate the greedy policy rather than sampling stochastic actions.\n",
"        action, _ = model.predict(obs, deterministic=True)\n",
"        obs, reward, terminated, truncated, info = env.step(action)\n",
"        # Gymnasium step API: an episode ends on termination OR truncation;\n",
"        # ignoring truncation would keep stepping a finished episode forever.\n",
"        done = terminated or truncated\n",
"        rewards += reward\n",
"print(rewards / n_eval_episodes)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "env",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|