From 7c121d48fc17f8820016afecff1d5edf5cc118b9 Mon Sep 17 00:00:00 2001 From: jlevine18 Date: Wed, 9 Oct 2019 22:23:56 -0500 Subject: [PATCH] fix PolyRegKernel --- data analysis/analysis/analysis.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/data analysis/analysis/analysis.py b/data analysis/analysis/analysis.py index 78591b44..4082bb23 100644 --- a/data analysis/analysis/analysis.py +++ b/data analysis/analysis/analysis.py @@ -530,13 +530,16 @@ class Regression: return 1 else: return n*self.factorial(n-1) - def take_all_pwrs(self, vec,pwr): + def take_all_pwrs(self, vec, pwr): #todo: vectorize (kinda) combins=torch.combinations(vec, r=pwr, with_replacement=True) - out=torch.ones(combins.size()[0]) - for i in torch.t(combins): + out=torch.ones(combins.size()[0]).to(device).to(torch.float) + for i in torch.t(combins).to(device).to(torch.float): out *= i - return torch.cat(out,take_all_pwrs(vec, pwr-1)) + if pwr == 1: + return out + else: + return torch.cat((out,self.take_all_pwrs(vec, pwr-1))) def forward(self,mtx): #TODO: Vectorize the last part cols=[] @@ -692,4 +695,4 @@ class Gliko2: def did_not_compete(self): - self._preRatingRD() \ No newline at end of file + self._preRatingRD()