diff --git a/basic.ipynb b/basic.ipynb index 812fbe3..5dd1600 100644 --- a/basic.ipynb +++ b/basic.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 44, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -64,26 +64,28 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Generate and encrypt query vector\n", "x = np.random.rand(encoding_width)\n", "#cx = np.array([HE_client.encrypt(x[j]) for j in range(len(x))])\n", + "# precompute 1 - |x|^2 for query vector\n", + "du = 1 - x @ x\n", "cx = HE_client.encrypt(x)" ] }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Client] sending HE_client= and cx=\n" + "[Client] sending HE_client= and cx=\n" ] } ], @@ -109,38 +111,45 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def hyperbolic_distance_parts(u, v): # returns only the numerator and denominator of the hyperbolic distance formula\n", " diff = u - v\n", - " du = -(1 - u @ u) # for some reason we need to negate this\n", - " dv = -(1 - v @ v) # for some reason we need to negate this\n", - " return diff @ diff, du * dv # returns the numerator and denominator\n" + " #du = -(1 - u @ u) # for some reason we need to negate this\n", + " #dv = -(1 - v @ v) # for some reason we need to negate this\n", + " #return diff @ diff, du * dv # returns the numerator and denominator\n", + " return diff @ diff # returns the numerator and denominator\n" ] }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# document matrix containing rows of document encoding vectors\n", - "D = np.random.rand(database_size, encoding_width)" + "D = np.random.rand(database_size, encoding_width)\n", + "# precompute 1 - |D|^2 for each row vector in D\n", + "dv = []\n", + "for i in range(len(D)):\n", + " v = D[i]\n", + " dv.append(1 - v @ v)\n", + "dv = np.array(dv)" ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Server] received HE_server= and cx=\n", - "[Server] Distances computed! Responding: res=[(, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, )]\n" + "[Server] received HE_server= and cx=\n", + "[Server] Distances computed! Responding: res=[, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ]\n" ] } ], @@ -164,7 +173,7 @@ " # Compute distance bewteen recieved query and D[i]\n", " res.append(hyperbolic_distance_parts(cx, cd))\n", "\n", - "s_res = [(res[j][0].to_bytes(), res[j][1].to_bytes()) for j in range(len(res))]\n", + "s_res = [res[j].to_bytes() for j in range(len(res))]\n", "\n", "print(f\"[Server] Distances computed! Responding: res={res}\")" ] @@ -185,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -197,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -205,11 +214,14 @@ "#res = HE_client.decrypt(c_res)\n", "c_res = []\n", "for i in range(len(s_res)):\n", - " c_num = PyCtxt(pyfhel=HE_server, bytestring=s_res[i][0])\n", - " c_den = PyCtxt(pyfhel=HE_server, bytestring=s_res[i][1])\n", + " c_num = PyCtxt(pyfhel=HE_server, bytestring=s_res[i])\n", + " #c_den = PyCtxt(pyfhel=HE_server, bytestring=s_res[i][1])\n", " p_num = HE_client.decrypt(c_num)[0]\n", - " p_den = HE_client.decrypt(c_den)[0]\n", - " dist = np.arccosh(1 + 2 * (p_num / p_den))\n", + " #p_den = HE_client.decrypt(c_den)[0]\n", + " #dist = np.arccosh(1 + 2 * (p_num / p_den))\n", + " # compute final score\n", + " dist = np.arccosh(1 + 2 * (p_num / (du * dv[i])))\n", + " #print(dist)\n", " c_res.append(dist)\n", "\n", "# Checking result\n", @@ -218,8 +230,11 @@ "for i in range(len(c_res)):\n", " result = c_res[i]\n", " expected = expected_res[i]\n", - " #print(f\"got: {result}, expected: {expected}\")\n", - " assert np.abs(result - expected) < 1e-3" + " if np.abs(result - expected) < 1e-3:\n", + " pass\n", + " else:\n", + " print(f\"got: {result}, expected: {expected}\")\n", + " assert False" ] } ],