added state saving

This commit is contained in:
Ethan Shapiro 2024-03-20 19:52:13 -07:00
parent 4fb81317f0
commit 3747af9d22
3 changed files with 230 additions and 14 deletions

View File

@ -43,6 +43,174 @@
"check_env(env) # Optional: Verify the environment is compatible with SB3" "check_env(env) # Optional: Verify the environment is compatible with SB3"
] ]
}, },
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"initial_state = env.clone_state()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"obs, _ = env.reset()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"model_save_path = \"wordle_ppo_model\"\n",
"model = PPO.load(model_save_path)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"action, _ = model.predict(obs)\n",
"obs, reward, done, _, info = env.step(action)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"action % 26"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ord('f') - ord('a')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'f'"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chr(ord('a') + action % 26)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
" 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
" 1, 1])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obs"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"env.set_state(initial_state)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all(env.get_obs() == obs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Perform your action to see the outcome\n",
"action = # Define your action\n",
"observation, reward, done, info = env.step(action)\n",
"\n",
"# Revert to the initial state\n",
"env.env.set_state(initial_state)"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
@ -128,7 +296,7 @@
} }
], ],
"source": [ "source": [
"model_save_path = \"wordle_ppo_model\"\n", "model_save_path = \"wordle_ppo_model_test\"\n",
"config = {\n", "config = {\n",
" \"policy_type\": \"MlpPolicy\",\n", " \"policy_type\": \"MlpPolicy\",\n",
" \"total_timesteps\": 200_000\n", " \"total_timesteps\": 200_000\n",
@ -369,7 +537,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.10" "version": "3.11.5"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -6,6 +6,7 @@ import numpy as np
from stable_baselines3 import PPO, DQN from stable_baselines3 import PPO, DQN
from letter_guess import LetterGuessingEnv from letter_guess import LetterGuessingEnv
def load_valid_words(file_path='wordle_words.txt'): def load_valid_words(file_path='wordle_words.txt'):
""" """
Load valid five-letter words from a specified text file. Load valid five-letter words from a specified text file.
@ -20,6 +21,7 @@ def load_valid_words(file_path='wordle_words.txt'):
valid_words = [line.strip() for line in file if len(line.strip()) == 5] valid_words = [line.strip() for line in file if len(line.strip()) == 5]
return valid_words return valid_words
class AI: class AI:
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False): def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
self.vocab_file = vocab_file self.vocab_file = vocab_file
@ -36,9 +38,10 @@ class AI:
if use_q_model: if use_q_model:
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all # we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
self.q_env = LetterGuessingEnv(vocab_file) self.q_env = LetterGuessingEnv(vocab_file)
self.q_env_state, _ = self.q_env.reset()
# load model # load model
self.q_model = PPO.load(model_file) self.q_model = PPO.load(model_file)
self.q_env_state = None
self.reset() self.reset()
@ -46,8 +49,10 @@ class AI:
num_guesses = 0 num_guesses = 0
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]: while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
num_guesses += 1 num_guesses += 1
# sample a word, this would use the q_env_state if the q_model is used # sample a word, this would use the q_env_state if the q_model is used
word = self.sample() word = self.sample()
# get emulated results # get emulated results
results = results_callback(word) results = results_callback(word)
if self.use_q_model: if self.use_q_model:
@ -132,9 +137,31 @@ class AI:
pattern = re.compile(regex_string) pattern = re.compile(regex_string)
# From the words with the highest scores, only return the best word that match the regex pattern # From the words with the highest scores, only return the best word that match the regex pattern
max_qval = float('-inf')
best_word = None
for word, _ in self.best_words: for word, _ in self.best_words:
# reset the state back to before we guessed a word
self.q_env.set_state(self.q_env_state)
if pattern.match(word) and False not in [e in word for e in self.possible_letters]: if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
if self.use_q_model:
# Use policy to grade word
# get the state and action pairs
curr_qval = 0
for l in word:
action = ord(l) - ord('a')
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
_, _, _, _, _ = self.q_env.step(action)
curr_qval += q_val
if curr_qval > max_qval:
max_qval = curr_qval
best_word = word
else:
# otherwise return the word from eric heuristic
return word return word
self.q_env.set_state(self.q_env_state)
return best_word
def get_vocab(self, vocab_file): def get_vocab(self, vocab_file):
vocab = [] vocab = []

