diff --git a/data analysis/cudaRegressTesting.ipynb b/data analysis/cudaRegressTesting.ipynb new file mode 100644 index 00000000..1a0ad492 --- /dev/null +++ b/data analysis/cudaRegressTesting.ipynb @@ -0,0 +1,1221 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "def factorial(n):\n", + " if n==0:\n", + " return 1\n", + " else:\n", + " return n*factorial(n-1)\n", + "\n", + "def take_all_pwrs(vec,pwr):\n", + " #todo: vectorize (kinda)\n", + " combins=torch.combinations(vec, r=pwr, with_replacement=True)\n", + " out=torch.ones(combins.size()[0])\n", + " for i in torch.t(combins):\n", + " out *= i\n", + " return out\n", + "\n", + "class LinearRegKernel():\n", + " parameters= []\n", + " weights=None\n", + " bias=None\n", + " def __init__(self, num_vars):\n", + " self.weights=torch.rand(num_vars, requires_grad=True)\n", + " self.bias=torch.rand(1, requires_grad=True)\n", + " self.parameters=[self.weights,self.bias]\n", + " def forward(self,mtx):\n", + " long_bias=self.bias.repeat([1,mtx.size()[1]])\n", + " return torch.matmul(self.weights,mtx)+long_bias\n", + " \n", + "class SigmoidalRegKernel():\n", + " parameters= []\n", + " weights=None\n", + " bias=None\n", + " sigmoid=torch.nn.Sigmoid()\n", + " def __init__(self, num_vars):\n", + " self.weights=torch.rand(num_vars)\n", + " self.bias=torch.rand(1)\n", + " self.parameters=[self.weights,self.bias]\n", + " def forward(self,mtx):\n", + " long_bias=self.bias.repeat([1,mtx.size()[1]])\n", + " return self.sigmoid(torch.matmul(self.weights,mtx)+long_bias)\n", + "\n", + "class LogRegKernel():\n", + " parameters= []\n", + " weights=None\n", + " bias=None\n", + " def __init__(self, num_vars):\n", + " self.weights=torch.rand(num_vars)\n", + " self.bias=torch.rand(1)\n", + " self.parameters=[self.weights,self.bias]\n", + " def forward(self,mtx):\n", + " long_bias=self.bias.repeat([1,mtx.size()[1]])\n", + " return torch.log(torch.matmul(self.weights,mtx)+long_bias)\n", + "\n", + "class ExpRegKernel():\n", + " parameters= []\n", + " weights=None\n", + " bias=None\n", + " def __init__(self, num_vars):\n", + " self.weights=torch.rand(num_vars)\n", + " self.bias=torch.rand(1)\n", + " self.parameters=[self.weights,self.bias]\n", + " def forward(self,mtx):\n", + " long_bias=self.bias.repeat([1,mtx.size()[1]])\n", + " return torch.exp(torch.matmul(self.weights,mtx)+long_bias)\n", + "\n", + "class PolyRegKernel():\n", + " parameters= []\n", + " weights=None\n", + " bias=None\n", + " power=None\n", + " def __init__(self, num_vars, power):\n", + " self.power=power\n", + " num_terms=int(factorial(num_vars+power-1) / factorial(power) / factorial(num_vars-1))\n", + " self.weights=torch.rand(num_terms)\n", + " self.bias=torch.rand(1)\n", + " self.parameters=[self.weights,self.bias]\n", + " def forward(self,mtx):\n", + " #TODO: Vectorize the last part\n", + " cols=[]\n", + " for i in torch.t(mtx):\n", + " cols.append(take_all_pwrs(i,self.power))\n", + " new_mtx=torch.t(torch.stack(cols))\n", + " long_bias=self.bias.repeat([1,mtx.size()[1]])\n", + " return torch.matmul(self.weights,new_mtx)+long_bias\n", + "\n", + "def SGDTrain(kernel, data, ground, loss=torch.nn.MSELoss(), iterations=1000, learning_rate=.1, return_losses=False):\n", + " optim=torch.optim.SGD(kernel.parameters, lr=learning_rate)\n", + " if (return_losses):\n", + " losses=[]\n", + " for i in range(iterations):\n", + " with torch.set_grad_enabled(True):\n", + " pred=kernel.forward(data)\n", + " ls=loss(pred,ground)\n", + " losses.append(ls.item())\n", + " ls.backward()\n", + " optim.step()\n", + " return [kernel,losses]\n", + " else:\n", + " for i in range(iterations):\n", + " with torch.set_grad_enabled(True):\n", + " pred=kernel.forward(data)\n", + " ls=loss(pred,ground)\n", + " ls.backward()\n", + " optim.step() \n", + " return kernel\n", + "\n", + "def CustomTrain(kernel, optim, data, ground, loss=torch.nn.MSELoss(), iterations=1000, return_losses=False):\n", + " if (return_losses):\n", + " losses=[]\n", + " for i in range(iterations):\n", + " with torch.set_grad_enabled(True):\n", + " pred=kernel.forward(data)\n", + " ls=loss(pred,ground)\n", + " losses.append(ls.item())\n", + " ls.backward()\n", + " optim.step()\n", + " return [kernel,losses]\n", + " else:\n", + " for i in range(iterations):\n", + " with torch.set_grad_enabled(True):\n", + " pred=kernel.forward(data)\n", + " ls=loss(pred,ground)\n", + " ls.backward()\n", + " optim.step() \n", + " return kernel" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[46.92313766479492,\n", + " 38.630210876464844,\n", + " 25.048364639282227,\n", + " 11.097504615783691,\n", + " 1.831160068511963,\n", + " 0.6059606671333313,\n", + " 7.865695476531982,\n", + " 20.980566024780273,\n", + " 35.19981002807617,\n", + " 45.37260818481445,\n", + " 47.813907623291016,\n", + " 41.63934326171875,\n", + " 29.08551788330078,\n", + " 14.699889183044434,\n", + " 3.693437099456787,\n", + " 0.05307894945144653,\n", + " 5.097426891326904,\n", + " 16.999135971069336,\n", + " 31.446857452392578,\n", + " 43.20696258544922,\n", + " 48.019371032714844,\n", + " 44.1407585144043,\n", + " 32.976009368896484,\n", + " 18.569347381591797,\n", + " 6.139341354370117,\n", + " 0.18853749334812164,\n", + " 2.872443675994873,\n", + " 13.218727111816406,\n", + " 27.47943878173828,\n", + " 40.48865509033203,\n", + " 47.53378677368164,\n", + " 46.06269454956055,\n", + " 36.60813522338867,\n", + " 22.59479522705078,\n", + " 9.098745346069336,\n", + " 1.008663535118103,\n", + " 1.2549662590026855,\n", + " 9.74829387664795,\n", + " 23.41189193725586,\n", + " 37.29610824584961,\n", + " 46.37138748168945,\n", + " 47.35015106201172,\n", + " 39.877723693847656,\n", + " 26.660755157470703,\n", + " 12.486827850341797,\n", + " 2.4901671409606934,\n", + " 0.2918170988559723,\n", + " 6.687965393066406,\n", + " 19.361522674560547,\n", + " 33.72147750854492,\n", + " 44.5659294128418,\n", + " 47.96643829345703,\n", + " 42.69106674194336,\n", + " 30.650619506835938,\n", + " 16.20648956298828,\n", + " 4.590782642364502,\n", + " 0.011036033742129803,\n", + " 4.126081943511963,\n", + " 15.445150375366211,\n", + " 29.867881774902344,\n", + " 42.16966247558594,\n", + " 47.8941535949707,\n", + " 44.96760177612305,\n", + " 34.449989318847656,\n", + " 20.151098251342773,\n", + " 7.250439643859863,\n", + " 0.42103928327560425,\n", + " 2.136667490005493,\n", + " 11.77574634552002,\n", + " 25.846508026123047,\n", + " 39.251853942871094,\n", + " 47.135738372802734,\n", + " 46.642189025878906,\n", + " 37.949920654296875,\n", + " 24.20751190185547,\n", + " 10.392948150634766,\n", + " 1.5103331804275513,\n", + " 0.7772310972213745,\n", + " 8.45913028717041,\n", + " 21.773271560668945,\n", + " 35.89667510986328,\n", + " 45.713226318359375,\n", + " 47.666908264160156,\n", + " 41.04998016357422,\n", + " 28.259319305419922,\n", + " 13.928166389465332,\n", + " 3.247812271118164,\n", + " 0.0870814174413681,\n", + " 5.590890407562256,\n", + " 17.765518188476562,\n", + " 32.20081329345703,\n", + " 43.6677131652832,\n", + " 48.012451171875,\n", + " 43.66118240356445,\n", + " 32.190120697021484,\n", + " 17.754549026489258,\n", + " 5.583629608154297,\n", + " 0.08617715537548065,\n", + " 3.253612995147705,\n", + " 13.938592910766602,\n", + " 28.270626068115234,\n", + " 41.058109283447266,\n", + " 47.66895294189453,\n", + " 45.70848846435547,\n", + " 35.88692092895508,\n", + " 21.7620792388916,\n", + " 8.450617790222168,\n", + " 0.7745394706726074,\n", + " 1.5145018100738525,\n", + " 10.402535438537598,\n", + " 24.219114303588867,\n", + " 37.959407806396484,\n", + " 46.646209716796875,\n", + " 47.132911682128906,\n", + " 39.24329376220703,\n", + " 25.83540153503418,\n", + " 11.766210556030273,\n", + " 2.1322507858276367,\n", + " 0.42343956232070923,\n", + " 7.258889675140381,\n", + " 20.162643432617188,\n", + " 34.460548400878906,\n", + " 44.97346496582031,\n", + " 47.893306732177734,\n", + " 42.16252136230469,\n", + " 29.857162475585938,\n", + " 15.434839248657227,\n", + " 4.120039939880371,\n", + " 0.011575348675251007,\n", + " 4.597830295562744,\n", + " 16.217615127563477,\n", + " 30.661922454833984,\n", + " 42.698577880859375,\n", + " 47.96756362915039,\n", + " 44.56039047241211,\n", + " 33.7114143371582,\n", + " 19.350709915161133,\n", + " 6.680446147918701,\n", + " 0.2904495596885681,\n", + " 2.495577335357666,\n", + " 12.497186660766602,\n", + " 26.67244529724121,\n", + " 39.88663101196289,\n", + " 47.35319900512695,\n", + " 46.36759948730469,\n", + " 37.28697204589844,\n", + " 23.400856018066406,\n", + " 9.739489555358887,\n", + " 1.2517011165618896,\n", + " 1.0122485160827637,\n", + " 9.108006477355957,\n", + " 22.60649871826172,\n", + " 36.618160247802734,\n", + " 46.06753158569336,\n", + " 47.531795501708984,\n", + " 40.480674743652344,\n", + " 27.46846580505371,\n", + " 13.208854675292969,\n", + " 2.867344617843628,\n", + " 0.1901664286851883,\n", + " 6.147210121154785,\n", + " 18.580703735351562,\n", + " 32.98683547973633,\n", + " 44.14723205566406,\n", + " 48.019229888916016,\n", + " 43.20034408569336,\n", + " 31.436246871948242,\n", + " 16.98845863342285,\n", + " 5.090616703033447,\n", + " 0.052681032568216324,\n", + " 3.699662685394287,\n", + " 14.710546493530273,\n", + " 29.09680938720703,\n", + " 41.64723205566406,\n", + " 47.81559371948242,\n", + " 45.367515563964844,\n", + " 35.189842224121094,\n", + " 20.9693660736084,\n", + " 7.857354640960693,\n", + " 0.603532075881958,\n", + " 1.8355506658554077,\n", + " 11.10714340209961,\n", + " 25.059782028198242,\n", + " 38.6392707824707,\n", + " 46.926578521728516,\n", + " 46.91971206665039,\n", + " 38.621158599853516,\n", + " 25.036964416503906,\n", + " 11.087864875793457,\n", + " 1.8267663717269897,\n", + " 0.6083835959434509,\n", + " 7.8740363121032715,\n", + " 20.99178123474121,\n", + " 35.20979690551758,\n", + " 45.37771987915039,\n", + " 47.81224822998047,\n", + " 41.63146209716797,\n", + " 29.074235916137695,\n", + " 14.689230918884277,\n", + " 3.6872079372406006,\n", + " 0.05347258225083351,\n", + " 5.10423469543457,\n", + " 17.00982666015625,\n", + " 31.457477569580078,\n", + " 43.21358108520508,\n", + " 48.01952362060547,\n", + " 44.13429260253906,\n", + " 32.96518325805664,\n", + " 18.55799102783203,\n", + " 6.13147497177124,\n", + " 0.18690752983093262,\n", + " 2.8775391578674316,\n", + " 13.22860050201416,\n", + " 27.49040412902832,\n", + " 40.49662399291992,\n", + " 47.535770416259766,\n", + " 46.05784606933594,\n", + " 36.59809875488281,\n", + " 22.583091735839844,\n", + " 9.089487075805664,\n", + " 1.0050803422927856,\n", + " 1.258231520652771,\n", + " 9.757102012634277,\n", + " 23.42291831970215,\n", + " 37.30523681640625,\n", + " 46.37517547607422,\n", + " 47.347103118896484,\n", + " 39.868804931640625,\n", + " 26.649065017700195,\n", + " 12.47647762298584,\n", + " 2.484764575958252,\n", + " 0.2931888997554779,\n", + " 6.6954803466796875,\n", + " 19.37232780456543,\n", + " 33.73152542114258,\n", + " 44.57144546508789,\n", + " 47.965293884277344,\n", + " 42.68354415893555,\n", + " 30.639318466186523,\n", + " 16.195371627807617,\n", + " 4.583748817443848,\n", + " 0.010506330989301205,\n", + " 4.132124423980713,\n", + " 15.455449104309082,\n", + " 29.87859344482422,\n", + " 42.176780700683594,\n", + " 47.894996643066406,\n", + " 44.96173858642578,\n", + " 34.43943786621094,\n", + " 20.139570236206055,\n", + " 7.242005825042725,\n", + " 0.4186532497406006,\n", + " 2.14109206199646,\n", + " 11.785279273986816,\n", + " 25.857601165771484,\n", + " 39.26039505004883,\n", + " 47.13853454589844,\n", + " 46.63816452026367,\n", + " 37.9404296875,\n", + " 24.195924758911133,\n", + " 10.383381843566895,\n", + " 1.5061837434768677,\n", + " 0.7799347043037415,\n", + " 8.467645645141602,\n", + " 21.784454345703125,\n", + " 35.90642547607422,\n", + " 45.71794509887695,\n", + " 47.664859771728516,\n", + " 41.04186248779297,\n", + " 28.24802589416504,\n", + " 13.917757987976074,\n", + " 3.2420341968536377,\n", + " 0.08800103515386581,\n", + " 5.598155498504639,\n", + " 17.7764835357666,\n", + " 32.21149826049805,\n", + " 43.67424392700195,\n", + " 48.01245880126953,\n", + " 43.6546745300293,\n", + " 32.1794548034668,\n", + " 17.743610382080078,\n", + " 5.5763959884643555,\n", + " 0.0852908343076706,\n", + " 3.2594220638275146,\n", + " 13.949024200439453,\n", + " 28.281938552856445,\n", + " 41.06624221801758,\n", + " 47.67100143432617,\n", + " 45.70375061035156,\n", + " 35.87717819213867,\n", + " 21.750917434692383,\n", + " 8.44212818145752,\n", + " 0.7718656063079834,\n", + " 1.5186808109283447,\n", + " 10.412126541137695,\n", + " 24.230722427368164,\n", + " 37.968902587890625,\n", + " 46.65023422241211,\n", + " 47.13009262084961,\n", + " 39.2347526550293,\n", + " 25.824317932128906,\n", + " 11.756693840026855,\n", + " 2.127849578857422,\n", + " 0.4258512556552887,\n", + " 7.267345905303955,\n", + " 20.17418098449707,\n", + " 34.47111892700195,\n", + " 44.97934341430664,\n", + " 47.892486572265625,\n", + " 42.15541458129883,\n", + " 29.84645652770996,\n", + " 15.42454719543457,\n", + " 4.114015102386475,\n", + " 0.012123552151024342,\n", + " 4.604880332946777,\n", + " 16.228744506835938,\n", + " 30.673227310180664,\n", + " 42.70610046386719,\n", + " 47.968711853027344,\n", + " 44.554874420166016,\n", + " 33.70136642456055,\n", + " 19.339908599853516,\n", + " 6.672937393188477,\n", + " 0.28908833861351013,\n", + " 2.5009899139404297,\n", + " 12.507547378540039,\n", + " 26.684123992919922,\n", + " 39.895545959472656,\n", + " 47.35624694824219,\n", + " 46.363800048828125,\n", + " 37.2778434753418,\n", + " 23.389827728271484,\n", + " 9.730683326721191,\n", + " 1.2484371662139893,\n", + " 1.0158343315124512,\n", + " 9.117264747619629,\n", + " 22.61819839477539,\n", + " 36.628177642822266,\n", + " 46.072364807128906,\n", + " 47.529808044433594,\n", + " 40.472686767578125,\n", + " 27.45749282836914,\n", + " 13.198968887329102,\n", + " 2.862241268157959,\n", + " 0.19179345667362213,\n", + " 6.155076026916504,\n", + " 18.592052459716797,\n", + " 32.99764633178711,\n", + " 44.1536750793457,\n", + " 48.01906204223633,\n", + " 43.193702697753906,\n", + " 31.42561912536621,\n", + " 16.977758407592773,\n", + " 5.083798885345459,\n", + " 0.052278175950050354,\n", + " 3.705883502960205,\n", + " 14.721196174621582,\n", + " 29.108083724975586,\n", + " 41.65509033203125,\n", + " 47.81724548339844,\n", + " 45.362396240234375,\n", + " 35.179847717285156,\n", + " 20.958147048950195,\n", + " 7.849003314971924,\n", + " 0.601097047328949,\n", + " 1.8399325609207153,\n", + " 11.116769790649414,\n", + " 25.07118034362793,\n", + " 38.64830780029297,\n", + " 46.929988861083984,\n", + " 46.916259765625,\n", + " 38.61207580566406,\n", + " 25.025535583496094,\n", + " 11.078215599060059,\n", + " 1.8223642110824585,\n", + " 0.6108006834983826,\n", + " 7.882364273071289,\n", + " 21.002965927124023,\n", + " 35.219749450683594,\n", + " 45.382789611816406,\n", + " 47.81056213378906,\n", + " 41.623565673828125,\n", + " 29.06293487548828,\n", + " 14.678563117980957,\n", + " 3.6809725761413574,\n", + " 0.05386172607541084,\n", + " 5.11103630065918,\n", + " 17.020503997802734,\n", + " 31.46808433532715,\n", + " 43.22019577026367,\n", + " 48.019657135009766,\n", + " 44.12782287597656,\n", + " 32.954349517822266,\n", + " 18.546627044677734,\n", + " 6.123597621917725,\n", + " 0.18527379631996155,\n", + " 2.882638692855835,\n", + " 13.23847770690918,\n", + " 27.501373291015625,\n", + " 40.50460433959961,\n", + " 47.537757873535156,\n", + " 46.053016662597656,\n", + " 36.58807373046875,\n", + " 22.571388244628906,\n", + " 9.08022689819336,\n", + " 1.0014973878860474,\n", + " 1.261501669883728,\n", + " 9.765915870666504,\n", + " 23.4339599609375,\n", + " 37.31438064575195,\n", + " 46.37897491455078,\n", + " 47.34406280517578,\n", + " 39.859893798828125,\n", + " 26.63738250732422,\n", + " 12.466124534606934,\n", + " 2.479365825653076,\n", + " 0.2945669889450073,\n", + " 6.703005790710449,\n", + " 19.38314437866211,\n", + " 33.74158477783203,\n", + " 44.57697677612305,\n", + " 47.96417236328125,\n", + " 42.67604064941406,\n", + " 30.628026962280273,\n", + " 16.18425941467285,\n", + " 4.576717853546143,\n", + " 0.00998594518750906,\n", + " 4.138182640075684,\n", + " 15.465778350830078,\n", + " 29.889331817626953,\n", + " 42.183929443359375,\n", + " 47.895851135253906,\n", + " 44.95588684082031,\n", + " 34.42888641357422,\n", + " 20.128040313720703,\n", + " 7.233575820922852,\n", + " 0.41627734899520874,\n", + " 2.1455371379852295,\n", + " 11.794844627380371,\n", + " 25.86873435974121,\n", + " 39.268978118896484,\n", + " 47.14138412475586,\n", + " 46.6341552734375,\n", + " 37.93095397949219,\n", + " 24.184341430664062,\n", + " 10.373819351196289,\n", + " 1.502044439315796,\n", + " 0.7826577425003052,\n", + " 8.476187705993652,\n", + " 21.795671463012695,\n", + " 35.916202545166016,\n", + " 45.722713470458984,\n", + " 47.66282272338867,\n", + " 41.03374099731445,\n", + " 28.236730575561523,\n", + " 13.907356262207031,\n", + " 3.236266851425171,\n", + " 0.08893852680921555,\n", + " 5.605445861816406,\n", + " 17.787477493286133,\n", + " 32.22220230102539,\n", + " 43.680789947509766,\n", + " 48.012481689453125,\n", + " 43.64816665649414,\n", + " 32.16879653930664,\n", + " 17.73267364501953,\n", + " 5.569169998168945,\n", + " 0.08441997319459915,\n", + " 3.265254020690918,\n", + " 13.95948314666748,\n", + " 28.293272018432617,\n", + " 41.074398040771484,\n", + " 47.67307662963867,\n", + " 45.69904327392578,\n", + " 35.86745071411133,\n", + " 21.73975372314453,\n", + " 8.433645248413086,\n", + " 0.7692044377326965,\n", + " 1.5228785276412964,\n", + " 10.421737670898438,\n", + " 24.24234390258789,\n", + " 37.97840881347656,\n", + " 46.654273986816406,\n", + " 47.12729263305664,\n", + " 39.22621536254883,\n", + " 25.813234329223633,\n", + " 11.747184753417969,\n", + " 2.1234588623046875,\n", + " 0.4282756745815277,\n", + " 7.275813102722168,\n", + " 20.18573760986328,\n", + " 34.481693267822266,\n", + " 44.985206604003906,\n", + " 47.891632080078125,\n", + " 42.148277282714844,\n", + " 29.83574676513672,\n", + " 15.414258003234863,\n", + " 4.107995510101318,\n", + " 0.012680652551352978,\n", + " 4.6119384765625,\n", + " 16.23987579345703,\n", + " 30.684524536132812,\n", + " 42.713600158691406,\n", + " 47.96981430053711,\n", + " 44.54933547973633,\n", + " 33.69130325317383,\n", + " 19.3291015625,\n", + " 6.665433883666992,\n", + " 0.2877325117588043,\n", + " 2.5064029693603516,\n", + " 12.517895698547363,\n", + " 26.695791244506836,\n", + " 39.90443420410156,\n", + " 47.3592529296875,\n", + " 46.35997009277344,\n", + " 37.26869583129883,\n", + " 23.378786087036133,\n", + " 9.721874237060547,\n", + " 1.2451752424240112,\n", + " 1.019419550895691,\n", + " 9.126518249511719,\n", + " 22.6298828125,\n", + " 36.638179779052734,\n", + " 46.07716751098633,\n", + " 47.52777862548828,\n", + " 40.46467590332031,\n", + " 27.446502685546875,\n", + " 13.189085006713867,\n", + " 2.857137680053711,\n", + " 0.19341740012168884,\n", + " 6.162932395935059,\n", + " 18.603382110595703,\n", + " 33.0084342956543,\n", + " 44.160099029541016,\n", + " 48.018863677978516,\n", + " 43.187034606933594,\n", + " 31.414968490600586,\n", + " 16.967052459716797,\n", + " 5.076975345611572,\n", + " 0.051871158182621,\n", + " 3.7120935916900635,\n", + " 14.731826782226562,\n", + " 29.119327545166016,\n", + " 41.66291809082031,\n", + " 47.81885528564453,\n", + " 45.35725402832031,\n", + " 35.16982650756836,\n", + " 20.946914672851562,\n", + " 7.840641975402832,\n", + " 0.5986564755439758,\n", + " 1.8443087339401245,\n", + " 11.126387596130371,\n", + " 25.08255386352539,\n", + " 38.657318115234375,\n", + " 46.93337631225586,\n", + " 46.91277313232422,\n", + " 38.60296630859375,\n", + " 25.014087677001953,\n", + " 11.068550109863281,\n", + " 1.817955732345581,\n", + " 0.6132127642631531,\n", + " 7.890689373016357,\n", + " 21.014148712158203,\n", + " 35.229698181152344,\n", + " 45.387840270996094,\n", + " 47.808837890625,\n", + " 41.61564636230469,\n", + " 29.051618576049805,\n", + " 14.66788101196289,\n", + " 3.6747305393218994,\n", + " 0.05424736067652702,\n", + " 5.11783504486084,\n", + " 17.031171798706055,\n", + " 31.478673934936523,\n", + " 43.22679138183594,\n", + " 48.019779205322266,\n", + " 44.12131881713867,\n", + " 32.94350814819336,\n", + " 18.535255432128906,\n", + " 6.115721702575684,\n", + " 0.18364056944847107,\n", + " 2.887730598449707,\n", + " 13.248343467712402,\n", + " 27.512325286865234,\n", + " 40.51255798339844,\n", + " 47.53971481323242,\n", + " 46.04814529418945,\n", + " 36.578025817871094,\n", + " 22.559669494628906,\n", + " 9.070963859558105,\n", + " 0.9979168772697449,\n", + " 1.264772891998291,\n", + " 9.774725914001465,\n", + " 23.44498634338379,\n", + " 37.3234977722168,\n", + " 46.382755279541016,\n", + " 47.34099578857422,\n", + " 39.85095977783203,\n", + " 26.625690460205078,\n", + " 12.455771446228027,\n", + " 2.473971128463745,\n", + " 0.2959519922733307,\n", + " 6.710536479949951,\n", + " 19.393962860107422,\n", + " 33.75163269042969,\n", + " 44.58250045776367,\n", + " 47.96303176879883,\n", + " 42.668521881103516,\n", + " 30.616731643676758,\n", + " 16.17315101623535,\n", + " 4.569698333740234,\n", + " 0.009476072154939175,\n", + " 4.144247531890869,\n", + " 15.476103782653809,\n", + " 29.900062561035156,\n", + " 42.19105911254883,\n", + " 47.896690368652344,\n", + " 44.950016021728516,\n", + " 34.41832733154297,\n", + " 20.116519927978516,\n", + " 7.225159645080566,\n", + " 0.4139162003993988,\n", + " 2.1499907970428467,\n", + " 11.804407119750977,\n", + " 25.879854202270508,\n", + " 39.27754211425781,\n", + " 47.14419937133789,\n", + " 46.63013458251953,\n", + " 37.92147445678711,\n", + " 24.17276382446289,\n", + " 10.364273071289062,\n", + " 1.4979220628738403,\n", + " 0.7853949666023254,\n", + " 8.484739303588867,\n", + " 21.806888580322266,\n", + " 35.92597961425781,\n", + " 45.72746276855469,\n", + " 47.660797119140625,\n", + " 41.025638580322266,\n", + " 28.22545623779297,\n", + " 13.896970748901367,\n", + " 3.2305147647857666,\n", + " 0.08989264070987701,\n", + " 5.61275053024292,\n", + " 17.79848289489746,\n", + " 32.23292541503906,\n", + " 43.68734359741211,\n", + " 48.01250076293945,\n", + " 43.641666412353516,\n", + " 32.158138275146484,\n", + " 17.72174835205078,\n", + " 5.561957836151123,\n", + " 0.08356527239084244,\n", + " 3.2711024284362793,\n", + " 13.96995735168457,\n", + " 28.304622650146484,\n", + " 41.082550048828125,\n", + " 47.675132751464844,\n", + " 45.6943244934082,\n", + " 35.85771179199219,\n", + " 21.728591918945312,\n", + " 8.425172805786133,\n", + " 0.7665574550628662,\n", + " 1.5270906686782837,\n", + " 10.431361198425293,\n", + " 24.253976821899414,\n", + " 37.9879264831543,\n", + " 46.6583137512207,\n", + " 47.12447738647461,\n", + " 39.21766662597656,\n", + " 25.80215072631836,\n", + " 11.737676620483398,\n", + " 2.11907696723938,\n", + " 0.4307127594947815,\n", + " 7.284297466278076,\n", + " 20.197315216064453,\n", + " 34.492271423339844,\n", + " 44.991085052490234,\n", + " 47.890804290771484,\n", + " 42.14115905761719,\n", + " 29.825029373168945,\n", + " 15.403965950012207,\n", + " 4.101979732513428,\n", + " 0.013246170245110989,\n", + " 4.619010925292969,\n", + " 16.251022338867188,\n", + " 30.695838928222656,\n", + " 42.72111511230469,\n", + " 47.970951080322266,\n", + " 44.54378890991211,\n", + " 33.68124008178711,\n", + " 19.318286895751953,\n", + " 6.657920837402344,\n", + " 0.2863781154155731,\n", + " 2.5118298530578613,\n", + " 12.528270721435547,\n", + " 26.707487106323242,\n", + " 39.91333770751953,\n", + " 47.362281799316406,\n", + " 46.356143951416016,\n", + " 37.259525299072266,\n", + " 23.36772918701172,\n", + " 9.713057518005371,\n", + " 1.2419092655181885,\n", + " 1.023009181022644,\n", + " 9.135781288146973,\n", + " 22.64158058166504,\n", + " 36.648189544677734,\n", + " 46.081966400146484,\n", + " 47.5257453918457,\n", + " 40.45664596557617,\n", + " 27.435495376586914,\n", + " 13.179183006286621,\n", + " 2.8520262241363525,\n", + " 0.19503945112228394,\n", + " 6.170792102813721,\n", + " 18.61471939086914,\n", + " 33.01922607421875,\n", + " 44.16651153564453,\n", + " 48.0186653137207,\n", + " 43.180362701416016,\n", + " 31.404300689697266,\n", + " 16.956331253051758,\n", + " 5.070141315460205,\n", + " 0.05145828425884247,\n", + " 3.7183074951171875,\n", + " 14.742465019226074,\n", + " 29.130584716796875,\n", + " 41.67076110839844,\n", + " 47.82048416137695,\n", + " 45.35210418701172,\n", + " 35.159793853759766,\n", + " 20.9356632232666,\n", + " 7.832267761230469,\n", + " 0.5962072014808655,\n", + " 1.8486831188201904,\n", + " 11.136005401611328,\n", + " 25.09393310546875,\n", + " 38.666343688964844,\n", + " 46.93675231933594,\n", + " 46.90928268432617,\n", + " 38.5938606262207,\n", + " 25.00263214111328,\n", + " 11.05887508392334,\n", + " 1.8135403394699097,\n", + " 0.615619421005249,\n", + " 7.899007320404053,\n", + " 21.025320053100586,\n", + " 35.23963165283203,\n", + " 45.392887115478516,\n", + " 47.80710983276367,\n", + " 41.607704162597656,\n", + " 29.040283203125,\n", + " 14.657190322875977,\n", + " 3.6684818267822266,\n", + " 0.05462893843650818,\n", + " 5.124626636505127,\n", + " 17.041828155517578,\n", + " 31.48925018310547,\n", + " 43.23335647583008,\n", + " 48.019866943359375,\n", + " 44.11480712890625,\n", + " 32.932640075683594,\n", + " 18.52387046813965,\n", + " 6.107840061187744,\n", + " 0.18200506269931793,\n", + " 2.892824411392212,\n", + " 13.258208274841309,\n", + " 27.523269653320312,\n", + " 40.520511627197266,\n", + " 47.54166793823242,\n", + " 46.043270111083984,\n", + " 36.567962646484375,\n", + " 22.547948837280273,\n", + " 9.061697959899902,\n", + " 0.9943376779556274,\n", + " 1.2680473327636719,\n", + " 9.783537864685059,\n", + " 23.45601463317871,\n", + " 37.332618713378906,\n", + " 46.386539459228516,\n", + " 47.33793258666992,\n", + " 39.842037200927734,\n", + " 26.614004135131836,\n", + " 12.445426940917969,\n", + " 2.4685826301574707,\n", + " 0.2973437011241913,\n", + " 6.718074798583984,\n", + " 19.404788970947266,\n", + " 33.76170349121094,\n", + " 44.58802795410156,\n", + " 47.961891174316406,\n", + " 42.661014556884766,\n", + " 30.605440139770508,\n", + " 16.162050247192383,\n", + " 4.562687873840332,\n", + " 0.008976398035883904,\n", + " 4.150323867797852,\n", + " 15.486442565917969,\n", + " 29.910804748535156,\n", + " 42.198211669921875,\n", + " 47.89754867553711,\n", + " 44.94417190551758,\n", + " 34.40779113769531,\n", + " 20.10500717163086,\n", + " 7.21675443649292,\n", + " 0.41156795620918274,\n", + " 2.154458999633789,\n", + " 11.813986778259277,\n", + " 25.890987396240234,\n", + " 39.28611755371094,\n", + " 47.14704513549805,\n", + " 46.62612533569336,\n", + " 37.912010192871094,\n", + " 24.16120147705078,\n", + " 10.354740142822266,\n", + " 1.493814468383789,\n", + " 0.7881485223770142,\n", + " 8.493309020996094,\n", + " 21.818126678466797,\n", + " 35.93577194213867,\n", + " 45.732215881347656,\n", + " 47.658775329589844,\n", + " 41.01753234863281,\n", + " 28.214183807373047,\n", + " 13.886595726013184,\n", + " 3.2247796058654785,\n", + " 0.09086360037326813,\n", + " 5.6200714111328125,\n", + " 17.809499740600586,\n", + " 32.243648529052734,\n", + " 43.69390106201172,\n", + " 48.012516021728516,\n", + " 43.63515853881836,\n", + " 32.147483825683594,\n", + " 17.710830688476562,\n", + " 5.554759502410889,\n", + " 0.08272704482078552,\n", + " 3.276963949203491,\n", + " 13.980438232421875,\n", + " 28.315967559814453,\n", + " 41.09071350097656,\n", + " 47.677188873291016,\n", + " 45.68959426879883,\n", + " 35.84798049926758,\n", + " 21.717445373535156,\n", + " 8.416711807250977,\n", + " 0.7639248371124268,\n", + " 1.531317114830017,\n", + " 10.440995216369629,\n", + " 24.26561164855957,\n", + " 37.9974365234375,\n", + " 46.6623420715332,\n", + " 47.121665954589844,\n", + " 39.20912551879883,\n", + " 25.79106903076172,\n", + " 11.72817611694336,\n", + " 2.1147053241729736,\n", + " 0.43316206336021423,\n", + " 7.292791366577148,\n", + " 20.20888328552246,\n", + " 34.502838134765625,\n", + " 44.9969482421875,\n", + " 47.88994216918945,\n", + " 42.13400650024414,\n", + " 29.814308166503906,\n", + " 15.393670082092285,\n", + " 4.095966815948486,\n", + " 0.013819677755236626,\n", + " 4.626088619232178,\n", + " 16.262165069580078,\n", + " 30.707136154174805,\n", + " 42.72860336303711,\n", + " 47.9720458984375,\n", + " 44.53822326660156,\n", + " 33.6711540222168,\n", + " 19.307464599609375,\n", + " 6.6504130363464355,\n", + " 0.2850286662578583,\n", + " 2.5172553062438965,\n", + " 12.538633346557617,\n", + " 26.719165802001953,\n", + " 39.92222213745117,\n", + " 47.365272521972656,\n", + " 46.35228729248047,\n", + " 37.250343322753906,\n", + " 23.356666564941406,\n", + " 9.704236030578613,\n", + " 1.2386443614959717,\n", + " 1.026596188545227,\n", + " 9.145030975341797,\n", + " 22.65325355529785,\n", + " 36.65815734863281,\n", + " 46.086727142333984,\n", + " 47.5236701965332,\n", + " 40.44859313964844,\n", + " 27.42446517944336,\n", + " 13.169272422790527,\n", + " 2.8469107151031494,\n", + " 0.19665808975696564,\n", + " 6.178638458251953,\n", + " 18.62603187561035,\n", + " 33.02998352050781,\n", + " 44.172882080078125,\n", + " 48.01840591430664,\n", + " 43.17363357543945,\n", + " 31.39360809326172,\n", + " 16.94559097290039,\n", + " 5.063300609588623,\n", + " 0.051040634512901306,\n", + " 3.7245066165924072,\n", + " 14.753074645996094,\n", + " 29.141794204711914,\n", + " 41.67854309082031,\n", + " 47.82203674316406,\n", + " 45.34687805175781,\n", + " 35.149723052978516,\n", + " 20.92438507080078,\n", + " 7.8238844871521,\n", + " 0.593753457069397,\n", + " 1.853044867515564,\n", + " 11.14559268951416,\n", + " 25.105268478393555,\n", + " 38.6753044128418,\n", + " 46.94007110595703,\n", + " 46.905723571777344,\n", + " 38.584686279296875,\n", + " 24.991140365600586,\n", + " 11.04918384552002,\n", + " 1.809116244316101,\n", + " 0.6180207133293152,\n", + " 7.907314777374268,\n", + " 21.03647232055664,\n", + " 35.24953079223633,\n", + " 45.397884368896484,\n", + " 47.805320739746094,\n", + " 41.599708557128906,\n", + " 29.028907775878906,\n", + " 14.646471977233887,\n", + " 3.6622226238250732,\n", + " 0.05500727519392967,\n", + " 5.131416320800781,\n", + " 17.052480697631836,\n", + " 31.499801635742188,\n", + " 43.23990249633789,\n", + " 48.019920349121094,\n", + " 44.10824203491211,\n", + " 32.92172622680664,\n", + " 18.5124568939209,\n", + " 6.099941730499268,\n", + " 0.18036723136901855,\n", + " 2.897919178009033,\n", + " 13.268067359924316,\n", + " 27.534196853637695,\n", + " 40.52841567993164,\n", + " 47.54356384277344,\n", + " 46.0383415222168,\n", + " 36.557857513427734,\n", + " 22.536197662353516,\n", + " 9.052422523498535,\n", + " 0.990759015083313,\n", + " 1.2713243961334229,\n", + " 9.79234504699707,\n", + " 23.46702003479004,\n", + " 37.34169387817383,\n", + " 46.3902473449707,\n", + " 47.33479690551758,\n", + " 39.833038330078125,\n", + " 26.602272033691406,\n", + " 12.43505859375,\n", + " 2.4631946086883545,\n", + " 0.29874250292778015,\n", + " 6.725613594055176,\n", + " 19.415590286254883,\n", + " 33.7717170715332,\n", + " 44.59348678588867,\n", + " 47.9606819152832,\n", + " 42.65342330932617,\n", + " 30.594093322753906,\n", + " 16.150917053222656,\n", + " 4.555673122406006,\n", + " 0.008487294428050518,\n", + " 4.156408786773682,\n", + " 15.496768951416016,\n", + " 29.921506881713867,\n", + " 42.20529556274414,\n", + " 47.89832305908203,\n", + " 44.938236236572266,\n", + " 34.39719772338867,\n", + " 20.093469619750977,\n", + " 7.208343029022217,\n", + " 0.4092309772968292,\n", + " 2.1589431762695312,\n", + " 11.823572158813477,\n", + " 25.902109146118164,\n", + " 39.29465866088867,\n", + " 47.14982223510742,\n", + " 46.622066497802734,\n", + " 37.90248489379883,\n", + " 24.149599075317383,\n", + " ...]" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model=SGDTrain(LinearRegKernel(3),torch.tensor([[1,2],[3,4],[5,6]]).to(torch.float),torch.tensor([[1,2]]).to(torch.float),iterations=10000, learning_rate=.001, return_losses=True)\n", + "model[1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}