mirror of
				https://github.com/ltcptgeneral/cse151b-final-project.git
				synced 2025-10-22 18:49:21 +00:00 
			
		
		
		
	Compare commits
	
		
			5 Commits
		
	
	
		
			arthur-tes
			...
			gymnasium-
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | cf977e4797 | ||
|  | bbe9a1891c | ||
|  | 9172326013 | ||
|  | 4836be8121 | ||
|  | 5672169073 | 
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -1,2 +1,4 @@ | ||||
| **/data/* | ||||
| **/*.zip | ||||
| **/*.zip | ||||
| **/__pycache__ | ||||
| /env | ||||
							
								
								
									
										345
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							
							
						
						
									
										345
									
								
								dqn_wordle.ipynb
									
									
									
									
									
								
							| @@ -2,69 +2,268 @@ | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import gym\n", | ||||
|     "import gym_wordle\n", | ||||
|     "from stable_baselines3 import DQN\n", | ||||
|     "from stable_baselines3 import DQN, PPO, common\n", | ||||
|     "import numpy as np\n", | ||||
|     "import tqdm" | ||||
|     "from tqdm import tqdm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "<Monitor<WordleEnv instance>>\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym.make(\"Wordle-v0\")\n", | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "env = common.monitor.Monitor(env)\n", | ||||
|     "\n", | ||||
|     "print(env)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 35, | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Using cuda device\n", | ||||
|       "Wrapping the env in a DummyVecEnv.\n", | ||||
|       "---------------------------------\n", | ||||
|       "| rollout/           |          |\n", | ||||
|       "|    ep_len_mean     | 6        |\n", | ||||
|       "|    ep_rew_mean     | 2.14     |\n", | ||||
|       "| time/              |          |\n", | ||||
|       "|    fps             | 750      |\n", | ||||
|       "|    iterations      | 1        |\n", | ||||
|       "|    time_elapsed    | 2        |\n", | ||||
|       "|    total_timesteps | 2048     |\n", | ||||
|       "---------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 4.59        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 625         |\n", | ||||
|       "|    iterations           | 2           |\n", | ||||
|       "|    time_elapsed         | 6           |\n", | ||||
|       "|    total_timesteps      | 4096        |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.022059526 |\n", | ||||
|       "|    clip_fraction        | 0.331       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.47       |\n", | ||||
|       "|    explained_variance   | -0.0118     |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 130         |\n", | ||||
|       "|    n_updates            | 10          |\n", | ||||
|       "|    policy_gradient_loss | -0.0851     |\n", | ||||
|       "|    value_loss           | 253         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 5.86        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 585         |\n", | ||||
|       "|    iterations           | 3           |\n", | ||||
|       "|    time_elapsed         | 10          |\n", | ||||
|       "|    total_timesteps      | 6144        |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.024416003 |\n", | ||||
|       "|    clip_fraction        | 0.462       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.47       |\n", | ||||
|       "|    explained_variance   | 0.152       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 85.2        |\n", | ||||
|       "|    n_updates            | 20          |\n", | ||||
|       "|    policy_gradient_loss | -0.0987     |\n", | ||||
|       "|    value_loss           | 218         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 4.75        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 566         |\n", | ||||
|       "|    iterations           | 4           |\n", | ||||
|       "|    time_elapsed         | 14          |\n", | ||||
|       "|    total_timesteps      | 8192        |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.026305672 |\n", | ||||
|       "|    clip_fraction        | 0.45        |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.47       |\n", | ||||
|       "|    explained_variance   | 0.161       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 144         |\n", | ||||
|       "|    n_updates            | 30          |\n", | ||||
|       "|    policy_gradient_loss | -0.105      |\n", | ||||
|       "|    value_loss           | 220         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "----------------------------------------\n", | ||||
|       "| rollout/                |            |\n", | ||||
|       "|    ep_len_mean          | 6          |\n", | ||||
|       "|    ep_rew_mean          | 1.47       |\n", | ||||
|       "| time/                   |            |\n", | ||||
|       "|    fps                  | 554        |\n", | ||||
|       "|    iterations           | 5          |\n", | ||||
|       "|    time_elapsed         | 18         |\n", | ||||
|       "|    total_timesteps      | 10240      |\n", | ||||
|       "| train/                  |            |\n", | ||||
|       "|    approx_kl            | 0.02928267 |\n", | ||||
|       "|    clip_fraction        | 0.498      |\n", | ||||
|       "|    clip_range           | 0.2        |\n", | ||||
|       "|    entropy_loss         | -9.46      |\n", | ||||
|       "|    explained_variance   | 0.167      |\n", | ||||
|       "|    learning_rate        | 0.0003     |\n", | ||||
|       "|    loss                 | 127        |\n", | ||||
|       "|    n_updates            | 40         |\n", | ||||
|       "|    policy_gradient_loss | -0.116     |\n", | ||||
|       "|    value_loss           | 207        |\n", | ||||
|       "----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 1.62        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 546         |\n", | ||||
|       "|    iterations           | 6           |\n", | ||||
|       "|    time_elapsed         | 22          |\n", | ||||
|       "|    total_timesteps      | 12288       |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.028425258 |\n", | ||||
|       "|    clip_fraction        | 0.483       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.46       |\n", | ||||
|       "|    explained_variance   | 0.143       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 109         |\n", | ||||
|       "|    n_updates            | 50          |\n", | ||||
|       "|    policy_gradient_loss | -0.117      |\n", | ||||
|       "|    value_loss           | 240         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 5.98        |\n", | ||||
|       "|    ep_rew_mean          | 6.14        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 541         |\n", | ||||
|       "|    iterations           | 7           |\n", | ||||
|       "|    time_elapsed         | 26          |\n", | ||||
|       "|    total_timesteps      | 14336       |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.026178032 |\n", | ||||
|       "|    clip_fraction        | 0.453       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.46       |\n", | ||||
|       "|    explained_variance   | 0.174       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 141         |\n", | ||||
|       "|    n_updates            | 60          |\n", | ||||
|       "|    policy_gradient_loss | -0.116      |\n", | ||||
|       "|    value_loss           | 235         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "----------------------------------------\n", | ||||
|       "| rollout/                |            |\n", | ||||
|       "|    ep_len_mean          | 6          |\n", | ||||
|       "|    ep_rew_mean          | 3.03       |\n", | ||||
|       "| time/                   |            |\n", | ||||
|       "|    fps                  | 537        |\n", | ||||
|       "|    iterations           | 8          |\n", | ||||
|       "|    time_elapsed         | 30         |\n", | ||||
|       "|    total_timesteps      | 16384      |\n", | ||||
|       "| train/                  |            |\n", | ||||
|       "|    approx_kl            | 0.02457074 |\n", | ||||
|       "|    clip_fraction        | 0.423      |\n", | ||||
|       "|    clip_range           | 0.2        |\n", | ||||
|       "|    entropy_loss         | -9.45      |\n", | ||||
|       "|    explained_variance   | 0.171      |\n", | ||||
|       "|    learning_rate        | 0.0003     |\n", | ||||
|       "|    loss                 | 111        |\n", | ||||
|       "|    n_updates            | 70         |\n", | ||||
|       "|    policy_gradient_loss | -0.112     |\n", | ||||
|       "|    value_loss           | 212        |\n", | ||||
|       "----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 9.54        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 532         |\n", | ||||
|       "|    iterations           | 9           |\n", | ||||
|       "|    time_elapsed         | 34          |\n", | ||||
|       "|    total_timesteps      | 18432       |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.024578478 |\n", | ||||
|       "|    clip_fraction        | 0.417       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.45       |\n", | ||||
|       "|    explained_variance   | 0.178       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 121         |\n", | ||||
|       "|    n_updates            | 80          |\n", | ||||
|       "|    policy_gradient_loss | -0.114      |\n", | ||||
|       "|    value_loss           | 232         |\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "-----------------------------------------\n", | ||||
|       "| rollout/                |             |\n", | ||||
|       "|    ep_len_mean          | 6           |\n", | ||||
|       "|    ep_rew_mean          | 3.81        |\n", | ||||
|       "| time/                   |             |\n", | ||||
|       "|    fps                  | 527         |\n", | ||||
|       "|    iterations           | 10          |\n", | ||||
|       "|    time_elapsed         | 38          |\n", | ||||
|       "|    total_timesteps      | 20480       |\n", | ||||
|       "| train/                  |             |\n", | ||||
|       "|    approx_kl            | 0.022704324 |\n", | ||||
|       "|    clip_fraction        | 0.379       |\n", | ||||
|       "|    clip_range           | 0.2         |\n", | ||||
|       "|    entropy_loss         | -9.45       |\n", | ||||
|       "|    explained_variance   | 0.194       |\n", | ||||
|       "|    learning_rate        | 0.0003      |\n", | ||||
|       "|    loss                 | 108         |\n", | ||||
|       "|    n_updates            | 90          |\n", | ||||
|       "|    policy_gradient_loss | -0.112      |\n", | ||||
|       "|    value_loss           | 216         |\n", | ||||
|       "-----------------------------------------\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "<stable_baselines3.ppo.ppo.PPO at 0x7f86ef4ddcd0>" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 3, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "total_timesteps = 100000\n", | ||||
|     "model = DQN(\"MlpPolicy\", env, verbose=0)\n", | ||||
|     "model.learn(total_timesteps=total_timesteps, progress_bar=True)" | ||||
|     "total_timesteps = 20_000\n", | ||||
|     "model = PPO(\"MlpPolicy\", env, verbose=1, device='cuda')\n", | ||||
|     "model.learn(total_timesteps=total_timesteps)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def test(model):\n", | ||||
|     "\n", | ||||
|     "    end_rewards = []\n", | ||||
|     "\n", | ||||
|     "    for i in range(1000):\n", | ||||
|     "        \n", | ||||
|     "        state = env.reset()\n", | ||||
|     "\n", | ||||
|     "        done = False\n", | ||||
|     "\n", | ||||
|     "        while not done:\n", | ||||
|     "\n", | ||||
|     "            action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "            state, reward, done, info = env.step(action)\n", | ||||
|     "            \n", | ||||
|     "        end_rewards.append(reward == 0)\n", | ||||
|     "        \n", | ||||
|     "    return np.sum(end_rewards) / len(end_rewards)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
| @@ -73,11 +272,69 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "execution_count": 5, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "model = DQN.load(\"dqn_wordle\")" | ||||
|     "model = PPO.load(\"dqn_wordle\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "100%|██████████| 1000/1000 [00:03<00:00, 252.17it/s]" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[[ 7 18  1 19 16  3  3  3  2  3]\n", | ||||
|       " [16  9  5 14  4  3  3  3  3  3]\n", | ||||
|       " [16  9  5 14  4  3  3  3  3  3]\n", | ||||
|       " [16  9  5 14  4  3  3  3  3  3]\n", | ||||
|       " [ 7 18  1 19 16  3  3  3  2  3]\n", | ||||
|       " [ 7 18  1 19 16  3  3  3  2  3]] -54 {'correct': False, 'guesses': defaultdict(<class 'int'>, {'grasp': 3, 'piend': 3})}\n", | ||||
|       "0\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "env = gym_wordle.wordle.WordleEnv()\n", | ||||
|     "\n", | ||||
|     "for i in tqdm(range(1000)):\n", | ||||
|     "        \n", | ||||
|     "    state, info = env.reset()\n", | ||||
|     "\n", | ||||
|     "    done = False\n", | ||||
|     "\n", | ||||
|     "    wins = 0\n", | ||||
|     "\n", | ||||
|     "    while not done:\n", | ||||
|     "\n", | ||||
|     "        action, _states = model.predict(state, deterministic=True)\n", | ||||
|     "\n", | ||||
|     "        state, reward, done, truncated, info = env.step(action)\n", | ||||
|     "\n", | ||||
|     "    if info[\"correct\"]:\n", | ||||
|     "        wins += 1\n", | ||||
|     "\n", | ||||
|     "print(state, reward, info)\n", | ||||
|     "\n", | ||||
|     "print(wins)\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
| @@ -85,9 +342,7 @@ | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "print(test(model))" | ||||
|    ] | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   | ||||
							
								
								
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								gym_wordle/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| from gym.envs.registration import register | ||||
| from .wordle import WordleEnv | ||||
|  | ||||
# Register the environment under the id 'Wordle-v0' so that it can be built
# with gym.make('Wordle-v0'); the entry point names the WordleEnv class in
# the gym_wordle.wordle module.
# NOTE(review): registration goes through `gym`, while gym_wordle/wordle.py
# builds the environment on `gymnasium` -- confirm the two are compatible.
register(
    id='Wordle-v0',
    entry_point='gym_wordle.wordle:WordleEnv'
)
							
								
								
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12972
									
								
								gym_wordle/dictionary/guess_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/guess_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2315
									
								
								gym_wordle/dictionary/solution_list.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								gym_wordle/dictionary/solution_list.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								gym_wordle/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,93 @@ | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
