{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1c0e7d2",
   "metadata": {},
   "source": [
    "# Decision tree export\n",
    "\n",
    "Train a depth-limited decision tree on `data.csv` (columns 0-2 are features, column 3 is the label), evaluate it on a held-out split, and export every root-to-leaf path to `tree.json`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5618056",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pydotplus\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier, _tree, export_graphviz, plot_tree\n",
    "\n",
    "from labels import mac_to_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9b2c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_FILE = \"data.csv\"\n",
    "OUTPUT_FILE = \"tree.json\"\n",
    "N_FEATURES = 3    # columns [0, N_FEATURES) are features; column N_FEATURES is the label\n",
    "MAX_DEPTH = 5\n",
    "TEST_SIZE = 0.2   # fraction of rows held out for evaluation\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b96f3403",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the CSV once and slice by column position; converting millions of rows\n",
    "# to Python lists (.values.tolist()) is slow and memory-hungry.\n",
    "data_raw = pd.read_csv(INPUT_FILE)\n",
    "feature_columns = list(data_raw.columns[:N_FEATURES])\n",
    "X = data_raw.iloc[:, :N_FEATURES].to_numpy()\n",
    "Y = data_raw.iloc[:, N_FEATURES].to_numpy()\n",
    "\n",
    "# Evaluate on rows the tree never sees during training; scoring the training\n",
    "# rows again would only report training accuracy.\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(\n",
    "    X, Y, test_size=TEST_SIZE, random_state=SEED\n",
    ")\n",
    "print(f\"dataset size: {len(X)} (train: {len(X_train)}, test: {len(X_test)})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7e9a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = DecisionTreeClassifier(max_depth=MAX_DEPTH, random_state=SEED)\n",
    "dt.fit(X_train, Y_train)\n",
    "\n",
    "print(f\"train accuracy: {accuracy_score(Y_train, dt.predict(X_train))}\")\n",
    "print(f\"test accuracy: {accuracy_score(Y_test, dt.predict(X_test))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d336971a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lineage(tree, feature_names):\n",
    "    \"\"\"Export every root-to-leaf path of a fitted decision tree as a dict.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tree : fitted ``DecisionTreeClassifier``\n",
    "    feature_names : sequence of str\n",
    "        Column names indexed by the feature ids stored in ``tree.tree_.feature``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``\"features\"``: maps each feature used in a split to its split\n",
    "        thresholds, in node-id order.\n",
    "        ``\"paths\"``: one entry per leaf, in node-id order, holding the\n",
    "        root-to-leaf ``\"conditions\"`` (``feature <= value`` for a left\n",
    "        branch, ``feature > value`` for a right branch) and the leaf's\n",
    "        ``\"classification\"``. The classification is the column index of the\n",
    "        largest entry in ``tree.tree_.value`` for that leaf, i.e. an index\n",
    "        into ``tree.classes_`` rather than the class label itself.\n",
    "    \"\"\"\n",
    "    t = tree.tree_\n",
    "    left, right = t.children_left, t.children_right\n",
    "    thresholds, feature_ids, value = t.threshold, t.feature, t.value\n",
    "\n",
    "    # One pass to record each non-root node's (parent, branch). Walking this\n",
    "    # map costs O(depth) per leaf instead of searching the child arrays at\n",
    "    # every step, and it does not depend on how large the node ids get.\n",
    "    parent_of = {}\n",
    "    for node in range(t.node_count):\n",
    "        if left[node] != _tree.TREE_LEAF:\n",
    "            parent_of[int(left[node])] = (node, \"l\")\n",
    "            parent_of[int(right[node])] = (node, \"r\")\n",
    "\n",
    "    paths = []\n",
    "    for leaf in np.flatnonzero(left == _tree.TREE_LEAF):\n",
    "        conditions = []\n",
    "        node = int(leaf)\n",
    "        # Climb until the root, which has no parent entry; a single-node tree\n",
    "        # therefore yields a path with no conditions.\n",
    "        while node in parent_of:\n",
    "            parent, branch = parent_of[node]\n",
    "            conditions.append({\n",
    "                \"feature\": feature_names[int(feature_ids[parent])],\n",
    "                \"operation\": \"<=\" if branch == \"l\" else \">\",\n",
    "                \"value\": float(thresholds[parent]),\n",
    "            })\n",
    "            node = parent\n",
    "        conditions.reverse()  # root first\n",
    "        classification = int(np.argmax(value[leaf][0]))  # first max wins ties\n",
    "        paths.append({\"conditions\": conditions, \"classification\": classification})\n",
    "\n",
    "    features = {}\n",
    "    for node in range(t.node_count):\n",
    "        if feature_ids[node] != _tree.TREE_UNDEFINED:\n",
    "            name = feature_names[int(feature_ids[node])]\n",
    "            features.setdefault(name, []).append(float(thresholds[node]))\n",
    "\n",
    "    return {\"features\": features, \"paths\": paths}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f36344d",
   "metadata": {},
   "outputs": [],
   "source": [
    "lineage = get_lineage(dt, feature_columns)\n",
    "with open(OUTPUT_FILE, \"w\") as f:\n",
    "    json.dump(lineage, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8832b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(25, 20))\n",
    "plot_tree(dt, feature_names=feature_columns, filled=True, ax=ax);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "switch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}