mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-09 22:54:45 +00:00
added state saving
This commit is contained in:
parent
4fb81317f0
commit
3747af9d22
@ -43,6 +43,174 @@
|
|||||||
"check_env(env) # Optional: Verify the environment is compatible with SB3"
|
"check_env(env) # Optional: Verify the environment is compatible with SB3"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"initial_state = env.clone_state()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"obs, _ = env.reset()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_save_path = \"wordle_ppo_model\"\n",
|
||||||
|
"model = PPO.load(model_save_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"action, _ = model.predict(obs)\n",
|
||||||
|
"obs, reward, done, _, info = env.step(action)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"5"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 24,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"action % 26"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"5"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 28,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"ord('f') - ord('a')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'f'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 26,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chr(ord('a') + action % 26)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||||
|
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||||
|
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||||
|
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||||
|
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||||
|
" 1, 1])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"obs"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"env.set_state(initial_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 20,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"all(env.get_obs() == obs)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Perform your action to see the outcome\n",
|
||||||
|
"action = # Define your action\n",
|
||||||
|
"observation, reward, done, info = env.step(action)\n",
|
||||||
|
"\n",
|
||||||
|
"# Revert to the initial state\n",
|
||||||
|
"env.env.set_state(initial_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
@ -128,7 +296,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model_save_path = \"wordle_ppo_model\"\n",
|
"model_save_path = \"wordle_ppo_model_test\"\n",
|
||||||
"config = {\n",
|
"config = {\n",
|
||||||
" \"policy_type\": \"MlpPolicy\",\n",
|
" \"policy_type\": \"MlpPolicy\",\n",
|
||||||
" \"total_timesteps\": 200_000\n",
|
" \"total_timesteps\": 200_000\n",
|
||||||
@ -369,7 +537,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.10"
|
"version": "3.11.5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -6,6 +6,7 @@ import numpy as np
|
|||||||
from stable_baselines3 import PPO, DQN
|
from stable_baselines3 import PPO, DQN
|
||||||
from letter_guess import LetterGuessingEnv
|
from letter_guess import LetterGuessingEnv
|
||||||
|
|
||||||
|
|
||||||
def load_valid_words(file_path='wordle_words.txt'):
|
def load_valid_words(file_path='wordle_words.txt'):
|
||||||
"""
|
"""
|
||||||
Load valid five-letter words from a specified text file.
|
Load valid five-letter words from a specified text file.
|
||||||
@ -20,6 +21,7 @@ def load_valid_words(file_path='wordle_words.txt'):
|
|||||||
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
||||||
return valid_words
|
return valid_words
|
||||||
|
|
||||||
|
|
||||||
class AI:
|
class AI:
|
||||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
|
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
@ -36,9 +38,10 @@ class AI:
|
|||||||
if use_q_model:
|
if use_q_model:
|
||||||
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
|
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
|
||||||
self.q_env = LetterGuessingEnv(vocab_file)
|
self.q_env = LetterGuessingEnv(vocab_file)
|
||||||
|
self.q_env_state, _ = self.q_env.reset()
|
||||||
|
|
||||||
# load model
|
# load model
|
||||||
self.q_model = PPO.load(model_file)
|
self.q_model = PPO.load(model_file)
|
||||||
self.q_env_state = None
|
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -46,8 +49,10 @@ class AI:
|
|||||||
num_guesses = 0
|
num_guesses = 0
|
||||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||||
num_guesses += 1
|
num_guesses += 1
|
||||||
|
|
||||||
# sample a word, this would use the q_env_state if the q_model is used
|
# sample a word, this would use the q_env_state if the q_model is used
|
||||||
word = self.sample()
|
word = self.sample()
|
||||||
|
|
||||||
# get emulated results
|
# get emulated results
|
||||||
results = results_callback(word)
|
results = results_callback(word)
|
||||||
if self.use_q_model:
|
if self.use_q_model:
|
||||||
@ -132,9 +137,31 @@ class AI:
|
|||||||
pattern = re.compile(regex_string)
|
pattern = re.compile(regex_string)
|
||||||
|
|
||||||
# From the words with the highest scores, only return the best word that match the regex pattern
|
# From the words with the highest scores, only return the best word that match the regex pattern
|
||||||
|
max_qval = float('-inf')
|
||||||
|
best_word = None
|
||||||
for word, _ in self.best_words:
|
for word, _ in self.best_words:
|
||||||
|
# reset the state back to before we guessed a word
|
||||||
|
self.q_env.set_state(self.q_env_state)
|
||||||
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||||
|
if self.use_q_model:
|
||||||
|
# Use policy to grade word
|
||||||
|
# get the state and action pairs
|
||||||
|
curr_qval = 0
|
||||||
|
|
||||||
|
for l in word:
|
||||||
|
action = ord(l) - ord('a')
|
||||||
|
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
|
||||||
|
_, _, _, _, _ = self.q_env.step(action)
|
||||||
|
curr_qval += q_val
|
||||||
|
|
||||||
|
if curr_qval > max_qval:
|
||||||
|
max_qval = curr_qval
|
||||||
|
best_word = word
|
||||||
|
else:
|
||||||
|
# otherwise return the word from eric heuristic
|
||||||
return word
|
return word
|
||||||
|
self.q_env.set_state(self.q_env_state)
|
||||||
|
return best_word
|
||||||
|
|
||||||
def get_vocab(self, vocab_file):
|
def get_vocab(self, vocab_file):
|
||||||
vocab = []
|
vocab = []
|
||||||
|
@ -3,6 +3,7 @@ from gymnasium import spaces
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
class LetterGuessingEnv(gym.Env):
|
class LetterGuessingEnv(gym.Env):
|
||||||
@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
def clone_state(self):
|
||||||
|
# Clone the current state
|
||||||
|
return {
|
||||||
|
'target_word': self.target_word,
|
||||||
|
'letter_flags': copy.deepcopy(self.letter_flags),
|
||||||
|
'letter_positions': copy.deepcopy(self.letter_positions),
|
||||||
|
'guessed_letters': copy.deepcopy(self.guessed_letters),
|
||||||
|
'guess_prefix': self.guess_prefix,
|
||||||
|
'round': self.round
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
# Restore the state
|
||||||
|
self.target_word = state['target_word']
|
||||||
|
self.letter_flags = copy.deepcopy(state['letter_flags'])
|
||||||
|
self.letter_positions = copy.deepcopy(state['letter_positions'])
|
||||||
|
self.guessed_letters = copy.deepcopy(state['guessed_letters'])
|
||||||
|
self.guess_prefix = state['guess_prefix']
|
||||||
|
self.round = state['round']
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
letter_index = action % 26 # Assuming action is the letter index directly
|
letter_index = action # Assuming action is the letter index directly
|
||||||
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
||||||
letter = chr(ord('a') + letter_index)
|
letter = chr(ord('a') + letter_index)
|
||||||
|
|
||||||
@ -77,7 +98,7 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
# reward = 5
|
# reward = 5
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
obs = self._get_obs()
|
obs = self.get_obs()
|
||||||
|
|
||||||
if reward < -5:
|
if reward < -5:
|
||||||
print(obs, reward, done)
|
print(obs, reward, done)
|
||||||
@ -93,7 +114,7 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
self.guessed_letters = set()
|
self.guessed_letters = set()
|
||||||
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||||
self.round = 0
|
self.round = 0
|
||||||
return self._get_obs(), {}
|
return self.get_obs(), {}
|
||||||
|
|
||||||
def encode_word(self, word):
|
def encode_word(self, word):
|
||||||
encoded = np.zeros((26,))
|
encoded = np.zeros((26,))
|
||||||
@ -102,7 +123,7 @@ class LetterGuessingEnv(gym.Env):
|
|||||||
encoded[index] = 1
|
encoded[index] = 1
|
||||||
return encoded
|
return encoded
|
||||||
|
|
||||||
def _get_obs(self):
|
def get_obs(self):
|
||||||
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
|
Loading…
Reference in New Issue
Block a user