Files
IdealRMT-DecisionTrees/TreeCompress.ipynb

261 lines
9.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"id": "ec310f34",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import math\n",
"from collections import defaultdict"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5b54797e",
"metadata": {},
"outputs": [],
"source": [
"# Load the exported decision tree; the following cells only use its \"paths\" entry.\n",
"# The context manager closes the file even if reading or JSON parsing raises.\n",
"with open(\"tree.json\") as f:\n",
"\ttree = json.load(f)\n",
"\n",
"paths = tree[\"paths\"]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a38fdb8a",
"metadata": {},
"outputs": [],
"source": [
"# First clean up the tree by rounding the decision thresholds down to integer values.\n",
"# We assume all features use integer values. If this is not the case, the training data\n",
"# should be normalized so that integer values are accurate enough.\n",
"\n",
"path_ids = set()\n",
"path_classes = set()\n",
"\n",
"# for each path in the tree\n",
"for path_id, path in enumerate(paths):\n",
"\t# assign a sequential path id and record the id and the class label of the path\n",
"\tpath[\"id\"] = path_id\n",
"\tpath_ids.add(path_id)\n",
"\tpath_classes.add(path[\"classification\"])\n",
"\n",
"\t# Flooring is the right rounding for both operators when the feature is an integer:\n",
"\t#   a <= x.y  holds exactly when  a <= floor(x.y)\n",
"\t#   a >  x.y  holds exactly when  a >  floor(x.y)\n",
"\t# so no branch on the operation is needed.\n",
"\tfor condition in path[\"conditions\"]:\n",
"\t\tcondition[\"value\"] = math.floor(condition[\"value\"])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2fd4f738",
"metadata": {},
"outputs": [],
"source": [
"# Gather every threshold used on each feature; once sorted, consecutive\n",
"# breakpoints delimit a set of disjoint ranges for that feature.\n",
"\n",
"breakpoints = defaultdict(set)\n",
"for path in paths:\n",
"\tfor condition in path[\"conditions\"]:\n",
"\t\tbreakpoints[condition[\"feature\"]].add(condition[\"value\"])\n",
"\n",
"# swap each feature's set for an ascending list so breakpoints can be addressed by index\n",
"for feature in breakpoints:\n",
"\tbreakpoints[feature] = sorted(breakpoints[feature])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "98cde024",
"metadata": {},
"outputs": [],
"source": [
"# Collapse the conditions of each path into a single range per feature.\n",
"# Because of how decision trees work, all conditions on a path must be true to reach the leaf node.\n",
"# Intuitively, a collection of statements x > a, x > b, x <= c, x <= d ... which must all be satisfied\n",
"# can logically be collapsed into the single range max(a, b) < x <= min(c, d).\n",
"\n",
"# for each path\n",
"for path in paths:\n",
"\t# one entry per known feature; None means the range is unbounded on that side\n",
"\tcompressed = {feature: {\"min\": None, \"max\": None} for feature in breakpoints}\n",
"\n",
"\t# for each condition in the path\n",
"\tfor condition in path[\"conditions\"]:\n",
"\t\tbounds = compressed[condition[\"feature\"]]\n",
"\t\toperation = condition[\"operation\"]\n",
"\t\tvalue = condition[\"value\"]\n",
"\n",
"\t\t# Tighten only the bound this condition constrains, regardless of whether the\n",
"\t\t# opposite bound is already set: \"<=\" can only lower the max, \">\" can only raise the min.\n",
"\t\t# Any other operation is ignored.\n",
"\t\tif operation == \"<=\":\n",
"\t\t\tif bounds[\"max\"] is None or value < bounds[\"max\"]:\n",
"\t\t\t\tbounds[\"max\"] = value\n",
"\t\telif operation == \">\":\n",
"\t\t\tif bounds[\"min\"] is None or value > bounds[\"min\"]:\n",
"\t\t\t\tbounds[\"min\"] = value\n",
"\n",
"\tpath[\"compressed\"] = compressed"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b6fbadbf",
"metadata": {},
"outputs": [],
"source": [
"# For each path, add the path's id and class to the buckets of every breakpoint range it covers.\n",
"# e.g. if breakpoints = [0, 1, 2, 3]\n",
"# then buckets = [(<= 0), (0 - 1], (1 - 2], (2 - 3], (> 3)]\n",
"# Bucket i therefore holds the paths covering values less than or equal to breakpoints[i]\n",
"# but strictly greater than breakpoints[i - 1]; the extra last bucket holds the paths\n",
"# covering values greater than the last breakpoint.\n",
"\n",
"def is_in_range(x, lower, upper):\n",
"\t\"\"\"Return True if lower < x <= upper, where a bound of None is unbounded.\"\"\"\n",
"\tabove_lower = lower is None or x > lower\n",
"\tbelow_upper = upper is None or x <= upper\n",
"\treturn above_lower and below_upper\n",
"\n",
"# create buckets for each feature: one set per breakpoint plus one for the open range above the last\n",
"buckets_id = {}\n",
"buckets_class = {}\n",
"for feature in breakpoints:\n",
"\tnum_buckets = len(breakpoints[feature]) + 1\n",
"\tbuckets_id[feature] = [set() for _ in range(num_buckets)]\n",
"\tbuckets_class[feature] = [set() for _ in range(num_buckets)]\n",
"\n",
"# for each path\n",
"for path in paths:\n",
"\tpath_id = path[\"id\"]\n",
"\tpath_class = path[\"classification\"]\n",
"\n",
"\t# for each feature in the compressed path conditions\n",
"\tfor feature_name, bounds in path[\"compressed\"].items():\n",
"\t\tlower = bounds[\"min\"]\n",
"\t\tupper = bounds[\"max\"]\n",
"\n",
"\t\t# bucket i ends at breakpoints[i], so the path covers it when that breakpoint is in its range\n",
"\t\tfor i, bp in enumerate(breakpoints[feature_name]):\n",
"\t\t\tif is_in_range(bp, lower, upper):\n",
"\t\t\t\tbuckets_id[feature_name][i].add(path_id)\n",
"\t\t\t\tbuckets_class[feature_name][i].add(path_class)\n",
"\n",
"\t\t# The last bucket has no breakpoint of its own to test. A path covers it exactly\n",
"\t\t# when it has no upper bound (its lower bound, if any, is itself a breakpoint).\n",
"\t\tif upper is None:\n",
"\t\t\tbuckets_id[feature_name][-1].add(path_id)\n",
"\t\t\tbuckets_class[feature_name][-1].add(path_class)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0a767971",
"metadata": {},
"outputs": [],
"source": [
"# Combine breakpoints and buckets into one representation: for every feature, a list of\n",
"# layers {min, max, paths, classes}, one per bucket, ordered from lowest to highest range.\n",
"\n",
"compressed_layers = defaultdict(list)\n",
"for feature_name in buckets_id:\n",
"\tpoints = breakpoints[feature_name]\n",
"\n",
"\t# bucket i spans (lowers[i], uppers[i]]; None marks the open end below the first\n",
"\t# breakpoint and above the last one\n",
"\tlowers = [None] + points\n",
"\tuppers = points + [None]\n",
"\n",
"\tfor i, (lower, upper) in enumerate(zip(lowers, uppers)):\n",
"\t\t# each layer gets its own bucket's members; the global `paths` list is left untouched\n",
"\t\tcompressed_layers[feature_name].append({\n",
"\t\t\t\"min\": lower,\n",
"\t\t\t\"max\": upper,\n",
"\t\t\t\"paths\": buckets_id[feature_name][i],\n",
"\t\t\t\"classes\": buckets_class[feature_name][i],\n",
"\t\t})\n",
"\n",
"compressed_tree = {\n",
"\t\"paths\": path_ids,\n",
"\t\"classes\": path_classes,\n",
"\t\"layers\": compressed_layers,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "561b0bc1",
"metadata": {},
"outputs": [],
"source": [
"class SetEncoder(json.JSONEncoder):\n",
"    \"\"\"JSON encoder that serializes Python sets as JSON arrays.\"\"\"\n",
"\n",
"    def default(self, obj):\n",
"        # sets are not JSON serializable out of the box; emit them as lists\n",
"        if isinstance(obj, set):\n",
"            return list(obj)\n",
"        # anything else is delegated to the base encoder\n",
"        return super().default(obj)\n",
"\n",
"# Serialize first, then write; the context manager closes the file even if the write raises.\n",
"serialized = json.dumps(compressed_tree, indent=4, cls=SetEncoder)\n",
"with open(\"compressed_tree.json\", \"w\") as f:\n",
"    f.write(serialized)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}