This commit is contained in:
jlevine18 2019-03-02 08:18:28 -06:00 committed by GitHub
parent 791c4e82a5
commit e98e66bdf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,11 +26,10 @@ __all__ = [
import torch
import warnings
from collections import OrderedDict
from sklearn import metrics
from sklearn import metrics, datasets
import numpy as np
import matplotlib.pyplot as plt
import math
from sklearn import datasets
#enable CUDA if possible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -38,22 +37,22 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#linear_nn: creates a fully connected network given params
def linear_nn(in_dim, hidden_dim, out_dim, num_hidden, act_fn="tanh", end="none"):
if act_fn.lower()=="tanh":
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('tanh0', torch.nn.Tanh())])
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
for i in range(num_hidden):
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "tanh"+str(i+1):torch.nn.Tanh()})
elif act_fn.lower()=="sigmoid":
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('sig0', torch.nn.Sigmoid())])
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
for i in range(num_hidden):
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "sig"+str(i+1):torch.nn.Sigmoid()})
elif act_fn.lower()=="relu":
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('relu0', torch.nn.ReLU())])
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
for i in range(num_hidden):
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "relu"+str(i+1):torch.nn.ReLU()})
elif act_fn.lower()=="leaky relu":
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim)), ('lre0', torch.nn.LeakyReLU())])
k=OrderedDict([("in", torch.nn.Linear(in_dim,hidden_dim))])
for i in range(num_hidden):
k.update({"lin"+str(i+1): torch.nn.Linear(hidden_dim,hidden_dim), "lre"+str(i+1):torch.nn.LeakyReLU()})
else:
@ -113,8 +112,8 @@ def train_sgd_simple(net, evalType, data, ground, dev=None, devg=None, iters=100
dev_losses.append(ap)
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev AP")
elif evalType == "regression":
ap = metrics.explained_variance_score(devg.numpy(), output.numpy())
dev_losses.append(ap)
ev = metrics.explained_variance_score(devg.numpy(), output.numpy())
dev_losses.append(ev)
plt.plot(np.array(range(0,i+1,testevery)),np.array(losses), label="dev EV")
@ -191,9 +190,12 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
plt.show()
return model
data = datasets.load_diabetes()
print(data["data"], data["target"])
ground = torch.tensor(data["target"]).to(torch.float)
data = torch.tensor(data["data"]).to(torch.float)
model = linear_nn(10, 100, 1, 20, act_fn = "tanh")
model = train_sgd_simple(model,"regression", data, ground, learnrate=1e-4)
def retyuoipufdyu():
data = torch.tensor(datasets.fetch_california_housing()['data']).to(torch.float)
ground = datasets.fetch_california_housing()['target']
ground=torch.tensor(ground).to(torch.float)
model = linear_nn(8, 100, 1, 20, act_fn = "relu")
print(model)
return train_sgd_simple(model,"regression", data, ground, learnrate=1e-4, iters=1000)
retyuoipufdyu()