7 Commits

Author SHA1 Message Date
Arthur Lu
dd5889da33 try more time steps 2024-03-14 10:57:18 -07:00
Arthur Lu
848ea719b7 still doesnt train 2024-03-13 21:36:26 -07:00
Arthur Lu
f641d77c47 attempt to use the other wordle gym, causing cuda errors 2024-03-13 14:27:34 -07:00
Arthur Lu
5ec123e0f1 minor changes 2024-03-13 13:57:23 -07:00
Arthur Lu
e9622b6f68 switch to notebook 2024-03-13 11:04:30 -07:00
ltcptgeneral
83e81722d2 this should probably be working but isn't 2024-03-12 22:14:03 -07:00
ltcptgeneral
320f2f81b7 delete tests 2024-03-12 21:42:59 -07:00
14 changed files with 15561 additions and 227 deletions

4
.gitignore vendored
View File

@@ -1 +1,3 @@
**/data/*
**/data/*
**/*.zip
**/__pycache__

38
dqn_wordle.py Normal file
View File

@@ -0,0 +1,38 @@
import gym
import sys
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import wordle_gym
import numpy as np
from tqdm import tqdm
def train (model, env, total_timesteps = 100000):
model.learn(total_timesteps=total_timesteps, progress_bar=True)
model.save("dqn_wordle")
def test(model, env, test_num=1000):
total_correct = 0
for i in tqdm(range(test_num)):
model = DQN.load("dqn_wordle")
env = gym.make("wordle-v0")
obs = env.reset()
done = False
while not done:
action, _states = model.predict(obs)
obs, rewards, done, info = env.step(action)
return total_correct / test_num
if __name__ == "__main__":
env = gym.make("wordle-v0")
model = DQN("MlpPolicy", env, verbose=0)
print(env)
print(model)
train(model, env, total_timesteps=500000)
print(test(model, env))

File diff suppressed because one or more lines are too long

61
test.py
View File

@@ -1,61 +0,0 @@
from torch.utils.data import Dataset
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer
from tqdm import tqdm as progress_bar
import torch
import matplotlib
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-base-uncased", bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("google-bert/bert-base-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
# create tokenizer...
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
import json
class CodeDataset(Dataset):
def __init__(self):
with open("data/conala-train.json") as f:
self.data = json.load(f)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
intent = self.data[idx]["rewritten_intent"] if self.data[idx]["rewritten_intent"] else self.data[idx]["intent"]
return intent, self.data[idx]["snippet"]
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
dataloader = CodeDataset()
model = model.to(device)
losses = []
epochs = 10
for i in range(epochs):
epoch_loss = 0
for idx, (question, answer) in progress_bar(enumerate(dataloader), total=len(dataloader)):
input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
label_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
loss = model(input_ids=input_ids, decoder_input_ids=label_ids, labels=label_ids).loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
losses.append(epoch_loss)
plt.plot(losses, color="green", label="Training Loss")
plt.legend(loc = 'upper left')
plt.savefig("plot.png")

9
wordle_gym/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from gym.envs.registration import register
register(
id="wordle-v0", entry_point="wordle_gym.envs.wordle_env:WordleEnv",
)
register(
id="wordle-alpha-v0", entry_point="wordle_gym.envs.wordle_alpha_env:WordleEnv",
)

View File

View File

@@ -0,0 +1,15 @@
from enum import Enum
from typing import List
class StrategyType(Enum):
RANDOM = 1
ELIMINATION = 2
PROBABILITY = 3
class Strategy:
def __init__(self, type: StrategyType):
self.type = type
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
raise NotImplementedError("Strategy.get_best_word() not implemented")

View File

@@ -0,0 +1,2 @@
def get_best_word(state):

View File

@@ -0,0 +1,20 @@
from random import sample
from typing import List
from base import Strategy
from base import StrategyType
from utils import freq
class Random(Strategy):
def __init__(self):
self.words = freq.get_5_letter_word_freqs()
super().__init__(StrategyType.RANDOM)
def get_best_word(self, state: List[List[int]]):
if __name__ == "__main__":
r = Random()
print(r.get_best_word([]))

