mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			f40301cac9
			...
			arthur-tes
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | dd5889da33 | ||
|  | 848ea719b7 | ||
|  | f641d77c47 | 
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,6 +1,3 @@ | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
| /env | ||||
| **/runs/* | ||||
| **/wandb/* | ||||
| **/__pycache__ | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										338
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										338
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							| @@ -1,338 +0,0 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gym\n", | ||||
|     "import gym_wordle\n", | ||||
|     "from stable_baselines3 import DQN, PPO, common\n", | ||||
|     "import numpy as np\n", | ||||
|     "import tqdm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "<Monitor<WordleEnv instance>>\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "env = common.monitor.Monitor(env)\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Using cuda device\n", | ||||
|       "Wrapping the env in a DummyVecEnv.\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "application/vnd.jupyter.widget-view+json": { | ||||
|        "model_id": "6921a0721569456abf5bceac7e7b6b34", | ||||
|        "version_major": 2, | ||||
|        "version_minor": 0 | ||||
|       }, | ||||
|       "text/plain": [ | ||||
|        "Output()" | ||||
|       ] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "----------------------------------\n", | ||||
|       "| rollout/            |          |\n", | ||||
|       "|    ep_len_mean      | 4.97     |\n", | ||||
|       "|    ep_rew_mean      | -63.8    |\n", | ||||
|       "|    exploration_rate | 0.05     |\n", | ||||
|       "| time/               |          |\n", | ||||
|       "|    episodes         | 10000    |\n", | ||||
|       "|    fps              | 1628     |\n", | ||||
|       "|    time_elapsed     | 30       |\n", | ||||
|       "|    total_timesteps  | 49995    |\n", | ||||
|       "----------------------------------\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "----------------------------------\n", | ||||
|       "| rollout/            |          |\n", | ||||
|       "|    ep_len_mean      | 5        |\n", | ||||
|       "|    ep_rew_mean      | -70.5    |\n", | ||||
|       "|    exploration_rate | 0.05     |\n", | ||||
|       "| time/               |          |\n", | ||||
|       "|    episodes         | 20000    |\n", | ||||
|       "|    fps              | 662      |\n", | ||||
|       "|    time_elapsed     | 150      |\n", | ||||
|       "|    total_timesteps  | 99992    |\n", | ||||
|       "| train/              |          |\n", | ||||
|       "|    learning_rate    | 0.0001   |\n", | ||||
|       "|    loss             | 11.7     |\n", | ||||
|       "|    n_updates        | 12497    |\n", | ||||
|       "----------------------------------\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/html": [ | ||||
|        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | ||||
|       ], | ||||
|       "text/plain": [] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/html": [ | ||||
|        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | ||||
|        "</pre>\n" | ||||
|       ], | ||||
|       "text/plain": [ | ||||
|        "\n" | ||||
|       ] | ||||
|      }, | ||||
|      "metadata": {}, | ||||
|      "output_type": "display_data" | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 3, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100_000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model.save(\"dqn_new_state\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n", | ||||
|       "Exception: code() argument 13 must be str, not int\n", | ||||
|       "  warnings.warn(\n", | ||||
|       "c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n", | ||||
|       "Exception: code() argument 13 must be str, not int\n", | ||||
|       "  warnings.warn(\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# model = DQN.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n", | ||||
|       " 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", | ||||
|       " 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n", | ||||
|       " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", | ||||
|       "0\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "\n", | ||||
|     "for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "    state, info = env.reset()\n", | ||||
|     "\n", | ||||
|     "    done = False\n", | ||||
|     "\n", | ||||
|     "    wins = 0\n", | ||||
|     "\n", | ||||
|     "    while not done:\n", | ||||
|     "\n", | ||||
|     "        action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "        state, reward, done, truncated, info = env.step(action)\n", | ||||
|     "\n", | ||||
|     "    print(state)\n", | ||||
|     "    if info[\"correct\"]:\n", | ||||
|     "        wins += 1\n", | ||||
|     "\n", | ||||
|     "print(wins)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n", | ||||
|        "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | ||||
|        "        1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | ||||
|        "        0., 0., 0., 0., 0., 0., 0., 1.]),\n", | ||||
|        " -50)" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 6, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "state, reward" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3 (ipykernel)", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.11.5" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
							
								
								
									
										38
									
								
								dqn_wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								dqn_wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| import gym | ||||
| import sys | ||||
| from stable_baselines3 import DQN | ||||
| from stable_baselines3.common.env_util import make_vec_env | ||||
| import wordle_gym | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
|  | ||||
| def train (model, env, total_timesteps = 100000):  | ||||
|     model.learn(total_timesteps=total_timesteps, progress_bar=True) | ||||
|     model.save("dqn_wordle") | ||||
|  | ||||
| def test(model, env, test_num=1000): | ||||
|  | ||||
|     total_correct = 0 | ||||
|  | ||||
|     for i in tqdm(range(test_num)): | ||||
|  | ||||
|         model = DQN.load("dqn_wordle") | ||||
|  | ||||
|         env = gym.make("wordle-v0") | ||||
|         obs = env.reset() | ||||
|         done = False | ||||
|         while not done: | ||||
|             action, _states = model.predict(obs) | ||||
|             obs, rewards, done, info = env.step(action) | ||||
|              | ||||
|     return total_correct / test_num | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|      | ||||
|     env = gym.make("wordle-v0") | ||||
|     model = DQN("MlpPolicy", env, verbose=0) | ||||
|     print(env) | ||||
|     print(model) | ||||
|  | ||||
|     train(model, env, total_timesteps=500000) | ||||
|     print(test(model, env)) | ||||
							
								
								
									
										129
									
								
								eric_wordle/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										129
									
								
								eric_wordle/.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,129 +0,0 @@ | ||||
| # Byte-compiled / optimized / DLL files | ||||
| __pycache__/ | ||||
| *.py[cod] | ||||
| *$py.class | ||||
|  | ||||
| # C extensions | ||||
| *.so | ||||
|  | ||||
| # Distribution / packaging | ||||
| .Python | ||||
| build/ | ||||
| develop-eggs/ | ||||
| dist/ | ||||
| downloads/ | ||||
| eggs/ | ||||
| .eggs/ | ||||
| lib/ | ||||
| lib64/ | ||||
| parts/ | ||||
| sdist/ | ||||
| var/ | ||||
| wheels/ | ||||
| pip-wheel-metadata/ | ||||
| share/python-wheels/ | ||||
| *.egg-info/ | ||||
| .installed.cfg | ||||
| *.egg | ||||
| MANIFEST | ||||
|  | ||||
| # PyInstaller | ||||
| #  Usually these files are written by a python script from a template | ||||
| #  before PyInstaller builds the exe, so as to inject date/other infos into it. | ||||
| *.manifest | ||||
| *.spec | ||||
|  | ||||
| # Installer logs | ||||
| pip-log.txt | ||||
| pip-delete-this-directory.txt | ||||
|  | ||||
| # Unit test / coverage reports | ||||
| htmlcov/ | ||||
| .tox/ | ||||
| .nox/ | ||||
| .coverage | ||||
| .coverage.* | ||||
| .cache | ||||
| nosetests.xml | ||||
| coverage.xml | ||||
| *.cover | ||||
| *.py,cover | ||||
| .hypothesis/ | ||||
| .pytest_cache/ | ||||
|  | ||||
| # Translations | ||||
| *.mo | ||||
| *.pot | ||||
|  | ||||
| # Django stuff: | ||||
| *.log | ||||
| local_settings.py | ||||
| db.sqlite3 | ||||
| db.sqlite3-journal | ||||
|  | ||||
| # Flask stuff: | ||||
| instance/ | ||||
| .webassets-cache | ||||
|  | ||||
| # Scrapy stuff: | ||||
| .scrapy | ||||
|  | ||||
| # Sphinx documentation | ||||
| docs/_build/ | ||||
|  | ||||
| # PyBuilder | ||||
| target/ | ||||
|  | ||||
| # Jupyter Notebook | ||||
| .ipynb_checkpoints | ||||
|  | ||||
| # IPython | ||||
| profile_default/ | ||||
| ipython_config.py | ||||
|  | ||||
| # pyenv | ||||
| .python-version | ||||
|  | ||||
| # pipenv | ||||
| #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||||
| #   However, in case of collaboration, if having platform-specific dependencies or dependencies | ||||
| #   having no cross-platform support, pipenv may install dependencies that don't work, or not | ||||
| #   install all needed dependencies. | ||||
| #Pipfile.lock | ||||
|  | ||||
| # PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||||
| __pypackages__/ | ||||
|  | ||||
| # Celery stuff | ||||
| celerybeat-schedule | ||||
| celerybeat.pid | ||||
|  | ||||
| # SageMath parsed files | ||||
| *.sage.py | ||||
|  | ||||
| # Environments | ||||
| .env | ||||
| .venv | ||||
| env/ | ||||
| venv/ | ||||
| ENV/ | ||||
| env.bak/ | ||||
| venv.bak/ | ||||
|  | ||||
| # Spyder project settings | ||||
| .spyderproject | ||||
| .spyproject | ||||
|  | ||||
| # Rope project settings | ||||
| .ropeproject | ||||
|  | ||||
| # mkdocs documentation | ||||
| /site | ||||
|  | ||||
| # mypy | ||||
| .mypy_cache/ | ||||
| .dmypy.json | ||||
| dmypy.json | ||||
|  | ||||
| # Pyre type checker | ||||
| .pyre/ | ||||
| @@ -1,11 +0,0 @@ | ||||
| # N-dle Solver | ||||
|  | ||||
| A solver designed to beat New York Time's Wordle (link [here](https://www.nytimes.com/games/wordle/index.html)). If you are bored enough, can extend to solve the more general N-dle problem (for quordle, octordle, etc.) | ||||
|  | ||||
| I originally made this out of frustration for the game (and my own lack of lingual talent). One day, my friend thought she could beat my bot. To her dismay, she learned that she is no better than a machine. Let's see if you can do any better (the average number of attempts is 3.6). | ||||
|  | ||||
| ## Usage: | ||||
| 1. Run `python main.py --n 1` | ||||
| 2. Follow the prompts | ||||
|  | ||||
| Currently only supports solving for 1 word at a time (i.e. wordle). | ||||
| @@ -1,126 +0,0 @@ | ||||
| import re | ||||
| import string | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class AI: | ||||
|     def __init__(self, vocab_file, num_letters=5, num_guesses=6): | ||||
|         self.vocab_file = vocab_file | ||||
|         self.num_letters = num_letters | ||||
|         self.num_guesses = 6 | ||||
|  | ||||
|         self.vocab, self.vocab_scores, self.letter_scores = self.get_vocab(self.vocab_file) | ||||
|         self.best_words = sorted(list(self.vocab_scores.items()), key=lambda tup: tup[1])[::-1] | ||||
|  | ||||
|         self.domains = None | ||||
|         self.possible_letters = None | ||||
|  | ||||
|         self.reset() | ||||
|  | ||||
|     def solve(self): | ||||
|         num_guesses = 0 | ||||
|         while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: | ||||
|             num_guesses += 1 | ||||
|             word = self.sample() | ||||
|  | ||||
|             # # Always start with these two words | ||||
|             # if num_guesses == 1: | ||||
|             #     word = 'soare' | ||||
|             # elif num_guesses == 2: | ||||
|             #     word = 'culti' | ||||
|  | ||||
|             print('-----------------------------------------------') | ||||
|             print(f'Guess #{num_guesses}/{self.num_guesses}: {word}') | ||||
|             print('-----------------------------------------------') | ||||
|             self.arc_consistency(word) | ||||
|  | ||||
|         print(f'You did it! The word is {"".join([e[0] for e in self.domains])}') | ||||
|  | ||||
|  | ||||
|     def arc_consistency(self, word): | ||||
|         print(f'Performing arc consistency check on {word}...') | ||||
|         print(f'Specify 0 for completely nonexistent letter at the specified index, 1 for existent letter but incorrect index, and 2 for correct letter at correct index.') | ||||
|         results = [] | ||||
|  | ||||
|         # Collect results | ||||
|         for l in word: | ||||
|             while True: | ||||
|                 result = input(f'{l}: ') | ||||
|                 if result not in ['0', '1', '2']: | ||||
|                     print('Incorrect option. Try again.') | ||||
|                     continue | ||||
|                 results.append(result) | ||||
|                 break | ||||
|  | ||||
|         self.possible_letters += [word[i] for i in range(len(word)) if results[i] == '1'] | ||||
|  | ||||
|         for i in range(len(word)): | ||||
|             if results[i] == '0': | ||||
|                 if word[i] in self.possible_letters: | ||||
|                     if word[i] in self.domains[i]: | ||||
|                         self.domains[i].remove(word[i]) | ||||
|                 else: | ||||
|                     for j in range(len(self.domains)): | ||||
|                         if word[i] in self.domains[j] and len(self.domains[j]) > 1: | ||||
|                             self.domains[j].remove(word[i]) | ||||
|             if results[i] == '1': | ||||
|                 if word[i] in self.domains[i]: | ||||
|                     self.domains[i].remove(word[i]) | ||||
|             if results[i] == '2': | ||||
|                 self.domains[i] = [word[i]] | ||||
|  | ||||
|  | ||||
|     def reset(self): | ||||
|         self.domains = [list(string.ascii_lowercase) for _ in range(self.num_letters)] | ||||
|         self.possible_letters = [] | ||||
|  | ||||
|     def sample(self): | ||||
|         """ | ||||
|         Samples a best word given the current domains | ||||
|         :return: | ||||
|         """ | ||||
|         # Compile a regex of possible words with the current domain | ||||
|         regex_string = '' | ||||
|         for domain in self.domains: | ||||
|             regex_string += ''.join(['[', ''.join(domain), ']', '{1}']) | ||||
|         pattern = re.compile(regex_string) | ||||
|  | ||||
|         # From the words with the highest scores, only return the best word that match the regex pattern | ||||
|         for word, _ in self.best_words: | ||||
|             if pattern.match(word) and False not in [e in word for e in self.possible_letters]: | ||||
|                 return word | ||||
|  | ||||
|     def get_vocab(self, vocab_file): | ||||
|         vocab = [] | ||||
|         with open(vocab_file, 'r') as f: | ||||
|             for l in f: | ||||
|                 vocab.append(l.strip()) | ||||
|  | ||||
|         # Count letter frequencies at each index | ||||
|         letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)] | ||||
|         for word in vocab: | ||||
|             for i, l in enumerate(word): | ||||
|                 letter_freqs[i][l] += 1 | ||||
|  | ||||
|         # Assign a score to each letter at each index by the probability of it appearing | ||||
|         letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(self.num_letters)] | ||||
|         for i in range(len(letter_scores)): | ||||
|             max_freq = np.max(list(letter_freqs[i].values())) | ||||
|             for l in letter_scores[i].keys(): | ||||
|                 letter_scores[i][l] = letter_freqs[i][l] / max_freq | ||||
|  | ||||
|         # Find a sorted list of words ranked by sum of letter scores | ||||
|         vocab_scores = {}  # (score, word) | ||||
|         for word in vocab: | ||||
|             score = 0 | ||||
|             for i, l in enumerate(word): | ||||
|                 score += letter_scores[i][l] | ||||
|  | ||||
|             # # Optimization: If repeating letters, deduct a couple points | ||||
|             # if len(set(word)) < len(word): | ||||
|             #     score -= 0.25 * (len(word) - len(set(word))) | ||||
|  | ||||
|             vocab_scores[word] = score | ||||
|  | ||||
|         return vocab, vocab_scores, letter_scores | ||||
| @@ -1,37 +0,0 @@ | ||||
| import string | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| words = [] | ||||
| with open('words.txt', 'r') as f: | ||||
|     for l in f: | ||||
|         words.append(l.strip()) | ||||
|  | ||||
| # Count letter frequencies at each index | ||||
| letter_freqs = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)] | ||||
| for word in words: | ||||
|     for i, l in enumerate(word): | ||||
|         letter_freqs[i][l] += 1 | ||||
|  | ||||
| # Assign a score to each letter at each index by the probability of it appearing | ||||
| letter_scores = [{letter: 0 for letter in string.ascii_lowercase} for _ in range(5)] | ||||
| for i in range(len(letter_scores)): | ||||
|     max_freq = np.max(list(letter_freqs[i].values())) | ||||
|     for l in letter_scores[i].keys(): | ||||
|         letter_scores[i][l] = letter_freqs[i][l] / max_freq | ||||
|  | ||||
| # Find a sorted list of words ranked by sum of letter scores | ||||
| word_scores = []  # (score, word) | ||||
| for word in words: | ||||
|     score = 0 | ||||
|     for i, l in enumerate(word): | ||||
|         score += letter_scores[i][l] | ||||
|     word_scores.append((score, word)) | ||||
|  | ||||
| sorted_by_second = sorted(word_scores, key=lambda tup: tup[0])[::-1] | ||||
| print(sorted_by_second[:10]) | ||||
|  | ||||
| for i, (score, word) in enumerate(sorted_by_second): | ||||
|     if word == 'soare': | ||||
|         print(f'{word} with a score of {score} is found at index {i}') | ||||
|  | ||||
| @@ -1,18 +0,0 @@ | ||||
| import argparse | ||||
| from ai import AI | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     if args.n is None: | ||||
|         raise Exception('Need to specify n (i.e. n = 1 for wordle, n = 4 for quordle, n = 16 for sedecordle).') | ||||
|  | ||||
|     ai = AI(args.vocab_file) | ||||
|     ai.solve() | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--n', dest='n', type=int, default=None) | ||||
|     parser.add_argument('--vocab_file', dest='vocab_file', type=str, default='wordle_words.txt') | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
| @@ -1,15 +0,0 @@ | ||||
| import pandas | ||||
|  | ||||
| print('Loading in words dictionary; this may take a while...') | ||||
| df = pandas.read_json('words_dictionary.json') | ||||
| print('Done loading words dictionary.') | ||||
| words = [] | ||||
| for word in df.axes[0].tolist(): | ||||
|     if len(word) != 5: | ||||
|         continue | ||||
|     words.append(word) | ||||
| words.sort() | ||||
|  | ||||
| with open('words.txt', 'w') as f: | ||||
|     for word in words: | ||||
|         f.write(word + '\n') | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										108
									
								
								letter_guess.py
									
									
									
									
									
								
							
							
						
						
									
										108
									
								
								letter_guess.py
									
									
									
									
									
								
							| @@ -1,108 +0,0 @@ | ||||
