Update titanlearn.py

This commit is contained in:
jlevine18 2019-03-01 12:25:41 -06:00 committed by GitHub
parent 5889978f1d
commit 145e48fc89

View File

@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
losses=[] losses=[]
dev_losses=[] dev_losses=[]
if loss.lower()=="mse": if loss.lower()=="mse":
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum') loss_fn = torch.nn.MSELoss()
elif loss.lower()=="cross entropy": elif loss.lower()=="cross entropy":
loss_fn = torch.nn.CrossEntropyLoss() loss_fn = torch.nn.CrossEntropyLoss()
elif loss.lower()=="nll": elif loss.lower()=="nll":