Update titanlearn.py

This commit is contained in:
jlevine18 2019-03-01 12:25:41 -06:00 committed by GitHub
parent 28b5f9d6a2
commit 0e9a706904
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
losses=[] losses=[]
dev_losses=[] dev_losses=[]
if loss.lower()=="mse": if loss.lower()=="mse":
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum') loss_fn = torch.nn.MSELoss()
elif loss.lower()=="cross entropy": elif loss.lower()=="cross entropy":
loss_fn = torch.nn.CrossEntropyLoss() loss_fn = torch.nn.CrossEntropyLoss()
elif loss.lower()=="nll": elif loss.lower()=="nll":