mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 10:39:20 +00:00 
			
		
		
		
	Compare commits
	
		
			10 Commits
		
	
	
		
			f40301cac9
			...
			main
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 3595fc1b07 | ||
|  | 4f8ca4aa06 | ||
|  | 2244de94ac | ||
|  | b46d335044 | ||
|  | 284a29d7af | ||
|  | 3747af9d22 | ||
|  | 4fb81317f0 | ||
|  | 12601964bd | ||
|  | c448e02512 | ||
|  | 848d385482 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,6 +1,6 @@ | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
| /env | ||||
| **/runs/* | ||||
| **/wandb/* | ||||
| **/wandb/* | ||||
| **/models/* | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										338
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										338
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							| @@ -1,338 +0,0 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gym\n", | ||||
|     "import gym_wordle\n", | ||||
|     "from stable_baselines3 import DQN, PPO, common\n", | ||||
|     "import numpy as np\n", | ||||
|     "import tqdm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "<Monitor<WordleEnv instance>>\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "env = common.monitor.Monitor(env)\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Using cuda device\n", | ||||
|       "Wrapping the env in a DummyVecEnv.\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "application/vnd.jupyter.widget-view+json": { | ||||
|        "model_id": "6921a0721569456abf5bceac7e7b6b34", | ||||
|        "version_major": 2, | ||||
|        "version_minor": 0 | ||||
|       }, | ||||
|       "text/plain": [ | ||||
|        "Output()" | ||||
|       ] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "----------------------------------\n", | ||||
|       "| rollout/            |          |\n", | ||||
|       "|    ep_len_mean      | 4.97     |\n", | ||||
|       "|    ep_rew_mean      | -63.8    |\n", | ||||
|       "|    exploration_rate | 0.05     |\n", | ||||
|       "| time/               |          |\n", | ||||
|       "|    episodes         | 10000    |\n", | ||||
|       "|    fps              | 1628     |\n", | ||||
|       "|    time_elapsed     | 30       |\n", | ||||
|       "|    total_timesteps  | 49995    |\n", | ||||
|       "----------------------------------\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "----------------------------------\n", | ||||
|       "| rollout/            |          |\n", | ||||
|       "|    ep_len_mean      | 5        |\n", | ||||
|       "|    ep_rew_mean      | -70.5    |\n", | ||||
|       "|    exploration_rate | 0.05     |\n", | ||||
|       "| time/               |          |\n", | ||||
|       "|    episodes         | 20000    |\n", | ||||
|       "|    fps              | 662      |\n", | ||||
|       "|    time_elapsed     | 150      |\n", | ||||
|       "|    total_timesteps  | 99992    |\n", | ||||
|       "| train/              |          |\n", | ||||
|       "|    learning_rate    | 0.0001   |\n", | ||||
|       "|    loss             | 11.7     |\n", | ||||
|       "|    n_updates        | 12497    |\n", | ||||
|       "----------------------------------\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/html": [ | ||||
|        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | ||||
|       ], | ||||
|       "text/plain": [] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/html": [ | ||||
|        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | ||||
|        "</pre>\n" | ||||
|       ], | ||||
|       "text/plain": [ | ||||
|        "\n" | ||||
|       ] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 3, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100_000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model.save(\"dqn_new_state\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", | ||||
|       "Exception: code() argument 13 must be str, not int\n", | ||||
|       "  warnings.warn(\n", | ||||
|       "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n", | ||||
|       "Exception: code() argument 13 must be str, not int\n", | ||||
|       "  warnings.warn(\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# model = DQN.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "0\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "\n", | ||||
|     "for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "    state, info = env.reset()\n", | ||||
|     "\n", | ||||
|     "    done = False\n", | ||||
|     "\n", | ||||
|     "    wins = 0\n", | ||||
|     "\n", | ||||
|     "    while not done:\n", | ||||
|     "\n", | ||||
|     "        action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "        state, reward, done, truncated, info = env.step(action)\n", | ||||
|     "\n", | ||||
|     "    print(state)\n", | ||||
|     "    if info[\"correct\"]:\n", | ||||
|     "        wins += 1\n", | ||||
|     "\n", | ||||
|     "print(wins)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        0., 0., 0., 0., 0., 0., 0., 1.]),\n", | ||||
|        " -50)" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 6, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "state, reward" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.5" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
| @@ -3,9 +3,28 @@ import string | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from stable_baselines3 import PPO, DQN | ||||
| from letter_guess import LetterGuessingEnv | ||||
| import torch | ||||
|  | ||||
| def load_valid_words(file_path='wordle_words.txt'): | ||||
|     """ | ||||
|     Load valid five-letter words from a specified text file. | ||||
|  | ||||
|     Parameters: | ||||
|     - file_path (str): The path to the text file containing valid words. | ||||
|  | ||||
|     Returns: | ||||
|     - list[str]: A list of valid words loaded from the file. | ||||
|     """ | ||||
|     with open(file_path, 'r') as file: | ||||
|         valid_words = [line.strip() for line in file if len(line.strip()) == 5] | ||||
|     return valid_words | ||||
|  | ||||
|  | ||||
| class AI: | ||||
|     def __init__(self, vocab_file, num_letters=5, num_guesses=6): | ||||
|     def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False, device="cuda"): | ||||
|         self.device = device | ||||
|         self.vocab_file = vocab_file | ||||
|         self.num_letters = num_letters | ||||
|         self.num_guesses = 6 | ||||
| @@ -16,43 +35,80 @@ class AI: | ||||
|         self.domains = None | ||||
|         self.possible_letters = None | ||||
|  | ||||
|         self.reset() | ||||
|         self.use_q_model = use_q_model | ||||
|         if use_q_model: | ||||
|             # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all | ||||
|             self.q_env = LetterGuessingEnv(load_valid_words(vocab_file)) | ||||
|             self.q_env_state, _ = self.q_env.reset() | ||||
|  | ||||
|             # load model | ||||
|             self.q_model = PPO.load(model_file, device=self.device) | ||||
|  | ||||
|         self.reset("") | ||||
|  | ||||
|     def solve_eval(self, results_callback): | ||||
|         num_guesses = 0 | ||||
|         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: | ||||
|             num_guesses += 1 | ||||
|             if self.use_q_model: | ||||
|                 self.freeze_state = self.q_env.clone_state() | ||||
|  | ||||
|             # sample a word, this would use the q_env_state if the q_model is used | ||||
|             word = self.sample(num_guesses) | ||||
|  | ||||
|             # get emulated results | ||||
|             results = results_callback(word) | ||||
|             if self.use_q_model: | ||||
|                 self.q_env.set_state(self.freeze_state) | ||||
|                 # step the q_env to match the guess we just made | ||||
|                 for i in range(len(word)): | ||||
|                     char = word[i] | ||||
|                     action = ord(char) - ord('a') | ||||
|                     self.q_env_state, _, _, _, _ = self.q_env.step(action) | ||||
|  | ||||
|             self.arc_consistency(word, results) | ||||
|         return num_guesses, word | ||||
|  | ||||
|     def solve(self): | ||||
|         num_guesses = 0 | ||||
|         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: | ||||
|             num_guesses += 1 | ||||
|             word = self.sample() | ||||
|             if self.use_q_model: | ||||
|                 self.freeze_state = self.q_env.clone_state() | ||||
|  | ||||
|             # # Always start with these two words | ||||
|             # if num_guesses == 1: | ||||
|             #     word = 'soare' | ||||
|             # elif num_guesses == 2: | ||||
|             #     word = 'culti' | ||||
|             # sample a word, this would use the q_env_state if the q_model is used | ||||
|             word = self.sample(num_guesses) | ||||
|  | ||||
|             print('-----------------------------------------------') | ||||
|             print(f'Guess #{num_guesses}/{self.num_guesses}: {word}') | ||||
|             print('-----------------------------------------------') | ||||
|             self.arc_consistency(word) | ||||
|  | ||||
|         print(f'You did it! The word is {"".join([e[0] for e in self.domains])}') | ||||
|             print(f'Performing arc consistency check on {word}...') | ||||
|             print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.') | ||||
|             results = [] | ||||
|  | ||||
|             # Collect results | ||||
|             for l in word: | ||||
|                 while True: | ||||
|                     result = input(f'{l}: ') | ||||
|                     if result not in ['0', '1', '2']: | ||||
|                         print('Incorrect option. Try again.') | ||||
|                         continue | ||||
|                     results.append(result) | ||||
|                     break | ||||
|  | ||||
|     def arc_consistency(self, word): | ||||
|         print(f'Performing arc consistency check on {word}...') | ||||
|         print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.') | ||||
|         results = [] | ||||
|             if self.use_q_model: | ||||
|                 self.q_env.set_state(self.freeze_state) | ||||
|                 # step the q_env to match the guess we just made | ||||
|                 for i in range(len(word)): | ||||
|                     char = word[i] | ||||
|                     action = ord(char) - ord('a') | ||||
|                     self.q_env_state, _, _, _, _ = self.q_env.step(action) | ||||
|  | ||||
|         # Collect results | ||||
|         for l in word: | ||||
|             while True: | ||||
|                 result = input(f'{l}: ') | ||||
|                 if result not in ['0', '1', '2']: | ||||
|                     print('Incorrect option. Try again.') | ||||
|                     continue | ||||
|                 results.append(result) | ||||
|                 break | ||||
|             self.arc_consistency(word, results) | ||||
|         return num_guesses, word | ||||
|  | ||||
|     def arc_consistency(self, word, results): | ||||
|         self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1'] | ||||
|  | ||||
|         for i in range(len(word)): | ||||
| @@ -70,12 +126,15 @@ class AI: | ||||
|             if results[i] == '2': | ||||
|                 self.domains[i] = [word[i]] | ||||
|  | ||||
|  | ||||
|     def reset(self): | ||||
|     def reset(self, target_word): | ||||
|         self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)] | ||||
|         self.possible_letters = [] | ||||
|  | ||||
|     def sample(self): | ||||
|         if self.use_q_model: | ||||
|             self.q_env_state, _ = self.q_env.reset() | ||||
|             self.q_env.target_word = target_word | ||||
|  | ||||
|     def sample(self, num_guesses): | ||||
|         """ | ||||
|         Samples a best word given the current domains | ||||
|         :return: | ||||
| @@ -87,9 +146,30 @@ class AI: | ||||
|         pattern = re.compile(regex_string) | ||||
|  | ||||
|         # From the words with the highest scores, only return the best word that match the regex pattern | ||||
|         max_qval = float('-inf') | ||||
|         best_word = None | ||||
|         for word, _ in self.best_words: | ||||
|             # reset the state back to before we guessed a word | ||||
|             if pattern.match(word) and False not in [e in word for e in self.possible_letters]: | ||||
|                 return word | ||||
|                 if self.use_q_model and num_guesses == 3: | ||||
|                     self.q_env.set_state(self.freeze_state) | ||||
|                     # Use policy to grade word | ||||
|                     # get the state and action pairs | ||||
|                     curr_qval = 0 | ||||
|  | ||||
|                     for l in word: | ||||
|                         action = ord(l) - ord('a') | ||||
|                         q_val, _, _ = self.q_model.policy.evaluate_actions(self.q_model.policy.obs_to_tensor(self.q_env.get_obs())[0], torch.Tensor(np.array([action])).to(self.device)) | ||||
|                         _, _, _, _, _ = self.q_env.step(action) | ||||
|                         curr_qval += q_val | ||||
|  | ||||
|                     if curr_qval > max_qval: | ||||
|                         max_qval = curr_qval | ||||
|                         best_word = word | ||||
|                 else: | ||||
|                     # otherwise return the word from eric heuristic | ||||
|                     return word | ||||
|         return best_word | ||||
|  | ||||
|     def get_vocab(self, vocab_file): | ||||
|         vocab = [] | ||||
|   | ||||
							
								
								
									
										63
									
								
								eric_wordle/eval.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								eric_wordle/eval.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| import argparse | ||||
| from ai import AI | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
|  | ||||
| global solution | ||||
|  | ||||
| def result_callback(word): | ||||
|  | ||||
|     global solution | ||||
|  | ||||
|     result = ['0', '0', '0', '0', '0'] | ||||
|  | ||||
|     for i, letter in enumerate(word): | ||||
|  | ||||
|         if solution[i] == word[i]: | ||||
|             result[i] = '2' | ||||
|         elif letter in solution: | ||||
|             result[i] = '1' | ||||
|         else: | ||||
|             pass | ||||
|  | ||||
|     return result | ||||
|  | ||||
| def main(args): | ||||
|     global solution  | ||||
|  | ||||
|     if args.n is None: | ||||
|         raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') | ||||
|  | ||||
|     ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device) | ||||
|  | ||||
|     total_guesses = 0 | ||||
|     wins = 0 | ||||
|     num_eval = args.num_eval | ||||
|  | ||||
|     np.random.seed(0) | ||||
|  | ||||
|     for i in tqdm(range(num_eval)): | ||||
|         idx = np.random.choice(range(len(ai.vocab))) | ||||
|         solution = ai.vocab[idx] | ||||
|  | ||||
|         ai.reset(solution) | ||||
|  | ||||
|         guesses, word = ai.solve_eval(results_callback=result_callback) | ||||
|         if word != solution: | ||||
|             total_guesses += 5 | ||||
|         else: | ||||
|             total_guesses += guesses | ||||
|             wins += 1 | ||||
|  | ||||
|     print(f"q_model?: {args.q_model} \t average guesses per game: {total_guesses / num_eval} \t win rate: {wins / num_eval}") | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--n', dest='n', type=int, default=None) | ||||
|     parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') | ||||
|     parser.add_argument('--num_eval', dest="num_eval", type=int, default=1000) | ||||
|     parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') | ||||
|     parser.add_argument('--q_model', dest="q_model", type=bool, default=False) | ||||
|     parser.add_argument('--device', dest="device", type=str, default="cuda") | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
							
								
								
									
										1
									
								
								eric_wordle/letter_guess.py
									
									
									
									
									
										Symbolic link
									
								
							
							
						
						
									
										1
									
								
								eric_wordle/letter_guess.py
									
									
									
									
									
										Symbolic link
									
								
							| @@ -0,0 +1 @@ | ||||
| ../letter_guess.py | ||||
| @@ -5,8 +5,9 @@ from ai import AI | ||||
| def main(args): | ||||
|     if args.n is None: | ||||
|         raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') | ||||
|  | ||||
|     ai = AI(args.vocab_file) | ||||
|     print(f"using q model? {args.q_model}") | ||||
|     ai = AI(args.vocab_file, args.model_file, use_q_model=args.q_model, device=args.device) | ||||
|     ai.reset("lingo") | ||||
|     ai.solve() | ||||
|  | ||||
|  | ||||
| @@ -14,5 +15,8 @@ if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--n', dest='n', type=int, default=None) | ||||
|     parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') | ||||
|     parser.add_argument('--model_file', dest="model_file", type=str, default='wordle_ppo_model') | ||||
|     parser.add_argument('--q_model', dest="q_model", type=bool, default=False) | ||||
|     parser.add_argument('--device', dest="device", type=str, default="cuda") | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
							
								
								
									
										2
									
								
								eval.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										2
									
								
								eval.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt  --num_eval 5000 | ||||
| python eric_wordle/eval.py --n 1 --vocab_file wordle_words.txt  --num_eval 5000 --q_model True --model_file wordle_ppo_model | ||||
							
								
								
									
										1
									
								
								inference.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										1
									
								
								inference.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1 @@ | ||||
| python eric_wordle/main.py --n 1 --vocab_file wordle_words.txt  --q_model True --model_file wordle_ppo_model --device cpu | ||||
| @@ -3,6 +3,7 @@ from gymnasium import spaces | ||||
| import numpy as np | ||||
| import random | ||||
| import re | ||||
| import copy | ||||
|  | ||||
|  | ||||
| class LetterGuessingEnv(gym.Env): | ||||
| @@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env): | ||||
|  | ||||
|         self.reset() | ||||
|  | ||||
|     def clone_state(self): | ||||
|         # Clone the current state | ||||
|         return { | ||||
|             'target_word': self.target_word, | ||||
|             'letter_flags': copy.deepcopy(self.letter_flags), | ||||
|             'letter_positions': copy.deepcopy(self.letter_positions), | ||||
|             'guessed_letters': copy.deepcopy(self.guessed_letters), | ||||
|             'guess_prefix': self.guess_prefix, | ||||
|             'round': self.round | ||||
|         } | ||||
|  | ||||
|     def set_state(self, state): | ||||
|         # Restore the state | ||||
|         self.target_word = state['target_word'] | ||||
|         self.letter_flags = copy.deepcopy(state['letter_flags']) | ||||
|         self.letter_positions = copy.deepcopy(state['letter_positions']) | ||||
|         self.guessed_letters = copy.deepcopy(state['guessed_letters']) | ||||
|         self.guess_prefix = state['guess_prefix'] | ||||
|         self.round = state['round'] | ||||
|  | ||||
|     def step(self, action): | ||||
|         letter_index = action % 26  # Assuming action is the letter index directly | ||||
|         letter_index = action  # Assuming action is the letter index directly | ||||
|         position = len(self.guess_prefix)  # The next position in the prefix is determined by its current length | ||||
|         letter = chr(ord('a') + letter_index) | ||||
|  | ||||
| @@ -56,8 +77,8 @@ class LetterGuessingEnv(gym.Env): | ||||
|             reward = 1  # Reward for adding new information by trying a new letter | ||||
|  | ||||
|             # Update the letter_positions matrix to reflect the new guess | ||||
|             if position == 4:  | ||||
|                 self.letter_positions[:,:] = 1 | ||||
|             if position == 4: | ||||
|                 self.letter_positions[:, :] = 1 | ||||
|             else: | ||||
|                 self.letter_positions[:, position] = 0 | ||||
|                 self.letter_positions[letter_index, position] = 1 | ||||
| @@ -72,15 +93,16 @@ class LetterGuessingEnv(gym.Env): | ||||
|             self.guess_prefix = '' | ||||
|             self.round += 1 | ||||
|  | ||||
|         # end after 5 rounds of total guesses | ||||
|         if self.round == 2: | ||||
|         # end after 3 rounds of total guesses | ||||
|         if self.round == 3: | ||||
|             # reward = 5 | ||||
|             done = True | ||||
|  | ||||
|         obs = self._get_obs() | ||||
|          | ||||
|         if reward < -50: | ||||
|         obs = self.get_obs() | ||||
|  | ||||
|         if reward < -5: | ||||
|             print(obs, reward, done) | ||||
|             exit(0) | ||||
|  | ||||
|         return obs, reward, done, False, {} | ||||
|  | ||||
| @@ -91,8 +113,8 @@ class LetterGuessingEnv(gym.Env): | ||||
|         self.letter_positions = np.ones((26, 4), dtype=np.int32) | ||||
|         self.guessed_letters = set() | ||||
|         self.guess_prefix = ""  # Reset the guess prefix for the new episode | ||||
|         self.round = 1 | ||||
|         return self._get_obs(), {} | ||||
|         self.round = 0 | ||||
|         return self.get_obs(), {} | ||||
|  | ||||
|     def encode_word(self, word): | ||||
|         encoded = np.zeros((26,)) | ||||
| @@ -101,7 +123,7 @@ class LetterGuessingEnv(gym.Env): | ||||
|             encoded[index] = 1 | ||||
|         return encoded | ||||
|  | ||||
|     def _get_obs(self): | ||||
|     def get_obs(self): | ||||
|         return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()]) | ||||
|  | ||||
|     def render(self, mode='human'): | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								wordle_ppo_model.zip
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								wordle_ppo_model.zip
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user