# Mirror of https://github.com/ltcptgeneral/cse151b-final-project.git
# synced 2024-11-10 07:04:45 +00:00
import math
import random

import torch
import torch.nn as nn
import torch.optim as optim
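
# The Agent below references DQN, ReplayMemory, and Transition, which are not
# defined in this file and presumably live elsewhere in the repository. The
# following is a minimal sketch of what those helpers could look like, based
# on the common PyTorch DQN-tutorial pattern this file echoes; the class names,
# layer sizes, and buffer interface here are assumptions, not the project's
# actual implementations.
from collections import deque, namedtuple

# A single replay-buffer entry: (state, action, next_state, reward).
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory:
    """Fixed-size FIFO buffer of Transitions with uniform random sampling."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):
    """Small fully connected network mapping an observation to per-action Q-values."""

    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.layer3(x)
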

class Agent:

    def __init__(self, n_observations, n_actions) -> None:
        self.batch_size = 128  # number of transitions sampled from the replay buffer
        self.gamma = 0.99      # discount factor
        self.eps_start = 0.9   # starting value of epsilon
        self.eps_end = 0.05    # final value of epsilon
        self.eps_decay = 1000  # rate of exponential decay of epsilon; higher means a slower decay
        self.tau = 0.005       # update rate of the target network (soft-update coefficient)
        self.lr = 1e-4         # learning rate of the AdamW optimizer

        self.n_actions = n_actions
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.steps_done = 0    # counts calls to select_action; drives the epsilon decay

        self.policy_net = DQN(n_observations, n_actions).to(self.device)
        self.target_net = DQN(n_observations, n_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)
        self.memory = ReplayMemory(10000)

    def get_state(self, game):
        pass

    def select_action(self, state):
        sample = random.random()
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * \
            math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # policy_net(state).max(1) returns the largest Q-value in each row;
                # .indices is the position of that maximum, i.e. the greedy action,
                # so we pick the action with the larger expected reward.
                return self.policy_net(state).max(1).indices.view(1, 1)
        else:
            # Explore: pick a uniformly random action using the stored n_actions
            # (the original line sampled from a global env.action_space instead).
            return torch.tensor([[random.randrange(self.n_actions)]],
                                device=self.device, dtype=torch.long)

    def optimize_model(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for a
        # detailed explanation). This converts a batch-array of Transitions into a
        # Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would have been the one after which the simulation ended).
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)),
                                      device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute Q(s_t, a): the model computes Q(s_t), then we select the columns
        # of the actions taken. These are the actions which would have been taken
        # for each batch state according to policy_net.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net, selecting their best reward with max(1).values.
        # This is merged based on the mask, so we have either the expected state
        # value or 0 in case the state was final.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        with torch.no_grad():
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1).values
        # Compute the expected Q values
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
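

# Minimal usage sketch, not the project's actual training script: it assumes a
# Gymnasium-style environment ("CartPole-v1", reset()/step()) purely for
# illustration, and converts observations to tensors directly since get_state
# above is still a stub. The environment, episode count, and reward handling
# are placeholders. The soft target-network update at the end shows how
# self.tau is typically applied: target <- tau * policy + (1 - tau) * target.
if __name__ == "__main__":
    import gymnasium as gym  # assumed dependency, used only for this example

    env = gym.make("CartPole-v1")
    n_observations = env.observation_space.shape[0]
    n_actions = env.action_space.n
    agent = Agent(n_observations, n_actions)

    for episode in range(10):  # placeholder episode count
        obs, _ = env.reset()
        state = torch.tensor(obs, dtype=torch.float32, device=agent.device).unsqueeze(0)
        done = False
        while not done:
            action = agent.select_action(state)
            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            reward = torch.tensor([reward], device=agent.device)
            next_state = None if terminated else torch.tensor(
                obs, dtype=torch.float32, device=agent.device).unsqueeze(0)

            # Store the transition, then take one optimization step on the policy net.
            agent.memory.push(state, action, next_state, reward)
            agent.optimize_model()
            if next_state is not None:
                state = next_state

            # Soft update of the target network's weights using tau.
            target_sd = agent.target_net.state_dict()
            policy_sd = agent.policy_net.state_dict()
            for key in policy_sd:
                target_sd[key] = policy_sd[key] * agent.tau + target_sd[key] * (1 - agent.tau)
            agent.target_net.load_state_dict(target_sd)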