diff --git a/data analysis/titanlearn.py b/data analysis/titanlearn.py index 2b895610..1340e49c 100644 --- a/data analysis/titanlearn.py +++ b/data analysis/titanlearn.py @@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch losses=[] dev_losses=[] if loss.lower()=="mse": - loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum') + loss_fn = torch.nn.MSELoss() elif loss.lower()=="cross entropy": loss_fn = torch.nn.CrossEntropyLoss() elif loss.lower()=="nll":