|  | ||||
| from pathlib import Path | ||||
|  | ||||
|  | ||||
# Character table for the integer word encoding: index 0 is the blank (space)
# character and indices 1..26 are the lowercase letters 'a'..'z'.
_chars = ' abcdefghijklmnopqrstuvwxyz'
# Reverse lookup: maps each character to its position in _chars.
_char_d = {c: i for i, c in enumerate(_chars)}
|  | ||||
|  | ||||
def to_english(array: npt.NDArray[np.int64]) -> str:
    """Decode an integer-encoded word back into its English spelling.

    Args:
        array: Word in array (int) form. Every element is used as an index
          into the module-level character table (a space followed by the
          lowercase alphabet), so values are expected to lie in 0,...,26
          (inclusive).

    Returns:
        The string formed by concatenating the character of every code, in
        order.
    """
    letters = []
    for code in array:
        letters.append(_chars[code])
    return ''.join(letters)
|  | ||||
|  | ||||
def to_array(word: str) -> npt.NDArray[np.int64]:
    """Converts a string of characters into a corresponding numpy array.

    Args:
        word: Word in string form. It is assumed that each character in the
          string is either an empty space ' ' or lowercase alphabetical
          character.

    Returns:
        An int64 array representation of the word, one element per character.
        An empty string yields an empty int64 array.

    Raises:
        KeyError: If the word contains a character outside the supported
          alphabet (e.g. an uppercase letter or a digit).
    """
    # Pin the dtype explicitly: without it an empty word would come back as a
    # float64 array, and platforms whose default integer is 32-bit would
    # contradict the annotated int64 return type.
    return np.array([_char_d[c] for c in word], dtype=np.int64)
