mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-12-26 01:59:10 +00:00
created custom env folder
This commit is contained in:
parent
c121415e31
commit
7ad5b97463
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
||||
**/data/*
|
||||
**/data/*
|
||||
/env
|
91
custom_env/agent.py
Normal file
91
custom_env/agent.py
Normal file
@ -0,0 +1,91 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Agent:
|
||||
|
||||
def __init__(self, ) -> None:
|
||||
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
# EPS_START is the starting value of epsilon
|
||||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
# TAU is the update rate of the target network
|
||||
# LR is the learning rate of the ``AdamW`` optimizer
|
||||
self.batch_size = 128
|
||||
self.gamma = 0.99
|
||||
self.eps_start = 0.9
|
||||
self.eps_end = 0.05
|
||||
self.eps_decay = 1000
|
||||
self.tau = 0.005
|
||||
self.lr = 1e-4
|
||||
self.n_actions = n_actions
|
||||
|
||||
policy_net = DQN(n_observations, n_actions).to(device)
|
||||
target_net = DQN(n_observations, n_actions).to(device)
|
||||
target_net.load_state_dict(policy_net.state_dict())
|
||||
|
||||
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
||||
memory = ReplayMemory(10000)
|
||||
|
||||
def get_state(self, game):
|
||||
pass
|
||||
|
||||
def select_action(state):
|
||||
sample = random.random()
|
||||
eps_threshold = EPS_END + (EPS_START - EPS_END) * \
|
||||
math.exp(-1. * steps_done / EPS_DECAY)
|
||||
steps_done += 1
|
||||
if sample > eps_threshold:
|
||||
with torch.no_grad():
|
||||
# t.max(1) will return the largest column value of each row.
|
||||
# second column on max result is index of where max element was
|
||||
# found, so we pick action with the larger expected reward.
|
||||
return policy_net(state).max(1).indices.view(1, 1)
|
||||
else:
|
||||
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
|
||||
|
||||
def optimize_model():
|
||||
if len(memory) < BATCH_SIZE:
|
||||
return
|
||||
transitions = memory.sample(BATCH_SIZE)
|
||||
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
||||
# detailed explanation). This converts batch-array of Transitions
|
||||
# to Transition of batch-arrays.
|
||||
batch = Transition(*zip(*transitions))
|
||||
|
||||
# Compute a mask of non-final states and concatenate the batch elements
|
||||
# (a final state would've been the one after which simulation ended)
|
||||
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
|
||||
batch.next_state)), device=device, dtype=torch.bool)
|
||||
non_final_next_states = torch.cat([s for s in batch.next_state
|
||||
if s is not None])
|
||||
state_batch = torch.cat(batch.state)
|
||||
action_batch = torch.cat(batch.action)
|
||||
reward_batch = torch.cat(batch.reward)
|
||||
|
||||
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
||||
# columns of actions taken. These are the actions which would've been taken
|
||||
# for each batch state according to policy_net
|
||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
||||
|
||||
# Compute V(s_{t+1}) for all next states.
|
||||
# Expected values of actions for non_final_next_states are computed based
|
||||
# on the "older" target_net; selecting their best reward with max(1).values
|
||||
# This is merged based on the mask, such that we'll have either the expected
|
||||
# state value or 0 in case the state was final.
|
||||
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
||||
with torch.no_grad():
|
||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
|
||||
# Compute the expected Q values
|
||||
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
||||
|
||||
# Compute Huber loss
|
||||
criterion = nn.SmoothL1Loss()
|
||||
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
||||
|
||||
# Optimize the model
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# In-place gradient clipping
|
||||
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
||||
optimizer.step()
|
16
custom_env/create_wordlist.py
Normal file
16
custom_env/create_wordlist.py
Normal file
@ -0,0 +1,16 @@
|
||||
import pathlib
|
||||
import sys
|
||||
from string import ascii_letters
|
||||
|
||||
in_path = pathlib.Path(sys.argv[1])
|
||||
out_path = pathlib.Path(sys.argv[2])
|
||||
|
||||
words = sorted(
|
||||
{
|
||||
word.lower()
|
||||
for word in in_path.read_text(encoding="utf-8").split()
|
||||
if all(letter in ascii_letters for letter in word)
|
||||
},
|
||||
key=lambda word: (len(word), word),
|
||||
)
|
||||
out_path.write_text("\n".join(words))
|
5757
custom_env/five_letter_words.txt
Normal file
5757
custom_env/five_letter_words.txt
Normal file
File diff suppressed because it is too large
Load Diff
44
custom_env/model.py
Normal file
44
custom_env/model.py
Normal file
@ -0,0 +1,44 @@
|
||||
import math
|
||||
import random
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import namedtuple, deque
|
||||
from itertools import count
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
Transition = namedtuple('Transition',
|
||||
('state', 'action', 'next_state', 'reward'))
|
||||
|
||||
|
||||
class ReplayMemory(object):
|
||||
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.memory = deque([], maxlen=capacity)
|
||||
|
||||
def push(self, *args):
|
||||
self.memory.append(Transition(*args))
|
||||
|
||||
def sample(self, batch_size):
|
||||
return random.sample(self.memory, batch_size)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.memory)
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
|
||||
def __init__(self, n_observations: int, n_actions: int) -> None:
|
||||
super(DQN, self).__init__()
|
||||
self.layer1 = nn.Linear(n_observations, 128)
|
||||
self.layer2 = nn.Linear(128, 128)
|
||||
self.layer3 = nn.Linear(128, n_actions)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.layer1(x))
|
||||
x = F.relu(self.layer2(x))
|
||||
return self.layer3(x)
|
61
custom_env/test2.ipynb
Normal file
61
custom_env/test2.ipynb
Normal file
@ -0,0 +1,61 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from string import ascii_letters, ascii_uppercase, ascii_lowercase"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'ABCDEFGHIJKLMNOPQRSTUVWXYZ'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ascii_uppercase"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
1098
custom_env/wordlist.txt
Normal file
1098
custom_env/wordlist.txt
Normal file
File diff suppressed because it is too large
Load Diff
119
custom_env/wyrdl.py
Normal file
119
custom_env/wyrdl.py
Normal file
@ -0,0 +1,119 @@
|
||||
import contextlib
|
||||
import pathlib
|
||||
import random
|
||||
from string import ascii_letters, ascii_lowercase
|
||||
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
|
||||
console = Console(width=40, theme=Theme({"warning": "red on yellow"}))
|
||||
|
||||
NUM_LETTERS = 5
|
||||
NUM_GUESSES = 6
|
||||
WORDS_PATH = pathlib.Path(__file__).parent / "wordlist.txt"
|
||||
|
||||
|
||||
class Wordle:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.word_list = WORDS_PATH.read_text(encoding="utf-8").split("\n")
|
||||
self.n_guesses = 6
|
||||
self.num_letters = 5
|
||||
self.curr_word = None
|
||||
self.reset()
|
||||
|
||||
def refresh_page(self, headline):
|
||||
console.clear()
|
||||
console.rule(f"[bold blue]:leafy_green: {headline} :leafy_green:[/]\n")
|
||||
|
||||
def start_game(self):
|
||||
# get a new random word
|
||||
word = self.get_random_word(self.word_list)
|
||||
|
||||
self.curr_word = word
|
||||
|
||||
def get_state(self):
|
||||
return
|
||||
|
||||
def action_to_word(self, action):
|
||||
# Calculate the word from the array
|
||||
word = ''
|
||||
for i in range(0, len(ascii_lowercase), 26):
|
||||
# Find the index of 1 in each block of 26
|
||||
letter_index = action[i:i+26].index(1)
|
||||
# Append the corresponding letter to the word
|
||||
word += ascii_lowercase[letter_index]
|
||||
|
||||
return word
|
||||
|
||||
def play_guess(self, action):
|
||||
# probably an array of length 26 * 5 for 26 letters and 5 positions
|
||||
guess = action
|
||||
|
||||
def get_random_word(self, word_list):
|
||||
if words := [
|
||||
word.upper()
|
||||
for word in word_list
|
||||
if len(word) == NUM_LETTERS
|
||||
and all(letter in ascii_letters for letter in word)
|
||||
]:
|
||||
return random.choice(words)
|
||||
else:
|
||||
console.print(
|
||||
f"No words of length {NUM_LETTERS} in the word list",
|
||||
style="warning",
|
||||
)
|
||||
raise SystemExit()
|
||||
|
||||
def show_guesses(self, guesses, word):
|
||||
letter_status = {letter: letter for letter in ascii_lowercase}
|
||||
for guess in guesses:
|
||||
styled_guess = []
|
||||
for letter, correct in zip(guess, word):
|
||||
if letter == correct:
|
||||
style = "bold white on green"
|
||||
elif letter in word:
|
||||
style = "bold white on yellow"
|
||||
elif letter in ascii_letters:
|
||||
style = "white on #666666"
|
||||
else:
|
||||
style = "dim"
|
||||
styled_guess.append(f"[{style}]{letter}[/]")
|
||||
if letter != "_":
|
||||
letter_status[letter] = f"[{style}]{letter}[/]"
|
||||
|
||||
console.print("".join(styled_guess), justify="center")
|
||||
console.print("\n" + "".join(letter_status.values()), justify="center")
|
||||
|
||||
def guess_word(self, previous_guesses):
|
||||
guess = console.input("\nGuess word: ").upper()
|
||||
|
||||
if guess in previous_guesses:
|
||||
console.print(f"You've already guessed {guess}.", style="warning")
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
if len(guess) != NUM_LETTERS:
|
||||
console.print(
|
||||
f"Your guess must be {NUM_LETTERS} letters.", style="warning"
|
||||
)
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
if any((invalid := letter) not in ascii_letters for letter in guess):
|
||||
console.print(
|
||||
f"Invalid letter: '{invalid}'. Please use English letters.",
|
||||
style="warning",
|
||||
)
|
||||
return guess_word(previous_guesses)
|
||||
|
||||
return guess
|
||||
|
||||
def reset(self, guesses, word, guessed_correctly, n_episodes):
|
||||
refresh_page(headline=f"Game: {n_episodes}")
|
||||
|
||||
if guessed_correctly:
|
||||
console.print(f"\n[bold white on green]Correct, the word is {word}[/]")
|
||||
else:
|
||||
console.print(f"\n[bold white on red]Sorry, the word was {word}[/]")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user