mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-10-22 18:49:21 +00:00
Compare commits
5 Commits
83e81722d2
...
arthur-tes
Author | SHA1 | Date | |
---|---|---|---|
|
dd5889da33 | ||
|
848ea719b7 | ||
|
f641d77c47 | ||
|
5ec123e0f1 | ||
|
e9622b6f68 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
**/data/*
|
||||
**/*.zip
|
||||
**/__pycache__
|
@@ -1,24 +1,38 @@
|
||||
import gym
|
||||
import gym_wordle
|
||||
import sys
|
||||
from stable_baselines3 import DQN
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
import wordle_gym
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
env = gym.make("Wordle-v0")
|
||||
done = False
|
||||
|
||||
print(env)
|
||||
|
||||
model = DQN("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=100)
|
||||
def train (model, env, total_timesteps = 100000):
|
||||
model.learn(total_timesteps=total_timesteps, progress_bar=True)
|
||||
model.save("dqn_wordle")
|
||||
|
||||
del model # remove to demonstrate saving and loading
|
||||
def test(model, env, test_num=1000):
|
||||
|
||||
total_correct = 0
|
||||
|
||||
for i in tqdm(range(test_num)):
|
||||
|
||||
model = DQN.load("dqn_wordle")
|
||||
|
||||
state = env.reset()
|
||||
|
||||
env = gym.make("wordle-v0")
|
||||
obs = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
action, _states = model.predict(obs)
|
||||
obs, rewards, done, info = env.step(action)
|
||||
|
||||
action, _states = model.predict(state, deterministic=True)
|
||||
return total_correct / test_num
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
if __name__ == "__main__":
|
||||
|
||||
env = gym.make("wordle-v0")
|
||||
model = DQN("MlpPolicy", env, verbose=0)
|
||||
print(env)
|
||||
print(model)
|
||||
|
||||
train(model, env, total_timesteps=500000)
|
||||
print(test(model, env))
|
9
wordle_gym/__init__.py
Normal file
9
wordle_gym/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from gym.envs.registration import register
|
||||
|
||||
register(
|
||||
id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv",
|
||||
)
|
||||
|
||||
register(
|
||||
id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv",
|
||||
)
|
0
wordle_gym/envs/__init__.py
Normal file
0
wordle_gym/envs/__init__.py
Normal file
15
wordle_gym/envs/strategies/base.py
Normal file
15
wordle_gym/envs/strategies/base.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import Enum
|
||||
|
||||
from typing import List
|
||||
|
||||
class StrategyType(Enum):
|
||||
RANDOM = 1
|
||||
ELIMINATION = 2
|
||||
PROBABILITY = 3
|
||||
|
||||
class Strategy:
|
||||
def __init__(self, type: StrategyType):
|
||||
self.type = type
|
||||
|
||||
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
|
||||
raise NotImplementedError("Strategy.get_best_word() not implemented")
|
2
wordle_gym/envs/strategies/elimination.py
Normal file
2
wordle_gym/envs/strategies/elimination.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def get_best_word(state):
|
||||
|
20
wordle_gym/envs/strategies/probabilistic.py
Normal file
20
wordle_gym/envs/strategies/probabilistic.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from random import sample
|
||||
from typing import List
|
||||
|
||||
from base import Strategy
|
||||
from base import StrategyType
|
||||
|
||||
from utils import freq
|
||||
|
||||
class Random(Strategy):
|
||||
def __init__(self):
|
||||
self.words = freq.get_5_letter_word_freqs()
|
||||
super().__init__(StrategyType.RANDOM)
|
||||
|
||||
def get_best_word(self, state: List[List[int]]):
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r = Random()
|
||||
print(r.get_best_word([]))
|
29
wordle_gym/envs/strategies/rand.py
Normal file
29
wordle_gym/envs/strategies/rand.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from random import sample
|
||||
from typing import List
|
||||
|
||||
from base import Strategy
|
||||
from base import StrategyType
|
||||
|
||||
from utils import freq
|
||||
|
||||
class Random(Strategy):
|
||||
def __init__(self):
|
||||
self.words = freq.get_5_letter_word_freqs()
|
||||
super().__init__(StrategyType.RANDOM)
|
||||
|
||||
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
|
||||
correct_letters = []
|
||||
regex = ""
|
||||
for g, s in zip(guesses, state):
|
||||
for c, s in zip(g, s):
|
||||
if s == 2:
|
||||
correct_letters.append(c)
|
||||
regex += c
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r = Random()
|
||||
print(r.get_best_word([]))
|
27
wordle_gym/envs/strategies/utils/freq.py
Normal file
27
wordle_gym/envs/strategies/utils/freq.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from os import path
|
||||
|
||||
def get_5_letter_word_freqs():
|
||||
"""
|
||||
Returns a list of words with 5 letters.
|
||||
"""
|
||||
FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt")
|
||||
lines = read_file(FILEPATH)
|
||||
return {k:v for k, v in get_freq(lines).items() if len(k) == 5}
|
||||
|
||||
|
||||
def read_file(filename):
|
||||
"""
|
||||
Reads a file and returns a list of words and frequencies
|
||||
"""
|
||||
with open(filename, 'r') as f:
|
||||
return f.readlines()
|
||||
|
||||
|
||||
def get_freq(lines):
|
||||
"""
|
||||
Returns a dictionary of words and their frequencies
|
||||
"""
|
||||
freqs = {}
|
||||
for word, freq in map(lambda x: x.split("\t"), lines):
|
||||
freqs[word] = int(freq)
|
||||
return freqs
|
131
wordle_gym/envs/wordle_env.py
Normal file
131
wordle_gym/envs/wordle_env.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
|
||||
|
||||
import gym
|
||||
from gym import error, spaces, utils
|
||||
from gym.utils import seeding
|
||||
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
WORD_LENGTH = 5
|
||||
TOTAL_GUESSES = 6
|
||||
SOLUTION_PATH = "../words/solution.csv"
|
||||
VALID_WORDS_PATH = "../words/guess.csv"
|
||||
|
||||
class LetterState(Enum):
|
||||
ABSENT = 0
|
||||
PRESENT = 1
|
||||
CORRECT_POSITION = 2
|
||||
|
||||
|
||||
class WordleEnv(gym.Env):
|
||||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def _current_path(self):
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _read_solutions(self):
|
||||
return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()
|
||||
|
||||
def _get_valid_words(self):
|
||||
words = []
|
||||
for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
|
||||
words.append((word, Counter(word)))
|
||||
return words
|
||||
|
||||
def get_valid(self):
|
||||
return self._valid_words
|
||||
|
||||
def __init__(self):
|
||||
self._solutions = self._read_solutions()
|
||||
self._valid_words = self._get_valid_words()
|
||||
self.action_space = spaces.Discrete(len(self._valid_words))
|
||||
self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
|
||||
np.random.seed(0)
|
||||
self.reset()
|
||||
|
||||
def _check_guess(self, guess, guess_counter):
|
||||
c = guess_counter & self.solution_ct
|
||||
result = []
|
||||
correct = True
|
||||
reward = 0
|
||||
for i, char in enumerate(guess):
|
||||
if c.get(char, 0) > 0:
|
||||
if self.solution[i] == char:
|
||||
result.append(2)
|
||||
reward += 2
|
||||
else:
|
||||
result.append(1)
|
||||
correct = False
|
||||
reward += 1
|
||||
c[char] -= 1
|
||||
else:
|
||||
result.append(0)
|
||||
correct = False
|
||||
return result, correct, reward
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
action: index of word in valid_words
|
||||
|
||||
returns:
|
||||
observation: (TOTAL_GUESSES, WORD_LENGTH)
|
||||
reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
|
||||
done: True if game over, w/ or w/o correct answer
|
||||
additional_info: empty
|
||||
"""
|
||||
guess, guess_counter = self._valid_words[action]
|
||||
if guess in self.guesses:
|
||||
return self.obs, -1, False, {}
|
||||
self.guesses.append(guess)
|
||||
result, correct, reward = self._check_guess(guess, guess_counter)
|
||||
done = False
|
||||
|
||||
for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
|
||||
self.obs[i] = result[i - self.guess_no*WORD_LENGTH]
|
||||
|
||||
self.guess_no += 1
|
||||
if correct:
|
||||
done = True
|
||||
reward = 1200
|
||||
if self.guess_no == TOTAL_GUESSES:
|
||||
done = True
|
||||
if not correct:
|
||||
reward = -15
|
||||
return self.obs, reward, done, {}
|
||||
|
||||
def reset(self):
|
||||
self.solution = self._solutions[np.random.randint(len(self._solutions))]
|
||||
self.solution_ct = Counter(self.solution)
|
||||
self.guess_no = 0
|
||||
self.guesses = []
|
||||
self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
|
||||
return self.obs
|
||||
|
||||
def render(self, mode="human"):
|
||||
m = {
|
||||
0: "⬜",
|
||||
1: "🟨",
|
||||
2: "🟩"
|
||||
}
|
||||
print("Solution:", self.solution)
|
||||
for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
|
||||
o_n = "".join(map(lambda x: m[x], o))
|
||||
print(g, o_n)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = WordleEnv()
|
||||
print(env.action_space)
|
||||
print(env.observation_space)
|
||||
print(env.solution)
|
||||
print(env.step(0))
|
||||
print(env.step(0))
|
||||
print(env.step(0))
|
||||
print(env.step(0))
|
||||
print(env.step(0))
|
||||
print(env.step(0))
|
12972
wordle_gym/words/guess.csv
Normal file
12972
wordle_gym/words/guess.csv
Normal file
File diff suppressed because it is too large
Load Diff
2315
wordle_gym/words/solution.csv
Normal file
2315
wordle_gym/words/solution.csv
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user