|  | ||||
|  | ||||
def get_words(category: str, build: bool = False) -> npt.NDArray[np.int64]:
    """Loads a list of words in array form.

    If specified, this will recompute the list from the human-readable list of
    words, and save the results in array form.

    Args:
        category: Either 'guess' or 'solution', which corresponds to the list
          of acceptable guess words and the list of acceptable solution words.
        build: If True, recomputes and saves the array-version of the computed
          list for future access.

    Returns:
        An array representation of the list of words specified by the category,
        as loaded from the saved '.npy' file. Each row is one encoded word.

    Raises:
        AssertionError: If category is not 'guess' or 'solution'.
        KeyError: When building, if a word contains an unsupported character.
    """
    assert category in {'guess', 'solution'}, \
        f"category must be 'guess' or 'solution', got {category!r}"

    dictionary_dir = Path(__file__).parent / 'dictionary'
    arr_path = dictionary_dir / f'{category}_list.npy'

    if build:
        list_path = dictionary_dir / f'{category}_list.csv'

        with open(list_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file); an empty word would otherwise make the rows ragged.
            stripped = (line.strip() for line in f)
            words = np.array(
                [to_array(word) for word in stripped if word],
                dtype=np.int64,
            )

        # The source file is closed before the array is written out.
        np.save(arr_path, words)

    return np.load(arr_path)