View File

@ -3,6 +3,7 @@ from gymnasium import spaces
import numpy as np import numpy as np
import random import random
import re import re
import copy
class LetterGuessingEnv(gym.Env): class LetterGuessingEnv(gym.Env):
@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env):
self.reset() self.reset()
def clone_state(self):
    """Return a snapshot of the current episode state as a plain dict.

    The snapshot holds the target word, the letter flags, the letter
    positions, the guessed letters, the guess prefix and the round
    counter. The flags, positions and guessed letters are deep-copied so
    that later changes to the environment do not alter the snapshot.

    Returns:
        dict: the captured state, suitable for passing to ``set_state``.
    """
    duplicate = copy.deepcopy

    # Keys are inserted in a fixed order: target word first, round last.
    snapshot = {'target_word': self.target_word}
    snapshot['letter_flags'] = duplicate(self.letter_flags)
    snapshot['letter_positions'] = duplicate(self.letter_positions)
    snapshot['guessed_letters'] = duplicate(self.guessed_letters)
    snapshot['guess_prefix'] = self.guess_prefix
    snapshot['round'] = self.round
    return snapshot
def set_state(self, state):
    """Restore the environment from a snapshot produced by ``clone_state``.

    The restore is all-or-nothing: every field is read from ``state`` (and
    deep-copied where the snapshot holds a mutable container) before any
    attribute of the environment is overwritten. A malformed snapshot
    therefore raises without leaving the environment half-restored.

    Args:
        state: mapping with the keys ``'target_word'``, ``'letter_flags'``,
            ``'letter_positions'``, ``'guessed_letters'``,
            ``'guess_prefix'`` and ``'round'``.

    Raises:
        KeyError: if ``state`` is a dict that lacks one of the required
            keys. The environment is left unchanged in that case.
    """
    # Phase 1: read and copy everything. Any lookup failure happens here,
    # before the environment has been touched. The copies keep the caller's
    # snapshot reusable for later restores.
    target_word = state['target_word']
    letter_flags = copy.deepcopy(state['letter_flags'])
    letter_positions = copy.deepcopy(state['letter_positions'])
    guessed_letters = copy.deepcopy(state['guessed_letters'])
    guess_prefix = state['guess_prefix']
    round_number = state['round']

    # Phase 2: commit. Plain attribute assignments only.
    self.target_word = target_word
    self.letter_flags = letter_flags
    self.letter_positions = letter_positions
    self.guessed_letters = guessed_letters
    self.guess_prefix = guess_prefix
    self.round = round_number
def step(self, action): def step(self, action):
letter_index = action % 26 # Assuming action is the letter index directly letter_index = action # Assuming action is the letter index directly
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
letter = chr(ord('a') + letter_index) letter = chr(ord('a') + letter_index)
@ -77,7 +98,7 @@ class LetterGuessingEnv(gym.Env):
# reward = 5 # reward = 5
done = True done = True
obs = self._get_obs() obs = self.get_obs()
if reward < -5: if reward < -5:
print(obs, reward, done) print(obs, reward, done)
@ -93,7 +114,7 @@ class LetterGuessingEnv(gym.Env):
self.guessed_letters = set() self.guessed_letters = set()
self.guess_prefix = "" # Reset the guess prefix for the new episode self.guess_prefix = "" # Reset the guess prefix for the new episode
self.round = 0 self.round = 0
return self._get_obs(), {} return self.get_obs(), {}
def encode_word(self, word): def encode_word(self, word):
encoded = np.zeros((26,)) encoded = np.zeros((26,))
@ -102,7 +123,7 @@ class LetterGuessingEnv(gym.Env):
encoded[index] = 1 encoded[index] = 1
return encoded return encoded
def _get_obs(self): def get_obs(self):
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()]) return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
def render(self, mode='human'): def render(self, mode='human'):