diff --git a/.gitignore b/.gitignore index 7a7b094..3a7c2d8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ data.* __pycache__ -tree.json \ No newline at end of file +tree.json +compressed_tree.json \ No newline at end of file diff --git a/RMTConvert.ipynb b/RMTConvert.ipynb new file mode 100644 index 0000000..4f0bc9a --- /dev/null +++ b/RMTConvert.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 25, + "id": "ec310f34", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import math\n", + "from collections import defaultdict" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "5b54797e", + "metadata": {}, + "outputs": [], + "source": [ + "f = open(\"tree.json\")\n", + "tree = json.loads(f.read())\n", + "#features = tree[\"features\"]\n", + "paths = tree[\"paths\"]\n", + "f.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a38fdb8a", + "metadata": {}, + "outputs": [], + "source": [ + "# First cleanup the tree by rounding the decision points to integer values\n", + "# We assume all features will use integer values. If this is not the case, then training data should be normalized so that integer values can be accurate enough\n", + "# we also enumerate all the paths for later use\n", + "\n", + "i = 0\n", + "\n", + "# for each path in the tree\n", + "for path in paths:\n", + "\t# assign a path id \n", + "\tpath[\"id\"] = i\n", + "\ti += 1\n", + "\t# for each condition\n", + "\tconditions = path[\"conditions\"]\n", + "\tfor condition in conditions:\n", + "\t\t# if the round the thresholds using floor\n", + "\t\toperation = condition[\"operation\"]\n", + "\t\tif operation == \"<=\": # if a <= x.y, then a <= x is as strong given integer values\n", + "\t\t\tcondition[\"value\"] = math.floor(condition[\"value\"])\n", + "\t\telse: # if a > x.y, then a > x is as strong given integer values\n", + "\t\t\tcondition[\"value\"] = math.floor(condition[\"value\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2fd4f738", + "metadata": {}, + "outputs": [], + "source": [ + "# Find all breakpoints for each feature and create a set of disjoint ranges\n", + "\n", + "breakpoints = defaultdict(set)\n", + "for path in paths:\n", + "\tconditions = path[\"conditions\"]\n", + "\tfor condition in conditions:\n", + "\t\tfeature = condition[\"feature\"]\n", + "\t\tvalue = condition[\"value\"]\n", + "\t\tbreakpoints[feature].add(value)\n", + "\n", + "for feature in breakpoints:\n", + "\tpoints = list(breakpoints[feature])\n", + "\tpoints.sort()\n", + "\tbreakpoints[feature] = points" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "98cde024", + "metadata": {}, + "outputs": [], + "source": [ + "# collapse all paths to ranges for each feature\n", + "# because of how decision trees work, all conditions on a path must be true to reach the leaf node\n", + "# intuitively, a collection of statements x > a, x > b, x < c, x < d ... which must all be satistifed\n", + "# logicall can be collapsed into a singular range\n", + "\n", + "# for each path\n", + "for path in paths:\n", + "\tconditions = path[\"conditions\"]\n", + "\tcompressed = {}\n", + "\n", + "\t# create a new compressed feature dict with 1 entry for each feature\n", + "\tfor feature in breakpoints:\n", + "\t\tcompressed[feature] = {\"min\": None, \"max\": None}\n", + "\t\n", + "\t# for each condition in the path\n", + "\tfor condition in conditions:\n", + "\t\tfeature = condition[\"feature\"]\n", + "\t\toperation = condition[\"operation\"]\n", + "\t\tvalue = condition[\"value\"]\n", + "\n", + "\t\t# move the min/max for the corresponding feature in compressed\n", + "\t\tif operation == \"<=\" and compressed[feature][\"min\"] is None:\n", + "\t\t\tcompressed[feature][\"max\"] = value\n", + "\t\telif operation == \">\" and compressed[feature][\"max\"] is None:\n", + "\t\t\tcompressed[feature][\"min\"] = value\n", + "\t\telif operation == \"<=\" and value < compressed[feature][\"min\"]:\n", + "\t\t\tcompressed[feature][\"max\"] = value\n", + "\t\telif operation == \">\" and value > compressed[feature][\"max\"]:\n", + "\t\t\tcompressed[feature][\"min\"] = value\n", + "\n", + "\tpath[\"compressed\"] = compressed" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "b6fbadbf", + "metadata": {}, + "outputs": [], + "source": [ + "# for each path, add the path's id to buckets corresponding to each breakpoint's range\n", + "# ie if breakpoints = [0, 1, 2, 3]\n", + "# then buckets = [(< 0), (0 - 1), (1 - 2), (2 - 3), (> 3)]\n", + "# therefore, each entry in buckets is paths less (or equal) than that entry in breakpoints but greater (stricly) than the previous\n", + "# the last entry would correspond to the paths which are greater than the last entry in breakpoints\n", + "\n", + "# helper function given a range and value x returs if x is in the range\n", + "def is_in_range(x, lower, upper):\n", + "\tif lower is None and upper is None:\n", + "\t\treturn True\n", + "\telif lower is None:\n", + "\t\treturn x <= upper\n", + "\telif upper is None:\n", + "\t\treturn x > lower\n", + "\telse:\n", + "\t\treturn x <= upper and x > lower\n", + "\n", + "# create buckets for each feature, where each is a list of sets\n", + "buckets = {}\n", + "for feature in breakpoints:\n", + "\tnum_points = len(breakpoints[feature])\n", + "\tbuckets[feature] = []\n", + "\t# each index in the feature corresponds to the corresponding breakpoint value in breakpoints\n", + "\t# each index holds a set, which is the membership of paths in that range\n", + "\tfor i in range(0, num_points + 1):\n", + "\t\tbuckets[feature].append(set())\n", + "\n", + "# for each path\n", + "for path in paths:\n", + "\t# for each feature in the compressed path conditions\n", + "\tfor feature_name in path[\"compressed\"]:\n", + "\t\tfeature = path[\"compressed\"][feature_name]\n", + "\t\tlower = feature[\"min\"]\n", + "\t\tupper = feature[\"max\"]\n", + "\t\tID = path[\"id\"]\n", + "\n", + "\t\t# for each bucket which encompases the condition's range, add this path's id to the sets \n", + "\t\ti = 0\n", + "\t\tfor bp in breakpoints[feature_name]:\n", + "\t\t\tin_range = is_in_range(bp, lower, upper)\n", + "\t\t\tif in_range:\n", + "\t\t\t\tbuckets[feature_name][i].add(ID)\n", + "\t\t\ti += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "0a767971", + "metadata": {}, + "outputs": [], + "source": [ + "# combine breakpoints and buckets to one representation\n", + "\n", + "compressed_tree = defaultdict(list)\n", + "for feature_name in buckets:\n", + "\tlower = None\n", + "\tupper = breakpoints[feature_name][0]\n", + "\tmembers = buckets[feature_name][0]\n", + "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", + "\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": list(members)})\n", + "\tfor i in range(1, len(buckets[feature_name]) - 1):\n", + "\t\tlower = breakpoints[feature_name][i-1]\n", + "\t\tupper = breakpoints[feature_name][i]\n", + "\t\tmembers = buckets[feature_name][i]\n", + "\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n", + "\t\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": list(members)})\n", + "\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n", + "\tupper = None\n", + "\tmembers = buckets[feature_name][len(buckets[feature_name]) - 1]\n", + "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", + "\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": list(members)})\n", + "\t#print(\"=\"*40)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "561b0bc1", + "metadata": {}, + "outputs": [], + "source": [ + "f = open(\"compressed_tree.json\", \"w+\")\n", + "f.write(json.dumps(compressed_tree, indent = 4))\n", + "f.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "switch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}