| import gymnasium as gym | ||||
| from gymnasium import spaces | ||||
| import numpy as np | ||||
| import random | ||||
| import re | ||||
|  | ||||
|  | ||||
| class LetterGuessingEnv(gym.Env): | ||||
|     """ | ||||
|     Custom Gymnasium environment for a letter guessing game with a focus on forming | ||||
|     valid prefixes and words from a list of valid Wordle words. The environment tracks | ||||
|     the current guess prefix and validates it against known valid words, ending the game | ||||
|     early with a negative reward for invalid prefixes. | ||||
|     """ | ||||
|  | ||||
|     metadata = {'render_modes': ['human']} | ||||
|  | ||||
|     def __init__(self, valid_words, seed=None): | ||||
|         self.action_space = spaces.Discrete(26) | ||||
|         self.observation_space = spaces.Box(low=0, high=1, shape=(26*2 + 26*4,), dtype=np.int32) | ||||
|  | ||||
|         self.valid_words = valid_words  # List of valid Wordle words | ||||
|         self.target_word = ''  # Target word for the current episode | ||||
|         self.valid_words_str = ' '.join(self.valid_words) + ' ' | ||||
|         self.letter_flags = None | ||||
|         self.letter_positions = None | ||||
|         self.guessed_letters = set() | ||||
|         self.guess_prefix = ""  # Tracks the current guess prefix | ||||
|  | ||||
|         self.reset() | ||||
|  | ||||
|     def step(self, action): | ||||
|         letter_index = action % 26  # Assuming action is the letter index directly | ||||
|         position = len(self.guess_prefix)  # The next position in the prefix is determined by its current length | ||||
|         letter = chr(ord('a') + letter_index) | ||||
|  | ||||
|         reward = 0 | ||||
|         done = False | ||||
|  | ||||
|         # Check if the letter has already been used in the guess prefix | ||||
|         if letter in self.guessed_letters: | ||||
|             reward = -1  # Penalize for repeating letters in the prefix | ||||
|         else: | ||||
|             # Add the new letter to the prefix and update guessed letters set | ||||
|             self.guess_prefix += letter | ||||
|             self.guessed_letters.add(letter) | ||||
|  | ||||
|             # Update letter flags based on whether the letter is in the target word | ||||
|             if self.target_word[position] == letter: | ||||
|                 self.letter_flags[letter_index, :] = [1, 0]  # Update flag for correct guess | ||||
|             elif letter in self.target_word: | ||||
|                 self.letter_flags[letter_index, :] = [0, 1]  # Update flag for correct guess wrong position | ||||
|             else: | ||||
|                 self.letter_flags[letter_index, :] = [0, 0]  # Update flag for incorrect guess | ||||
|  | ||||
|             reward = 1  # Reward for adding new information by trying a new letter | ||||
|  | ||||
|             # Update the letter_positions matrix to reflect the new guess | ||||
|             if position == 4:  | ||||
|                 self.letter_positions[:,:] = 1 | ||||
|             else: | ||||
|                 self.letter_positions[:, position] = 0 | ||||
|                 self.letter_positions[letter_index, position] = 1 | ||||
|  | ||||
|         # Use regex to check if the current prefix can lead to a valid word | ||||
|         if not re.search(r'\b' + self.guess_prefix, self.valid_words_str): | ||||
|             reward = -5  # Penalize for forming an invalid prefix | ||||
|             done = True  # End the episode if the prefix is invalid | ||||
|  | ||||
|         # guessed a full word so we reset our guess prefix to guess next round | ||||
|         if len(self.guess_prefix) == len(self.target_word): | ||||
|             self.guess_prefix = '' | ||||
|             self.round += 1 | ||||
|  | ||||
|         # end after 5 rounds of total guesses | ||||
|         if self.round == 2: | ||||
|             # reward = 5 | ||||
|             done = True | ||||
|  | ||||
|         obs = self._get_obs() | ||||
|          | ||||
|         if reward < -50: | ||||
|             print(obs, reward, done) | ||||
|  | ||||
|         return obs, reward, done, False, {} | ||||
|  | ||||
|     def reset(self, seed=None): | ||||
|         self.target_word = random.choice(self.valid_words) | ||||
|         # self.target_word_encoded = self.encode_word(self.target_word) | ||||
|         self.letter_flags = np.ones((26, 2), dtype=np.int32) | ||||
|         self.letter_positions = np.ones((26, 4), dtype=np.int32) | ||||
|         self.guessed_letters = set() | ||||
|         self.guess_prefix = ""  # Reset the guess prefix for the new episode | ||||
|         self.round = 1 | ||||
|         return self._get_obs(), {} | ||||
|  | ||||
|     def encode_word(self, word): | ||||
|         encoded = np.zeros((26,)) | ||||
|         for char in word: | ||||
|             index = ord(char) - ord('a') | ||||
|             encoded[index] = 1 | ||||
|         return encoded | ||||
|  | ||||
|     def _get_obs(self): | ||||
|         return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()]) | ||||
|  | ||||
|     def render(self, mode='human'): | ||||
|         pass  # Optional: Implement rendering logic if needed | ||||
							
								
								
									
										9
									
								
								wordle_gym/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								wordle_gym/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| from gym.envs.registration import register | ||||
