mirror of
https://github.com/ltcptgeneral/IdealRMT-DecisionTrees.git
synced 2025-09-05 14:57:23 +00:00
add list of paths and classes to compressed_tree output
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 1,
|
||||
"id": "ec310f34",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -14,7 +14,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 2,
|
||||
"id": "5b54797e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -28,22 +28,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 3,
|
||||
"id": "a38fdb8a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# First clean up the tree by rounding the decision points to integer values\n",
|
||||
"# We assume all features use integer values. If this is not the case, the training data should be normalized so that rounding to integer values is accurate enough\n",
|
||||
"# we also enumerate all the paths for later use\n",
|
||||
"\n",
|
||||
"i = 0\n",
|
||||
"\n",
|
||||
"path_ids = set()\n",
|
||||
"path_classes = set()\n",
|
||||
"\n",
|
||||
"# for each path in the tree\n",
|
||||
"for path in paths:\n",
|
||||
"\t# assign a path id \n",
|
||||
"\tpath[\"id\"] = i\n",
|
||||
"\ti += 1\n",
|
||||
"\tpath_ids.add(i)\n",
|
||||
"\tpath_classes.add(path[\"classification\"])\n",
|
||||
"\ti += 1\t\n",
|
||||
"\t# for each condition\n",
|
||||
"\tconditions = path[\"conditions\"]\n",
|
||||
"\tfor condition in conditions:\n",
|
||||
@@ -57,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 4,
|
||||
"id": "2fd4f738",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -80,7 +83,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 5,
|
||||
"id": "98cde024",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -120,7 +123,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 6,
|
||||
"id": "b6fbadbf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -177,40 +180,46 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 7,
|
||||
"id": "0a767971",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# combine breakpoints and buckets to one representation\n",
|
||||
"\n",
|
||||
"compressed_tree = defaultdict(list)\n",
|
||||
"compressed_layers = defaultdict(list)\n",
|
||||
"for feature_name in buckets_id:\n",
|
||||
"\tlower = None\n",
|
||||
"\tupper = breakpoints[feature_name][0]\n",
|
||||
"\tpaths = buckets_id[feature_name][0]\n",
|
||||
"\tclasses = buckets_class[feature_name][0]\n",
|
||||
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
||||
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n",
|
||||
"\t\tlower = breakpoints[feature_name][i-1]\n",
|
||||
"\t\tupper = breakpoints[feature_name][i]\n",
|
||||
"\t\tmembers = buckets_id[feature_name][i]\n",
|
||||
"\t\tclasses = buckets_class[feature_name][i]\n",
|
||||
"\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n",
|
||||
"\t\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n",
|
||||
"\tupper = None\n",
|
||||
"\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n",
|
||||
"\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n",
|
||||
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
||||
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\t#print(\"=\"*40)"
|
||||
"\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||
"\t#print(\"=\"*40)\n",
|
||||
"\n",
|
||||
"compressed_tree = {\n",
|
||||
"\t\"paths\": path_ids,\n",
|
||||
"\t\"classes\": path_classes,\n",
|
||||
"\t\"layers\": compressed_layers,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 8,
|
||||
"id": "561b0bc1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
Reference in New Issue
Block a user