From 5eeec6e73081aa380c9de6411ed811a4d3614152 Mon Sep 17 00:00:00 2001 From: art Date: Fri, 4 Oct 2019 10:37:29 -0500 Subject: [PATCH] quick change --- data analysis/analysis/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data analysis/analysis/analysis.py b/data analysis/analysis/analysis.py index bffd5703..4ba96a6a 100644 --- a/data analysis/analysis/analysis.py +++ b/data analysis/analysis/analysis.py @@ -256,7 +256,7 @@ def regression(device, inputs, outputs, args, loss = torch.nn.MSELoss(), _iterat Regression.set_device(device) - if 'linear' in args: + if 'lin' in args: model = Regression.SGDTrain(Regression.LinearRegKernel(len(inputs)), torch.tensor(inputs).to(torch.float).cuda(), torch.tensor([outputs]).to(torch.float).cuda(), iterations=_iterations, learning_rate=lr, return_losses=True) regressions.append([model[0].parameters, model[1][::-1][0]])