cse151b-final-project/wordle_gym/envs/wordle_env.py

131 lines
3.8 KiB
Python
Raw Permalink Normal View History

2024-03-14 04:36:26 +00:00
import os
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum
from collections import Counter
import numpy as np
WORD_LENGTH = 5
TOTAL_GUESSES = 6
SOLUTION_PATH = "../words/solution.csv"
VALID_WORDS_PATH = "../words/guess.csv"
class LetterState(Enum):
ABSENT = 0
PRESENT = 1
CORRECT_POSITION = 2
class WordleEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def _current_path(self):
return os.path.dirname(os.path.abspath(__file__))
def _read_solutions(self):
return open(os.path.join(self._current_path(), SOLUTION_PATH)).read().splitlines()
def _get_valid_words(self):
words = []
for word in open(os.path.join(self._current_path(), VALID_WORDS_PATH)).read().splitlines():
words.append((word, Counter(word)))
return words
def get_valid(self):
return self._valid_words
def __init__(self):
self._solutions = self._read_solutions()
self._valid_words = self._get_valid_words()
self.action_space = spaces.Discrete(len(self._valid_words))
self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
np.random.seed(0)
self.reset()
def _check_guess(self, guess, guess_counter):
c = guess_counter & self.solution_ct
result = []
correct = True
reward = 0
for i, char in enumerate(guess):
if c.get(char, 0) > 0:
if self.solution[i] == char:
result.append(2)
reward += 2
else:
result.append(1)
correct = False
reward += 1
c[char] -= 1
else:
result.append(0)
correct = False
return result, correct, reward
def step(self, action):
"""
action: index of word in valid_words
returns:
observation: (TOTAL_GUESSES, WORD_LENGTH)
reward: 0 if incorrect, 1 if correct, -1 if game over w/o final answer being obtained
done: True if game over, w/ or w/o correct answer
additional_info: empty
"""
guess, guess_counter = self._valid_words[action]
if guess in self.guesses:
return self.obs, -1, False, {}
self.guesses.append(guess)
result, correct, reward = self._check_guess(guess, guess_counter)
done = False
for i in range(self.guess_no*WORD_LENGTH, self.guess_no*WORD_LENGTH + WORD_LENGTH):
self.obs[i] = result[i - self.guess_no*WORD_LENGTH]
self.guess_no += 1
if correct:
done = True
reward = 1200
if self.guess_no == TOTAL_GUESSES:
done = True
if not correct:
reward = -15
return self.obs, reward, done, {}
def reset(self):
self.solution = self._solutions[np.random.randint(len(self._solutions))]
self.solution_ct = Counter(self.solution)
self.guess_no = 0
self.guesses = []
self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
return self.obs
def render(self, mode="human"):
m = {
0: "",
1: "🟨",
2: "🟩"
}
print("Solution:", self.solution)
for g, o in zip(self.guesses, np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))):
o_n = "".join(map(lambda x: m[x], o))
print(g, o_n)
def close(self):
pass
if __name__ == "__main__":
env = WordleEnv()
print(env.action_space)
print(env.observation_space)
print(env.solution)
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))
print(env.step(0))