quick change

This commit is contained in:
art 2019-10-04 10:37:29 -05:00
parent 61e2cad206
commit 5eeec6e730

View File

@ -256,7 +256,7 @@ def regression(device, inputs, outputs, args, loss = torch.nn.MSELoss(), _iterat
Regression.set_device(device)
if 'linear' in args:
if 'lin' in args:
model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True)
regressions.append([model[0].parameters, model[1][::-1][0]])