fix PolyRegKernel

This commit is contained in:
jlevine18 2019-10-09 22:23:56 -05:00 committed by GitHub
parent 8eac3d5af1
commit 7c121d48fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -530,13 +530,16 @@ class Regression:
return 1 return 1
else: else:
return n*self.factorial(n-1) return n*self.factorial(n-1)
def take_all_pwrs(self, vec,pwr): def take_all_pwrs(self, vec, pwr):
#todo: vectorize (kinda) #todo: vectorize (kinda)
combins=torch.combinations(vec, r=pwr, with_replacement=True) combins=torch.combinations(vec, r=pwr, with_replacement=True)
out=torch.ones(combins.size()[0]) out=torch.ones(combins.size()[0]).to(device).to(torch.float)
for i in torch.t(combins): for i in torch.t(combins).to(device).to(torch.float):
out *= i out *= i
return torch.cat(out,take_all_pwrs(vec, pwr-1)) if pwr == 1:
return out
else:
return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
def forward(self,mtx): def forward(self,mtx):
#TODO: Vectorize the last part #TODO: Vectorize the last part
cols=[] cols=[]