#!/usr/bin/env python3
"""Batch-compress decision-tree JSON files.

This script preserves the original logic but loops over every *.json file
in results/tree and drops a corresponding compressed file in
results/compressed_tree.

Example:
    $ python compress_trees_batch.py
"""
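
# The input files are assumed to follow the schema implied by the key accesses
# below; a minimal, hypothetical example of one such file (feature and class
# names are made up for illustration):
#
# {
#     "classes": ["benign", "attack"],
#     "paths": [
#         {
#             "classification": "benign",
#             "conditions": [
#                 {"feature": "pkt_len", "operation": "<=", "value": 128.5},
#                 {"feature": "pkt_len", "operation": ">", "value": 64.2}
#             ]
#         }
#     ]
# }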

from __future__ import annotations

import json
import math
from collections import defaultdict
from pathlib import Path

INPUT_DIR = Path("results/tree")
OUTPUT_DIR = Path("results/compressed_tree")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


class SetEncoder(json.JSONEncoder):
    """JSON encoder that also handles sets by emitting them as lists."""

    def default(self, obj):  # type: ignore[override]
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
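
# Usage sketch (hypothetical values): json.dumps({"ids": {1, 2}}, cls=SetEncoder)
# yields '{"ids": [1, 2]}'; set iteration order is arbitrary, so the order of
# the resulting list is not guaranteed.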


# helper: given a half-open range (lower, upper] and a value x, report whether
# x falls inside it; a bound of None means unbounded on that side
def is_in_range(x: int, lower: int | None, upper: int | None) -> bool:  # noqa: N803
    if lower is None and upper is None:
        return True
    if lower is None:
        return x <= upper  # type: ignore[operator]
    if upper is None:
        return x > lower
    return x <= upper and x > lower  # type: ignore[operator]
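
# e.g. is_in_range(5, None, 5) -> True (the upper bound is inclusive), while
# is_in_range(5, 5, None) -> False (the lower bound is exclusive).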


for tree_path in INPUT_DIR.glob("*.json"):
    with tree_path.open() as f:
        tree = json.load(f)

    paths = tree["paths"]

    # First clean up the tree by rounding the decision thresholds to integers
    path_ids: set[int] = set()
    path_classes = tree["classes"]

    # assign ids and round thresholds; flooring preserves the comparison for
    # integer-valued features under both operators (x <= v iff x <= floor(v),
    # and x > v iff x > floor(v)), so "<=" and ">" use the same rounding
    for idx, path in enumerate(paths):
        path["id"] = idx
        path_ids.add(idx)
        for condition in path["conditions"]:
            condition["value"] = math.floor(condition["value"])
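
    # e.g. a hypothetical learned threshold of 128.5 floors to 128:
    # "pkt_len <= 128.5" becomes "pkt_len <= 128" and "pkt_len > 128.5"
    # becomes "pkt_len > 128", matching exactly the same integers as before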

    # Find all breakpoints for each feature so disjoint ranges can be built;
    # values are collected as sets here and converted to sorted lists just
    # below, which is why the list annotation needs the type: ignore
    breakpoints: dict[str, list[int]] = defaultdict(set)  # type: ignore[assignment]
    for path in paths:
        for condition in path["conditions"]:
            feature = condition["feature"]
            value = condition["value"]
            breakpoints[feature].add(value)

    # sort breakpoint lists, replacing each set with its sorted list
    for feature in breakpoints:
        breakpoints[feature] = sorted(breakpoints[feature])  # type: ignore[assignment]
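
    # with the hypothetical input sketched at the top of this file, breakpoints
    # would now be {"pkt_len": [64, 128]}: the sorted, de-duplicated integer
    # thresholds that appear anywhere in the tree, per feature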

    # collapse each path's conditions into a single (min, max] range per
    # feature, keeping the tightest bounds: the smallest "<=" threshold and
    # the largest ">" threshold seen along the path
    for path in paths:
        compressed: dict[str, dict[str, int | None]] = {}
        for feature in breakpoints:
            compressed[feature] = {"min": None, "max": None}

        for condition in path["conditions"]:
            feature = condition["feature"]
            operation = condition["operation"]
            value = condition["value"]
            if operation == "<=" and compressed[feature]["max"] is None:
                compressed[feature]["max"] = value
            elif operation == ">" and compressed[feature]["min"] is None:
                compressed[feature]["min"] = value
            elif operation == "<=" and value < compressed[feature]["max"]:  # type: ignore[operator]
                compressed[feature]["max"] = value
            elif operation == ">" and value > compressed[feature]["min"]:  # type: ignore[operator]
                compressed[feature]["min"] = value

        path["compressed"] = compressed
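
    # e.g. the conditions pkt_len <= 200, pkt_len <= 128, pkt_len > 64 collapse
    # to {"pkt_len": {"min": 64, "max": 128}}, i.e. 64 < pkt_len <= 128
    # (hypothetical values)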

    # create buckets for each feature, where each is a list of sets; n sorted
    # breakpoints split the feature axis into n + 1 buckets (the extra one
    # covers values above the last breakpoint)
    buckets_id: dict[str, list[set[int]]] = {}
    buckets_class: dict[str, list[set[str]]] = {}
    for feature in breakpoints:
        num_points = len(breakpoints[feature])
        buckets_id[feature] = [set() for _ in range(num_points + 1)]
        buckets_class[feature] = [set() for _ in range(num_points + 1)]

    # fill buckets: bucket i is probed with its breakpoint value, so a path is
    # added whenever that probe falls inside the path's (min, max] range
    for path in paths:
        for feature_name, feature in path["compressed"].items():
            lower = feature["min"]
            upper = feature["max"]
            pid = path["id"]
            cls = path["classification"]

            for idx, bp in enumerate(breakpoints[feature_name]):
                if is_in_range(bp, lower, upper):
                    buckets_id[feature_name][idx].add(pid)
                    buckets_class[feature_name][idx].add(cls)
            # last bucket (> last breakpoint): probe with one past the final
            # breakpoint, reusing bp from the loop above
            if is_in_range(bp + 1, lower, upper):
                buckets_id[feature_name][-1].add(pid)
                buckets_class[feature_name][-1].add(cls)
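
    # with hypothetical breakpoints [64, 128], bucket 0 covers pkt_len <= 64
    # (probed with 64), bucket 1 covers 64 < pkt_len <= 128 (probed with 128),
    # and bucket 2 covers pkt_len > 128 (probed with 129)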

    # combine breakpoints and buckets into one representation: for each
    # feature, an ordered list of contiguous (min, max] ranges, each carrying
    # the path ids and classes that can still match inside that range
    compressed_layers: dict[str, list[dict[str, object]]] = defaultdict(list)
    for feature_name in buckets_id:
        # first range: everything up to and including the smallest breakpoint
        lower = None
        upper = breakpoints[feature_name][0]
        compressed_layers[feature_name].append(
            {
                "min": lower,
                "max": upper,
                "paths": buckets_id[feature_name][0],
                "classes": buckets_class[feature_name][0],
            }
        )
        # middle ranges between consecutive breakpoints
        for i in range(1, len(buckets_id[feature_name]) - 1):
            lower = breakpoints[feature_name][i - 1]
            upper = breakpoints[feature_name][i]
            compressed_layers[feature_name].append(
                {
                    "min": lower,
                    "max": upper,
                    "paths": buckets_id[feature_name][i],
                    "classes": buckets_class[feature_name][i],
                }
            )
        # last range: everything above the largest breakpoint
        lower = breakpoints[feature_name][-1]
        upper = None
        compressed_layers[feature_name].append(
            {
                "min": lower,
                "max": upper,
                "paths": buckets_id[feature_name][-1],
                "classes": buckets_class[feature_name][-1],
            }
        )

    path_to_class = {path["id"]: path["classification"] for path in paths}

    compressed_tree = {
        "paths": list(path_ids),
        "classes": path_classes,
        "layers": compressed_layers,
        "path_to_class": path_to_class,
    }

    out_path = OUTPUT_DIR / tree_path.name.replace("tree", "compressed_tree")
    with out_path.open("w") as f_out:
        json.dump(compressed_tree, f_out, indent=4, cls=SetEncoder)

    # print(f"Wrote {out_path.relative_to(Path.cwd())}")
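
# A compressed output file then has roughly this shape (hypothetical values;
# note that json.dump stringifies the integer keys of path_to_class):
#
# {
#     "paths": [0, 1],
#     "classes": ["benign", "attack"],
#     "layers": {
#         "pkt_len": [
#             {"min": null, "max": 64, "paths": [...], "classes": [...]},
#             {"min": 64, "max": 128, "paths": [...], "classes": [...]},
#             {"min": 128, "max": null, "paths": [...], "classes": [...]}
#         ]
#     },
#     "path_to_class": {"0": "benign", "1": "attack"}
# }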