mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-25 17:49:10 +00:00
added state saving
This commit is contained in:
parent
4fb81317f0
commit
3747af9d22
@ -43,6 +43,174 @@
|
||||
"check_env(env) # Optional: Verify the environment is compatible with SB3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"initial_state = env.clone_state()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, _ = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_save_path = \"wordle_ppo_model\"\n",
|
||||
"model = PPO.load(model_save_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"action, _ = model.predict(obs)\n",
|
||||
"obs, reward, done, _, info = env.step(action)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"5"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"action % 26"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"5"
|
||||
]
|
||||
},
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ord('f') - ord('a')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'f'"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chr(ord('a') + action % 26)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
||||
" 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||
" 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||
" 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1,\n",
|
||||
" 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n",
|
||||
" 1, 1])"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"obs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.set_state(initial_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all(env.get_obs() == obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Perform your action to see the outcome\n",
|
||||
"action = # Define your action\n",
|
||||
"observation, reward, done, info = env.step(action)\n",
|
||||
"\n",
|
||||
"# Revert to the initial state\n",
|
||||
"env.env.set_state(initial_state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
@ -128,7 +296,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_save_path = \"wordle_ppo_model\"\n",
|
||||
"model_save_path = \"wordle_ppo_model_test\"\n",
|
||||
"config = {\n",
|
||||
" \"policy_type\": \"MlpPolicy\",\n",
|
||||
" \"total_timesteps\": 200_000\n",
|
||||
@ -369,7 +537,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -6,6 +6,7 @@ import numpy as np
|
||||
from stable_baselines3 import PPO, DQN
|
||||
from letter_guess import LetterGuessingEnv
|
||||
|
||||
|
||||
def load_valid_words(file_path='wordle_words.txt'):
|
||||
"""
|
||||
Load valid five-letter words from a specified text file.
|
||||
@ -20,8 +21,9 @@ def load_valid_words(file_path='wordle_words.txt'):
|
||||
valid_words = [line.strip() for line in file if len(line.strip()) == 5]
|
||||
return valid_words
|
||||
|
||||
|
||||
class AI:
|
||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model = False):
|
||||
def __init__(self, vocab_file, model_file, num_letters=5, num_guesses=6, use_q_model=False):
|
||||
self.vocab_file = vocab_file
|
||||
self.num_letters = num_letters
|
||||
self.num_guesses = 6
|
||||
@ -36,18 +38,21 @@ class AI:
|
||||
if use_q_model:
|
||||
# we initialize the same q env as the model train ONLY to simplify storing/calculating the gym state, not used to control the game at all
|
||||
self.q_env = LetterGuessingEnv(vocab_file)
|
||||
self.q_env_state, _ = self.q_env.reset()
|
||||
|
||||
# load model
|
||||
self.q_model = PPO.load(model_file)
|
||||
self.q_env_state = None
|
||||
|
||||
self.reset()
|
||||
|
||||
|
||||
def solve_eval(self, results_callback):
|
||||
num_guesses = 0
|
||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||
num_guesses += 1
|
||||
|
||||
# sample a word, this would use the q_env_state if the q_model is used
|
||||
word = self.sample()
|
||||
|
||||
# get emulated results
|
||||
results = results_callback(word)
|
||||
if self.use_q_model:
|
||||
@ -59,7 +64,7 @@ class AI:
|
||||
|
||||
self.arc_consistency(word, results)
|
||||
return num_guesses, word
|
||||
|
||||
|
||||
def solve(self):
|
||||
num_guesses = 0
|
||||
while [len(e) for e in self.domains] != [1 for _ in range(self.num_letters)]:
|
||||
@ -132,9 +137,31 @@ class AI:
|
||||
pattern = re.compile(regex_string)
|
||||
|
||||
# From the words with the highest scores, only return the best word that match the regex pattern
|
||||
max_qval = float('-inf')
|
||||
best_word = None
|
||||
for word, _ in self.best_words:
|
||||
# reset the state back to before we guessed a word
|
||||
self.q_env.set_state(self.q_env_state)
|
||||
if pattern.match(word) and False not in [e in word for e in self.possible_letters]:
|
||||
return word
|
||||
if self.use_q_model:
|
||||
# Use policy to grade word
|
||||
# get the state and action pairs
|
||||
curr_qval = 0
|
||||
|
||||
for l in word:
|
||||
action = ord(l) - ord('a')
|
||||
q_val = self.q_model.policy.evaluate_actions(self.q_env.get_obs(), action)
|
||||
_, _, _, _, _ = self.q_env.step(action)
|
||||
curr_qval += q_val
|
||||
|
||||
if curr_qval > max_qval:
|
||||
max_qval = curr_qval
|
||||
best_word = word
|
||||
else:
|
||||
# otherwise return the word from eric heuristic
|
||||
return word
|
||||
self.q_env.set_state(self.q_env_state)
|
||||
return best_word
|
||||
|
||||
def get_vocab(self, vocab_file):
|
||||
vocab = []
|
||||
|
@ -3,6 +3,7 @@ from gymnasium import spaces
|
||||
import numpy as np
|
||||
import random
|
||||
import re
|
||||
import copy
|
||||
|
||||
|
||||
class LetterGuessingEnv(gym.Env):
|
||||
@ -29,8 +30,28 @@ class LetterGuessingEnv(gym.Env):
|
||||
|
||||
self.reset()
|
||||
|
||||
def clone_state(self):
|
||||
# Clone the current state
|
||||
return {
|
||||
'target_word': self.target_word,
|
||||
'letter_flags': copy.deepcopy(self.letter_flags),
|
||||
'letter_positions': copy.deepcopy(self.letter_positions),
|
||||
'guessed_letters': copy.deepcopy(self.guessed_letters),
|
||||
'guess_prefix': self.guess_prefix,
|
||||
'round': self.round
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
# Restore the state
|
||||
self.target_word = state['target_word']
|
||||
self.letter_flags = copy.deepcopy(state['letter_flags'])
|
||||
self.letter_positions = copy.deepcopy(state['letter_positions'])
|
||||
self.guessed_letters = copy.deepcopy(state['guessed_letters'])
|
||||
self.guess_prefix = state['guess_prefix']
|
||||
self.round = state['round']
|
||||
|
||||
def step(self, action):
|
||||
letter_index = action % 26 # Assuming action is the letter index directly
|
||||
letter_index = action # Assuming action is the letter index directly
|
||||
position = len(self.guess_prefix) # The next position in the prefix is determined by its current length
|
||||
letter = chr(ord('a') + letter_index)
|
||||
|
||||
@ -56,8 +77,8 @@ class LetterGuessingEnv(gym.Env):
|
||||
reward = 1 # Reward for adding new information by trying a new letter
|
||||
|
||||
# Update the letter_positions matrix to reflect the new guess
|
||||
if position == 4:
|
||||
self.letter_positions[:,:] = 1
|
||||
if position == 4:
|
||||
self.letter_positions[:, :] = 1
|
||||
else:
|
||||
self.letter_positions[:, position] = 0
|
||||
self.letter_positions[letter_index, position] = 1
|
||||
@ -77,8 +98,8 @@ class LetterGuessingEnv(gym.Env):
|
||||
# reward = 5
|
||||
done = True
|
||||
|
||||
obs = self._get_obs()
|
||||
|
||||
obs = self.get_obs()
|
||||
|
||||
if reward < -5:
|
||||
print(obs, reward, done)
|
||||
exit(0)
|
||||
@ -93,7 +114,7 @@ class LetterGuessingEnv(gym.Env):
|
||||
self.guessed_letters = set()
|
||||
self.guess_prefix = "" # Reset the guess prefix for the new episode
|
||||
self.round = 0
|
||||
return self._get_obs(), {}
|
||||
return self.get_obs(), {}
|
||||
|
||||
def encode_word(self, word):
|
||||
encoded = np.zeros((26,))
|
||||
@ -102,7 +123,7 @@ class LetterGuessingEnv(gym.Env):
|
||||
encoded[index] = 1
|
||||
return encoded
|
||||
|
||||
def _get_obs(self):
|
||||
def get_obs(self):
|
||||
return np.concatenate([self.letter_flags.flatten(), self.letter_positions.flatten()])
|
||||
|
||||
def render(self, mode='human'):
|
||||
|
Loading…
Reference in New Issue
Block a user