mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
quick change
This commit is contained in:
parent
61e2cad206
commit
5eeec6e730
@ -256,7 +256,7 @@ def regression(device, inputs, outputs, args, loss = torch.nn.MSELoss(), _iterat
|
|||||||
|
|
||||||
Regression.set_device(device)
|
Regression.set_device(device)
|
||||||
|
|
||||||
if 'linear' in args:
|
if 'lin' in args:
|
||||||
|
|
||||||
model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True)
|
model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True)
|
||||||
regressions.append([model[0].parameters, model[1][::-1][0]])
|
regressions.append([model[0].parameters, model[1][::-1][0]])
|
||||||
|
Loading…
Reference in New Issue
Block a user