|  | ||||
|  | ||||
def play():
    """Play Wordle yourself!

    Creates the registered 'Wordle-v0' environment, then repeatedly prompts
    for a guess on standard input, steps the environment with it and renders
    the board until the episode ends. The solution is printed at the end.
    """
    import gym
    import gym_wordle  # imported for its side effect: registers 'Wordle-v0'

    env = gym.make('Wordle-v0')  # load the environment
    # Custom attributes live on the underlying environment; go through it
    # consistently rather than relying on the wrappers forwarding them.
    game = env.unwrapped

    env.reset()
    solution = to_english(game.solution_space[game.solution]).upper()  # no peeking!

    done = False

    while not done:
        action = -1

        # in general, the environment won't be forgiving if you input an
        # invalid word, but for this function I want to let you screw up user
        # input without consequence, so just loops until valid input is taken
        while not env.action_space.contains(action):
            guess = input('Guess: ').strip().lower()
            try:
                action = game.action_space.index_of(to_array(guess))
            except KeyError:
                # the guess contains a character outside ' a-z'; ask again
                action = -1

        # step() returns (obs, reward, done, info) under the old API and
        # (obs, reward, terminated, truncated, info) under the new one;
        # accept either instead of unpacking a fixed number of values.
        outcome = env.step(action)
        done = bool(outcome[2])
        if len(outcome) == 5:
            done = done or bool(outcome[3])
        env.render()

    print(f"The word was {solution}")
							
								
								
									
										285
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										285
									
								
								gym_wordle/wordle.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,285 @@ | ||||
