{ "cells": [ { "cell_type": "code", "execution_count": 73, "id": "ec310f34", "metadata": {}, "outputs": [], "source": [ "import json\n", "import math\n", "from collections import defaultdict" ] }, { "cell_type": "code", "execution_count": 74, "id": "5b54797e", "metadata": {}, "outputs": [], "source": [ "# Load the exported decision tree; the context manager guarantees the file is closed even on error\n", "with open(\"tree.json\") as f:\n", "\ttree = json.load(f)\n", "#features = tree[\"features\"]\n", "paths = tree[\"paths\"]" ] }, { "cell_type": "code", "execution_count": 75, "id": "a38fdb8a", "metadata": {}, "outputs": [], "source": [ "# First cleanup the tree by rounding the decision points to integer values\n", "# We assume all features use integer values. If this is not the case, then training data should be normalized so that integer thresholds are accurate enough.\n", "\n", "path_ids = set()\n", "path_classes = tree[\"classes\"]\n", "\n", "# assign each path a sequential id and round its condition thresholds\n", "for i, path in enumerate(paths):\n", "\tpath[\"id\"] = i\n", "\tpath_ids.add(i)\n", "\t# Flooring preserves the decision boundary for integer feature values in BOTH directions:\n", "\t#   a <= x.y  <=>  a <= floor(x.y)   and   a > x.y  <=>  a > floor(x.y)\n", "\t# so the original <= / > branches (which had identical bodies) collapse into one statement.\n", "\tfor condition in path[\"conditions\"]:\n", "\t\tcondition[\"value\"] = math.floor(condition[\"value\"])" ] }, { "cell_type": "code", "execution_count": 76, "id": "2fd4f738", "metadata": {}, "outputs": [], "source": [ "# Find all breakpoints for each feature and create a set of disjoint ranges\n", "\n", "breakpoints = defaultdict(set)\n", "for path in paths:\n", "\tconditions = path[\"conditions\"]\n", "\tfor condition in conditions:\n", "\t\tfeature = condition[\"feature\"]\n", "\t\tvalue = condition[\"value\"]\n", 
"\t\tbreakpoints[feature].add(value)\n", "\n", "# replace each feature's set of breakpoints with an ordered list\n", "for feature in breakpoints:\n", "\tbreakpoints[feature] = sorted(breakpoints[feature])" ] }, { "cell_type": "code", "execution_count": 77, "id": "98cde024", "metadata": {}, "outputs": [], "source": [ "# collapse all paths to ranges for each feature\n", "# because of how decision trees work, all conditions on a path must be true to reach the leaf node\n", "# intuitively, a collection of statements x > a, x > b, x <= c, x <= d ... which must all be satisfied\n", "# logically can be collapsed into a single range: the tightest lower bound and the tightest upper bound\n", "\n", "for path in paths:\n", "\t# one {min, max} interval per feature; None means unbounded on that side\n", "\tcompressed = {feature: {\"min\": None, \"max\": None} for feature in breakpoints}\n", "\n", "\tfor condition in path[\"conditions\"]:\n", "\t\tbounds = compressed[condition[\"feature\"]]\n", "\t\tvalue = condition[\"value\"]\n", "\t\tif condition[\"operation\"] == \"<=\":\n", "\t\t\t# tightest upper bound is the smallest \"<=\" threshold seen so far\n", "\t\t\tif bounds[\"max\"] is None or value < bounds[\"max\"]:\n", "\t\t\t\tbounds[\"max\"] = value\n", "\t\telse:\n", "\t\t\t# tightest lower bound is the largest \">\" threshold seen so far\n", "\t\t\tif bounds[\"min\"] is None or value > bounds[\"min\"]:\n", "\t\t\t\tbounds[\"min\"] = value\n", "\n", "\tpath[\"compressed\"] = compressed" ] }, { "cell_type": "code", "execution_count": 78, "id": "b6fbadbf", "metadata": {}, "outputs": [], "source": [ "# for each path, add the path's id to buckets corresponding to each breakpoint's range\n", "# ie if breakpoints = [0, 1, 2, 3]\n", "# then buckets = [(< 0), (0 
- 1), (1 - 2), (2 - 3), (> 3)]\n", "# therefore, each entry in buckets is paths less than (or equal to) that entry in breakpoints but strictly greater than the previous\n", "# the last entry corresponds to the paths which are greater than the last entry in breakpoints\n", "\n", "def is_in_range(x, lower, upper):\n", "\t\"\"\"Return True when x lies in the half-open interval (lower, upper]; None means unbounded on that side.\"\"\"\n", "\treturn (lower is None or x > lower) and (upper is None or x <= upper)\n", "\n", "# one list of sets per feature: index i is the membership of paths in the range\n", "# ending at breakpoint i, plus one trailing overflow bucket above the last breakpoint\n", "buckets_id = {}\n", "buckets_class = {}\n", "for feature in breakpoints:\n", "\tnum_points = len(breakpoints[feature])\n", "\tbuckets_id[feature] = [set() for _ in range(num_points + 1)]\n", "\tbuckets_class[feature] = [set() for _ in range(num_points + 1)]\n", "\n", "for path in paths:\n", "\t# for each feature in the compressed path conditions\n", "\tfor feature_name, feature in path[\"compressed\"].items():\n", "\t\tlower = feature[\"min\"]\n", "\t\tupper = feature[\"max\"]\n", "\t\tID = path[\"id\"]\n", "\t\tClass = path[\"classification\"]\n", "\n", "\t\t# add this path to every finite bucket whose breakpoint falls inside its range\n", "\t\tfor i, bp in enumerate(breakpoints[feature_name]):\n", "\t\t\tif is_in_range(bp, lower, upper):\n", "\t\t\t\tbuckets_id[feature_name][i].add(ID)\n", "\t\t\t\tbuckets_class[feature_name][i].add(Class)\n", "\n", "\t\t# overflow bucket: probe with a value strictly above the last breakpoint.\n", "\t\t# bind bp and i explicitly rather than relying on the loop variables\n", "\t\t# leaking out of the for-loop above (fragile and easy to break on edit)\n", "\t\tbp = breakpoints[feature_name][-1]\n", "\t\ti = len(breakpoints[feature_name])\n", "\t\tif is_in_range(bp + 1, lower, upper):\n", "\t\t\tbuckets_id[feature_name][i].add(ID)\n", 
"\t\t\tbuckets_class[feature_name][i].add(Class)" ] }, { "cell_type": "code", "execution_count": 79, "id": "0a767971", "metadata": {}, "outputs": [], "source": [ "# combine breakpoints and buckets into one representation\n", "\n", "compressed_layers = defaultdict(list)\n", "for feature_name in buckets_id:\n", "\tpoints = breakpoints[feature_name]\n", "\t# bucket i spans (points[i-1], points[i]]; the first and last buckets are\n", "\t# unbounded below/above. Use locally-scoped names here -- the previous version\n", "\t# assigned to `paths`, clobbering the module-level list of tree paths and\n", "\t# breaking re-runs of the earlier cells.\n", "\tfor i in range(len(points) + 1):\n", "\t\tlayer = {\n", "\t\t\t\"min\": points[i - 1] if i > 0 else None,\n", "\t\t\t\"max\": points[i] if i < len(points) else None,\n", "\t\t\t\"paths\": buckets_id[feature_name][i],\n", "\t\t\t\"classes\": buckets_class[feature_name][i],\n", "\t\t}\n", "\t\tcompressed_layers[feature_name].append(layer)\n", "\n", "# map each path id to its predicted class for quick lookup\n", "path_to_class = {}\n", "for path in tree[\"paths\"]:\n", "\tpath_to_class[path[\"id\"]] = path[\"classification\"]\n", "\n", "compressed_tree = {\n", "\t\"paths\": path_ids,\n", "\t\"classes\": path_classes,\n", "\t\"layers\": compressed_layers,\n", "\t\"path_to_class\": path_to_class,\n", "}" ] }, { "cell_type": "code", "execution_count": 80, "id": "561b0bc1", "metadata": 
{}, "outputs": [], "source": [ "class SetEncoder(json.JSONEncoder):\n", "\t\"\"\"JSON encoder that serializes Python sets as lists (path/class sets above).\"\"\"\n", "\tdef default(self, obj):\n", "\t\tif isinstance(obj, set):\n", "\t\t\treturn list(obj)\n", "\t\treturn super().default(obj)\n", "\n", "# Write the compressed tree. The context manager closes the file even if dumping\n", "# fails, and plain \"w\" replaces the unnecessary \"w+\" read/write mode.\n", "with open(\"compressed_tree.json\", \"w\") as f:\n", "\tf.write(json.dumps(compressed_tree, indent=4, cls=SetEncoder))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }