mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-09 22:54:45 +00:00
run model train, abt 3 avg reward
This commit is contained in:
parent
f40301cac9
commit
848d385482
File diff suppressed because it is too large
Load Diff
338
dqn_wordle.ipynb
338
dqn_wordle.ipynb
@ -1,338 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import gym\n",
|
|
||||||
"import gym_wordle\n",
|
|
||||||
"from stable_baselines3 import DQN, PPO, common\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import tqdm"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"<Monitor<WordleEnv instance>>\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
|
||||||
"env = common.monitor.Monitor(env)\n",
|
|
||||||
"\n",
|
|
||||||
"print(env)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Using cuda device\n",
|
|
||||||
"Wrapping the env in a DummyVecEnv.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "6921a0721569456abf5bceac7e7b6b34",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
"Output()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 4.97 |\n",
|
|
||||||
"| ep_rew_mean | -63.8 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 10000 |\n",
|
|
||||||
"| fps | 1628 |\n",
|
|
||||||
"| time_elapsed | 30 |\n",
|
|
||||||
"| total_timesteps | 49995 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"----------------------------------\n",
|
|
||||||
"| rollout/ | |\n",
|
|
||||||
"| ep_len_mean | 5 |\n",
|
|
||||||
"| ep_rew_mean | -70.5 |\n",
|
|
||||||
"| exploration_rate | 0.05 |\n",
|
|
||||||
"| time/ | |\n",
|
|
||||||
"| episodes | 20000 |\n",
|
|
||||||
"| fps | 662 |\n",
|
|
||||||
"| time_elapsed | 150 |\n",
|
|
||||||
"| total_timesteps | 99992 |\n",
|
|
||||||
"| train/ | |\n",
|
|
||||||
"| learning_rate | 0.0001 |\n",
|
|
||||||
"| loss | 11.7 |\n",
|
|
||||||
"| n_updates | 12497 |\n",
|
|
||||||
"----------------------------------\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
|
||||||
],
|
|
||||||
"text/plain": []
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
|
||||||
"</pre>\n"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"<stable_baselines3.dqn.dqn.DQN at 0x1bfd6cc0210>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"total_timesteps = 100_000\n",
|
|
||||||
"model = DQN(\"MlpPolicy\", env, verbose=1, device='cuda')\n",
|
|
||||||
"model.learn(total_timesteps=total_timesteps, log_interval=10_000, progress_bar=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model.save(\"dqn_new_state\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
|
||||||
"Exception: code() argument 13 must be str, not int\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"c:\\Repository\\cse151b-final-project\\env\\Lib\\site-packages\\stable_baselines3\\common\\save_util.py:166: UserWarning: Could not deserialize object exploration_schedule. Consider using `custom_objects` argument to replace this object.\n",
|
|
||||||
"Exception: code() argument 13 must be str, not int\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# model = DQN.load(\"dqn_wordle\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"[1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.\n",
|
|
||||||
" 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
|
|
||||||
" 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
|
|
||||||
" 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
"0\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"env = gym_wordle.wordle.WordleEnv()\n",
|
|
||||||
"\n",
|
|
||||||
"for i in range(1000):\n",
|
|
||||||
" \n",
|
|
||||||
" state, info = env.reset()\n",
|
|
||||||
"\n",
|
|
||||||
" done = False\n",
|
|
||||||
"\n",
|
|
||||||
" wins = 0\n",
|
|
||||||
"\n",
|
|
||||||
" while not done:\n",
|
|
||||||
"\n",
|
|
||||||
" action, _states = model.predict(state, deterministic=True)\n",
|
|
||||||
"\n",
|
|
||||||
" state, reward, done, truncated, info = env.step(action)\n",
|
|
||||||
"\n",
|
|
||||||
" print(state)\n",
|
|
||||||
" if info[\"correct\"]:\n",
|
|
||||||
" wins += 1\n",
|
|
||||||
"\n",
|
|
||||||
"print(wins)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"(array([1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
|
||||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,\n",
|
|
||||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
|
||||||
" 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
|
||||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n",
|
|
||||||
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
|
|
||||||
" 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
|
||||||
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
|
||||||
" 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
|
||||||
" 0., 0., 0., 0., 0., 0., 0., 1.]),\n",
|
|
||||||
" -50)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"state, reward"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.11.5"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
@ -72,15 +72,16 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
self.guess_prefix = ''
|
self.guess_prefix = ''
|
||||||
self.round += 1
|
self.round += 1
|
||||||
|
|
||||||
# end after 5 rounds of total guesses
|
# end after 3 rounds of total guesses
|
||||||
if self.round == 2:
|
if self.round == 3:
|
||||||
# reward = 5
|
# reward = 5
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
obs = self._get_obs()
|
obs = self._get_obs()
|
||||||
|
|
||||||
if reward < -50:
|
if reward < -5:
|
||||||
print(obs, reward, done)
|
print(obs, reward, done)
|
||||||
|
exit(0)
|
||||||
|
|
||||||
return obs, reward, done, False, {}
|
return obs, reward, done, False, {}
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
self.letter_positions = np.ones((26, 4), dtype=np.int32)
|
self.letter_positions = np.ones((26, 4), dtype=np.int32)
|
||||||
self.guessed_letters = set()
|
self.guessed_letters = set()
|
||||||
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||||
self.round = 1
|
self.round = 0
|
||||||
return self._get_obs(), {}
|
return self._get_obs(), {}
|
||||||
|
|
||||||
def encode_word(self, word):
|
def encode_word(self, word):
|
||||||
|
Loading…
Reference in New Issue
Block a user