mirror of
https://github.com/ltcptgeneral/IdealRMT-DecisionTrees.git
synced 2025-09-09 00:37:23 +00:00
Run data and code
This commit is contained in:
173
run/tree_compress.py
Normal file
173
run/tree_compress.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Batch‑compress decision‑tree JSON files.
|
||||
|
||||
This script preserves the original logic but loops over every *.json file
|
||||
in results/tree and drops a corresponding compressed file in
|
||||
results/compressed_tree.
|
||||
|
||||
Example:
|
||||
$ python compress_trees_batch.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
INPUT_DIR = Path("results/tree")
|
||||
OUTPUT_DIR = Path("results/compressed_tree")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class SetEncoder(json.JSONEncoder):
|
||||
def default(self, obj): # type: ignore[override]
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
# helper function given a range and value x returns if x is in the range
|
||||
|
||||
def is_in_range(x: int, lower: int | None, upper: int | None) -> bool: # noqa: N803
|
||||
if lower is None and upper is None:
|
||||
return True
|
||||
if lower is None:
|
||||
return x <= upper # type: ignore[operator]
|
||||
if upper is None:
|
||||
return x > lower
|
||||
return x <= upper and x > lower # type: ignore[operator]
|
||||
|
||||
|
||||
for tree_path in INPUT_DIR.glob("*.json"):
|
||||
with tree_path.open() as f:
|
||||
tree = json.load(f)
|
||||
|
||||
paths = tree["paths"]
|
||||
|
||||
# First cleanup the tree by rounding the decision points to integer values
|
||||
path_ids: set[int] = set()
|
||||
path_classes = tree["classes"]
|
||||
|
||||
# assign ids and round thresholds
|
||||
for idx, path in enumerate(paths):
|
||||
path["id"] = idx
|
||||
path_ids.add(idx)
|
||||
for condition in path["conditions"]:
|
||||
operation = condition["operation"]
|
||||
if operation == "<=":
|
||||
condition["value"] = math.floor(condition["value"])
|
||||
else:
|
||||
condition["value"] = math.floor(condition["value"])
|
||||
|
||||
# Find all breakpoints for each feature and create a set of disjoint ranges
|
||||
breakpoints: dict[str, list[int]] = defaultdict(set) # type: ignore[assignment]
|
||||
for path in paths:
|
||||
for condition in path["conditions"]:
|
||||
feature = condition["feature"]
|
||||
value = condition["value"]
|
||||
breakpoints[feature].add(value)
|
||||
|
||||
# sort breakpoint lists
|
||||
for feature in breakpoints:
|
||||
points = list(breakpoints[feature])
|
||||
points.sort()
|
||||
breakpoints[feature] = points # type: ignore[assignment]
|
||||
|
||||
# collapse all paths to ranges for each feature
|
||||
for path in paths:
|
||||
compressed: dict[str, dict[str, int | None]] = {}
|
||||
for feature in breakpoints:
|
||||
compressed[feature] = {"min": None, "max": None}
|
||||
|
||||
for condition in path["conditions"]:
|
||||
feature = condition["feature"]
|
||||
operation = condition["operation"]
|
||||
value = condition["value"]
|
||||
if operation == "<=" and compressed[feature]["max"] is None:
|
||||
compressed[feature]["max"] = value
|
||||
elif operation == ">" and compressed[feature]["min"] is None:
|
||||
compressed[feature]["min"] = value
|
||||
elif operation == "<=" and value < compressed[feature]["max"]: # type: ignore[operator]
|
||||
compressed[feature]["max"] = value
|
||||
elif operation == ">" and value > compressed[feature]["min"]: # type: ignore[operator]
|
||||
compressed[feature]["min"] = value
|
||||
|
||||
path["compressed"] = compressed
|
||||
|
||||
# create buckets for each feature, where each is a list of sets
|
||||
buckets_id: dict[str, list[set[int]]] = {}
|
||||
buckets_class: dict[str, list[set[str]]] = {}
|
||||
for feature in breakpoints:
|
||||
num_points = len(breakpoints[feature])
|
||||
buckets_id[feature] = [set() for _ in range(num_points + 1)]
|
||||
buckets_class[feature] = [set() for _ in range(num_points + 1)]
|
||||
|
||||
# fill buckets
|
||||
for path in paths:
|
||||
for feature_name, feature in path["compressed"].items():
|
||||
lower = feature["min"]
|
||||
upper = feature["max"]
|
||||
pid = path["id"]
|
||||
cls = path["classification"]
|
||||
|
||||
for idx, bp in enumerate(breakpoints[feature_name]):
|
||||
if is_in_range(bp, lower, upper):
|
||||
buckets_id[feature_name][idx].add(pid)
|
||||
buckets_class[feature_name][idx].add(cls)
|
||||
# last bucket (> last breakpoint)
|
||||
if is_in_range(bp + 1, lower, upper):
|
||||
buckets_id[feature_name][-1].add(pid)
|
||||
buckets_class[feature_name][-1].add(cls)
|
||||
|
||||
# combine breakpoints and buckets to one representation
|
||||
compressed_layers: dict[str, list[dict[str, object]]] = defaultdict(list)
|
||||
for feature_name in buckets_id:
|
||||
lower = None
|
||||
upper = breakpoints[feature_name][0]
|
||||
compressed_layers[feature_name].append(
|
||||
{
|
||||
"min": lower,
|
||||
"max": upper,
|
||||
"paths": buckets_id[feature_name][0],
|
||||
"classes": buckets_class[feature_name][0],
|
||||
}
|
||||
)
|
||||
for i in range(1, len(buckets_id[feature_name]) - 1):
|
||||
lower = breakpoints[feature_name][i - 1]
|
||||
upper = breakpoints[feature_name][i]
|
||||
compressed_layers[feature_name].append(
|
||||
{
|
||||
"min": lower,
|
||||
"max": upper,
|
||||
"paths": buckets_id[feature_name][i],
|
||||
"classes": buckets_class[feature_name][i],
|
||||
}
|
||||
)
|
||||
lower = breakpoints[feature_name][-1]
|
||||
upper = None
|
||||
compressed_layers[feature_name].append(
|
||||
{
|
||||
"min": lower,
|
||||
"max": upper,
|
||||
"paths": buckets_id[feature_name][-1],
|
||||
"classes": buckets_class[feature_name][-1],
|
||||
}
|
||||
)
|
||||
|
||||
path_to_class = {path["id"]: path["classification"] for path in paths}
|
||||
|
||||
compressed_tree = {
|
||||
"paths": list(path_ids),
|
||||
"classes": path_classes,
|
||||
"layers": compressed_layers,
|
||||
"path_to_class": path_to_class,
|
||||
}
|
||||
|
||||
out_path = OUTPUT_DIR / tree_path.name.replace("tree", "compressed_tree")
|
||||
with out_path.open("w") as f_out:
|
||||
json.dump(compressed_tree, f_out, indent=4, cls=SetEncoder)
|
||||
|
||||
# print(f"Wrote {out_path.relative_to(Path.cwd())}")
|
Reference in New Issue
Block a user