|  | ||||
| register( | ||||
|     id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv", | ||||
| ) | ||||
|  | ||||
| register( | ||||
|     id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv", | ||||
| ) | ||||
							
								
								
									
										0
									
								
								wordle_gym/envs/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								wordle_gym/envs/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										15
									
								
								wordle_gym/envs/strategies/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								wordle_gym/envs/strategies/base.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| from enum import Enum | ||||
|  | ||||
| from typing import List | ||||
|  | ||||
| class StrategyType(Enum): | ||||
|     RANDOM = 1 | ||||
|     ELIMINATION = 2 | ||||
|     PROBABILITY = 3 | ||||
|  | ||||
| class Strategy: | ||||
|     def __init__(self, type: StrategyType): | ||||
|         self.type = type | ||||
|      | ||||
|     def get_best_word(self, guesses: List[List[str]], state: List[List[int]]): | ||||
|         raise NotImplementedError("Strategy.get_best_word() not implemented") | ||||
							
								
								
									
										2
									
								
								wordle_gym/envs/strategies/elimination.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								wordle_gym/envs/strategies/elimination.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| def get_best_word(state): | ||||
|      | ||||
							
								
								
									
										20
									
								
								wordle_gym/envs/strategies/probabilistic.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								wordle_gym/envs/strategies/probabilistic.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| from random import sample | ||||
