PlayPredictor: use the average betaI, gammaU, and gammaI when the item or user has not been seen before
This commit is contained in:
parent
aa423db398
commit
0a096b2af3
@ -9,7 +9,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -22,7 +22,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -45,7 +45,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -60,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -95,7 +95,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -116,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -132,7 +132,15 @@
|
||||
"\n",
|
||||
" # Prediction for a single instance\n",
|
||||
" def predict(self, u, i):\n",
|
||||
" p = self.betaI[i] + tf.tensordot(self.gammaU[u], self.gammaI[i], 1)\n",
|
||||
" bi = self.bi\n",
|
||||
" gu = self.gu\n",
|
||||
" gi = self.gi\n",
|
||||
" if u != None:\n",
|
||||
" gu = self.gammaU[u]\n",
|
||||
" if i != None:\n",
|
||||
" bi = self.betaI[i]\n",
|
||||
" gi = self.gammaI[i]\n",
|
||||
" p = bi + tf.tensordot(gu, gi, 1)\n",
|
||||
" return p\n",
|
||||
"\n",
|
||||
" # Regularizer\n",
|
||||
@ -153,7 +161,12 @@
|
||||
" def call(self, sampleU, sampleI, sampleJ):\n",
|
||||
" x_ui = self.score(sampleU, sampleI)\n",
|
||||
" x_uj = self.score(sampleU, sampleJ)\n",
|
||||
" return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))"
|
||||
" return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))\n",
|
||||
" \n",
|
||||
" def finalize(self):\n",
|
||||
" self.bi = np.average(self.betaI, axis=0)\n",
|
||||
" self.gu = np.average(self.gammaU, axis=0)\n",
|
||||
" self.gi = np.average(self.gammaI, axis=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -165,7 +178,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -175,8 +188,6 @@
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def fit(self, data, threshold=0.6, K=5, iters=100): # data is an array of (user, game, review) tuples\n",
|
||||
" self.topGames = self.getTopGames(threshold)\n",
|
||||
"\n",
|
||||
" self.userIDs = {}\n",
|
||||
" self.itemIDs = {}\n",
|
||||
" interactions = []\n",
|
||||
@ -222,84 +233,85 @@
|
||||
" obj = trainingStepBPR(self.modelBPR, interactions)\n",
|
||||
" if (i % 10 == 9): print(\"iteration \" + str(i+1) + \", objective = \" + str(obj))\n",
|
||||
"\n",
|
||||
" self.modelBPR.finalize()\n",
|
||||
" \n",
|
||||
" def predict(self, user, game, threshold=0.5):\n",
|
||||
" if user in self.userIDs and game in self.itemIDs:\n",
|
||||
" pred = self.modelBPR.predict(self.userIDs[user], self.itemIDs[game]).numpy()\n",
|
||||
" return int(pred > threshold)\n",
|
||||
" else:\n",
|
||||
" return int(game in self.topGames)\n",
|
||||
"\n",
|
||||
" def getTopGames (self, threshold):\n",
|
||||
" gameCount = defaultdict(int)\n",
|
||||
" totalPlayed = 0\n",
|
||||
"\n",
|
||||
" for user,game,_ in readJSON(\"train.json.gz\"):\n",
|
||||
" gameCount[game] += 1\n",
|
||||
" totalPlayed += 1\n",
|
||||
"\n",
|
||||
" mostPopular = [(gameCount[x], x) for x in gameCount]\n",
|
||||
" mostPopular.sort()\n",
|
||||
" mostPopular.reverse()\n",
|
||||
"\n",
|
||||
" return1 = set()\n",
|
||||
" count = 0\n",
|
||||
" for ic, i in mostPopular:\n",
|
||||
" count += ic\n",
|
||||
" return1.add(i)\n",
|
||||
" if count > totalPlayed * threshold: break\n",
|
||||
" return return1\n"
|
||||
" uid = None\n",
|
||||
" gid = None\n",
|
||||
" if user in self.userIDs:\n",
|
||||
" uid = self.userIDs[user]\n",
|
||||
" if game in self.itemIDs:\n",
|
||||
" gid = self.itemIDs[game]\n",
|
||||
" pred = self.modelBPR.predict(uid, gid).numpy()\n",
|
||||
" return int(pred > threshold)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"iteration 10, objective = 0.51180786\n",
|
||||
"iteration 20, objective = 0.48082852\n",
|
||||
"iteration 30, objective = 0.47100148\n",
|
||||
"iteration 40, objective = 0.45862892\n",
|
||||
"iteration 50, objective = 0.45290428\n",
|
||||
"iteration 60, objective = 0.44695023\n",
|
||||
"iteration 70, objective = 0.4453482\n",
|
||||
"iteration 80, objective = 0.444919\n",
|
||||
"iteration 90, objective = 0.4451945\n",
|
||||
"iteration 100, objective = 0.44311014\n",
|
||||
"iteration 110, objective = 0.44101325\n",
|
||||
"iteration 120, objective = 0.43727913\n",
|
||||
"iteration 130, objective = 0.43938398\n",
|
||||
"iteration 140, objective = 0.43788543\n",
|
||||
"iteration 150, objective = 0.43573555\n",
|
||||
"iteration 160, objective = 0.4379884\n",
|
||||
"iteration 170, objective = 0.43852594\n",
|
||||
"iteration 180, objective = 0.4391472\n",
|
||||
"iteration 190, objective = 0.4318109\n",
|
||||
"iteration 200, objective = 0.4389726\n",
|
||||
"PlayPredictor accuracy: 0.7234723472347235\n"
|
||||
"iteration 10, objective = 0.5121787\n",
|
||||
"iteration 20, objective = 0.4860348\n",
|
||||
"iteration 30, objective = 0.4675451\n",
|
||||
"iteration 40, objective = 0.46167055\n",
|
||||
"iteration 50, objective = 0.45701832\n",
|
||||
"iteration 60, objective = 0.44749424\n",
|
||||
"iteration 70, objective = 0.44757926\n",
|
||||
"iteration 80, objective = 0.4452785\n",
|
||||
"iteration 90, objective = 0.4446122\n",
|
||||
"iteration 100, objective = 0.44039646\n",
|
||||
"iteration 110, objective = 0.44507992\n",
|
||||
"iteration 120, objective = 0.44116876\n",
|
||||
"iteration 130, objective = 0.4395796\n",
|
||||
"iteration 140, objective = 0.4408364\n",
|
||||
"iteration 150, objective = 0.44295114\n",
|
||||
"iteration 160, objective = 0.43921968\n",
|
||||
"iteration 170, objective = 0.44189137\n",
|
||||
"iteration 180, objective = 0.43661243\n",
|
||||
"iteration 190, objective = 0.43899748\n",
|
||||
"iteration 200, objective = 0.4371943\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = PlayPredictor()\n",
|
||||
"model.fit(train, K=6, iters=200)\n",
|
||||
"\n",
|
||||
"error = 0\n",
|
||||
"balanced_valid = get_balanced_set(dataset, valid)\n",
|
||||
"for user, game, review in balanced_valid:\n",
|
||||
" pred = model.predict(user, game, threshold=0.5)\n",
|
||||
" if pred != review[\"played\"]:\n",
|
||||
" error += 1\n",
|
||||
"\n",
|
||||
"print(f\"PlayPredictor accuracy: \", 1 - error / len(balanced_valid))"
|
||||
"model.fit(train, K=6, iters=200)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[7044 2955]\n",
|
||||
" [2598 7401]]\n",
|
||||
"PlayPredictor accuracy: 0.7223222322232223\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"CM = np.array([[0,0], [0,0]])\n",
|
||||
"balanced_valid = get_balanced_set(dataset, valid)\n",
|
||||
"for user, game, review in balanced_valid:\n",
|
||||
" pred = model.predict(user, game, threshold=0.5)\n",
|
||||
" CM[review[\"played\"]][pred] += 1\n",
|
||||
"\n",
|
||||
"print(CM)\n",
|
||||
"print(f\"PlayPredictor accuracy: \", 1 - (CM[1][0] + CM[0][1]) / len(balanced_valid))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -315,7 +327,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -392,7 +404,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -423,7 +435,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user