mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-11-10 06:54:44 +00:00
fix PolyRegKernel
This commit is contained in:
parent
6a082825eb
commit
e47d2ab142
@ -530,13 +530,16 @@ class Regression:
|
|||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return n*self.factorial(n-1)
|
return n*self.factorial(n-1)
|
||||||
def take_all_pwrs(self, vec,pwr):
|
def take_all_pwrs(self, vec, pwr):
|
||||||
#todo: vectorize (kinda)
|
#todo: vectorize (kinda)
|
||||||
combins=torch.combinations(vec, r=pwr, with_replacement=True)
|
combins=torch.combinations(vec, r=pwr, with_replacement=True)
|
||||||
out=torch.ones(combins.size()[0])
|
out=torch.ones(combins.size()[0]).to(device).to(torch.float)
|
||||||
for i in torch.t(combins):
|
for i in torch.t(combins).to(device).to(torch.float):
|
||||||
out *= i
|
out *= i
|
||||||
return torch.cat(out,take_all_pwrs(vec, pwr-1))
|
if pwr == 1:
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
|
||||||
def forward(self,mtx):
|
def forward(self,mtx):
|
||||||
#TODO: Vectorize the last part
|
#TODO: Vectorize the last part
|
||||||
cols=[]
|
cols=[]
|
||||||
|
Loading…
Reference in New Issue
Block a user