This commit is contained in:
ltcptgeneral 2019-03-01 13:49:36 -06:00
commit 791c4e82a5

View File

@ -142,7 +142,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
losses=[]
dev_losses=[]
if loss.lower()=="mse":
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum')
loss_fn = torch.nn.MSELoss()
elif loss.lower()=="cross entropy":
loss_fn = torch.nn.CrossEntropyLoss()
elif loss.lower()=="nll":