| from typing import List | ||||
|  | ||||
| from base import Strategy | ||||
| from base import StrategyType | ||||
|  | ||||
| from utils import freq | ||||
|  | ||||
| class Random(Strategy): | ||||
|     def __init__(self): | ||||
|         self.words = freq.get_5_letter_word_freqs() | ||||
|         super().__init__(StrategyType.RANDOM) | ||||
|  | ||||
|     def get_best_word(self, state: List[List[int]]): | ||||
|          | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     r = Random() | ||||
|     print(r.get_best_word([])) | ||||
							
								
								
									
										29
									
								
								wordle_gym/envs/strategies/rand.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								wordle_gym/envs/strategies/rand.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| from random import sample | ||||
| from typing import List | ||||
|  | ||||
| from base import Strategy | ||||
| from base import StrategyType | ||||
|  | ||||
| from utils import freq | ||||
|  | ||||
| class Random(Strategy): | ||||
|     def __init__(self): | ||||
|         self.words = freq.get_5_letter_word_freqs() | ||||
|         super().__init__(StrategyType.RANDOM) | ||||
|  | ||||
|     def get_best_word(self, guesses: List[List[str]], state: List[List[int]]): | ||||
|         correct_letters = [] | ||||
|         regex = "" | ||||
|         for g, s in zip(guesses, state): | ||||
|             for c, s in zip(g, s): | ||||
|                 if s == 2: | ||||
|                     correct_letters.append(c) | ||||
|                     regex += c | ||||
|                  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     r = Random() | ||||
|     print(r.get_best_word([])) | ||||
							
								
								
									
										27
									
								
								wordle_gym/envs/strategies/utils/freq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								wordle_gym/envs/strategies/utils/freq.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,27 @@ | ||||
