mirror of
https://github.com/ltcptgeneral/cse151b-final-project.git
synced 2024-11-10 07:04:45 +00:00
45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
import math
|
|
import random
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
from collections import namedtuple, deque
|
|
from itertools import count
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import torch.nn.functional as F
|
|
|
|
|
|
# Record type stored by the replay memory. ReplayMemory.push builds one of
# these positionally, so the field order below is the argument order:
# (state, action, next_state, reward).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)
|
|
|
|
|
|
class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` records with random sampling.

    At most ``capacity`` items are kept. Pushing past that limit silently
    evicts the oldest item, because the backing deque is created with
    ``maxlen``.
    """

    def __init__(self, capacity: int) -> None:
        # A deque with maxlen discards from the left end when appending
        # beyond its bound, which gives the FIFO eviction.
        self.memory = deque(maxlen=capacity)

    def push(self, *args) -> None:
        """Build a ``Transition`` from the positional ``args`` and store it."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` stored items drawn without replacement.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` is
        negative or larger than the number of stored items.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self) -> int:
        """Return how many items are currently stored."""
        return len(self.memory)
|
|
|
|
|
|
class DQN(nn.Module):
    """Three-layer fully connected network with ReLU activations.

    Maps an input whose last dimension is ``n_observations`` to
    ``n_actions`` outputs, one per action.

    Args:
        n_observations: Size of the input's last dimension.
        n_actions: Number of outputs produced per input row.
        hidden_size: Width of both hidden layers. Defaults to 128, the
            width this network previously had hard-coded, so existing
            callers and checkpoints are unaffected.
    """

    def __init__(self, n_observations: int, n_actions: int,
                 hidden_size: int = 128) -> None:
        super().__init__()
        # Attribute names (layer1..layer3) and construction order are kept
        # stable, so state_dict keys and seeded initialization match the
        # previous fixed-width version.
        self.layer1 = nn.Linear(n_observations, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (unactivated) output of the final layer for ``x``.

        ``nn.Linear`` acts on the trailing dimension, so any leading
        dimensions of ``x`` (e.g. a batch axis) are preserved.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
|