#!/usr/bin/env python3
"""
Train a decision tree, optionally “nudge” its split thresholds, and
export the result as JSON.

Usage examples
--------------
# plain training, no nudging
python build_tree.py --input data/combined/data.csv --output tree.json

# nudge every internal threshold, dropping the bottom 2 bits
python build_tree.py --input data/combined/data.csv --output tree.json \
    --nudge --bits 2
"""
import argparse
import copy
import json
import math
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier, _tree

# ----------------------------------------------------------------------
# 1. command-line arguments
# ----------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", default="../data/combined/data.csv",
                    help="CSV file with protocol,src,dst,label columns")
parser.add_argument("--output", "-o", default="tree.json",
                    help="Path for the exported JSON tree")
parser.add_argument("--depth", "-d", type=int, default=5,
                    help="Max depth of the decision tree (default: 5)")
parser.add_argument("--nudge", action="store_true",
                    help="Enable threshold nudging")
parser.add_argument("--bits", type=int, default=2,
                    help="Number of low-order bits to drop when nudging (default: 2)")
args = parser.parse_args()

# ----------------------------------------------------------------------
# 2. helper functions
# ----------------------------------------------------------------------
def nudge_threshold_max_n_bits(threshold: float, n_bits: int) -> int:
    """Round a split threshold to a multiple of 2**n_bits.

    The bottom n_bits bits are cleared; if the highest cleared bit was set,
    the value is rounded up instead of down.
    """
    threshold = math.floor(threshold)
    if n_bits == 0:
        return threshold

    # 32-bit mask with the bottom n_bits bits cleared
    mask = (pow(2, 32) - 1) ^ ((1 << n_bits) - 1)
    nudged_value = threshold & mask
    if threshold & (1 << (n_bits - 1)):          # round to the nearest multiple
        nudged_value += 1 << n_bits

    return nudged_value
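
# Quick hand-checked illustration of the rounding behaviour (purely
# illustrative, not used by the pipeline below): with n_bits = 2 the two low
# bits are cleared and bit 1 decides the direction, so 14 (0b1110) rounds up
# to 16 while 13 (0b1101) rounds down to 12; with n_bits = 0 the threshold is
# only floored.
assert nudge_threshold_max_n_bits(14, 2) == 16
assert nudge_threshold_max_n_bits(13, 2) == 12
assert nudge_threshold_max_n_bits(7.9, 0) == 7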

def apply_nudging(tree: _tree.Tree, node_idx: int, n_bits: int) -> None:
    """Post-order traversal that nudges every internal node’s threshold."""
    flag = False
    if tree.children_left[node_idx] != -1:       # -1 marks a missing child (leaf)
        apply_nudging(tree, tree.children_left[node_idx], n_bits)
        flag = True
    if tree.children_right[node_idx] != -1:
        apply_nudging(tree, tree.children_right[node_idx], n_bits)
        flag = True
    if flag:                                     # internal node: has at least one child
        tree.threshold[node_idx] = nudge_threshold_max_n_bits(
            tree.threshold[node_idx], n_bits
        )

# output the tree as a JSON-serialisable dict
def get_lineage(tree, feature_names):
    """Convert a fitted DecisionTreeClassifier into a dict of per-feature
    thresholds and root-to-leaf decision paths."""
    data = {"features": {}, "paths": [], "classes": list(tree.classes_)}

    thresholds = tree.tree_.threshold
    # leaves carry feature id -2 (_tree.TREE_UNDEFINED); the names looked up
    # for them are placeholders and are filtered out below
    features = [feature_names[i] for i in tree.tree_.feature]
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    value = tree.tree_.value

    # -------- helper to climb up from a leaf to the root -----------
    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]                    # leaf marker (an int)
        if child in left:
            parent = np.where(left == child)[0].item()
            split = "l"
        elif child in right:
            parent = np.where(right == child)[0].item()
            split = "r"
        else:                                    # should never happen
            return lineage

        lineage.append((parent, split, thresholds[parent], features[parent]))
        if parent == 0:
            return list(reversed(lineage))
        return recurse(left, right, parent, lineage)

    leaf_ids = np.where(left == -1)[0]           # indices of all leaves
    for path_id, leaf in enumerate(leaf_ids):
        clause = []

        for node in recurse(left, right, leaf):
            if not isinstance(node, tuple):      # skip the leaf marker
                continue

            direction, threshold, feature = node[1], node[2], node[3]
            if direction == "l":
                clause.append(
                    {"feature": feature, "operation": "<=", "value": threshold}
                )
            else:
                clause.append(
                    {"feature": feature, "operation": ">", "value": threshold}
                )

        class_idx = int(np.argmax(value[leaf][0]))   # majority class at the leaf
        data["paths"].append(
            {"conditions": clause, "classification": class_idx, "id": path_id}
        )

    # collect all thresholds per feature
    for i, feat in enumerate(features):
        if tree.tree_.feature[i] != _tree.TREE_UNDEFINED:
            data["features"].setdefault(feat, []).append(thresholds[i])

    return data
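
# The exported dict therefore has roughly the shape sketched below; the
# feature names and threshold values shown are hypothetical and depend
# entirely on the input CSV:
#
#   {
#       "features": {"src": [443.5, 1024.5], "dst": [80.5]},
#       "paths": [
#           {"conditions": [{"feature": "src", "operation": "<=", "value": 443.5}],
#            "classification": 0,
#            "id": 0},
#           ...
#       ],
#       "classes": [...]
#   }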


class SetEncoder(json.JSONEncoder):
    """JSON encoder that serialises Python sets as lists."""
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)
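
# SetEncoder lets json.dumps() handle Python sets, which it otherwise rejects
# with a TypeError; e.g. json.dumps({"a": {1, 2}}, cls=SetEncoder) yields
# '{"a": [1, 2]}' (the order of elements inside the list is not guaranteed).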

# ----------------------------------------------------------------------
# 3. load data
# ----------------------------------------------------------------------
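# The loader below assumes the column layout named in the --input help text
# (protocol, src, dst, label); hypothetical example rows, for illustration only:
#
#   protocol,src,dst,label
#   6,443,52101,web
#   17,53,41000,dns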
df = pd.read_csv(args.input)
X = df.iloc[:, :3].to_numpy()     # first three columns are the features
Y = df.iloc[:, 3].to_numpy()      # fourth column is the label

print(f"dataset size: {len(X)}")

# ----------------------------------------------------------------------
# 4. train the tree
# ----------------------------------------------------------------------
dt = DecisionTreeClassifier(max_depth=args.depth)
dt.fit(X, Y)
print("train accuracy (before nudging):",
      accuracy_score(Y, dt.predict(X)))

if args.nudge:
    nudged_tree = copy.deepcopy(dt.tree_)
    apply_nudging(nudged_tree, 0, args.bits)
    dt.tree_ = nudged_tree                   # predict() and the export now use the nudged thresholds
    print(f"nudging enabled, removed bottom {args.bits} bit(s) per threshold")

print("train accuracy (after nudging):",
      accuracy_score(Y, dt.predict(X)))

# ----------------------------------------------------------------------
# 5. export
# ----------------------------------------------------------------------
lineage = get_lineage(dt, df.columns[:3])

output_path = Path(args.output)
output_path.write_text(json.dumps(lineage, indent=4, cls=SetEncoder))
print(f"Wrote tree to {output_path.resolve()}")