mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
quick change
This commit is contained in:
parent
af20fb0fa7
commit
a853e9b02b
@ -256,7 +256,7 @@ def regression(device, inputs, outputs, args, loss = torch.nn.MSELoss(), _iterat
|
||||
|
||||
Regression.set_device(device)
|
||||
|
||||
if 'linear' in args:
|
||||
if 'lin' in args:
|
||||
|
||||
model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True)
|
||||
regressions.append([model[0].parameters, model[1][::-1][0]])
|
||||
|
Loading…
Reference in New Issue
Block a user