View File

@@ -0,0 +1,29 @@
from random import sample
from typing import List
from base import Strategy
from base import StrategyType
from utils import freq
class Random(Strategy):
def __init__(self):
self.words = freq.get_5_letter_word_freqs()
super().__init__(StrategyType.RANDOM)
def get_best_word(self, guesses: List[List[str]], state: List[List[int]]):
correct_letters = []
regex = ""
for g, s in zip(guesses, state):
for c, s in zip(g, s):
if s == 2:
correct_letters.append(c)
regex += c
if __name__ == "__main__":
r = Random()
print(r.get_best_word([]))

View File

@@ -0,0 +1,27 @@
from os import path
def get_5_letter_word_freqs():
"""
Returns a list of words with 5 letters.
"""
FILEPATH = path.join(path.dirname(path.abspath(__file__)), "data/norvig.txt")
lines = read_file(FILEPATH)
return {k:v for k, v in get_freq(lines).items() if len(k) == 5}
def read_file(filename):
"""
Reads a file and returns a list of words and frequencies
"""
with open(filename, 'r') as f:
return f.readlines()
def get_freq(lines):
"""
Returns a dictionary of words and their frequencies
"""
freqs = {}
for word, freq in map(lambda x: x.split("\t"), lines):
freqs[word] = int(freq)
return freqs

View File

@@ -0,0 +1,131 @@
import os
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum
from collections import Counter
import numpy as np
WORD_LENGTH = 5
TOTAL_GUESSES = 6
SOLUTION_PATH = "../words/solution.csv"
VALID_WORDS_PATH = "../words/guess.csv"
class LetterState(Enum):
ABSENT = 0
PRESENT = 1
CORRECT_POSITION = 2
class WordleEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def _current_path(self):
return os.path.dirname(os.path.abspath(__file__))
def _read_solutions(self):
return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()
def _get_valid_words(self):
words = []
for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
words.append((word, Counter(word)))
return words
def get_valid(self):
return self._valid_words
def __init__(self):
self._solutions = self._read_solutions()
self._valid_words = self._get_valid_words()
self.action_space = spaces.Discrete(len(self._valid_words))
self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
np.random.seed(0)
self.reset()
def _check_guess(self, guess, guess_counter):
c = guess_counter & self.solution_ct
result = []
correct = True
reward = 0
for i, char in enumerate(guess):
if c.get(char, 0) > 0:
if self.solution[i] == char:
result.append(2)
reward += 2
else:
result.append(1)
correct = False
reward += 1
c[char] -= 1
else:
result.append(0)
correct = False
return result, correct, reward
def step(self, action):
"""
action: index of word in valid_words
returns:
observation: (TOTAL_GUESSES, WORD_LENGTH)
reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
done: True if game over, w/ or w/o correct answer
additional_info: empty
"""
guess, guess_counter = self._valid_words[action]
if guess in self.guesses:
return self.obs, -1, False, {}
self.guesses.append(guess)
result, correct, reward = self._check_guess(guess, guess_counter)
done = False
for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
self.obs[i] = result[i - self.guess_no*WORD_LENGTH]
self.guess_no += 1
if correct:
done = True
reward = 1200
if self.guess_no == TOTAL_GUESSES:
done = True
if not correct:
reward = -15
return self.obs, reward, done, {}
def reset(self):
self.solution = self._solutions[np.random.randint(len(self._solutions))]
self.solution_ct = Counter(self.solution)
self.guess_no = 0
self.guesses = []
self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
return self.obs
def render(self, mode="human"):
m = {
0: "",
1: "🟨",
2: "🟩"
}
print("Solution:", self.solution)
for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
o_n = "".join(map(lambda x: m[x], o))
print(g, o_n)
def close(self):
pass
if __name__ == "__main__":
env = WordleEnv()
print(env.action_space)
print(env.observation_space)
print(env.solution)
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))

12972
wordle_gym/words/guess.csv Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff