mirror of
https://github.com/ltcptgeneral/CS-239-Cryptography-Project.git
synced 2025-11-10 11:36:51 +00:00
convert notebook to script
This commit is contained in:
292
basic.ipynb
292
basic.ipynb
@@ -1,292 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"encoding_width = 2**13 # restricted by ckks size\n",
|
|
||||||
"database_size = 100\n",
|
|
||||||
"import time"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Client Setup"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from Pyfhel import Pyfhel, PyCtxt"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[Client] Initializing Pyfhel session and data...\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(f\"[Client] Initializing Pyfhel session and data...\")\n",
|
|
||||||
"HE_client = Pyfhel() # Creating empty Pyfhel object\n",
|
|
||||||
"ckks_params = {\n",
|
|
||||||
" 'scheme': 'CKKS', # can also be 'ckks'\n",
|
|
||||||
" 'n': 2**14, # Polynomial modulus degree. For CKKS, n/2 values can be\n",
|
|
||||||
" # encoded in a single ciphertext. \n",
|
|
||||||
" # Typ. 2^D for D in [10, 15]\n",
|
|
||||||
" 'scale': 2**30, # All the encodings will use it for float->fixed point\n",
|
|
||||||
" # conversion: x_fix = round(x_float * scale)\n",
|
|
||||||
" # You can use this as default scale or use a different\n",
|
|
||||||
" # scale on each operation (set in HE.encryptFrac)\n",
|
|
||||||
" 'qi_sizes': [60, 30, 30, 30, 60] # Number of bits of each prime in the chain. \n",
|
|
||||||
" # Intermediate values should be close to log2(scale)\n",
|
|
||||||
" # for each operation, to have small rounding errors.\n",
|
|
||||||
"}\n",
|
|
||||||
"HE_client.contextGen(**ckks_params) # Generate context for bfv scheme\n",
|
|
||||||
"HE_client.keyGen() # Generates both a public and a private key\n",
|
|
||||||
"HE_client.relinKeyGen()\n",
|
|
||||||
"HE_client.rotateKeyGen()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Generate and encrypt query vector\n",
|
|
||||||
"x = np.random.rand(encoding_width)\n",
|
|
||||||
"#cx = np.array([HE_client.encrypt(x[j]) for j in range(len(x))])\n",
|
|
||||||
"# precompute 1 - |x|^2 for query vector\n",
|
|
||||||
"du = 1 - x @ x\n",
|
|
||||||
"cx = HE_client.encrypt(x)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[Client] Sending HE_client=<ckks Pyfhel obj at 0x74bfcf9f92b0, [pk:Y, sk:Y, rtk:Y, rlk:Y, contx(n=16384, t=0, sec=128, qi=[60, 30, 30, 30, 60], scale=1073741824.0, )]> and cx=<Pyfhel Ciphertext at 0x74bfe7e34090, scheme=ckks, size=2/2, scale_bits=30, mod_level=0>\n",
|
|
||||||
"[Client] Sent 106.78689 MB\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Serializing data and public context information\n",
|
|
||||||
"\n",
|
|
||||||
"s_context = HE_client.to_bytes_context()\n",
|
|
||||||
"s_public_key = HE_client.to_bytes_public_key()\n",
|
|
||||||
"s_relin_key = HE_client.to_bytes_relin_key()\n",
|
|
||||||
"s_rotate_key = HE_client.to_bytes_rotate_key()\n",
|
|
||||||
"#s_cx = [cx[j].to_bytes() for j in range(len(cx))]\n",
|
|
||||||
"s_cx = cx.to_bytes()\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"[Client] Sending HE_client={HE_client} and cx={cx}\")\n",
|
|
||||||
"print(f\"[Client] Sent {(len(s_context) + len(s_public_key) + len(s_relin_key) + len(s_rotate_key) + len(s_cx)) / (10**6)} MB\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Server Mock"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"ename": "NameError",
|
|
||||||
"evalue": "name 'nb' is not defined",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;129m@nb\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mhyperbolic_distance_parts\u001b[39m(u, v): \u001b[38;5;66;03m# returns only the numerator and denominator of the hyperbolic distance formula\u001b[39;00m\n\u001b[1;32m 3\u001b[0m diff \u001b[38;5;241m=\u001b[39m u \u001b[38;5;241m-\u001b[39m v\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m#du = -(1 - u @ u) # for some reason we need to negate this\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m#dv = -(1 - v @ v) # for some reason we need to negate this\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;66;03m#return diff @ diff, du * dv # returns the numerator and denominator\u001b[39;00m\n",
|
|
||||||
"\u001b[0;31mNameError\u001b[0m: name 'nb' is not defined"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"def hyperbolic_distance_parts(u, v): # returns only the numerator and denominator of the hyperbolic distance formula\n",
|
|
||||||
" diff = u - v\n",
|
|
||||||
" #du = -(1 - u @ u) # for some reason we need to negate this\n",
|
|
||||||
" #dv = -(1 - v @ v) # for some reason we need to negate this\n",
|
|
||||||
" #return diff @ diff, du * dv # returns the numerator and denominator\n",
|
|
||||||
" return diff @ diff # returns the numerator and denominator\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# document matrix containing rows of document encoding vectors\n",
|
|
||||||
"D = np.random.rand(database_size, encoding_width)\n",
|
|
||||||
"# precompute 1 - |D|^2 for each row vector in D\n",
|
|
||||||
"dv = []\n",
|
|
||||||
"for i in range(len(D)):\n",
|
|
||||||
" v = D[i]\n",
|
|
||||||
" dv.append(1 - v @ v)\n",
|
|
||||||
"dv = np.array(dv)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 16,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[Server] received HE_server=<ckks Pyfhel obj at 0x73188071c170, [pk:Y, sk:-, rtk:Y, rlk:Y, contx(n=16384, t=0, sec=128, qi=[60, 30, 30, 30, 60], scale=1073741824.0, )]> and cx=<Pyfhel Ciphertext at 0x73185d540f40, scheme=ckks, size=2/2, scale_bits=30, mod_level=0>\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "TypingError",
|
|
||||||
"evalue": "Failed in nopython mode pipeline (step: nopython frontend)\nnon-precise type pyobject\nDuring: typing of argument at /tmp/ipykernel_5991/1812651176.py (1)\n\nFile \"../../tmp/ipykernel_5991/1812651176.py\", line 1:\n<source missing, REPL/exec in use?> \n\nThis error may have been caused by the following argument(s):\n- argument 0: Cannot determine Numba type of <class 'Pyfhel.PyCtxt.PyCtxt'>\n- argument 1: Cannot determine Numba type of <class 'Pyfhel.PyCtxt.PyCtxt'>\n",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mTypingError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn[16], line 18\u001b[0m\n\u001b[1;32m 16\u001b[0m cd \u001b[38;5;241m=\u001b[39m HE_server\u001b[38;5;241m.\u001b[39mencrypt(D[i])\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# Compute distance bewteen recieved query and D[i]\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m res\u001b[38;5;241m.\u001b[39mappend(\u001b[43mhyperbolic_distance_parts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcd\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 19\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Server] Compute took \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mend\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39mstart\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124ms with bandwidth \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(D)\u001b[38;5;250m \u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;250m \u001b[39m(end\u001b[38;5;241m-\u001b[39mstart)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m documents/s\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
|
||||||
"File \u001b[0;32m~/.venvs/cs239/lib/python3.9/site-packages/numba/core/dispatcher.py:423\u001b[0m, in \u001b[0;36m_DispatcherBase._compile_for_args\u001b[0;34m(self, *args, **kws)\u001b[0m\n\u001b[1;32m 419\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mstr\u001b[39m(e)\u001b[38;5;241m.\u001b[39mrstrip()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mThis error may have been caused \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 420\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mby the following argument(s):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00margs_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 421\u001b[0m e\u001b[38;5;241m.\u001b[39mpatch_message(msg)\n\u001b[0;32m--> 423\u001b[0m \u001b[43merror_rewrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtyping\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 424\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m errors\u001b[38;5;241m.\u001b[39mUnsupportedError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 425\u001b[0m \u001b[38;5;66;03m# Something unsupported is present in the user code, add help info\u001b[39;00m\n\u001b[1;32m 426\u001b[0m error_rewrite(e, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124munsupported_error\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
|
||||||
"File \u001b[0;32m~/.venvs/cs239/lib/python3.9/site-packages/numba/core/dispatcher.py:364\u001b[0m, in \u001b[0;36m_DispatcherBase._compile_for_args.<locals>.error_rewrite\u001b[0;34m(e, issue_type)\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 363\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 364\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(\u001b[38;5;28;01mNone\u001b[39;00m)\n",
|
|
||||||
"\u001b[0;31mTypingError\u001b[0m: Failed in nopython mode pipeline (step: nopython frontend)\nnon-precise type pyobject\nDuring: typing of argument at /tmp/ipykernel_5991/1812651176.py (1)\n\nFile \"../../tmp/ipykernel_5991/1812651176.py\", line 1:\n<source missing, REPL/exec in use?> \n\nThis error may have been caused by the following argument(s):\n- argument 0: Cannot determine Numba type of <class 'Pyfhel.PyCtxt.PyCtxt'>\n- argument 1: Cannot determine Numba type of <class 'Pyfhel.PyCtxt.PyCtxt'>\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"HE_server = Pyfhel()\n",
|
|
||||||
"HE_server.from_bytes_context(s_context)\n",
|
|
||||||
"HE_server.from_bytes_public_key(s_public_key)\n",
|
|
||||||
"HE_server.from_bytes_relin_key(s_relin_key)\n",
|
|
||||||
"HE_server.from_bytes_rotate_key(s_rotate_key)\n",
|
|
||||||
"#cx = np.array([PyCtxt(pyfhel=HE_server, bytestring=s_cx[j]) for j in range(len(s_cx))])\n",
|
|
||||||
"cx = PyCtxt(pyfhel=HE_server, bytestring=s_cx)\n",
|
|
||||||
"print(f\"[Server] received HE_server={HE_server} and cx={cx}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Encode each document weights in plaintext\n",
|
|
||||||
"res = []\n",
|
|
||||||
"start = time.time()\n",
|
|
||||||
"for i in range(len(D)):\n",
|
|
||||||
" #d = np.array(D[i])\n",
|
|
||||||
" #cd = np.array([HE_server.encrypt(d[j]) for j in range(len(d))])\n",
|
|
||||||
" cd = HE_server.encrypt(D[i])\n",
|
|
||||||
" # Compute distance bewteen recieved query and D[i]\n",
|
|
||||||
" res.append(hyperbolic_distance_parts(cx, cd))\n",
|
|
||||||
"end = time.time()\n",
|
|
||||||
"print(f\"[Server] Compute took {end - start}s with bandwidth {len(D) / (end-start)} documents/s\")\n",
|
|
||||||
"\n",
|
|
||||||
"s_res = [res[j].to_bytes() for j in range(len(res))]\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"[Server] Distances computed! Responding: res={res}\")\n",
|
|
||||||
"print(f\"[Server] Sent {(np.sum([len(s_res[i]) for i in range(len(s_res))])) / (10**6)} MB\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Note that the time is mostly restricted by database size and not encoding size"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Client Parse Response"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 19,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def hyperbolic_distance(u, v):\n",
|
|
||||||
" num = ((u - v) @ (u - v))\n",
|
|
||||||
" den = (1 - (u @ u)) * (1 - (v @ v))\n",
|
|
||||||
" return np.arccosh(1 + 2 * (num / den))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 20,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"#res = np.array([HE_client.decrypt(c_res[j]) for j in range(len(c_res))])[:,0]\n",
|
|
||||||
"#res = HE_client.decrypt(c_res)\n",
|
|
||||||
"c_res = []\n",
|
|
||||||
"for i in range(len(s_res)):\n",
|
|
||||||
" c_num = PyCtxt(pyfhel=HE_server, bytestring=s_res[i])\n",
|
|
||||||
" #c_den = PyCtxt(pyfhel=HE_server, bytestring=s_res[i][1])\n",
|
|
||||||
" p_num = HE_client.decrypt(c_num)[0]\n",
|
|
||||||
" #p_den = HE_client.decrypt(c_den)[0]\n",
|
|
||||||
" #dist = np.arccosh(1 + 2 * (p_num / p_den))\n",
|
|
||||||
" # compute final score\n",
|
|
||||||
" dist = np.arccosh(1 + 2 * (p_num / (du * dv[i])))\n",
|
|
||||||
" #print(dist)\n",
|
|
||||||
" c_res.append(dist)\n",
|
|
||||||
"\n",
|
|
||||||
"# Checking result\n",
|
|
||||||
"expected_res = [hyperbolic_distance(x, np.array(w)) for w in D]\n",
|
|
||||||
"#print(f\"[Client] Response received! \\nResult is {c_res} \\nShould be {expected}\\nDiff {np.abs(np.array(c_res) - np.array(expected))}\")\n",
|
|
||||||
"for i in range(len(c_res)):\n",
|
|
||||||
" result = c_res[i]\n",
|
|
||||||
" expected = expected_res[i]\n",
|
|
||||||
" if np.abs(result - expected) < 1e-3:\n",
|
|
||||||
" pass\n",
|
|
||||||
" else:\n",
|
|
||||||
" print(f\"got: {result}, expected: {expected}\")\n",
|
|
||||||
" assert False"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "cs239",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.9.20"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
||||||
128
basic.py
Normal file
128
basic.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
encoding_width = 2**13 # restricted by ckks size
|
||||||
|
database_size = 100
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Client Setup
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from Pyfhel import Pyfhel, PyCtxt
|
||||||
|
|
||||||
|
print(f"[Client] Initializing Pyfhel session and data...")
|
||||||
|
HE_client = Pyfhel() # Creating empty Pyfhel object
|
||||||
|
ckks_params = {
|
||||||
|
'scheme': 'CKKS', # can also be 'ckks'
|
||||||
|
'n': 2**14, # Polynomial modulus degree. For CKKS, n/2 values can be
|
||||||
|
# encoded in a single ciphertext.
|
||||||
|
# Typ. 2^D for D in [10, 15]
|
||||||
|
'scale': 2**30, # All the encodings will use it for float->fixed point
|
||||||
|
# conversion: x_fix = round(x_float * scale)
|
||||||
|
# You can use this as default scale or use a different
|
||||||
|
# scale on each operation (set in HE.encryptFrac)
|
||||||
|
'qi_sizes': [60, 30, 30, 30, 60] # Number of bits of each prime in the chain.
|
||||||
|
# Intermediate values should be close to log2(scale)
|
||||||
|
# for each operation, to have small rounding errors.
|
||||||
|
}
|
||||||
|
HE_client.contextGen(**ckks_params) # Generate context for bfv scheme
|
||||||
|
HE_client.keyGen() # Generates both a public and a private key
|
||||||
|
HE_client.relinKeyGen()
|
||||||
|
HE_client.rotateKeyGen()
|
||||||
|
|
||||||
|
# Generate and encrypt query vector
|
||||||
|
x = np.random.rand(encoding_width)
|
||||||
|
#cx = np.array([HE_client.encrypt(x[j]) for j in range(len(x))])
|
||||||
|
# precompute 1 - |x|^2 for query vector
|
||||||
|
du = 1 - x @ x
|
||||||
|
cx = HE_client.encrypt(x)
|
||||||
|
|
||||||
|
# Serializing data and public context information
|
||||||
|
|
||||||
|
s_context = HE_client.to_bytes_context()
|
||||||
|
s_public_key = HE_client.to_bytes_public_key()
|
||||||
|
s_relin_key = HE_client.to_bytes_relin_key()
|
||||||
|
s_rotate_key = HE_client.to_bytes_rotate_key()
|
||||||
|
#s_cx = [cx[j].to_bytes() for j in range(len(cx))]
|
||||||
|
s_cx = cx.to_bytes()
|
||||||
|
|
||||||
|
print(f"[Client] Sending HE_client={HE_client} and cx={cx}")
|
||||||
|
print(f"[Client] Sent {(len(s_context) + len(s_public_key) + len(s_relin_key) + len(s_rotate_key) + len(s_cx)) / (10**6)} MB")
|
||||||
|
|
||||||
|
# Server Mock
|
||||||
|
|
||||||
|
def hyperbolic_distance_parts(u, v): # returns only the numerator and denominator of the hyperbolic distance formula
|
||||||
|
diff = u - v
|
||||||
|
#du = -(1 - u @ u) # for some reason we need to negate this
|
||||||
|
#dv = -(1 - v @ v) # for some reason we need to negate this
|
||||||
|
#return diff @ diff, du * dv # returns the numerator and denominator
|
||||||
|
return diff @ diff
|
||||||
|
|
||||||
|
|
||||||
|
# document matrix containing rows of document encoding vectors
|
||||||
|
D = np.random.rand(database_size, encoding_width)
|
||||||
|
# precompute 1 - |D|^2 for each row vector in D
|
||||||
|
dv = []
|
||||||
|
for i in range(len(D)):
|
||||||
|
v = D[i]
|
||||||
|
dv.append(1 - v @ v)
|
||||||
|
dv = np.array(dv)
|
||||||
|
|
||||||
|
HE_server = Pyfhel()
|
||||||
|
HE_server.from_bytes_context(s_context)
|
||||||
|
HE_server.from_bytes_public_key(s_public_key)
|
||||||
|
HE_server.from_bytes_relin_key(s_relin_key)
|
||||||
|
HE_server.from_bytes_rotate_key(s_rotate_key)
|
||||||
|
#cx = np.array([PyCtxt(pyfhel=HE_server, bytestring=s_cx[j]) for j in range(len(s_cx))])
|
||||||
|
cx = PyCtxt(pyfhel=HE_server, bytestring=s_cx)
|
||||||
|
print(f"[Server] received HE_server={HE_server} and cx={cx}")
|
||||||
|
|
||||||
|
# Encode each document weights in plaintext
|
||||||
|
res = []
|
||||||
|
start = time.time()
|
||||||
|
for i in range(len(D)):
|
||||||
|
#d = np.array(D[i])
|
||||||
|
#cd = np.array([HE_server.encrypt(d[j]) for j in range(len(d))])
|
||||||
|
cd = HE_server.encrypt(D[i])
|
||||||
|
# Compute distance bewteen recieved query and D[i]
|
||||||
|
res.append(hyperbolic_distance_parts(cx, cd))
|
||||||
|
end = time.time()
|
||||||
|
print(f"[Server] Compute took {end - start}s with bandwidth {len(D) / (end-start)} documents/s")
|
||||||
|
|
||||||
|
s_res = [res[j].to_bytes() for j in range(len(res))]
|
||||||
|
|
||||||
|
print(f"[Server] Distances computed! Responding: res={res[0]}...")
|
||||||
|
print(f"[Server] Sent {(np.sum([len(s_res[i]) for i in range(len(s_res))])) / (10**6)} MB")
|
||||||
|
|
||||||
|
# Note that the time is mostly restricted by database size and not encoding size
|
||||||
|
|
||||||
|
# Client Parse Response
|
||||||
|
|
||||||
|
def hyperbolic_distance(u, v):
|
||||||
|
num = ((u - v) @ (u - v))
|
||||||
|
den = (1 - (u @ u)) * (1 - (v @ v))
|
||||||
|
return np.arccosh(1 + 2 * (num / den))
|
||||||
|
|
||||||
|
#res = np.array([HE_client.decrypt(c_res[j]) for j in range(len(c_res))])[:,0]
|
||||||
|
#res = HE_client.decrypt(c_res)
|
||||||
|
c_res = []
|
||||||
|
for i in range(len(s_res)):
|
||||||
|
#c_num = PyCtxt(pyfhel=HE_server, bytestring=s_res[i])
|
||||||
|
c_num = PyCtxt(pyfhel=HE_server, bytestring=s_res[i])
|
||||||
|
#c_den = PyCtxt(pyfhel=HE_server, bytestring=s_res[i][1])
|
||||||
|
p_num = HE_client.decrypt(c_num)[0]
|
||||||
|
#p_den = HE_client.decrypt(c_den)[0]
|
||||||
|
#dist = np.arccosh(1 + 2 * (p_num / p_den))
|
||||||
|
# compute final score
|
||||||
|
dist = np.arccosh(1 + 2 * (p_num / (du * dv[i])))
|
||||||
|
#print(dist)
|
||||||
|
c_res.append(dist)
|
||||||
|
|
||||||
|
# Checking result
|
||||||
|
expected_res = [hyperbolic_distance(x, np.array(w)) for w in D]
|
||||||
|
#print(f"[Client] Response received! \nResult is {c_res} \nShould be {expected}\nDiff {np.abs(np.array(c_res) - np.array(expected))}")
|
||||||
|
for i in range(len(c_res)):
|
||||||
|
result = c_res[i]
|
||||||
|
expected = expected_res[i]
|
||||||
|
if np.abs(result - expected) < 1e-3:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print(f"got: {result}, expected: {expected}")
|
||||||
|
assert False
|
||||||
@@ -3,4 +3,4 @@ mkdir ~/.venvs
|
|||||||
rm -rf ~/.venvs/cs239
|
rm -rf ~/.venvs/cs239
|
||||||
python3.9 -m venv ~/.venvs/cs239
|
python3.9 -m venv ~/.venvs/cs239
|
||||||
source ~/.venvs/cs239/bin/activate
|
source ~/.venvs/cs239/bin/activate
|
||||||
pip install numpy pyfhel
|
pip install ipykernel numpy pyfhel
|
||||||
Reference in New Issue
Block a user