fix PolyRegKernel

This commit is contained in:
jlevine18 2019-10-09 22:23:56 -05:00 committed by GitHub
parent 6a082825eb
commit e47d2ab142

View File

@ -533,10 +533,13 @@ class Regression:
def take_all_pwrs(self, vec, pwr):
#todo: vectorize (kinda)
combins=torch.combinations(vec, r=pwr, with_replacement=True)
out=torch.ones(combins.size()[0])
for i in torch.t(combins):
out=torch.ones(combins.size()[0]).to(device).to(torch.float)
for i in torch.t(combins).to(device).to(torch.float):
out *= i
return torch.cat(out,take_all_pwrs(vec, pwr-1))
if pwr == 1:
return out
else:
return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
def forward(self,mtx):
#TODO: Vectorize the last part
cols=[]