add list of paths and classes to compressed_tree output

This commit is contained in:
2025-06-02 21:48:10 +00:00
parent 9729c6e68c
commit 0d5e51f582

View File

@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 1,
"id": "ec310f34", "id": "ec310f34",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -14,7 +14,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 2,
"id": "5b54797e", "id": "5b54797e",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -28,22 +28,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 3,
"id": "a38fdb8a", "id": "a38fdb8a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# First cleanup the tree by rounding the decision points to integer values\n", "# First cleanup the tree by rounding the decision points to integer values\n",
"# We assume all features will use integer values. If this is not the case, then training data should be normalized so that integer values can be accurate enough\n", "# We assume all features will use integer values. If this is not the case, then training data should be normalized so that integer values can be accurate enough\n",
"# we also enumerate all the paths for later use\n",
"\n",
"i = 0\n", "i = 0\n",
"\n", "\n",
"path_ids = set()\n",
"path_classes = set()\n",
"\n",
"# for each path in the tree\n", "# for each path in the tree\n",
"for path in paths:\n", "for path in paths:\n",
"\t# assign a path id \n", "\t# assign a path id \n",
"\tpath[\"id\"] = i\n", "\tpath[\"id\"] = i\n",
"\ti += 1\n", "\tpath_ids.add(i)\n",
"\tpath_classes.add(path[\"classification\"])\n",
"\ti += 1\t\n",
"\t# for each condition\n", "\t# for each condition\n",
"\tconditions = path[\"conditions\"]\n", "\tconditions = path[\"conditions\"]\n",
"\tfor condition in conditions:\n", "\tfor condition in conditions:\n",
@@ -57,7 +60,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 4,
"id": "2fd4f738", "id": "2fd4f738",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -80,7 +83,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 5,
"id": "98cde024", "id": "98cde024",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -120,7 +123,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 6,
"id": "b6fbadbf", "id": "b6fbadbf",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -177,40 +180,46 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 7,
"id": "0a767971", "id": "0a767971",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# combine breakpoints and buckets to one representation\n", "# combine breakpoints and buckets to one representation\n",
"\n", "\n",
"compressed_tree = defaultdict(list)\n", "compressed_layers = defaultdict(list)\n",
"for feature_name in buckets_id:\n", "for feature_name in buckets_id:\n",
"\tlower = None\n", "\tlower = None\n",
"\tupper = breakpoints[feature_name][0]\n", "\tupper = breakpoints[feature_name][0]\n",
"\tpaths = buckets_id[feature_name][0]\n", "\tpaths = buckets_id[feature_name][0]\n",
"\tclasses = buckets_class[feature_name][0]\n", "\tclasses = buckets_class[feature_name][0]\n",
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
"\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n", "\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n",
"\t\tlower = breakpoints[feature_name][i-1]\n", "\t\tlower = breakpoints[feature_name][i-1]\n",
"\t\tupper = breakpoints[feature_name][i]\n", "\t\tupper = breakpoints[feature_name][i]\n",
"\t\tmembers = buckets_id[feature_name][i]\n", "\t\tmembers = buckets_id[feature_name][i]\n",
"\t\tclasses = buckets_class[feature_name][i]\n", "\t\tclasses = buckets_class[feature_name][i]\n",
"\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n", "\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n",
"\t\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
"\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n", "\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n",
"\tupper = None\n", "\tupper = None\n",
"\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n", "\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n",
"\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n", "\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n",
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
"\t#print(\"=\"*40)" "\t#print(\"=\"*40)\n",
"\n",
"compressed_tree = {\n",
"\t\"paths\": path_ids,\n",
"\t\"classes\": path_classes,\n",
"\t\"layers\": compressed_layers,\n",
"}"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 8,
"id": "561b0bc1", "id": "561b0bc1",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],