mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-12-26 09:39:10 +00:00
titanlearn v 2.0.1.000
This commit is contained in:
parent
dacf12f8a4
commit
ab9b38da95
Binary file not shown.
@ -7,10 +7,13 @@
|
|||||||
# this module learns from its mistakes far faster than 2022's captains
|
# this module learns from its mistakes far faster than 2022's captains
|
||||||
# setup:
|
# setup:
|
||||||
|
|
||||||
__version__ = "2.0.0.001"
|
__version__ = "2.0.1.000"
|
||||||
|
|
||||||
#changelog should be viewed using print(analysis.__changelog__)
|
#changelog should be viewed using print(analysis.__changelog__)
|
||||||
__changelog__ = """changelog:
|
__changelog__ = """changelog:
|
||||||
|
2.0.1.000:
|
||||||
|
- added net, dataset, dataloader, and stdtrain template definitions
|
||||||
|
- added graphloss function
|
||||||
2.0.0.001:
|
2.0.0.001:
|
||||||
- added clear functions
|
- added clear functions
|
||||||
2.0.0.000:
|
2.0.0.000:
|
||||||
@ -33,6 +36,8 @@ __all__ = [
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from os import system, name
|
from os import system, name
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def clear():
|
def clear():
|
||||||
if name == 'nt':
|
if name == 'nt':
|
||||||
@ -40,7 +45,34 @@ def clear():
|
|||||||
else:
|
else:
|
||||||
_ = system('clear')
|
_ = system('clear')
|
||||||
|
|
||||||
def train(device, net, epochs, trainloader, optimizer, criterion):
|
class net(torch.nn.Module): #template for standard neural net
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class dataset(torch.utils.data.Dataset): #template for standard dataset
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(torch.utils.data.Dataset).__init__()
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def dataloader(dataset, batch_size, num_workers, shuffle = True):
|
||||||
|
|
||||||
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
|
||||||
|
|
||||||
|
def train(device, net, epochs, trainloader, optimizer, criterion): #expects standard dataloader, whch returns (inputs, labels)
|
||||||
|
|
||||||
|
dataset_len = trainloader.dataset.__len__()
|
||||||
|
iter_count = 0
|
||||||
|
running_loss = 0
|
||||||
|
running_loss_list = []
|
||||||
|
|
||||||
for epoch in range(epochs): # loop over the dataset multiple times
|
for epoch in range(epochs): # loop over the dataset multiple times
|
||||||
|
|
||||||
@ -56,6 +88,35 @@ def train(device, net, epochs, trainloader, optimizer, criterion):
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# monitoring steps below
|
||||||
|
|
||||||
|
iter_count += 1
|
||||||
|
running_loss += loss.item()
|
||||||
|
running_loss_list.append(running_loss)
|
||||||
|
clear()
|
||||||
|
|
||||||
|
print("training on: " + device)
|
||||||
|
print("iteration: " + str(i) + "/" + str(int(dataset_len / trainloader.batch_size)) + " | " + "epoch: " + str(epoch) + "/" + str(epochs))
|
||||||
|
print("current batch loss: " + str(loss.item))
|
||||||
|
print("running loss: " + str(running_loss / iter_count))
|
||||||
|
|
||||||
return net
|
return net, running_loss_list
|
||||||
print("finished training")
|
print("finished training")
|
||||||
|
|
||||||
|
def stdtrainer(net, criterion, optimizer, dataloader, epochs, batch_size):
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
net = net.to(device)
|
||||||
|
criterion = criterion.to(device)
|
||||||
|
optimizer = optimizer.to(device)
|
||||||
|
trainloader = dataloader
|
||||||
|
|
||||||
|
return train(device, net, epochs, trainloader, optimizer, criterion)
|
||||||
|
|
||||||
|
def graphloss(losses):
|
||||||
|
|
||||||
|
x = range(0, len(losses))
|
||||||
|
plt.plot(x, losses)
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user