mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-12-27 09:59:10 +00:00
Update titanlearn.py
This commit is contained in:
parent
0ae47b072b
commit
7f312dcba9
@ -30,9 +30,10 @@ from sklearn import metrics, datasets
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import time
|
||||
|
||||
#enable CUDA if possible
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cpu")
|
||||
|
||||
#linear_nn: creates a fully connected network given params
|
||||
def linear_nn(in_dim, hidden_dim, out_dim, num_hidden, act_fn="tanh", end="none"):
|
||||
@ -198,4 +199,8 @@ def retyuoipufdyu():
|
||||
model = linear_nn(8, 100, 1, 20, act_fn = "relu")
|
||||
print(model)
|
||||
return train_sgd_simple(model,"regression", data, ground, learnrate=1e-4, iters=1000)
|
||||
#retyuoipufdyu()
|
||||
|
||||
start = time.time()
|
||||
retyuoipufdyu()
|
||||
end = time.time()
|
||||
print(end-start)
|
||||
|
Loading…
Reference in New Issue
Block a user