tra-analysis/data analysis/cudaRegressTesting.ipynb
2019-09-21 13:35:51 -05:00

1222 lines
38 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"def factorial(n):\n",
" if n==0:\n",
" return 1\n",
" else:\n",
" return n*factorial(n-1)\n",
"\n",
"def take_all_pwrs(vec,pwr):\n",
" #todo: vectorize (kinda)\n",
" combins=torch.combinations(vec, r=pwr, with_replacement=True)\n",
" out=torch.ones(combins.size()[0])\n",
" for i in torch.t(combins):\n",
" out *= i\n",
" return out\n",
"\n",
"class LinearRegKernel():\n",
" parameters= []\n",
" weights=None\n",
" bias=None\n",
" def __init__(self, num_vars):\n",
" self.weights=torch.rand(num_vars, requires_grad=True)\n",
" self.bias=torch.rand(1, requires_grad=True)\n",
" self.parameters=[self.weights,self.bias]\n",
" def forward(self,mtx):\n",
" long_bias=self.bias.repeat([1,mtx.size()[1]])\n",
" return torch.matmul(self.weights,mtx)+long_bias\n",
" \n",
"class SigmoidalRegKernel():\n",
" parameters= []\n",
" weights=None\n",
" bias=None\n",
" sigmoid=torch.nn.Sigmoid()\n",
" def __init__(self, num_vars):\n",
" self.weights=torch.rand(num_vars)\n",
" self.bias=torch.rand(1)\n",
" self.parameters=[self.weights,self.bias]\n",
" def forward(self,mtx):\n",
" long_bias=self.bias.repeat([1,mtx.size()[1]])\n",
" return self.sigmoid(torch.matmul(self.weights,mtx)+long_bias)\n",
"\n",
"class LogRegKernel():\n",
" parameters= []\n",
" weights=None\n",
" bias=None\n",
" def __init__(self, num_vars):\n",
" self.weights=torch.rand(num_vars)\n",
" self.bias=torch.rand(1)\n",
" self.parameters=[self.weights,self.bias]\n",
" def forward(self,mtx):\n",
" long_bias=self.bias.repeat([1,mtx.size()[1]])\n",
" return torch.log(torch.matmul(self.weights,mtx)+long_bias)\n",
"\n",
"class ExpRegKernel():\n",
" parameters= []\n",
" weights=None\n",
" bias=None\n",
" def __init__(self, num_vars):\n",
" self.weights=torch.rand(num_vars)\n",
" self.bias=torch.rand(1)\n",
" self.parameters=[self.weights,self.bias]\n",
" def forward(self,mtx):\n",
" long_bias=self.bias.repeat([1,mtx.size()[1]])\n",
" return torch.exp(torch.matmul(self.weights,mtx)+long_bias)\n",
"\n",
"class PolyRegKernel():\n",
" parameters= []\n",
" weights=None\n",
" bias=None\n",
" power=None\n",
" def __init__(self, num_vars, power):\n",
" self.power=power\n",
" num_terms=int(factorial(num_vars+power-1) / factorial(power) / factorial(num_vars-1))\n",
" self.weights=torch.rand(num_terms)\n",
" self.bias=torch.rand(1)\n",
" self.parameters=[self.weights,self.bias]\n",
" def forward(self,mtx):\n",
" #TODO: Vectorize the last part\n",
" cols=[]\n",
" for i in torch.t(mtx):\n",
" cols.append(take_all_pwrs(i,self.power))\n",
" new_mtx=torch.t(torch.stack(cols))\n",
" long_bias=self.bias.repeat([1,mtx.size()[1]])\n",
" return torch.matmul(self.weights,new_mtx)+long_bias\n",
"\n",
"def SGDTrain(kernel, data, ground, loss=torch.nn.MSELoss(), iterations=1000, learning_rate=.1, return_losses=False):\n",
" optim=torch.optim.SGD(kernel.parameters, lr=learning_rate)\n",
" if (return_losses):\n",
" losses=[]\n",
" for i in range(iterations):\n",
" with torch.set_grad_enabled(True):\n",
" pred=kernel.forward(data)\n",
" ls=loss(pred,ground)\n",
" losses.append(ls.item())\n",
" ls.backward()\n",
" optim.step()\n",
" return [kernel,losses]\n",
" else:\n",
" for i in range(iterations):\n",
" with torch.set_grad_enabled(True):\n",
" pred=kernel.forward(data)\n",
" ls=loss(pred,ground)\n",
" ls.backward()\n",
" optim.step() \n",
" return kernel\n",
"\n",
"def CustomTrain(kernel, optim, data, ground, loss=torch.nn.MSELoss(), iterations=1000, return_losses=False):\n",
" if (return_losses):\n",
" losses=[]\n",
" for i in range(iterations):\n",
" with torch.set_grad_enabled(True):\n",
" pred=kernel.forward(data)\n",
" ls=loss(pred,ground)\n",
" losses.append(ls.item())\n",
" ls.backward()\n",
" optim.step()\n",
" return [kernel,losses]\n",
" else:\n",
" for i in range(iterations):\n",
" with torch.set_grad_enabled(True):\n",
" pred=kernel.forward(data)\n",
" ls=loss(pred,ground)\n",
" ls.backward()\n",
" optim.step() \n",
" return kernel"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[46.92313766479492,\n",
" 38.630210876464844,\n",
" 25.048364639282227,\n",
" 11.097504615783691,\n",
" 1.831160068511963,\n",
" 0.6059606671333313,\n",
" 7.865695476531982,\n",
" 20.980566024780273,\n",
" 35.19981002807617,\n",
" 45.37260818481445,\n",
" 47.813907623291016,\n",
" 41.63934326171875,\n",
" 29.08551788330078,\n",
" 14.699889183044434,\n",
" 3.693437099456787,\n",
" 0.05307894945144653,\n",
" 5.097426891326904,\n",
" 16.999135971069336,\n",
" 31.446857452392578,\n",
" 43.20696258544922,\n",
" 48.019371032714844,\n",
" 44.1407585144043,\n",
" 32.976009368896484,\n",
" 18.569347381591797,\n",
" 6.139341354370117,\n",
" 0.18853749334812164,\n",
" 2.872443675994873,\n",
" 13.218727111816406,\n",
" 27.47943878173828,\n",
" 40.48865509033203,\n",
" 47.53378677368164,\n",
" 46.06269454956055,\n",
" 36.60813522338867,\n",
" 22.59479522705078,\n",
" 9.098745346069336,\n",
" 1.008663535118103,\n",
" 1.2549662590026855,\n",
" 9.74829387664795,\n",
" 23.41189193725586,\n",
" 37.29610824584961,\n",
" 46.37138748168945,\n",
" 47.35015106201172,\n",
" 39.877723693847656,\n",
" 26.660755157470703,\n",
" 12.486827850341797,\n",
" 2.4901671409606934,\n",
" 0.2918170988559723,\n",
" 6.687965393066406,\n",
" 19.361522674560547,\n",
" 33.72147750854492,\n",
" 44.5659294128418,\n",
" 47.96643829345703,\n",
" 42.69106674194336,\n",
" 30.650619506835938,\n",
" 16.20648956298828,\n",
" 4.590782642364502,\n",
" 0.011036033742129803,\n",
" 4.126081943511963,\n",
" 15.445150375366211,\n",
" 29.867881774902344,\n",
" 42.16966247558594,\n",
" 47.8941535949707,\n",
" 44.96760177612305,\n",
" 34.449989318847656,\n",
" 20.151098251342773,\n",
" 7.250439643859863,\n",
" 0.42103928327560425,\n",
" 2.136667490005493,\n",
" 11.77574634552002,\n",
" 25.846508026123047,\n",
" 39.251853942871094,\n",
" 47.135738372802734,\n",
" 46.642189025878906,\n",
" 37.949920654296875,\n",
" 24.20751190185547,\n",
" 10.392948150634766,\n",
" 1.5103331804275513,\n",
" 0.7772310972213745,\n",
" 8.45913028717041,\n",
" 21.773271560668945,\n",
" 35.89667510986328,\n",
" 45.713226318359375,\n",
" 47.666908264160156,\n",
" 41.04998016357422,\n",
" 28.259319305419922,\n",
" 13.928166389465332,\n",
" 3.247812271118164,\n",
" 0.0870814174413681,\n",
" 5.590890407562256,\n",
" 17.765518188476562,\n",
" 32.20081329345703,\n",
" 43.6677131652832,\n",
" 48.012451171875,\n",
" 43.66118240356445,\n",
" 32.190120697021484,\n",
" 17.754549026489258,\n",
" 5.583629608154297,\n",
" 0.08617715537548065,\n",
" 3.253612995147705,\n",
" 13.938592910766602,\n",
" 28.270626068115234,\n",
" 41.058109283447266,\n",
" 47.66895294189453,\n",
" 45.70848846435547,\n",
" 35.88692092895508,\n",
" 21.7620792388916,\n",
" 8.450617790222168,\n",
" 0.7745394706726074,\n",
" 1.5145018100738525,\n",
" 10.402535438537598,\n",
" 24.219114303588867,\n",
" 37.959407806396484,\n",
" 46.646209716796875,\n",
" 47.132911682128906,\n",
" 39.24329376220703,\n",
" 25.83540153503418,\n",
" 11.766210556030273,\n",
" 2.1322507858276367,\n",
" 0.42343956232070923,\n",
" 7.258889675140381,\n",
" 20.162643432617188,\n",
" 34.460548400878906,\n",
" 44.97346496582031,\n",
" 47.893306732177734,\n",
" 42.16252136230469,\n",
" 29.857162475585938,\n",
" 15.434839248657227,\n",
" 4.120039939880371,\n",
" 0.011575348675251007,\n",
" 4.597830295562744,\n",
" 16.217615127563477,\n",
" 30.661922454833984,\n",
" 42.698577880859375,\n",
" 47.96756362915039,\n",
" 44.56039047241211,\n",
" 33.7114143371582,\n",
" 19.350709915161133,\n",
" 6.680446147918701,\n",
" 0.2904495596885681,\n",
" 2.495577335357666,\n",
" 12.497186660766602,\n",
" 26.67244529724121,\n",
" 39.88663101196289,\n",
" 47.35319900512695,\n",
" 46.36759948730469,\n",
" 37.28697204589844,\n",
" 23.400856018066406,\n",
" 9.739489555358887,\n",
" 1.2517011165618896,\n",
" 1.0122485160827637,\n",
" 9.108006477355957,\n",
" 22.60649871826172,\n",
" 36.618160247802734,\n",
" 46.06753158569336,\n",
" 47.531795501708984,\n",
" 40.480674743652344,\n",
" 27.46846580505371,\n",
" 13.208854675292969,\n",
" 2.867344617843628,\n",
" 0.1901664286851883,\n",
" 6.147210121154785,\n",
" 18.580703735351562,\n",
" 32.98683547973633,\n",
" 44.14723205566406,\n",
" 48.019229888916016,\n",
" 43.20034408569336,\n",
" 31.436246871948242,\n",
" 16.98845863342285,\n",
" 5.090616703033447,\n",
" 0.052681032568216324,\n",
" 3.699662685394287,\n",
" 14.710546493530273,\n",
" 29.09680938720703,\n",
" 41.64723205566406,\n",
" 47.81559371948242,\n",
" 45.367515563964844,\n",
" 35.189842224121094,\n",
" 20.9693660736084,\n",
" 7.857354640960693,\n",
" 0.603532075881958,\n",
" 1.8355506658554077,\n",
" 11.10714340209961,\n",
" 25.059782028198242,\n",
" 38.6392707824707,\n",
" 46.926578521728516,\n",
" 46.91971206665039,\n",
" 38.621158599853516,\n",
" 25.036964416503906,\n",
" 11.087864875793457,\n",
" 1.8267663717269897,\n",
" 0.6083835959434509,\n",
" 7.8740363121032715,\n",
" 20.99178123474121,\n",
" 35.20979690551758,\n",
" 45.37771987915039,\n",
" 47.81224822998047,\n",
" 41.63146209716797,\n",
" 29.074235916137695,\n",
" 14.689230918884277,\n",
" 3.6872079372406006,\n",
" 0.05347258225083351,\n",
" 5.10423469543457,\n",
" 17.00982666015625,\n",
" 31.457477569580078,\n",
" 43.21358108520508,\n",
" 48.01952362060547,\n",
" 44.13429260253906,\n",
" 32.96518325805664,\n",
" 18.55799102783203,\n",
" 6.13147497177124,\n",
" 0.18690752983093262,\n",
" 2.8775391578674316,\n",
" 13.22860050201416,\n",
" 27.49040412902832,\n",
" 40.49662399291992,\n",
" 47.535770416259766,\n",
" 46.05784606933594,\n",
" 36.59809875488281,\n",
" 22.583091735839844,\n",
" 9.089487075805664,\n",
" 1.0050803422927856,\n",
" 1.258231520652771,\n",
" 9.757102012634277,\n",
" 23.42291831970215,\n",
" 37.30523681640625,\n",
" 46.37517547607422,\n",
" 47.347103118896484,\n",
" 39.868804931640625,\n",
" 26.649065017700195,\n",
" 12.47647762298584,\n",
" 2.484764575958252,\n",
" 0.2931888997554779,\n",
" 6.6954803466796875,\n",
" 19.37232780456543,\n",
" 33.73152542114258,\n",
" 44.57144546508789,\n",
" 47.965293884277344,\n",
" 42.68354415893555,\n",
" 30.639318466186523,\n",
" 16.195371627807617,\n",
" 4.583748817443848,\n",
" 0.010506330989301205,\n",
" 4.132124423980713,\n",
" 15.455449104309082,\n",
" 29.87859344482422,\n",
" 42.176780700683594,\n",
" 47.894996643066406,\n",
" 44.96173858642578,\n",
" 34.43943786621094,\n",
" 20.139570236206055,\n",
" 7.242005825042725,\n",
" 0.4186532497406006,\n",
" 2.14109206199646,\n",
" 11.785279273986816,\n",
" 25.857601165771484,\n",
" 39.26039505004883,\n",
" 47.13853454589844,\n",
" 46.63816452026367,\n",
" 37.9404296875,\n",
" 24.195924758911133,\n",
" 10.383381843566895,\n",
" 1.5061837434768677,\n",
" 0.7799347043037415,\n",
" 8.467645645141602,\n",
" 21.784454345703125,\n",
" 35.90642547607422,\n",
" 45.71794509887695,\n",
" 47.664859771728516,\n",
" 41.04186248779297,\n",
" 28.24802589416504,\n",
" 13.917757987976074,\n",
" 3.2420341968536377,\n",
" 0.08800103515386581,\n",
" 5.598155498504639,\n",
" 17.7764835357666,\n",
" 32.21149826049805,\n",
" 43.67424392700195,\n",
" 48.01245880126953,\n",
" 43.6546745300293,\n",
" 32.1794548034668,\n",
" 17.743610382080078,\n",
" 5.5763959884643555,\n",
" 0.0852908343076706,\n",
" 3.2594220638275146,\n",
" 13.949024200439453,\n",
" 28.281938552856445,\n",
" 41.06624221801758,\n",
" 47.67100143432617,\n",
" 45.70375061035156,\n",
" 35.87717819213867,\n",
" 21.750917434692383,\n",
" 8.44212818145752,\n",
" 0.7718656063079834,\n",
" 1.5186808109283447,\n",
" 10.412126541137695,\n",
" 24.230722427368164,\n",
" 37.968902587890625,\n",
" 46.65023422241211,\n",
" 47.13009262084961,\n",
" 39.2347526550293,\n",
" 25.824317932128906,\n",
" 11.756693840026855,\n",
" 2.127849578857422,\n",
" 0.4258512556552887,\n",
" 7.267345905303955,\n",
" 20.17418098449707,\n",
" 34.47111892700195,\n",
" 44.97934341430664,\n",
" 47.892486572265625,\n",
" 42.15541458129883,\n",
" 29.84645652770996,\n",
" 15.42454719543457,\n",
" 4.114015102386475,\n",
" 0.012123552151024342,\n",
" 4.604880332946777,\n",
" 16.228744506835938,\n",
" 30.673227310180664,\n",
" 42.70610046386719,\n",
" 47.968711853027344,\n",
" 44.554874420166016,\n",
" 33.70136642456055,\n",
" 19.339908599853516,\n",
" 6.672937393188477,\n",
" 0.28908833861351013,\n",
" 2.5009899139404297,\n",
" 12.507547378540039,\n",
" 26.684123992919922,\n",
" 39.895545959472656,\n",
" 47.35624694824219,\n",
" 46.363800048828125,\n",
" 37.2778434753418,\n",
" 23.389827728271484,\n",
" 9.730683326721191,\n",
" 1.2484371662139893,\n",
" 1.0158343315124512,\n",
" 9.117264747619629,\n",
" 22.61819839477539,\n",
" 36.628177642822266,\n",
" 46.072364807128906,\n",
" 47.529808044433594,\n",
" 40.472686767578125,\n",
" 27.45749282836914,\n",
" 13.198968887329102,\n",
" 2.862241268157959,\n",
" 0.19179345667362213,\n",
" 6.155076026916504,\n",
" 18.592052459716797,\n",
" 32.99764633178711,\n",
" 44.1536750793457,\n",
" 48.01906204223633,\n",
" 43.193702697753906,\n",
" 31.42561912536621,\n",
" 16.977758407592773,\n",
" 5.083798885345459,\n",
" 0.052278175950050354,\n",
" 3.705883502960205,\n",
" 14.721196174621582,\n",
" 29.108083724975586,\n",
" 41.65509033203125,\n",
" 47.81724548339844,\n",
" 45.362396240234375,\n",
" 35.179847717285156,\n",
" 20.958147048950195,\n",
" 7.849003314971924,\n",
" 0.601097047328949,\n",
" 1.8399325609207153,\n",
" 11.116769790649414,\n",
" 25.07118034362793,\n",
" 38.64830780029297,\n",
" 46.929988861083984,\n",
" 46.916259765625,\n",
" 38.61207580566406,\n",
" 25.025535583496094,\n",
" 11.078215599060059,\n",
" 1.8223642110824585,\n",
" 0.6108006834983826,\n",
" 7.882364273071289,\n",
" 21.002965927124023,\n",
" 35.219749450683594,\n",
" 45.382789611816406,\n",
" 47.81056213378906,\n",
" 41.623565673828125,\n",
" 29.06293487548828,\n",
" 14.678563117980957,\n",
" 3.6809725761413574,\n",
" 0.05386172607541084,\n",
" 5.11103630065918,\n",
" 17.020503997802734,\n",
" 31.46808433532715,\n",
" 43.22019577026367,\n",
" 48.019657135009766,\n",
" 44.12782287597656,\n",
" 32.954349517822266,\n",
" 18.546627044677734,\n",
" 6.123597621917725,\n",
" 0.18527379631996155,\n",
" 2.882638692855835,\n",
" 13.23847770690918,\n",
" 27.501373291015625,\n",
" 40.50460433959961,\n",
" 47.537757873535156,\n",
" 46.053016662597656,\n",
" 36.58807373046875,\n",
" 22.571388244628906,\n",
" 9.08022689819336,\n",
" 1.0014973878860474,\n",
" 1.261501669883728,\n",
" 9.765915870666504,\n",
" 23.4339599609375,\n",
" 37.31438064575195,\n",
" 46.37897491455078,\n",
" 47.34406280517578,\n",
" 39.859893798828125,\n",
" 26.63738250732422,\n",
" 12.466124534606934,\n",
" 2.479365825653076,\n",
" 0.2945669889450073,\n",
" 6.703005790710449,\n",
" 19.38314437866211,\n",
" 33.74158477783203,\n",
" 44.57697677612305,\n",
" 47.96417236328125,\n",
" 42.67604064941406,\n",
" 30.628026962280273,\n",
" 16.18425941467285,\n",
" 4.576717853546143,\n",
" 0.00998594518750906,\n",
" 4.138182640075684,\n",
" 15.465778350830078,\n",
" 29.889331817626953,\n",
" 42.183929443359375,\n",
" 47.895851135253906,\n",
" 44.95588684082031,\n",
" 34.42888641357422,\n",
" 20.128040313720703,\n",
" 7.233575820922852,\n",
" 0.41627734899520874,\n",
" 2.1455371379852295,\n",
" 11.794844627380371,\n",
" 25.86873435974121,\n",
" 39.268978118896484,\n",
" 47.14138412475586,\n",
" 46.6341552734375,\n",
" 37.93095397949219,\n",
" 24.184341430664062,\n",
" 10.373819351196289,\n",
" 1.502044439315796,\n",
" 0.7826577425003052,\n",
" 8.476187705993652,\n",
" 21.795671463012695,\n",
" 35.916202545166016,\n",
" 45.722713470458984,\n",
" 47.66282272338867,\n",
" 41.03374099731445,\n",
" 28.236730575561523,\n",
" 13.907356262207031,\n",
" 3.236266851425171,\n",
" 0.08893852680921555,\n",
" 5.605445861816406,\n",
" 17.787477493286133,\n",
" 32.22220230102539,\n",
" 43.680789947509766,\n",
" 48.012481689453125,\n",
" 43.64816665649414,\n",
" 32.16879653930664,\n",
" 17.73267364501953,\n",
" 5.569169998168945,\n",
" 0.08441997319459915,\n",
" 3.265254020690918,\n",
" 13.95948314666748,\n",
" 28.293272018432617,\n",
" 41.074398040771484,\n",
" 47.67307662963867,\n",
" 45.69904327392578,\n",
" 35.86745071411133,\n",
" 21.73975372314453,\n",
" 8.433645248413086,\n",
" 0.7692044377326965,\n",
" 1.5228785276412964,\n",
" 10.421737670898438,\n",
" 24.24234390258789,\n",
" 37.97840881347656,\n",
" 46.654273986816406,\n",
" 47.12729263305664,\n",
" 39.22621536254883,\n",
" 25.813234329223633,\n",
" 11.747184753417969,\n",
" 2.1234588623046875,\n",
" 0.4282756745815277,\n",
" 7.275813102722168,\n",
" 20.18573760986328,\n",
" 34.481693267822266,\n",
" 44.985206604003906,\n",
" 47.891632080078125,\n",
" 42.148277282714844,\n",
" 29.83574676513672,\n",
" 15.414258003234863,\n",
" 4.107995510101318,\n",
" 0.012680652551352978,\n",
" 4.6119384765625,\n",
" 16.23987579345703,\n",
" 30.684524536132812,\n",
" 42.713600158691406,\n",
" 47.96981430053711,\n",
" 44.54933547973633,\n",
" 33.69130325317383,\n",
" 19.3291015625,\n",
" 6.665433883666992,\n",
" 0.2877325117588043,\n",
" 2.5064029693603516,\n",
" 12.517895698547363,\n",
" 26.695791244506836,\n",
" 39.90443420410156,\n",
" 47.3592529296875,\n",
" 46.35997009277344,\n",
" 37.26869583129883,\n",
" 23.378786087036133,\n",
" 9.721874237060547,\n",
" 1.2451752424240112,\n",
" 1.019419550895691,\n",
" 9.126518249511719,\n",
" 22.6298828125,\n",
" 36.638179779052734,\n",
" 46.07716751098633,\n",
" 47.52777862548828,\n",
" 40.46467590332031,\n",
" 27.446502685546875,\n",
" 13.189085006713867,\n",
" 2.857137680053711,\n",
" 0.19341740012168884,\n",
" 6.162932395935059,\n",
" 18.603382110595703,\n",
" 33.0084342956543,\n",
" 44.160099029541016,\n",
" 48.018863677978516,\n",
" 43.187034606933594,\n",
" 31.414968490600586,\n",
" 16.967052459716797,\n",
" 5.076975345611572,\n",
" 0.051871158182621,\n",
" 3.7120935916900635,\n",
" 14.731826782226562,\n",
" 29.119327545166016,\n",
" 41.66291809082031,\n",
" 47.81885528564453,\n",
" 45.35725402832031,\n",
" 35.16982650756836,\n",
" 20.946914672851562,\n",
" 7.840641975402832,\n",
" 0.5986564755439758,\n",
" 1.8443087339401245,\n",
" 11.126387596130371,\n",
" 25.08255386352539,\n",
" 38.657318115234375,\n",
" 46.93337631225586,\n",
" 46.91277313232422,\n",
" 38.60296630859375,\n",
" 25.014087677001953,\n",
" 11.068550109863281,\n",
" 1.817955732345581,\n",
" 0.6132127642631531,\n",
" 7.890689373016357,\n",
" 21.014148712158203,\n",
" 35.229698181152344,\n",
" 45.387840270996094,\n",
" 47.808837890625,\n",
" 41.61564636230469,\n",
" 29.051618576049805,\n",
" 14.66788101196289,\n",
" 3.6747305393218994,\n",
" 0.05424736067652702,\n",
" 5.11783504486084,\n",
" 17.031171798706055,\n",
" 31.478673934936523,\n",
" 43.22679138183594,\n",
" 48.019779205322266,\n",
" 44.12131881713867,\n",
" 32.94350814819336,\n",
" 18.535255432128906,\n",
" 6.115721702575684,\n",
" 0.18364056944847107,\n",
" 2.887730598449707,\n",
" 13.248343467712402,\n",
" 27.512325286865234,\n",
" 40.51255798339844,\n",
" 47.53971481323242,\n",
" 46.04814529418945,\n",
" 36.578025817871094,\n",
" 22.559669494628906,\n",
" 9.070963859558105,\n",
" 0.9979168772697449,\n",
" 1.264772891998291,\n",
" 9.774725914001465,\n",
" 23.44498634338379,\n",
" 37.3234977722168,\n",
" 46.382755279541016,\n",
" 47.34099578857422,\n",
" 39.85095977783203,\n",
" 26.625690460205078,\n",
" 12.455771446228027,\n",
" 2.473971128463745,\n",
" 0.2959519922733307,\n",
" 6.710536479949951,\n",
" 19.393962860107422,\n",
" 33.75163269042969,\n",
" 44.58250045776367,\n",
" 47.96303176879883,\n",
" 42.668521881103516,\n",
" 30.616731643676758,\n",
" 16.17315101623535,\n",
" 4.569698333740234,\n",
" 0.009476072154939175,\n",
" 4.144247531890869,\n",
" 15.476103782653809,\n",
" 29.900062561035156,\n",
" 42.19105911254883,\n",
" 47.896690368652344,\n",
" 44.950016021728516,\n",
" 34.41832733154297,\n",
" 20.116519927978516,\n",
" 7.225159645080566,\n",
" 0.4139162003993988,\n",
" 2.1499907970428467,\n",
" 11.804407119750977,\n",
" 25.879854202270508,\n",
" 39.27754211425781,\n",
" 47.14419937133789,\n",
" 46.63013458251953,\n",
" 37.92147445678711,\n",
" 24.17276382446289,\n",
" 10.364273071289062,\n",
" 1.4979220628738403,\n",
" 0.7853949666023254,\n",
" 8.484739303588867,\n",
" 21.806888580322266,\n",
" 35.92597961425781,\n",
" 45.72746276855469,\n",
" 47.660797119140625,\n",
" 41.025638580322266,\n",
" 28.22545623779297,\n",
" 13.896970748901367,\n",
" 3.2305147647857666,\n",
" 0.08989264070987701,\n",
" 5.61275053024292,\n",
" 17.79848289489746,\n",
" 32.23292541503906,\n",
" 43.68734359741211,\n",
" 48.01250076293945,\n",
" 43.641666412353516,\n",
" 32.158138275146484,\n",
" 17.72174835205078,\n",
" 5.561957836151123,\n",
" 0.08356527239084244,\n",
" 3.2711024284362793,\n",
" 13.96995735168457,\n",
" 28.304622650146484,\n",
" 41.082550048828125,\n",
" 47.675132751464844,\n",
" 45.6943244934082,\n",
" 35.85771179199219,\n",
" 21.728591918945312,\n",
" 8.425172805786133,\n",
" 0.7665574550628662,\n",
" 1.5270906686782837,\n",
" 10.431361198425293,\n",
" 24.253976821899414,\n",
" 37.9879264831543,\n",
" 46.6583137512207,\n",
" 47.12447738647461,\n",
" 39.21766662597656,\n",
" 25.80215072631836,\n",
" 11.737676620483398,\n",
" 2.11907696723938,\n",
" 0.4307127594947815,\n",
" 7.284297466278076,\n",
" 20.197315216064453,\n",
" 34.492271423339844,\n",
" 44.991085052490234,\n",
" 47.890804290771484,\n",
" 42.14115905761719,\n",
" 29.825029373168945,\n",
" 15.403965950012207,\n",
" 4.101979732513428,\n",
" 0.013246170245110989,\n",
" 4.619010925292969,\n",
" 16.251022338867188,\n",
" 30.695838928222656,\n",
" 42.72111511230469,\n",
" 47.970951080322266,\n",
" 44.54378890991211,\n",
" 33.68124008178711,\n",
" 19.318286895751953,\n",
" 6.657920837402344,\n",
" 0.2863781154155731,\n",
" 2.5118298530578613,\n",
" 12.528270721435547,\n",
" 26.707487106323242,\n",
" 39.91333770751953,\n",
" 47.362281799316406,\n",
" 46.356143951416016,\n",
" 37.259525299072266,\n",
" 23.36772918701172,\n",
" 9.713057518005371,\n",
" 1.2419092655181885,\n",
" 1.023009181022644,\n",
" 9.135781288146973,\n",
" 22.64158058166504,\n",
" 36.648189544677734,\n",
" 46.081966400146484,\n",
" 47.5257453918457,\n",
" 40.45664596557617,\n",
" 27.435495376586914,\n",
" 13.179183006286621,\n",
" 2.8520262241363525,\n",
" 0.19503945112228394,\n",
" 6.170792102813721,\n",
" 18.61471939086914,\n",
" 33.01922607421875,\n",
" 44.16651153564453,\n",
" 48.0186653137207,\n",
" 43.180362701416016,\n",
" 31.404300689697266,\n",
" 16.956331253051758,\n",
" 5.070141315460205,\n",
" 0.05145828425884247,\n",
" 3.7183074951171875,\n",
" 14.742465019226074,\n",
" 29.130584716796875,\n",
" 41.67076110839844,\n",
" 47.82048416137695,\n",
" 45.35210418701172,\n",
" 35.159793853759766,\n",
" 20.9356632232666,\n",
" 7.832267761230469,\n",
" 0.5962072014808655,\n",
" 1.8486831188201904,\n",
" 11.136005401611328,\n",
" 25.09393310546875,\n",
" 38.666343688964844,\n",
" 46.93675231933594,\n",
" 46.90928268432617,\n",
" 38.5938606262207,\n",
" 25.00263214111328,\n",
" 11.05887508392334,\n",
" 1.8135403394699097,\n",
" 0.615619421005249,\n",
" 7.899007320404053,\n",
" 21.025320053100586,\n",
" 35.23963165283203,\n",
" 45.392887115478516,\n",
" 47.80710983276367,\n",
" 41.607704162597656,\n",
" 29.040283203125,\n",
" 14.657190322875977,\n",
" 3.6684818267822266,\n",
" 0.05462893843650818,\n",
" 5.124626636505127,\n",
" 17.041828155517578,\n",
" 31.48925018310547,\n",
" 43.23335647583008,\n",
" 48.019866943359375,\n",
" 44.11480712890625,\n",
" 32.932640075683594,\n",
" 18.52387046813965,\n",
" 6.107840061187744,\n",
" 0.18200506269931793,\n",
" 2.892824411392212,\n",
" 13.258208274841309,\n",
" 27.523269653320312,\n",
" 40.520511627197266,\n",
" 47.54166793823242,\n",
" 46.043270111083984,\n",
" 36.567962646484375,\n",
" 22.547948837280273,\n",
" 9.061697959899902,\n",
" 0.9943376779556274,\n",
" 1.2680473327636719,\n",
" 9.783537864685059,\n",
" 23.45601463317871,\n",
" 37.332618713378906,\n",
" 46.386539459228516,\n",
" 47.33793258666992,\n",
" 39.842037200927734,\n",
" 26.614004135131836,\n",
" 12.445426940917969,\n",
" 2.4685826301574707,\n",
" 0.2973437011241913,\n",
" 6.718074798583984,\n",
" 19.404788970947266,\n",
" 33.76170349121094,\n",
" 44.58802795410156,\n",
" 47.961891174316406,\n",
" 42.661014556884766,\n",
" 30.605440139770508,\n",
" 16.162050247192383,\n",
" 4.562687873840332,\n",
" 0.008976398035883904,\n",
" 4.150323867797852,\n",
" 15.486442565917969,\n",
" 29.910804748535156,\n",
" 42.198211669921875,\n",
" 47.89754867553711,\n",
" 44.94417190551758,\n",
" 34.40779113769531,\n",
" 20.10500717163086,\n",
" 7.21675443649292,\n",
" 0.41156795620918274,\n",
" 2.154458999633789,\n",
" 11.813986778259277,\n",
" 25.890987396240234,\n",
" 39.28611755371094,\n",
" 47.14704513549805,\n",
" 46.62612533569336,\n",
" 37.912010192871094,\n",
" 24.16120147705078,\n",
" 10.354740142822266,\n",
" 1.493814468383789,\n",
" 0.7881485223770142,\n",
" 8.493309020996094,\n",
" 21.818126678466797,\n",
" 35.93577194213867,\n",
" 45.732215881347656,\n",
" 47.658775329589844,\n",
" 41.01753234863281,\n",
" 28.214183807373047,\n",
" 13.886595726013184,\n",
" 3.2247796058654785,\n",
" 0.09086360037326813,\n",
" 5.6200714111328125,\n",
" 17.809499740600586,\n",
" 32.243648529052734,\n",
" 43.69390106201172,\n",
" 48.012516021728516,\n",
" 43.63515853881836,\n",
" 32.147483825683594,\n",
" 17.710830688476562,\n",
" 5.554759502410889,\n",
" 0.08272704482078552,\n",
" 3.276963949203491,\n",
" 13.980438232421875,\n",
" 28.315967559814453,\n",
" 41.09071350097656,\n",
" 47.677188873291016,\n",
" 45.68959426879883,\n",
" 35.84798049926758,\n",
" 21.717445373535156,\n",
" 8.416711807250977,\n",
" 0.7639248371124268,\n",
" 1.531317114830017,\n",
" 10.440995216369629,\n",
" 24.26561164855957,\n",
" 37.9974365234375,\n",
" 46.6623420715332,\n",
" 47.121665954589844,\n",
" 39.20912551879883,\n",
" 25.79106903076172,\n",
" 11.72817611694336,\n",
" 2.1147053241729736,\n",
" 0.43316206336021423,\n",
" 7.292791366577148,\n",
" 20.20888328552246,\n",
" 34.502838134765625,\n",
" 44.9969482421875,\n",
" 47.88994216918945,\n",
" 42.13400650024414,\n",
" 29.814308166503906,\n",
" 15.393670082092285,\n",
" 4.095966815948486,\n",
" 0.013819677755236626,\n",
" 4.626088619232178,\n",
" 16.262165069580078,\n",
" 30.707136154174805,\n",
" 42.72860336303711,\n",
" 47.9720458984375,\n",
" 44.53822326660156,\n",
" 33.6711540222168,\n",
" 19.307464599609375,\n",
" 6.6504130363464355,\n",
" 0.2850286662578583,\n",
" 2.5172553062438965,\n",
" 12.538633346557617,\n",
" 26.719165802001953,\n",
" 39.92222213745117,\n",
" 47.365272521972656,\n",
" 46.35228729248047,\n",
" 37.250343322753906,\n",
" 23.356666564941406,\n",
" 9.704236030578613,\n",
" 1.2386443614959717,\n",
" 1.026596188545227,\n",
" 9.145030975341797,\n",
" 22.65325355529785,\n",
" 36.65815734863281,\n",
" 46.086727142333984,\n",
" 47.5236701965332,\n",
" 40.44859313964844,\n",
" 27.42446517944336,\n",
" 13.169272422790527,\n",
" 2.8469107151031494,\n",
" 0.19665808975696564,\n",
" 6.178638458251953,\n",
" 18.62603187561035,\n",
" 33.02998352050781,\n",
" 44.172882080078125,\n",
" 48.01840591430664,\n",
" 43.17363357543945,\n",
" 31.39360809326172,\n",
" 16.94559097290039,\n",
" 5.063300609588623,\n",
" 0.051040634512901306,\n",
" 3.7245066165924072,\n",
" 14.753074645996094,\n",
" 29.141794204711914,\n",
" 41.67854309082031,\n",
" 47.82203674316406,\n",
" 45.34687805175781,\n",
" 35.149723052978516,\n",
" 20.92438507080078,\n",
" 7.8238844871521,\n",
" 0.593753457069397,\n",
" 1.853044867515564,\n",
" 11.14559268951416,\n",
" 25.105268478393555,\n",
" 38.6753044128418,\n",
" 46.94007110595703,\n",
" 46.905723571777344,\n",
" 38.584686279296875,\n",
" 24.991140365600586,\n",
" 11.04918384552002,\n",
" 1.809116244316101,\n",
" 0.6180207133293152,\n",
" 7.907314777374268,\n",
" 21.03647232055664,\n",
" 35.24953079223633,\n",
" 45.397884368896484,\n",
" 47.805320739746094,\n",
" 41.599708557128906,\n",
" 29.028907775878906,\n",
" 14.646471977233887,\n",
" 3.6622226238250732,\n",
" 0.05500727519392967,\n",
" 5.131416320800781,\n",
" 17.052480697631836,\n",
" 31.499801635742188,\n",
" 43.23990249633789,\n",
" 48.019920349121094,\n",
" 44.10824203491211,\n",
" 32.92172622680664,\n",
" 18.5124568939209,\n",
" 6.099941730499268,\n",
" 0.18036723136901855,\n",
" 2.897919178009033,\n",
" 13.268067359924316,\n",
" 27.534196853637695,\n",
" 40.52841567993164,\n",
" 47.54356384277344,\n",
" 46.0383415222168,\n",
" 36.557857513427734,\n",
" 22.536197662353516,\n",
" 9.052422523498535,\n",
" 0.990759015083313,\n",
" 1.2713243961334229,\n",
" 9.79234504699707,\n",
" 23.46702003479004,\n",
" 37.34169387817383,\n",
" 46.3902473449707,\n",
" 47.33479690551758,\n",
" 39.833038330078125,\n",
" 26.602272033691406,\n",
" 12.43505859375,\n",
" 2.4631946086883545,\n",
" 0.29874250292778015,\n",
" 6.725613594055176,\n",
" 19.415590286254883,\n",
" 33.7717170715332,\n",
" 44.59348678588867,\n",
" 47.9606819152832,\n",
" 42.65342330932617,\n",
" 30.594093322753906,\n",
" 16.150917053222656,\n",
" 4.555673122406006,\n",
" 0.008487294428050518,\n",
" 4.156408786773682,\n",
" 15.496768951416016,\n",
" 29.921506881713867,\n",
" 42.20529556274414,\n",
" 47.89832305908203,\n",
" 44.938236236572266,\n",
" 34.39719772338867,\n",
" 20.093469619750977,\n",
" 7.208343029022217,\n",
" 0.4092309772968292,\n",
" 2.1589431762695312,\n",
" 11.823572158813477,\n",
" 25.902109146118164,\n",
" 39.29465866088867,\n",
" 47.14982223510742,\n",
" 46.622066497802734,\n",
" 37.90248489379883,\n",
" 24.149599075317383,\n",
" ...]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model=SGDTrain(LinearRegKernel(3),torch.tensor([[1,2],[3,4],[5,6]]).to(torch.float),torch.tensor([[1,2]]).to(torch.float),iterations=10000, learning_rate=.001, return_losses=True)\n",
"model[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}