ok fixed half of it

This commit is contained in:
art 2019-10-08 13:49:19 -05:00
parent f47be637a0
commit 8eac3d5af1

View File

@ -429,24 +429,6 @@ class Regression:
#todo: document completely #todo: document completely
def factorial(n):
if n==0:
return 1
else:
return n*factorial(n-1)
def num_poly_terms(num_vars, power):
if power == 0:
return 0
return int(factorial(num_vars+power-1) / factorial(power) / factorial(num_vars-1)) + num_poly_terms(num_vars, power-1)
def take_all_pwrs(vec,pwr):
#todo: vectorize (kinda)
combins=torch.combinations(vec, r=pwr, with_replacement=True)
out=torch.ones(combins.size()[0])
for i in torch.t(combins):
out *= i
return torch.cat(out,take_all_pwrs(vec, pwr-1))
def set_device(new_device): def set_device(new_device):
global device global device
device=new_device device=new_device
@ -535,15 +517,31 @@ class Regression:
power=None power=None
def __init__(self, num_vars, power): def __init__(self, num_vars, power):
self.power=power self.power=power
num_terms=num_poly_terms(num_vars, power) num_terms=self.num_poly_terms(num_vars, power)
self.weights=torch.rand(num_terms, requires_grad=True, device=device) self.weights=torch.rand(num_terms, requires_grad=True, device=device)
self.bias=torch.rand(1, requires_grad=True, device=device) self.bias=torch.rand(1, requires_grad=True, device=device)
self.parameters=[self.weights,self.bias] self.parameters=[self.weights,self.bias]
def num_poly_terms(self,num_vars, power):
if power == 0:
return 0
return int(self.factorial(num_vars+power-1) / self.factorial(power) / self.factorial(num_vars-1)) + self.num_poly_terms(num_vars, power-1)
def factorial(self,n):
if n==0:
return 1
else:
return n*self.factorial(n-1)
def take_all_pwrs(self, vec,pwr):
#todo: vectorize (kinda)
combins=torch.combinations(vec, r=pwr, with_replacement=True)
out=torch.ones(combins.size()[0])
for i in torch.t(combins):
out *= i
return torch.cat(out,take_all_pwrs(vec, pwr-1))
def forward(self,mtx): def forward(self,mtx):
#TODO: Vectorize the last part #TODO: Vectorize the last part
cols=[] cols=[]
for i in torch.t(mtx): for i in torch.t(mtx):
cols.append(take_all_pwrs(i,self.power)) cols.append(self.take_all_pwrs(i,self.power))
new_mtx=torch.t(torch.stack(cols)) new_mtx=torch.t(torch.stack(cols))
long_bias=self.bias.repeat([1,mtx.size()[1]]) long_bias=self.bias.repeat([1,mtx.size()[1]])
return torch.matmul(self.weights,new_mtx)+long_bias return torch.matmul(self.weights,new_mtx)+long_bias