mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
fix PolyRegKernel
This commit is contained in:
parent
6a082825eb
commit
e47d2ab142
@ -530,13 +530,16 @@ class Regression:
|
||||
return 1
|
||||
else:
|
||||
return n*self.factorial(n-1)
|
||||
def take_all_pwrs(self, vec,pwr):
|
||||
def take_all_pwrs(self, vec, pwr):
|
||||
#todo: vectorize (kinda)
|
||||
combins=torch.combinations(vec, r=pwr, with_replacement=True)
|
||||
out=torch.ones(combins.size()[0])
|
||||
for i in torch.t(combins):
|
||||
out=torch.ones(combins.size()[0]).to(device).to(torch.float)
|
||||
for i in torch.t(combins).to(device).to(torch.float):
|
||||
out *= i
|
||||
return torch.cat(out,take_all_pwrs(vec, pwr-1))
|
||||
if pwr == 1:
|
||||
return out
|
||||
else:
|
||||
return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
|
||||
def forward(self,mtx):
|
||||
#TODO: Vectorize the last part
|
||||
cols=[]
|
||||
|
Loading…
Reference in New Issue
Block a user