cse151b-final-project/wordle_gym/envs/wordle_env.py
2024-03-13 21:36:26 -07:00

131 lines
3.8 KiB
Python

import os
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum
from collections import Counter
import numpy as np
# Board dimensions: every word has WORD_LENGTH letters and one game allows
# at most TOTAL_GUESSES attempts.
WORD_LENGTH = 5
TOTAL_GUESSES = 6

# Word lists, given relative to the directory that contains this module
# (WordleEnv joins them onto that directory before opening them).
SOLUTION_PATH = "../words/solution.csv"
VALID_WORDS_PATH = "../words/guess.csv"
class LetterState(Enum):
    """Feedback code for a single letter of a guess."""

    # The letter has no unmatched occurrence left in the solution.
    ABSENT = 0
    # The letter occurs in the solution, but at a different position.
    PRESENT = 1
    # The letter occurs in the solution at exactly this position.
    CORRECT_POSITION = 2
class WordleEnv(gym.Env):
    """Gym environment for the Wordle word-guessing game.

    An action is an index into the list of valid guess words. The observation
    is a flat array of ``TOTAL_GUESSES * WORD_LENGTH`` cells that stores the
    per-letter feedback of each guess played so far (0 = absent, 1 = present
    at another position, 2 = correct position). Cells belonging to guesses
    that have not been played yet are 0.
    """

    metadata = {"render.modes": ["human"]}

    def _current_path(self):
        """Return the absolute path of the directory containing this module."""
        return os.path.dirname(os.path.abspath(__file__))

    def _read_word_file(self, relative_path):
        """Return the lines of a word file located relative to this module."""
        full_path = os.path.join(self._current_path(), relative_path)
        # The context manager closes the handle even if reading fails.
        with open(full_path) as handle:
            return handle.read().splitlines()

    def _read_solutions(self):
        """Return the list of possible solution words (one per file line)."""
        return self._read_word_file(SOLUTION_PATH)

    def _get_valid_words(self):
        """Return ``(word, Counter(word))`` pairs for every allowed guess."""
        return [
            (word, Counter(word))
            for word in self._read_word_file(VALID_WORDS_PATH)
        ]

    def get_valid(self):
        """Return the list of ``(word, letter Counter)`` pairs usable as guesses."""
        return self._valid_words

    def __init__(self):
        """Load both word lists, define the spaces and start the first game."""
        self._solutions = self._read_solutions()
        self._valid_words = self._get_valid_words()
        # One discrete action per valid guess word.
        self.action_space = spaces.Discrete(len(self._valid_words))
        # One three-valued cell per letter of every possible guess.
        self.observation_space = spaces.MultiDiscrete([3] * TOTAL_GUESSES * WORD_LENGTH)
        # Seeds NumPy's *global* RNG, which reset() uses to pick the solution,
        # so every new environment replays the same sequence of solutions.
        np.random.seed(0)
        self.reset()

    def _check_guess(self, guess, guess_counter):
        """Score ``guess`` against the current solution.

        Scoring is done in two passes so that repeated letters are handled
        correctly: exact-position matches claim their letter first, and only
        the occurrences still unmatched afterwards can be reported as present
        elsewhere. With solution "abide" and guess "eagle" the result is
        ``[0, 1, 0, 0, 2]``; a single left-to-right pass would spend the only
        "e" on the first letter and report the final, correctly placed "e" as
        absent.

        Args:
            guess: the guessed word.
            guess_counter: letter counts of ``guess``; no longer needed by the
                two-pass scoring, kept so existing callers keep working.

        Returns:
            A tuple ``(result, correct, reward)``: ``result`` holds one
            feedback code per letter (0 absent, 1 present, 2 correct
            position), ``correct`` is True when every letter is in the
            correct position, and ``reward`` is 2 per correctly placed letter
            plus 1 per present letter.
        """
        absent = LetterState.ABSENT.value
        present = LetterState.PRESENT.value
        correct_position = LetterState.CORRECT_POSITION.value

        result = [absent] * len(guess)
        # Solution letters that have not been matched to a guess letter yet.
        unmatched = Counter(self.solution_ct)
        reward = 0

        # Pass 1: exact matches take priority over misplaced ones.
        for i, (guess_char, solution_char) in enumerate(zip(guess, self.solution)):
            if guess_char == solution_char:
                result[i] = correct_position
                unmatched[guess_char] -= 1
                reward += 2

        # Pass 2: remaining letters are "present" only while the solution
        # still has an unmatched copy of them.
        for i, guess_char in enumerate(guess):
            if result[i] == absent and unmatched[guess_char] > 0:
                result[i] = present
                unmatched[guess_char] -= 1
                reward += 1

        correct = all(state == correct_position for state in result)
        return result, correct, reward

    def step(self, action):
        """Play one guess.

        Args:
            action: index of the guessed word in the valid-word list.

        Returns:
            A tuple ``(observation, reward, done, info)``:

            observation: the flat feedback array of length
                ``TOTAL_GUESSES * WORD_LENGTH`` (the same array object is
                updated in place and returned on every call).
            reward: -1 if the word was already guessed in this game (the
                attempt is not consumed); 1200 if the guess is the solution;
                -15 if the last allowed guess is wrong; otherwise 2 per
                correctly placed letter plus 1 per present letter.
            done: True once the solution is found or all guesses are used.
            info: always an empty dict.

        A new (non-repeated) guess made after all ``TOTAL_GUESSES`` attempts
        are used writes past the end of the observation array and raises
        IndexError; call ``reset()`` first.
        """
        guess, guess_counter = self._valid_words[action]
        if guess in self.guesses:
            # Repeated guess: penalise without advancing the game.
            return self.obs, -1, False, {}
        self.guesses.append(guess)

        result, correct, reward = self._check_guess(guess, guess_counter)

        # Store this guess's feedback in its row of the flat observation.
        row_start = self.guess_no * WORD_LENGTH
        for offset in range(WORD_LENGTH):
            self.obs[row_start + offset] = result[offset]
        self.guess_no += 1

        done = False
        if correct:
            done = True
            reward = 1200
        if self.guess_no == TOTAL_GUESSES:
            done = True
            if not correct:
                reward = -15
        return self.obs, reward, done, {}

    def reset(self):
        """Start a new game with a randomly drawn solution.

        Returns:
            The fresh all-zero observation array.
        """
        self.solution = self._solutions[np.random.randint(len(self._solutions))]
        self.solution_ct = Counter(self.solution)
        self.guess_no = 0
        self.guesses = []
        self.obs = np.zeros((TOTAL_GUESSES * WORD_LENGTH, ))
        return self.obs

    def render(self, mode="human"):
        """Print the solution, then each guess followed by its feedback symbols."""
        # NOTE(review): state 0 maps to an empty string, so a row with absent
        # letters prints fewer symbols than the word has letters. Possibly a
        # gray-square emoji was lost somewhere; confirm before changing.
        m = {
            0: "",
            1: "🟨",
            2: "🟩"
        }
        print("Solution:", self.solution)
        rows = np.reshape(self.obs, (TOTAL_GUESSES, WORD_LENGTH))
        # zip() stops at the last guess played, so unused rows are not printed.
        for g, o in zip(self.guesses, rows):
            o_n = "".join(m[x] for x in o)
            print(g, o_n)

    def close(self):
        """Do nothing; the environment holds no resources to release."""
        pass
if __name__ == "__main__":
    # Manual smoke test: build an environment, show its spaces and the hidden
    # solution, then submit word 0 six times. Only the first submission is
    # scored; the other five repeat the same word and therefore take the
    # repeated-guess branch of step().
    env = WordleEnv()
    for attribute_name in ("action_space", "observation_space", "solution"):
        print(getattr(env, attribute_name))
    for _ in range(6):
        print(env.step(0))