fix incorrect classes in TreeCompress

closes #1
This commit is contained in:
2025-06-08 18:34:39 +00:00
parent c208037ae9
commit fadeab8a99

View File

@@ -86,7 +86,41 @@
"execution_count": 5, "execution_count": 5,
"id": "98cde024", "id": "98cde024",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'dst': {'min': None, 'max': 578}, 'src': {'min': None, 'max': 60}, 'protocl': {'min': None, 'max': 0}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': None, 'max': 60}, 'protocl': {'min': None, 'max': 0}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': None, 'max': 60}, 'protocl': {'min': 0, 'max': None}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': None, 'max': 60}, 'protocl': {'min': 1, 'max': None}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': None, 'max': 67}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 101}, 'src': {'min': 67, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 101}, 'src': {'min': 54978, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 101}, 'src': {'min': 59817, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 101}, 'src': {'min': 60043, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': 67, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': 130, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': None, 'max': 3031}, 'src': {'min': 1223, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 3031, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 3067, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 5110, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 33925, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 46329, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 46331, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 49152, 'max': None}, 'src': {'min': None, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 49157, 'max': None}, 'src': {'min': None, 'max': 283}, 'protocl': {'min': None, 'max': 11}}\n",
"{'dst': {'min': 49157, 'max': None}, 'src': {'min': None, 'max': 283}, 'protocl': {'min': 11, 'max': None}}\n",
"{'dst': {'min': 49157, 'max': None}, 'src': {'min': None, 'max': 4566}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 56320, 'max': None}, 'src': {'min': None, 'max': 4566}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 49157, 'max': None}, 'src': {'min': 4566, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 51848, 'max': None}, 'src': {'min': 4566, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 49157, 'max': None}, 'src': {'min': 5225, 'max': None}, 'protocl': {'min': None, 'max': None}}\n",
"{'dst': {'min': 53283, 'max': None}, 'src': {'min': 5225, 'max': None}, 'protocl': {'min': None, 'max': None}}\n"
]
}
],
"source": [ "source": [
"# collapse all paths to ranges for each feature\n", "# collapse all paths to ranges for each feature\n",
"# because of how decision trees work, all conditions on a path must be true to reach the leaf node\n", "# because of how decision trees work, all conditions on a path must be true to reach the leaf node\n",
@@ -118,7 +152,8 @@
"\t\telif operation == \">\" and value > compressed[feature][\"max\"]:\n", "\t\telif operation == \">\" and value > compressed[feature][\"max\"]:\n",
"\t\t\tcompressed[feature][\"min\"] = value\n", "\t\t\tcompressed[feature][\"min\"] = value\n",
"\n", "\n",
"\tpath[\"compressed\"] = compressed" "\tpath[\"compressed\"] = compressed\n",
"\tprint(compressed)"
] ]
}, },
{ {
@@ -171,11 +206,14 @@
"\t\t# for each bucket which encompasses the condition's range, add this path's id to the sets \n", "\t\t# for each bucket which encompasses the condition's range, add this path's id to the sets \n",
"\t\ti = 0\n", "\t\ti = 0\n",
"\t\tfor bp in breakpoints[feature_name]:\n", "\t\tfor bp in breakpoints[feature_name]:\n",
"\t\t\tin_range = is_in_range(bp, lower, upper)\n", "\t\t\tif is_in_range(bp, lower, upper):\n",
"\t\t\tif in_range:\n",
"\t\t\t\tbuckets_id[feature_name][i].add(ID)\n", "\t\t\t\tbuckets_id[feature_name][i].add(ID)\n",
"\t\t\t\tbuckets_class[feature_name][i].add(Class)\n", "\t\t\t\tbuckets_class[feature_name][i].add(Class)\n",
"\t\t\ti += 1" "\t\t\ti += 1\n",
"\n",
"\t\tif is_in_range(bp+1, lower, upper):\n",
"\t\t\tbuckets_id[feature_name][i].add(ID)\n",
"\t\t\tbuckets_class[feature_name][i].add(Class)"
] ]
}, },
{ {
@@ -198,13 +236,13 @@
"\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n", "\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n",
"\t\tlower = breakpoints[feature_name][i-1]\n", "\t\tlower = breakpoints[feature_name][i-1]\n",
"\t\tupper = breakpoints[feature_name][i]\n", "\t\tupper = breakpoints[feature_name][i]\n",
"\t\tmembers = buckets_id[feature_name][i]\n", "\t\tpaths = buckets_id[feature_name][i]\n",
"\t\tclasses = buckets_class[feature_name][i]\n", "\t\tclasses = buckets_class[feature_name][i]\n",
"\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n", "\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n",
"\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",
"\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n", "\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n",
"\tupper = None\n", "\tupper = None\n",
"\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n", "\tpaths = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n",
"\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n", "\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n",
"\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n",
"\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n",