import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# One stored step of experience: the fields passed to ReplayMemory.push,
# in this positional order.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-capacity buffer of ``Transition`` records with uniform sampling."""

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer holding at most *capacity* transitions.

        Once full, each new push silently evicts the oldest transition
        (``deque`` with ``maxlen``).
        """
        self.memory = deque([], maxlen=capacity)

    def push(self, *args) -> None:
        """Store one transition built as ``Transition(*args)``.

        Args must be ``(state, action, next_state, reward)`` in that order.
        """
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int) -> list:
        """Return *batch_size* distinct transitions chosen uniformly at random.

        Raises:
            ValueError: if *batch_size* exceeds the number of stored
                transitions (propagated from ``random.sample``). Callers
                should check ``len(memory)`` first.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.memory)


class DQN(nn.Module):
    """Three-layer fully connected network: observation -> one output per action.

    The outputs are unactivated (presumably Q-value estimates per action).
    """

    def __init__(self, n_observations: int, n_actions: int,
                 hidden_size: int = 128) -> None:
        """Build the network.

        Args:
            n_observations: size of the last dimension of the input tensor.
            n_actions: number of outputs produced per input row.
            hidden_size: width of both hidden layers (default 128, the
                previously hard-coded value).
        """
        super().__init__()
        self.layer1 = nn.Linear(n_observations, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dim ``n_observations``) to last dim ``n_actions``."""
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # No activation on the output layer: values are returned raw.
        return self.layer3(x)