| from os import path | ||||
|  | ||||
| def get_5_letter_word_freqs(): | ||||
|     """ | ||||
|     Returns a list of words with 5 letters. | ||||
|     """ | ||||
|     FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt") | ||||
|     lines = read_file(FILEPATH) | ||||
|     return {k:v for k, v in get_freq(lines).items() if len(k) == 5} | ||||
|  | ||||
|  | ||||
| def read_file(filename): | ||||
|     """ | ||||
|     Reads a file and returns a list of words and frequencies | ||||
|     """ | ||||
|     with open(filename, 'r') as f: | ||||
|         return f.readlines() | ||||
|  | ||||
|  | ||||
| def get_freq(lines): | ||||
|     """ | ||||
|     Returns a dictionary of words and their frequencies | ||||
|     """ | ||||
|     freqs = {} | ||||
|     for word, freq in map(lambda x: x.split("\t"), lines): | ||||
|         freqs[word] = int(freq) | ||||
|     return freqs | ||||
							
								
								
									
										131
									
								
								wordle_gym/envs/wordle_env.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								wordle_gym/envs/wordle_env.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | ||||
| import os | ||||
|  | ||||
|  | ||||
| import gym | ||||
| from gym import error, spaces, utils | ||||
| from gym.utils import seeding | ||||
|  | ||||
| from enum import Enum | ||||
| from collections import Counter | ||||
| import numpy as np | ||||
|  | ||||
| WORD_LENGTH = 5 | ||||
| TOTAL_GUESSES = 6 | ||||
| SOLUTION_PATH = "../words/solution.csv" | ||||
| VALID_WORDS_PATH = "../words/guess.csv" | ||||
|  | ||||
| class LetterState(Enum): | ||||
|     ABSENT = 0 | ||||
|     PRESENT = 1 | ||||
|     CORRECT_POSITION = 2 | ||||
|  | ||||
|  | ||||
| class WordleEnv(gym.Env): | ||||
|     metadata = {"render.modes": ["human"]} | ||||
|  | ||||
|     def _current_path(self): | ||||
|         return os.path.dirname(os.path.abspath(__file__)) | ||||
|  | ||||
|     def _read_solutions(self): | ||||
|         return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines() | ||||
|      | ||||
|     def _get_valid_words(self): | ||||
|         words = [] | ||||
|         for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines(): | ||||
|             words.append((word, Counter(word))) | ||||
|         return words | ||||
|  | ||||
|     def get_valid(self): | ||||
|         return self._valid_words | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._solutions = self._read_solutions() | ||||
|         self._valid_words = self._get_valid_words() | ||||
|         self.action_space = spaces.Discrete(len(self._valid_words)) | ||||
|         self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH) | ||||
|         np.random.seed(0) | ||||
|         self.reset() | ||||
|      | ||||
|     def _check_guess(self, guess, guess_counter): | ||||
|         c = guess_counter & self.solution_ct | ||||
|         result = [] | ||||
|         correct = True | ||||
|         reward = 0 | ||||
|         for i, char in enumerate(guess): | ||||
|             if c.get(char, 0) > 0: | ||||
|                 if self.solution[i] == char: | ||||
|                     result.append(2) | ||||
|                     reward += 2 | ||||
|                 else: | ||||
|                     result.append(1) | ||||
|                     correct = False | ||||
|                     reward += 1 | ||||
|                 c[char] -= 1 | ||||
|             else: | ||||
|                 result.append(0) | ||||
|                 correct = False | ||||
|         return result, correct, reward | ||||
|  | ||||
|     def step(self, action): | ||||
|         """ | ||||
|         action: index of word in valid_words | ||||
|  | ||||
|         returns: | ||||
|             observation: (TOTAL_GUESSES, WORD_LENGTH) | ||||
|             reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained | ||||
|             done: True if game over, w/ or w/o correct answer | ||||
|             additional_info: empty | ||||
|         """ | ||||
|         guess, guess_counter = self._valid_words[action] | ||||
|         if guess in self.guesses: | ||||
|             return self.obs, -1, False, {} | ||||
|         self.guesses.append(guess) | ||||
|         result, correct, reward = self._check_guess(guess, guess_counter) | ||||
|         done = False | ||||
|  | ||||
|         for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH): | ||||
|             self.obs[i] = result[i - self.guess_no*WORD_LENGTH] | ||||
|          | ||||
|         self.guess_no += 1 | ||||
|         if correct: | ||||
|             done = True | ||||
|             reward = 1200 | ||||
|         if self.guess_no == TOTAL_GUESSES: | ||||
|             done = True | ||||
|             if not correct: | ||||
|                 reward = -15 | ||||
|         return self.obs, reward, done, {} | ||||
|  | ||||
|     def reset(self): | ||||
|         self.solution = self._solutions[np.random.randint(len(self._solutions))] | ||||
|         self.solution_ct = Counter(self.solution) | ||||
|         self.guess_no = 0 | ||||
|         self.guesses = [] | ||||
|         self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, )) | ||||
|         return self.obs | ||||
|  | ||||
|     def render(self, mode="human"): | ||||
|         m = { | ||||
|             0: "⬜", | ||||
|             1: "🟨", | ||||
|             2: "🟩" | ||||
|         } | ||||
|         print("Solution:", self.solution) | ||||
|         for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))): | ||||
|             o_n = "".join(map(lambda x: m[x], o)) | ||||
|             print(g, o_n) | ||||
|  | ||||
|     def close(self): | ||||
|         pass | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     env = WordleEnv() | ||||
|     print(env.action_space) | ||||
|     print(env.observation_space) | ||||
|     print(env.solution) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
|     print(env.step(0)) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user