| import gymnasium as gym | ||||
| import numpy as np | ||||
| import numpy.typing as npt | ||||
| from sty import fg, bg, ef, rs | ||||
|  | ||||
| from collections import Counter, defaultdict | ||||
| from gym_wordle.utils import to_english, to_array, get_words | ||||
| from typing import Optional | ||||
|  | ||||
class WordList(gym.spaces.Discrete):
    """Super class for defining a space of valid words according to a specified
    list.

    The space is a subclass of gym.spaces.Discrete, where each element
    corresponds to an index of a valid word in the word list. The obfuscation
    is necessary for more direct implementation of RL algorithms, which expect
    spaces of less sophisticated form.

    In addition to the default methods of the Discrete space, it implements
    a __getitem__ method for easy index lookup, and an index_of method to
    convert potential words into their corresponding index (if they exist).
    """

    def __init__(self, words: npt.NDArray[np.int64], **kwargs):
        """
        Args:
            words: Collection of words in array form with shape (_, 5), where
              each word is a row of the array. Each array element is an integer
              between 0,...,26 (inclusive).
            kwargs: See documentation for gym.spaces.Discrete
        """
        super().__init__(words.shape[0], **kwargs)
        self.words = words

    def __getitem__(self, index: int) -> npt.NDArray[np.int64]:
        """Obtains the (int-encoded) word associated with the given index.

        Args:
            index: Index for the list of words.

        Returns:
            Associated word at the position specified by index.
        """
        return self.words[index]

    def index_of(self, word: npt.NDArray[np.int64]) -> int:
        """Given a word, determine its index in the list (if it exists),
        otherwise returning -1 if no index exists.

        Args:
            word: Word to find in the word list.

        Returns:
            The index of the first row equal to the given word if it exists,
            otherwise -1. A word whose shape does not match a row of the list
            (e.g. wrong length) is reported as not found.
        """
        candidate = np.asarray(word)

        # A malformed word cannot be in the list. Checking the shape up front
        # replaces the former bare `except:`, which also hid unrelated errors
        # such as KeyboardInterrupt.
        if self.words.ndim != 2 or candidate.shape != self.words.shape[1:]:
            return -1

        matches = np.flatnonzero((self.words == candidate).all(axis=1))
        return int(matches[0]) if matches.size else -1
|  | ||||
|  | ||||
class SolutionList(WordList):
    """Word space holding the *solution* words of the Wordle environment.

    Wordle works with two different collections of words:

    * "guesses": every word the game accepts as a valid guess.
    * "solutions": the words the game may pick as the hidden answer.

    The solutions form a strict subset of the guesses. This class represents
    the set of solution words.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded unchanged to the WordList constructor.
        """
        super().__init__(get_words('solution'), **kwargs)
