mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2025-09-10 01:07:21 +00:00
try penalizing duplicate guesses
This commit is contained in:
@@ -3,11 +3,10 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
from sty import fg, bg, ef, rs
|
||||
|
||||
from collections import Counter
|
||||
from collections import Counter, defaultdict
|
||||
from gym_wordle.utils import to_english, to_array, get_words
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WordList(gym.spaces.Discrete):
|
||||
"""Super class for defining a space of valid words according to a specified
|
||||
list.
|
||||
@@ -160,7 +159,7 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.n_rounds = 6
|
||||
self.n_letters = 5
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
def _highlighter(self, char: str, flag: int) -> str:
|
||||
"""Terminal renderer functionality. Properly highlights a character
|
||||
@@ -195,7 +194,7 @@ class WordleEnv(gym.Env):
|
||||
|
||||
self.state = np.zeros((self.n_rounds, 2 * self.n_letters), dtype=np.int64)
|
||||
|
||||
self.info = {'correct': False, 'guesses': set()}
|
||||
self.info = {'correct': False, 'guesses': defaultdict(int)}
|
||||
|
||||
return self.state, self.info
|
||||
|
||||
@@ -269,13 +268,13 @@ class WordleEnv(gym.Env):
|
||||
reward += np.sum(self.state[:, 5:] == 3) * -1
|
||||
|
||||
# guess same word as before
|
||||
hashable_action = tuple(action)
|
||||
hashable_action = to_english(action)
|
||||
if hashable_action in self.info['guesses']:
|
||||
reward += -10
|
||||
reward += -10 * self.info['guesses'][hashable_action]
|
||||
else: # guess different word
|
||||
reward += 10
|
||||
|
||||
self.info['guesses'].add(hashable_action)
|
||||
self.info['guesses'][hashable_action] += 1
|
||||
|
||||
# for game ending in win or loss
|
||||
reward += 10 if correct else -10 if done else 0
|
||||
|
Reference in New Issue
Block a user