mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
tl.py
This commit is contained in:
parent
a2763d0e07
commit
d3b8329164
@ -26,11 +26,10 @@ __all__ = [
|
|||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from sklearn import metrics
|
from sklearn import metrics, datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import math
|
import math
|
||||||
from sklearn import datasets
|
|
||||||
|
|
||||||
#enable CUDA if possible
|
#enable CUDA if possible
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
@ -38,22 +37,22 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|||||||
#linear_nn: creates a fully connected network given params
|
#linear_nn: creates a fully connected network given params
|
||||||
def linear_nn(in_dim, hidden_dim, out_dim, num_hidden, act_fn="tanh", end="none"):
|
def linear_nn(in_dim, hidden_dim, out_dim, num_hidden, act_fn="tanh", end="none"):
|
||||||
if act_fn.lower()=="tanh":
|
if act_fn.lower()=="tanh":
|
||||||
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('tanh0', torch.nn.Tanh())])
|
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
|
||||||
for i in range(num_hidden):
|
for i in range(num_hidden):
|
||||||
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "tanh"+str(i+1):torch.nn.Tanh()})
|
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "tanh"+str(i+1):torch.nn.Tanh()})
|
||||||
|
|
||||||
elif act_fn.lower()=="sigmoid":
|
elif act_fn.lower()=="sigmoid":
|
||||||
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('sig0', torch.nn.Sigmoid())])
|
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
|
||||||
for i in range(num_hidden):
|
for i in range(num_hidden):
|
||||||
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "sig"+str(i+1):torch.nn.Sigmoid()})
|
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "sig"+str(i+1):torch.nn.Sigmoid()})
|
||||||
|
|
||||||
elif act_fn.lower()=="relu":
|
elif act_fn.lower()=="relu":
|
||||||
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('relu0', torch.nn.ReLU())])
|
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
|
||||||
for i in range(num_hidden):
|
for i in range(num_hidden):
|
||||||
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "relu"+str(i+1):torch.nn.ReLU()})
|
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "relu"+str(i+1):torch.nn.ReLU()})
|
||||||
|
|
||||||
elif act_fn.lower()=="leaky relu":
|
elif act_fn.lower()=="leaky relu":
|
||||||
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('lre0', torch.nn.LeakyReLU())])
|
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
|
||||||
for i in range(num_hidden):
|
for i in range(num_hidden):
|
||||||
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "lre"+str(i+1):torch.nn.LeakyReLU()})
|
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "lre"+str(i+1):torch.nn.LeakyReLU()})
|
||||||
else:
|
else:
|
||||||
@ -113,8 +112,8 @@ def train_sgd_simple(net, evalType, data, ground, dev=None, devg=None, iters=100
|
|||||||
dev_losses.append(ap)
|
dev_losses.append(ap)
|
||||||
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev AP")
|
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev AP")
|
||||||
elif evalType == "regression":
|
elif evalType == "regression":
|
||||||
ap = metrics.explained_variance_score(devg.numpy(), output.numpy())
|
ev = metrics.explained_variance_score(devg.numpy(), output.numpy())
|
||||||
dev_losses.append(ap)
|
dev_losses.append(ev)
|
||||||
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev EV")
|
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev EV")
|
||||||
|
|
||||||
|
|
||||||
@ -191,9 +190,12 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
|
|||||||
plt.show()
|
plt.show()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
data = datasets.load_diabetes()
|
def retyuoipufdyu():
|
||||||
print(data["data"], data["target"])
|
|
||||||
ground = torch.tensor(data["target"]).to(torch.float)
|
data = torch.tensor(datasets.fetch_california_housing()['data']).to(torch.float)
|
||||||
data = torch.tensor(data["data"]).to(torch.float)
|
ground = datasets.fetch_california_housing()['target']
|
||||||
model = linear_nn(10, 100, 1, 20, act_fn = "tanh")
|
ground=torch.tensor(ground).to(torch.float)
|
||||||
model = train_sgd_simple(model,"regression", data, ground, learnrate=1e-4)
|
model = linear_nn(8, 100, 1, 20, act_fn = "relu")
|
||||||
|
print(model)
|
||||||
|
return train_sgd_simple(model,"regression", data, ground, learnrate=1e-4, iters=1000)
|
||||||
|
retyuoipufdyu()
|
||||||
|
Loading…
Reference in New Issue
Block a user