|  | ||||
|  | ||||
class WordleObsSpace(gym.spaces.Box):
    """State (observation) space of the game, built on gym.spaces.Box.

    Conceptually the Wordle board is a 6x5 grid carrying two channels:

      - a character channel holding the letters placed on the board (rows that
        have not been played yet hold the empty character, 0)
      - a flag channel holding the in-game feedback attached to each letter
        (green highlight, yellow highlight, etc.)

    with one row per turn (6) and one column per letter of the solution (5).

    To stay simple and compatible with stable_baselines algorithms, the two
    channels are laid side by side along the columns, giving a 6x10 array.
    A row therefore reads c0 c1 c2 c3 c4 f0 f1 f2 f3 f4 for the word c0...c4
    with flags f0...f4.
    """

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Forwarded unchanged to the gym.spaces.Box constructor.
        """
        self.n_rows, self.n_cols = 6, 5
        # NOTE(review): the flag codes declared on WordleEnv run 0..3, while
        # the upper bound used here is 4 -- confirm this is intentional.
        self.max_char, self.max_flag = 26, 4

        board_shape = (self.n_rows, self.n_cols)
        char_high = np.full(board_shape, self.max_char)
        flag_high = np.full(board_shape, self.max_flag)

        # Character bounds fill the left half, flag bounds the right half.
        upper = np.concatenate((char_high, flag_high), axis=1)
        lower = np.zeros((self.n_rows, 2 * self.n_cols))

        super().__init__(lower, upper, dtype=np.int64, **kwargs)
|  | ||||
|  | ||||
class GuessList(WordList):
    """Space of words that may be submitted as a *guess* to the Wordle
    environment.
    """

    def __init__(self, **kwargs):
        """Build the space from the 'guess' word list.

        Args:
            kwargs: Forwarded unchanged to the WordList constructor; see the
                documentation for gym.spaces.MultiDiscrete.
        """
        # get_words is defined elsewhere in this file; 'guess' selects the
        # list of words the agent is allowed to play.
        super().__init__(get_words('guess'), **kwargs)
|  | ||||
|  | ||||
class WordleEnv(gym.Env):
    """Wordle as a gym environment.

    Every episode draws a solution from ``SolutionList``; the agent then has
    ``n_rounds`` guesses, each an action taken from ``GuessList``. The
    observation is the board: one row per round holding the guessed
    characters followed by one flag per character (see the flag codes
    below).
    """

    metadata = {'render.modes': ['human']}

    # Character flag codes
    no_char = 0
    right_pos = 1
    wrong_pos = 2
    wrong_char = 3

    def __init__(self):
        super().__init__()

        self.action_space = GuessList()
        self.solution_space = SolutionList()

        self.observation_space = WordleObsSpace()

        # Maps a flag code to the (prefix, suffix) strings wrapped around a
        # character when rendering to the terminal.
        self._highlights = {
            self.right_pos: (bg.green, bg.rs),
            self.wrong_pos: (bg.yellow, bg.rs),
            self.wrong_char: ('', ''),
            self.no_char: ('', ''),
        }

        self.n_rounds = 6
        self.n_letters = 5
        self.info = {'correct': False, 'guesses': defaultdict(int)}

    def _highlighter(self, char: str, flag: int) -> str:
        """Terminal renderer functionality. Wraps a character in the
        highlight that corresponds to its flag.

        Args:
            char: Character in question.
            flag: Associated flag, one of:
                - 0: no character (render no background)
                - 1: right position (render green background)
                - 2: wrong position (render yellow background)
                - 3: wrong character (render no background)

        Returns:
            The character surrounded by the prefix/suffix strings registered
            for ``flag`` in ``self._highlights``.
        """
        front, back = self._highlights[flag]
        return front + char + back

    def _score_guess(self, action, solution):
        """Compute the per-letter flags for one guess.

        Scoring is done in two passes so that repeated letters are handled
        the way Wordle does: exact matches are marked first and consume the
        corresponding solution letters; only the solution letters left over
        can then be claimed, left to right, by misplaced guess letters.
        A single-pass scan would mark an early misplaced letter yellow even
        when the only copy in the solution is matched exactly later on.

        Args:
            action: Encoded guess, a sequence of ``n_letters`` characters.
            solution: Encoded solution, same length as ``action``.

        Returns:
            Integer array of length ``n_letters`` holding ``right_pos``,
            ``wrong_pos`` or ``wrong_char`` for each guessed letter.
        """
        action = np.asarray(action)
        solution = np.asarray(solution)

        exact = action == solution
        flags = np.full(self.n_letters, self.wrong_char, dtype=np.int64)
        flags[exact] = self.right_pos

        # Solution letters that were not matched exactly are still
        # available to be reported as "wrong position".
        available = Counter(solution[~exact].tolist())
        for i, char in enumerate(action.tolist()):
            if exact[i]:
                continue
            if available[char] > 0:
                flags[i] = self.wrong_pos
                available[char] -= 1

        return flags

    def reset(self, seed=None, options=None):
        """Reset the environment to an initial state and return the initial
        observation.

        Note: The observation space instance should be a Box space.

        Args:
            seed: Optional seed. When given it is passed to the base class
                and also used to seed the solution space, so that the
                sequence of sampled solutions is reproducible.
            options: Accepted for API compatibility; unused.

        Returns:
            state (object): The initial (all-zero) board.
            info (dict): Fresh diagnostic information for the new episode.
        """
        super().reset(seed=seed)
        if seed is not None:
            # The solution is drawn from the space's own generator, so the
            # space has to be seeded explicitly for the seed to take effect.
            self.solution_space.seed(seed)

        self.round = 0
        self.solution = self.solution_space.sample()

        self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)

        self.info = {'correct': False, 'guesses': defaultdict(int)}

        return self.state, self.info

    def render(self, mode: str = 'human'):
        """Renders the Wordle environment.

        Currently supported render modes:
        - human: renders the Wordle game to the terminal.

        Args:
            mode: the mode to render with.
        """
        if mode == 'human':
            for row in self.state:
                # Left half of the row holds the characters, right half the
                # flags; pair them up and highlight each character.
                text = ''.join(map(
                    self._highlighter,
                    to_english(row[:self.n_letters]).upper(),
                    row[self.n_letters:]
                ))
                print(text)
        else:
            super().render(mode=mode)

    def step(self, action):
        """Run one step of the Wordle game. Every game must be previously
        initialized by a call to the `reset` method.

        The reward is the sum of:
          - +2 / +1 / -1 for every right-position / wrong-position /
            wrong-character flag on the board. These terms are summed over
            all rows played so far, not only the current guess.
          - +10 for a word not guessed before in this episode, otherwise
            -10 times the number of times it was already guessed.
          - +10 if the guess is correct, -10 if the game ends without a
            correct guess.

        Args:
            action: Word guessed by the agent.

        Returns:
            state (object): Wordle game state after the guess.
            reward: Reward associated with the guess.
            terminated (bool): Whether the game has ended (win or out of
                rounds).
            truncated (bool): Always False.
            info (dict): Auxiliary diagnostic information ('correct' and the
                per-word 'guesses' counts).
        """
        assert self.action_space.contains(action), 'Invalid word!'

        action = np.asarray(self.action_space[action])
        solution = np.asarray(self.solution_space[self.solution])

        self.state[self.round, :self.n_letters] = action
        self.state[self.round, self.n_letters:] = self._score_guess(
            action, solution)

        self.round += 1

        correct = bool((action == solution).all())
        game_over = (self.round == self.n_rounds)

        done = correct or game_over

        flags = self.state[:, self.n_letters:]

        reward = 0
        # correct spot
        reward += np.sum(flags == self.right_pos) * 2

        # correct letter not correct spot
        reward += np.sum(flags == self.wrong_pos) * 1

        # incorrect letter
        reward += np.sum(flags == self.wrong_char) * -1

        # guess same word as before
        hashable_action = to_english(action)
        if hashable_action in self.info['guesses']:
            reward += -10 * self.info['guesses'][hashable_action]
        else:  # guess different word
            reward += 10

        self.info['guesses'][hashable_action] += 1

        # for game ending in win or loss
        reward += 10 if correct else -10 if done else 0

        self.info['correct'] = correct

        # observation, reward, terminated, truncated, info
        return self.state, reward, done, False, self.info
		Reference in New Issue
	
	Block a user