From 0e9a706904f70e10dbc7a107471a281bea1dec15 Mon Sep 17 00:00:00 2001 From: jlevine18 Date: Fri, 1 Mar 2019 12:25:41 -0600 Subject: [PATCH] Update titanlearn.py --- data analysis/titanlearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data analysis/titanlearn.py b/data analysis/titanlearn.py index 2b895610..1340e49c 100644 --- a/data analysis/titanlearn.py +++ b/data analysis/titanlearn.py @@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch losses=[] dev_losses=[] if loss.lower()=="mse": - loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum') + loss_fn = torch.nn.MSELoss() elif loss.lower()=="cross entropy": loss_fn = torch.nn.CrossEntropyLoss() elif loss.lower()=="nll":