diff --git a/basic.ipynb b/basic.ipynb index 682f8a4..3a301a8 100644 --- a/basic.ipynb +++ b/basic.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "encoding_width = 64\n", + "database_size = 100" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -9,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -54,28 +64,26 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Generate and encrypt query vector\n", - "x = np.array([1.5, 2, 3.3, 4])\n", - "cx = np.array([HE_client.encrypt(x[j]) for j in range(len(x))])" + "x = np.random.rand(encoding_width)\n", + "#cx = np.array([HE_client.encrypt(x[j]) for j in range(len(x))])\n", + "cx = HE_client.encrypt(x)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Client] sending HE_client= and cx=[\n", - " \n", - " \n", - " ]\n" + "[Client] sending HE_client= and cx=\n" ] } ], @@ -86,7 +94,8 @@ "s_public_key = HE_client.to_bytes_public_key()\n", "s_relin_key = HE_client.to_bytes_relin_key()\n", "s_rotate_key = HE_client.to_bytes_rotate_key()\n", - "s_cx = [cx[j].to_bytes() for j in range(len(cx))]\n", + "#s_cx = [cx[j].to_bytes() for j in range(len(cx))]\n", + "s_cx = cx.to_bytes()\n", "\n", "print(f\"[Client] sending HE_client={HE_client} and cx={cx}\")" ] @@ -100,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -113,31 +122,25 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# document matrix containing rows of document encoding vectors\n", - "D = [\n", - " [0.5, -1.5, 4, 5],\n", - " [1.0, 1.5, 4, 5]\n", - "]" + "D = np.random.rand(database_size, encoding_width)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[Server] received HE_server= and cx=[\n", - " \n", - " \n", - " ]\n", - "[Server] Distances computed! Responding: res=[(, ), (, )]\n" + "[Server] received HE_server= and cx=\n", + "[Server] Distances computed! Responding: res=[(, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, ), (, )]\n" ] } ], @@ -147,15 +150,17 @@ "HE_server.from_bytes_public_key(s_public_key)\n", "HE_server.from_bytes_relin_key(s_relin_key)\n", "HE_server.from_bytes_rotate_key(s_rotate_key)\n", - "cx = np.array([PyCtxt(pyfhel=HE_server, bytestring=s_cx[j]) for j in range(len(s_cx))])\n", + "#cx = np.array([PyCtxt(pyfhel=HE_server, bytestring=s_cx[j]) for j in range(len(s_cx))])\n", + "cx = PyCtxt(pyfhel=HE_server, bytestring=s_cx)\n", "print(f\"[Server] received HE_server={HE_server} and cx={cx}\")\n", "\n", "# Encode each document weights in plaintext\n", "res = []\n", "\n", "for i in range(len(D)):\n", - " d = np.array(D[i])\n", - " cd = np.array([HE_server.encrypt(d[j]) for j in range(len(d))])\n", + " #d = np.array(D[i])\n", + " #cd = np.array([HE_server.encrypt(d[j]) for j in range(len(d))])\n", + " cd = HE_server.encrypt(D[i])\n", " # Compute distance bewteen recieved query and D[i]\n", " res.append(hyperbolic_distance_parts(cx, cd))\n", "\n", @@ -173,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -185,20 +190,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 20, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Client] Response received! \n", - "Result is [np.float64(0.20736836029574815), np.float64(0.07564769989949309)] \n", - "Should be [np.float64(0.20738785895993414), np.float64(0.07565488467449914)]\n", - "Diff [1.94986642e-05 7.18477501e-06]\n" - ] - } - ], + "outputs": [], "source": [ "#res = np.array([HE_client.decrypt(c_res[j]) for j in range(len(c_res))])[:,0]\n", "#res = HE_client.decrypt(c_res)\n", @@ -212,8 +206,13 @@ " c_res.append(dist)\n", "\n", "# Checking result\n", - "expected = [hyperbolic_distance(x, np.array(w)) for w in D]\n", - "print(f\"[Client] Response received! \\nResult is {c_res} \\nShould be {expected}\\nDiff {np.abs(np.array(c_res) - np.array(expected))}\")" + "expected_res = [hyperbolic_distance(x, np.array(w)) for w in D]\n", + "#print(f\"[Client] Response received! \\nResult is {c_res} \\nShould be {expected}\\nDiff {np.abs(np.array(c_res) - np.array(expected))}\")\n", + "for i in range(len(c_res)):\n", + " result = c_res[i]\n", + " expected = expected_res[i]\n", + " #print(f\"got: {result}, expected: {expected}\")\n", + " assert np.abs(result - expected) < 1e-3" ] } ],