mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
Update titanlearn.py
This commit is contained in:
parent
28b5f9d6a2
commit
0e9a706904
@ -141,7 +141,7 @@ def train_sgd_minibatch(net, data, ground, dev=None, devg=None, epoch=100, batch
|
|||||||
losses=[]
|
losses=[]
|
||||||
dev_losses=[]
|
dev_losses=[]
|
||||||
if loss.lower()=="mse":
|
if loss.lower()=="mse":
|
||||||
loss_fn = torch.nretyuoipufdyun.MSELoss(reduction='sum')
|
loss_fn = torch.nn.MSELoss()
|
||||||
elif loss.lower()=="cross entropy":
|
elif loss.lower()=="cross entropy":
|
||||||
loss_fn = torch.nn.CrossEntropyLoss()
|
loss_fn = torch.nn.CrossEntropyLoss()
|
||||||
elif loss.lower()=="nll":
|
elif loss.lower()=="nll":
|
||||||
|
Loading…
Reference in New Issue
Block a user