mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
Update titanlearn.py
This commit is contained in:
parent
5889978f1d
commit
145e48fc89
@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
|
||||
losses=[]
|
||||
dev_losses=[]
|
||||
if loss.lower()=="mse":
|
||||
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum')
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
elif loss.lower()=="cross entropy":
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
elif loss.lower()=="nll":
|
||||
|
Loading…
Reference in New Issue
Block a user