In [25]:
import json
import math
from collections import defaultdict

In [26]:
# Load the exported decision tree. The context manager closes the file
# even if parsing the JSON or looking up "paths" raises.
with open("tree.json") as f:
	tree = json.load(f)
#features = tree["features"]
# each entry of paths is one root-to-leaf path carrying a "conditions" list
paths = tree["paths"]

In [27]:
# First clean up the tree:
#  - give every path a sequential integer id (its position in paths) for later use
#  - round every decision threshold down to an integer
# We assume all features take integer values. If this is not the case, the training data
# should be normalized so that integer thresholds are accurate enough.
# Flooring is correct for both operations: for an integer a,
#   a <= x.y  is equivalent to  a <= floor(x.y)
#   a >  x.y  is equivalent to  a >  floor(x.y)
# so no branching on the operation is needed.

for path_id, path in enumerate(paths):
	path["id"] = path_id
	for condition in path["conditions"]:
		condition["value"] = math.floor(condition["value"])

In [28]:
# Gather every threshold used for each feature; these breakpoints split the
# feature's axis into a set of disjoint ranges

breakpoints = defaultdict(set)
for path in paths:
	for condition in path["conditions"]:
		breakpoints[condition["feature"]].add(condition["value"])

# swap each feature's set for an ascending list so breakpoints can be addressed by index
for feature, points in breakpoints.items():
	breakpoints[feature] = sorted(points)

In [29]:
# Collapse all conditions of a path into one range per feature.
# Because of how decision trees work, all conditions on a path must be true to reach the leaf node.
# Intuitively, a collection of statements x > a, x > b, x <= c, x <= d ... which must all be satisfied
# can logically be collapsed into a single range (min, max]:
#   min = the largest threshold among the ">" conditions (None if there are none)
#   max = the smallest threshold among the "<=" conditions (None if there are none)

for path in paths:
	# start from an unbounded range for every known feature
	compressed = {feature: {"min": None, "max": None} for feature in breakpoints}

	for condition in path["conditions"]:
		bounds = compressed[condition["feature"]]
		operation = condition["operation"]
		value = condition["value"]

		# each bound is only ever compared against the bound of its own kind,
		# so a path such as "x > 3, x <= 7" yields min = 3 and max = 7
		if operation == "<=":
			# x <= value tightens the upper bound; keep the smallest one seen
			if bounds["max"] is None or value < bounds["max"]:
				bounds["max"] = value
		elif operation == ">":
			# x > value tightens the lower bound; keep the largest one seen
			if bounds["min"] is None or value > bounds["min"]:
				bounds["min"] = value

	path["compressed"] = compressed

In [30]:
# for each path, add the path's id to the buckets corresponding to the breakpoint ranges it covers
# i.e. if breakpoints = [0, 1, 2, 3]
# then buckets = [(<= 0), (0 - 1], (1 - 2], (2 - 3], (> 3)]
# therefore, each entry in buckets holds the paths that are less than (or equal to) the matching entry in breakpoints but (strictly) greater than the previous one
# the last entry corresponds to the paths which are greater than the last entry in breakpoints

# helper function: given a value x and a range, returns whether x lies in the half-open range (lower, upper]
def is_in_range(x, lower, upper):
	"""Return True when lower < x <= upper.

	A bound of None means the range is unbounded on that side, so
	(None, None) accepts every x.
	"""
	# upper bound is inclusive
	if upper is not None and not (x <= upper):
		return False
	# lower bound is exclusive
	if lower is not None and not (x > lower):
		return False
	return True

# create buckets for each feature, where each is a list of sets of path ids
# index i < len(breakpoints[feature]) is the range whose upper endpoint is breakpoints[feature][i]
# the one extra index at the end is the open range above the last breakpoint
buckets = {}
for feature in breakpoints:
	buckets[feature] = [set() for _ in range(len(breakpoints[feature]) + 1)]

# for each path
for path in paths:
	ID = path["id"]
	# for each feature in the compressed path conditions
	for feature_name, feature in path["compressed"].items():
		lower = feature["min"]
		upper = feature["max"]
		feature_buckets = buckets[feature_name]

		# a bucket is identified by its upper endpoint: add this path's id to every
		# bucket whose upper endpoint falls inside the path's (lower, upper] range
		for i, bp in enumerate(breakpoints[feature_name]):
			if is_in_range(bp, lower, upper):
				feature_buckets[i].add(ID)

		# the final bucket has no breakpoint as its upper endpoint, so the loop above
		# never reaches it; a path covers it exactly when it has no upper bound
		if upper is None:
			feature_buckets[-1].add(ID)

In [31]:
# combine breakpoints and buckets into one representation:
# for every feature, a list of {"min", "max", "paths"} records, one per bucket
# None marks the open end of the first range (no min) and of the last range (no max)

compressed_tree = defaultdict(list)
for feature_name, feature_buckets in buckets.items():
	points = breakpoints[feature_name]
	# pad the breakpoints with None on both sides so that bucket i is bounded by edges[i] and edges[i + 1]
	edges = [None] + points + [None]
	for lower, upper, members in zip(edges, edges[1:], feature_buckets):
		compressed_tree[feature_name].append({"min": lower, "max": upper, "paths": list(members)})

In [32]:
# Persist the combined representation as pretty-printed JSON. The context manager
# flushes and closes the file even if serialization or the write raises.
with open("compressed_tree.json", "w+") as f:
	f.write(json.dumps(compressed_tree, indent = 4))