Playpredict use average betaI, gammaU, gammaI when item or user has not been seen before

This commit is contained in:
ltcptgeneral 2023-11-04 17:06:48 -07:00
parent aa423db398
commit 0a096b2af3
2 changed files with 2391 additions and 2379 deletions

View File

@ -9,7 +9,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -22,7 +22,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -45,7 +45,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -60,7 +60,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -95,7 +95,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -116,7 +116,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -132,7 +132,15 @@
"\n", "\n",
" # Prediction for a single instance\n", " # Prediction for a single instance\n",
" def predict(self, u, i):\n", " def predict(self, u, i):\n",
" p = self.betaI[i] + tf.tensordot(self.gammaU[u], self.gammaI[i], 1)\n", " bi = self.bi\n",
" gu = self.gu\n",
" gi = self.gi\n",
" if u != None:\n",
" gu = self.gammaU[u]\n",
" if i != None:\n",
" bi = self.betaI[i]\n",
" gi = self.gammaI[i]\n",
" p = bi + tf.tensordot(gu, gi, 1)\n",
" return p\n", " return p\n",
"\n", "\n",
" # Regularizer\n", " # Regularizer\n",
@ -153,7 +161,12 @@
" def call(self, sampleU, sampleI, sampleJ):\n", " def call(self, sampleU, sampleI, sampleJ):\n",
" x_ui = self.score(sampleU, sampleI)\n", " x_ui = self.score(sampleU, sampleI)\n",
" x_uj = self.score(sampleU, sampleJ)\n", " x_uj = self.score(sampleU, sampleJ)\n",
" return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))" " return -tf.reduce_mean(tf.math.log(tf.math.sigmoid(x_ui - x_uj)))\n",
" \n",
" def finalize(self):\n",
" self.bi = np.average(self.betaI, axis=0)\n",
" self.gu = np.average(self.gammaU, axis=0)\n",
" self.gi = np.average(self.gammaI, axis=0)"
] ]
}, },
{ {
@ -165,7 +178,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -175,8 +188,6 @@
" pass\n", " pass\n",
"\n", "\n",
" def fit(self, data, threshold=0.6, K=5, iters=100): # data is an array of (user, game, review) tuples\n", " def fit(self, data, threshold=0.6, K=5, iters=100): # data is an array of (user, game, review) tuples\n",
" self.topGames = self.getTopGames(threshold)\n",
"\n",
" self.userIDs = {}\n", " self.userIDs = {}\n",
" self.itemIDs = {}\n", " self.itemIDs = {}\n",
" interactions = []\n", " interactions = []\n",
@ -221,85 +232,86 @@
" for i in range(iters):\n", " for i in range(iters):\n",
" obj = trainingStepBPR(self.modelBPR, interactions)\n", " obj = trainingStepBPR(self.modelBPR, interactions)\n",
" if (i % 10 == 9): print(\"iteration \" + str(i+1) + \", objective = \" + str(obj))\n", " if (i % 10 == 9): print(\"iteration \" + str(i+1) + \", objective = \" + str(obj))\n",
"\n",
" self.modelBPR.finalize()\n",
" \n", " \n",
" def predict(self, user, game, threshold=0.5):\n", " def predict(self, user, game, threshold=0.5):\n",
" if user in self.userIDs and game in self.itemIDs:\n", " uid = None\n",
" pred = self.modelBPR.predict(self.userIDs[user], self.itemIDs[game]).numpy()\n", " gid = None\n",
" return int(pred > threshold)\n", " if user in self.userIDs:\n",
" else:\n", " uid = self.userIDs[user]\n",
" return int(game in self.topGames)\n", " if game in self.itemIDs:\n",
"\n", " gid = self.itemIDs[game]\n",
" def getTopGames (self, threshold):\n", " pred = self.modelBPR.predict(uid, gid).numpy()\n",
" gameCount = defaultdict(int)\n", " return int(pred > threshold)\n"
" totalPlayed = 0\n",
"\n",
" for user,game,_ in readJSON(\"train.json.gz\"):\n",
" gameCount[game] += 1\n",
" totalPlayed += 1\n",
"\n",
" mostPopular = [(gameCount[x], x) for x in gameCount]\n",
" mostPopular.sort()\n",
" mostPopular.reverse()\n",
"\n",
" return1 = set()\n",
" count = 0\n",
" for ic, i in mostPopular:\n",
" count += ic\n",
" return1.add(i)\n",
" if count > totalPlayed * threshold: break\n",
" return return1\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"iteration 10, objective = 0.51180786\n", "iteration 10, objective = 0.5121787\n",
"iteration 20, objective = 0.48082852\n", "iteration 20, objective = 0.4860348\n",
"iteration 30, objective = 0.47100148\n", "iteration 30, objective = 0.4675451\n",
"iteration 40, objective = 0.45862892\n", "iteration 40, objective = 0.46167055\n",
"iteration 50, objective = 0.45290428\n", "iteration 50, objective = 0.45701832\n",
"iteration 60, objective = 0.44695023\n", "iteration 60, objective = 0.44749424\n",
"iteration 70, objective = 0.4453482\n", "iteration 70, objective = 0.44757926\n",
"iteration 80, objective = 0.444919\n", "iteration 80, objective = 0.4452785\n",
"iteration 90, objective = 0.4451945\n", "iteration 90, objective = 0.4446122\n",
"iteration 100, objective = 0.44311014\n", "iteration 100, objective = 0.44039646\n",
"iteration 110, objective = 0.44101325\n", "iteration 110, objective = 0.44507992\n",
"iteration 120, objective = 0.43727913\n", "iteration 120, objective = 0.44116876\n",
"iteration 130, objective = 0.43938398\n", "iteration 130, objective = 0.4395796\n",
"iteration 140, objective = 0.43788543\n", "iteration 140, objective = 0.4408364\n",
"iteration 150, objective = 0.43573555\n", "iteration 150, objective = 0.44295114\n",
"iteration 160, objective = 0.4379884\n", "iteration 160, objective = 0.43921968\n",
"iteration 170, objective = 0.43852594\n", "iteration 170, objective = 0.44189137\n",
"iteration 180, objective = 0.4391472\n", "iteration 180, objective = 0.43661243\n",
"iteration 190, objective = 0.4318109\n", "iteration 190, objective = 0.43899748\n",
"iteration 200, objective = 0.4389726\n", "iteration 200, objective = 0.4371943\n"
"PlayPredictor accuracy: 0.7234723472347235\n"
] ]
} }
], ],
"source": [ "source": [
"model = PlayPredictor()\n", "model = PlayPredictor()\n",
"model.fit(train, K=6, iters=200)\n", "model.fit(train, K=6, iters=200)"
"\n",
"error = 0\n",
"balanced_valid = get_balanced_set(dataset, valid)\n",
"for user, game, review in balanced_valid:\n",
" pred = model.predict(user, game, threshold=0.5)\n",
" if pred != review[\"played\"]:\n",
" error += 1\n",
"\n",
"print(f\"PlayPredictor accuracy: \", 1 - error / len(balanced_valid))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[7044 2955]\n",
" [2598 7401]]\n",
"PlayPredictor accuracy: 0.7223222322232223\n"
]
}
],
"source": [
"CM = np.array([[0,0], [0,0]])\n",
"balanced_valid = get_balanced_set(dataset, valid)\n",
"for user, game, review in balanced_valid:\n",
" pred = model.predict(user, game, threshold=0.5)\n",
" CM[review[\"played\"]][pred] += 1\n",
"\n",
"print(CM)\n",
"print(f\"PlayPredictor accuracy: \", 1 - (CM[1][0] + CM[0][1]) / len(balanced_valid))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -315,7 +327,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -392,7 +404,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -423,7 +435,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [

File diff suppressed because it is too large Load Diff