diff --git a/data analysis/analysis/analysis.py b/data analysis/analysis/analysis.py index 78591b44..4082bb23 100644 --- a/data analysis/analysis/analysis.py +++ b/data analysis/analysis/analysis.py @@ -530,13 +530,16 @@ class Regression: return 1 else: return n*self.factorial(n-1) - def take_all_pwrs(self, vec,pwr): + def take_all_pwrs(self, vec, pwr): #todo: vectorize (kinda) combins=torch.combinations(vec, r=pwr, with_replacement=True) - out=torch.ones(combins.size()[0]) - for i in torch.t(combins): + out=torch.ones(combins.size()[0]).to(device).to(torch.float) + for i in torch.t(combins).to(device).to(torch.float): out *= i - return torch.cat(out,take_all_pwrs(vec, pwr-1)) + if pwr == 1: + return out + else: + return torch.cat((out,self.take_all_pwrs(vec, pwr-1))) def forward(self,mtx): #TODO: Vectorize the last part cols=[] @@ -692,4 +695,4 @@ class Gliko2: def did_not_compete(self): - self._preRatingRD() \ No newline at end of file + self._preRatingRD()