diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 475c463..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/TreeCompress.ipynb b/TreeCompress.ipynb index 11ae0fc..59208a4 100644 --- a/TreeCompress.ipynb +++ b/TreeCompress.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "id": "ec310f34", "metadata": {}, "outputs": [], @@ -14,7 +14,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "id": "5b54797e", "metadata": {}, "outputs": [], @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "id": "a38fdb8a", "metadata": {}, "outputs": [], @@ -38,14 +38,14 @@ "i = 0\n", "\n", "path_ids = set()\n", - "path_classes = set()\n", + "path_classes = tree[\"classes\"]\n", "\n", "# for each path in the tree\n", "for path in paths:\n", "\t# assign a path id \n", "\tpath[\"id\"] = i\n", "\tpath_ids.add(i)\n", - "\tpath_classes.add(path[\"classification\"])\n", + "\t#path_classes.add(path[\"classification\"])\n", "\ti += 1\t\n", "\t# for each condition\n", "\tconditions = path[\"conditions\"]\n", @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "id": "2fd4f738", "metadata": {}, "outputs": [], @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "id": "98cde024", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "id": "b6fbadbf", "metadata": {}, "outputs": [], @@ -171,16 +171,19 @@ "\t\t# for each bucket which encompases the condition's range, add this path's id to the sets \n", "\t\ti = 0\n", "\t\tfor bp in breakpoints[feature_name]:\n", - "\t\t\tin_range = is_in_range(bp, lower, upper)\n", - "\t\t\tif in_range:\n", + "\t\t\tif is_in_range(bp, lower, upper):\n", "\t\t\t\tbuckets_id[feature_name][i].add(ID)\n", "\t\t\t\tbuckets_class[feature_name][i].add(Class)\n", - "\t\t\ti += 1" + "\t\t\ti += 1\n", + "\n", + "\t\tif is_in_range(bp+1, lower, upper):\n", + "\t\t\tbuckets_id[feature_name][i].add(ID)\n", + "\t\t\tbuckets_class[feature_name][i].add(Class)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "id": "0a767971", "metadata": {}, "outputs": [], @@ -198,13 +201,13 @@ "\tfor i in range(1, len(buckets_id[feature_name]) - 1):\n", "\t\tlower = breakpoints[feature_name][i-1]\n", "\t\tupper = breakpoints[feature_name][i]\n", - "\t\tmembers = buckets_id[feature_name][i]\n", + "\t\tpaths = buckets_id[feature_name][i]\n", "\t\tclasses = buckets_class[feature_name][i]\n", "\t\t#print(f\"{feature_name} = [{lower}, {upper}]: {buckets[feature_name][i]}\")\n", "\t\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", "\tlower = breakpoints[feature_name][len(breakpoints[feature_name]) - 1]\n", "\tupper = None\n", - "\tmembers = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n", + "\tpaths = buckets_id[feature_name][len(buckets_id[feature_name]) - 1]\n", "\tclasses = buckets_class[feature_name][len(buckets_class[feature_name]) - 1]\n", "\t#print(f\"{feature_name} = [{lower}, {upper}]: {members}\")\n", "\tcompressed_layers[feature_name].append({\"min\": lower, \"max\": upper, \"paths\": paths, \"classes\": classes})\n", @@ -219,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "id": "561b0bc1", "metadata": {}, "outputs": [], @@ -238,7 +241,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "switch", "language": "python", "name": "python3" }, @@ -252,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/TreeToRMT.ipynb b/TreeToRMT.ipynb index 8c8d546..d0def9e 100644 --- a/TreeToRMT.ipynb +++ b/TreeToRMT.ipynb @@ -382,7 +382,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "switch", "language": "python", "name": "python3" }, @@ -396,7 +396,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.12.7" } }, "nbformat": 4,