mirror of
https://github.com/titanscouting/tra-analysis.git
synced 2024-12-26 17:49:09 +00:00
regression v 1.0.0.003
This commit is contained in:
parent
a20c155cfc
commit
13c536282d
@ -619,10 +619,12 @@ class Regression:
|
||||
# this module is cuda-optimized and vectorized (except for one small part)
|
||||
# setup:
|
||||
|
||||
__version__ = "1.0.0.002"
|
||||
__version__ = "1.0.0.003"
|
||||
|
||||
# changelog should be viewed using print(analysis.regression.__changelog__)
|
||||
__changelog__ = """
|
||||
1.0.0.003:
|
||||
- bug fixes
|
||||
1.0.0.002:
|
||||
-Added more parameters to log, exponential, polynomial
|
||||
-Added SigmoidalRegKernelArthur, because Arthur apparently needs
|
||||
@ -653,12 +655,13 @@ class Regression:
|
||||
'CustomTrain'
|
||||
]
|
||||
|
||||
global device
|
||||
|
||||
device = "cuda:0" if torch.torch.cuda.is_available() else "cpu"
|
||||
|
||||
#todo: document completely
|
||||
|
||||
def set_device(self, new_device):
|
||||
global device
|
||||
device=new_device
|
||||
|
||||
class LinearRegKernel():
|
||||
@ -777,7 +780,7 @@ class Regression:
|
||||
long_bias=self.bias.repeat([1,mtx.size()[1]])
|
||||
return torch.matmul(self.weights,new_mtx)+long_bias
|
||||
|
||||
def SGDTrain(kernel, data, ground, loss=torch.nn.MSELoss(), iterations=1000, learning_rate=.1, return_losses=False):
|
||||
def SGDTrain(self, kernel, data, ground, loss=torch.nn.MSELoss(), iterations=1000, learning_rate=.1, return_losses=False):
|
||||
optim=torch.optim.SGD(kernel.parameters, lr=learning_rate)
|
||||
data_cuda=data.to(device)
|
||||
ground_cuda=ground.to(device)
|
||||
|
Loading…
Reference in New Issue
Block a user