fix PolyRegKernel

This commit is contained in:
jlevine18 2019-10-09 22:23:56 -05:00 committed by GitHub
parent 6a082825eb
commit e47d2ab142

View File

@ -533,10 +533,13 @@ class Regression:
def take_all_pwrs(self, vec, pwr): def take_all_pwrs(self, vec, pwr):
#todo: vectorize (kinda) #todo: vectorize (kinda)
combins=torch.combinations(vec, r=pwr, with_replacement=True) combins=torch.combinations(vec, r=pwr, with_replacement=True)
out=torch.ones(combins.size()[0]) out=torch.ones(combins.size()[0]).to(device).to(torch.float)
for i in torch.t(combins): for i in torch.t(combins).to(device).to(torch.float):
out *= i out *= i
return torch.cat(out,take_all_pwrs(vec, pwr-1)) if pwr == 1:
return out
else:
return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
def forward(self,mtx): def forward(self,mtx):
#TODO: Vectorize the last part #TODO: Vectorize the last part
cols=[] cols=[]