This commit is contained in:
ltcptgeneral
2019-03-01 13:49:36 -06:00

View File

@@ -142,7 +142,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
losses=[]
dev_losses=[]
if loss.lower()=="mse":
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum')
loss_fn = torch.nn.MSELoss()
elif loss.lower()=="cross entropy":
loss_fn = torch.nn.CrossEntropyLoss()
elif loss.lower()=="nll":