IdealRMT-DecisionTrees/run/tree_compress.py

#!/usr/bin/env python3
"""Batch-compress decision-tree JSON files.

This script preserves the original logic but loops over every *.json file
in results/tree and drops a corresponding compressed file in
results/compressed_tree.

Example:

    $ python tree_compress.py
"""
from __future__ import annotations

import json
import math
from collections import defaultdict
from pathlib import Path

INPUT_DIR = Path("results/tree")
OUTPUT_DIR = Path("results/compressed_tree")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
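

# json.dump cannot serialize Python sets; this encoder turns any set it
# encounters into a list so the bucket contents survive serialization.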
class SetEncoder(json.JSONEncoder):
    def default(self, obj):  # type: ignore[override]
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)


def is_in_range(x: int, lower: int | None, upper: int | None) -> bool:
    """Return True if x lies in the half-open interval (lower, upper].

    A bound of None means that side is unbounded.
    """
    if upper is not None and x > upper:
        return False
    if lower is not None and x <= lower:
        return False
    return True
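

# Process every decision-tree JSON file in INPUT_DIR independently.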
for tree_path in INPUT_DIR.glob("*.json"):
    with tree_path.open() as f:
        tree = json.load(f)
    paths = tree["paths"]

    # First clean up the tree: assign an id to every path and round every
    # decision threshold down to an integer.
    path_ids: set[int] = set()
    path_classes = tree["classes"]
    for idx, path in enumerate(paths):
        path["id"] = idx
        path_ids.add(idx)
        for condition in path["conditions"]:
            # Thresholds for both "<=" and ">" are floored, so all later
            # comparisons work on integer breakpoints.
            condition["value"] = math.floor(condition["value"])

    # Find all breakpoints for each feature: collect the distinct threshold
    # values, then sort them into disjoint range boundaries.
    breakpoint_sets: dict[str, set[int]] = defaultdict(set)
    for path in paths:
        for condition in path["conditions"]:
            breakpoint_sets[condition["feature"]].add(condition["value"])
    breakpoints: dict[str, list[int]] = {
        feature: sorted(points) for feature, points in breakpoint_sets.items()
    }

    # Collapse each path's conditions into a single (min, max] range per
    # feature; None marks an unbounded side.
    for path in paths:
        compressed: dict[str, dict[str, int | None]] = {}
        for feature in breakpoints:
            compressed[feature] = {"min": None, "max": None}
        for condition in path["conditions"]:
            feature = condition["feature"]
            operation = condition["operation"]
            value = condition["value"]
            bounds = compressed[feature]
            if operation == "<=":
                # Keep the tightest upper bound seen on this path.
                if bounds["max"] is None or value < bounds["max"]:
                    bounds["max"] = value
            elif operation == ">":
                # Keep the tightest lower bound seen on this path.
                if bounds["min"] is None or value > bounds["min"]:
                    bounds["min"] = value
        path["compressed"] = compressed

    # Create buckets for each feature, where each is a list of sets:
    # one set per disjoint range, i.e. one more bucket than breakpoints.
    buckets_id: dict[str, list[set[int]]] = {}
    buckets_class: dict[str, list[set[str]]] = {}
    for feature in breakpoints:
        num_points = len(breakpoints[feature])
        buckets_id[feature] = [set() for _ in range(num_points + 1)]
        buckets_class[feature] = [set() for _ in range(num_points + 1)]
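    # Bucket i covers the range (breakpoints[i - 1], breakpoints[i]]; the
    # final extra bucket covers everything above the last breakpoint.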

    # Fill the buckets: a path lands in every bucket whose range intersects
    # the path's own (min, max] range; each breakpoint stands in for the
    # bucket that ends at it.
    for path in paths:
        for feature_name, feature in path["compressed"].items():
            lower = feature["min"]
            upper = feature["max"]
            pid = path["id"]
            cls = path["classification"]
            for idx, bp in enumerate(breakpoints[feature_name]):
                if is_in_range(bp, lower, upper):
                    buckets_id[feature_name][idx].add(pid)
                    buckets_class[feature_name][idx].add(cls)
            # Last bucket (> last breakpoint): probe one past the largest
            # breakpoint, which is still bound to bp after the loop.
            if is_in_range(bp + 1, lower, upper):
                buckets_id[feature_name][-1].add(pid)
                buckets_class[feature_name][-1].add(cls)

    # Combine breakpoints and buckets into one per-feature representation:
    # an open-ended first range, the interior ranges, and an open-ended
    # last range.
    compressed_layers: dict[str, list[dict[str, object]]] = defaultdict(list)
    for feature_name in buckets_id:
        points = breakpoints[feature_name]
        # First range: everything up to and including the first breakpoint.
        compressed_layers[feature_name].append(
            {
                "min": None,
                "max": points[0],
                "paths": buckets_id[feature_name][0],
                "classes": buckets_class[feature_name][0],
            }
        )
        # Interior ranges: (points[i - 1], points[i]].
        for i in range(1, len(buckets_id[feature_name]) - 1):
            compressed_layers[feature_name].append(
                {
                    "min": points[i - 1],
                    "max": points[i],
                    "paths": buckets_id[feature_name][i],
                    "classes": buckets_class[feature_name][i],
                }
            )
        # Last range: everything above the last breakpoint.
        compressed_layers[feature_name].append(
            {
                "min": points[-1],
                "max": None,
                "paths": buckets_id[feature_name][-1],
                "classes": buckets_class[feature_name][-1],
            }
        )
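
    # Assemble the final artifact for this tree.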
    path_to_class = {path["id"]: path["classification"] for path in paths}
    compressed_tree = {
        "paths": list(path_ids),
        "classes": path_classes,
        "layers": compressed_layers,
        "path_to_class": path_to_class,
    }

    # Mirror the input file name, swapping "tree" for "compressed_tree".
    out_path = OUTPUT_DIR / tree_path.name.replace("tree", "compressed_tree")
    with out_path.open("w") as f_out:
        json.dump(compressed_tree, f_out, indent=4, cls=SetEncoder)
    # print(f"Wrote {out_path.relative_to(Path.cwd())}")