diff --git a/data analysis/analysis/analysis.py b/data analysis/analysis/analysis.py index bffd5703..4ba96a6a 100644 --- a/data analysis/analysis/analysis.py +++ b/data analysis/analysis/analysis.py @@ -256,7 +256,7 @@ def regression(device, inputs, outputs, args, loss = torch.nn.MSELoss(), _iterat Regression.set_device(device) - if 'linear' in args: + if 'lin' in args: model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True) regressions.append([model[0].parameters, model[1][::-1][0]])