From 0d5e51f582b7cfbd2fe1d4d2c496685dc09fc7b3 Mon Sep 17 00:00:00 2001 From: Arthur Lu Date: Mon, 2 Jun 2025 21:48:10 +0000 Subject: [PATCH] add list of paths and classes to compressed_tree output --- RMTConvert.ipynb | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/RMTConvert.ipynb b/RMTConvert.ipynb index 17335bf..d252a6b 100644 --- a/RMTConvert.ipynb +++ b/RMTConvert.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "id": "ec310f34", "metadata": {}, "outputs": [], @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "id": "5b54797e", "metadata": {}, "outputs": [], @@ -28,22 +28,25 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "id": "a38fdb8a", "metadata": {}, "outputs": [], "source": [ "# First cleanup the tree by rounding the decision points to integer values\n", "# We assume all features will use integer values. 
If this is not the case, then training data should be normalized so that integer values can be accurate enough\n", - "# we also enumerate all the paths for later use\n", - "\n", "i = 0\n", "\n", + "path_ids = set()\n", + "path_classes = set()\n", + "\n", "# for each path in the tree\n", "for path in paths:\n", "\t# assign a path id \n", "\tpath[\"id\"] = i\n", - "\ti += 1\n", + "\tpath_ids.add(i)\n", + "\tpath_classes.add(path[\"classification\"])\n", + "\ti += 1\t\n", "\t# for each condition\n", "\tconditions = path[\"conditions\"]\n", "\tfor condition in conditions:\n", @@ -57,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "id": "2fd4f738", "metadata": {}, "outputs": [], @@ -80,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "id": "98cde024", "metadata": {}, "outputs": [], @@ -120,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "id": "b6fbadbf", "metadata": {}, "outputs": [], @@ -177,40 +180,46 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "id": "0a767971", "metadata": {}, "outputs": [], "source": [ "# combine breakpoints and buckets to one representation\n", "\n", - "compressed_tree = defaultdict(list)\n", + "compressed_layers = defaultdict(list)\n", "for feature_name in buckets_id:\n", "\tlower = None\n", "\tupper = breakpoints[feature_name][0]\n", "\tpaths = buckets_id[feature_name][0]\n", "\tclasses = buckets_class[feature_name][0]\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", - "\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", + "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n", "\t\tlower = breakpoints[feature_name][i-1]\n", "\t\tupper = breakpoints[feature_name][i]\n", "\t\tmembers = 
buckets_id[feature_name][i]\n", "\t\tclasses = buckets_class[feature_name][i]\n", "\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n", - "\t\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", + "\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": members, \"classes\": classes})  # bucket i's own members, not bucket 0's\n", "\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n", "\tupper = None\n", "\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n", "\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", - "\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", - "\t#print(\"=\"*40)" + "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": members, \"classes\": classes})  # last bucket's own members, not bucket 0's\n", + "\t#print(\"=\"*40)\n", + "\n", + "compressed_tree = {\n", + "\t\"paths\": path_ids,\n", + "\t\"classes\": path_classes,\n", + "\t\"layers\": compressed_layers,\n", + "}" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "id": "561b0bc1", "metadata": {}, "outputs": [],