mirror of
https://github.com/ltcptgeneral/IdealRMT-DecisionTrees.git
synced 2025-09-05 23:07:24 +00:00
add list of paths and classes to compressed_tree output
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 1,
|
||||||
"id": "ec310f34",
|
"id": "ec310f34",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 2,
|
||||||
"id": "5b54797e",
|
"id": "5b54797e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -28,22 +28,25 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 3,
|
||||||
"id": "a38fdb8a",
|
"id": "a38fdb8a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# First cleanup the tree by rounding the decision points to integer values\n",
|
"# First cleanup the tree by rounding the decision points to integer values\n",
|
||||||
"# We assume all features will use integer values. If this is not the case, then training data should be normalized so that integer values can be accurate enough\n",
|
"# We assume all features will use integer values. If this is not the case, then training data should be normalized so that integer values can be accurate enough\n",
|
||||||
"# we also enumerate all the paths for later use\n",
|
|
||||||
"\n",
|
|
||||||
"i = 0\n",
|
"i = 0\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"path_ids = set()\n",
|
||||||
|
"path_classes = set()\n",
|
||||||
|
"\n",
|
||||||
"# for each path in the tree\n",
|
"# for each path in the tree\n",
|
||||||
"for path in paths:\n",
|
"for path in paths:\n",
|
||||||
"\t# assign a path id \n",
|
"\t# assign a path id \n",
|
||||||
"\tpath[\"id\"] = i\n",
|
"\tpath[\"id\"] = i\n",
|
||||||
"\ti += 1\n",
|
"\tpath_ids.add(i)\n",
|
||||||
|
"\tpath_classes.add(path[\"classification\"])\n",
|
||||||
|
"\ti += 1\t\n",
|
||||||
"\t# for each condition\n",
|
"\t# for each condition\n",
|
||||||
"\tconditions = path[\"conditions\"]\n",
|
"\tconditions = path[\"conditions\"]\n",
|
||||||
"\tfor condition in conditions:\n",
|
"\tfor condition in conditions:\n",
|
||||||
@@ -57,7 +60,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 4,
|
||||||
"id": "2fd4f738",
|
"id": "2fd4f738",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -80,7 +83,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 5,
|
||||||
"id": "98cde024",
|
"id": "98cde024",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -120,7 +123,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 6,
|
||||||
"id": "b6fbadbf",
|
"id": "b6fbadbf",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -177,40 +180,46 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 7,
|
||||||
"id": "0a767971",
|
"id": "0a767971",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# combine breakpoints and buckets to one representation\n",
|
"# combine breakpoints and buckets to one representation\n",
|
||||||
"\n",
|
"\n",
|
||||||
"compressed_tree = defaultdict(list)\n",
|
"compressed_layers = defaultdict(list)\n",
|
||||||
"for feature_name in buckets_id:\n",
|
"for feature_name in buckets_id:\n",
|
||||||
"\tlower = None\n",
|
"\tlower = None\n",
|
||||||
"\tupper = breakpoints[feature_name][0]\n",
|
"\tupper = breakpoints[feature_name][0]\n",
|
||||||
"\tpaths = buckets_id[feature_name][0]\n",
|
"\tpaths = buckets_id[feature_name][0]\n",
|
||||||
"\tclasses = buckets_class[feature_name][0]\n",
|
"\tclasses = buckets_class[feature_name][0]\n",
|
||||||
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
||||||
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
"\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||||
"\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n",
|
"\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n",
|
||||||
"\t\tlower = breakpoints[feature_name][i-1]\n",
|
"\t\tlower = breakpoints[feature_name][i-1]\n",
|
||||||
"\t\tupper = breakpoints[feature_name][i]\n",
|
"\t\tupper = breakpoints[feature_name][i]\n",
|
||||||
"\t\tmembers = buckets_id[feature_name][i]\n",
|
"\t\tmembers = buckets_id[feature_name][i]\n",
|
||||||
"\t\tclasses = buckets_class[feature_name][i]\n",
|
"\t\tclasses = buckets_class[feature_name][i]\n",
|
||||||
"\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n",
|
"\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n",
|
||||||
"\t\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
"\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||||
"\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n",
|
"\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n",
|
||||||
"\tupper = None\n",
|
"\tupper = None\n",
|
||||||
"\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n",
|
"\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n",
|
||||||
"\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n",
|
"\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n",
|
||||||
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
|
||||||
"\tcompressed_tree[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
"\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
|
||||||
"\t#print(\"=\"*40)"
|
"\t#print(\"=\"*40)\n",
|
||||||
|
"\n",
|
||||||
|
"compressed_tree = {\n",
|
||||||
|
"\t\"paths\": path_ids,\n",
|
||||||
|
"\t\"classes\": path_classes,\n",
|
||||||
|
"\t\"layers\": compressed_layers,\n",
|
||||||
|
"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 8,
|
||||||
"id": "561b0bc1",
|
"id": "561b0bc1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
Reference in New Issue
Block a user