{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c2f4a10",
   "metadata": {},
   "source": [
    "# Decision-tree rule export\n",
    "\n",
    "Trains a `DecisionTreeClassifier` on the first three CSV columns (`proto`, `src`, `dst`) with the fourth column as the label, reports train/test accuracy, and writes the split thresholds per feature plus one `when ... then <class index>` rule per leaf to a text file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5618056",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import argparse\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.tree import export_graphviz\n",
    "import pydotplus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d336971a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# configuration\n",
    "inputfile = \"data.csv\"  # training data: columns 0-2 are features, column 3 is the label\n",
    "# TODO: point testfile at held-out data; while it equals inputfile, \"test accuracy\" is just training accuracy\n",
    "testfile = inputfile\n",
    "outputfile = \"tree\"  # text file that receives the thresholds and per-leaf rules\n",
    "feature_names = ['proto', 'src', 'dst']\n",
    "max_depth = 5\n",
    "seed = 0  # fixes sklearn's random tie-breaking between equally good splits so runs are reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e8b2c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lineage(tree, feature_names, file):\n",
    "    \"\"\"Write one rule per leaf of a fitted decision tree to `file`.\n",
    "\n",
    "    Each rule reads ` when f1<=t1 and f2>t2  then k;` : the conditions follow the\n",
    "    path from the root to the leaf, and k is the index (into `tree.classes_`) of\n",
    "    the leaf's majority class. Leaves are written in ascending node-id order.\n",
    "    A tree that is a single leaf yields one unconditional rule ` when  then k;`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tree : fitted sklearn DecisionTreeClassifier\n",
    "    feature_names : sequence of str, indexed by feature column\n",
    "    file : writable text file object\n",
    "    \"\"\"\n",
    "    left = tree.tree_.children_left\n",
    "    right = tree.tree_.children_right\n",
    "    threshold = tree.tree_.threshold\n",
    "    feature = tree.tree_.feature\n",
    "    value = tree.tree_.value\n",
    "\n",
    "    # parent[c] = (parent id, comparison taken to reach c); -1 marks a leaf in sklearn's child arrays\n",
    "    parent = {}\n",
    "    for node_id, (l_child, r_child) in enumerate(zip(left, right)):\n",
    "        if l_child != -1:\n",
    "            parent[l_child] = (node_id, '<=')\n",
    "            parent[r_child] = (node_id, '>')\n",
    "\n",
    "    for leaf in np.flatnonzero(left == -1):\n",
    "        # walk leaf -> root, then reverse so the conditions read root -> leaf\n",
    "        conditions = []\n",
    "        node = leaf\n",
    "        while node in parent:\n",
    "            node, sign = parent[node]\n",
    "            conditions.append(feature_names[feature[node]] + sign + str(threshold[node]))\n",
    "        conditions.reverse()\n",
    "\n",
    "        # majority class at the leaf (ties resolve to the lowest class index)\n",
    "        ind = int(np.argmax(value[leaf][0]))\n",
    "        # NOTE: the double space before 'then' is intentional; it preserves the existing\n",
    "        # output format for whatever consumes this file\n",
    "        file.write(' when ' + ' and '.join(conditions) + ' ' + ' then ' + str(ind))\n",
    "        file.write(\";\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b96f3403",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_xy(path):\n",
    "    \"\"\"Read a CSV; return (features = first 3 columns, labels = 4th column) as arrays.\"\"\"\n",
    "    frame = pd.read_csv(path)\n",
    "    assert frame.shape[1] >= 4, f\"{path}: expected at least 4 columns, got {frame.shape[1]}\"\n",
    "    return frame.iloc[:, 0:3].to_numpy(), frame.iloc[:, 3].to_numpy()\n",
    "\n",
    "\n",
    "# training set X, Y and test set Xt, Yt, each read from its own file\n",
    "X, Y = load_xy(inputfile)\n",
    "Xt, Yt = load_xy(testfile)\n",
    "\n",
    "dt = DecisionTreeClassifier(max_depth=max_depth, random_state=seed)\n",
    "dt.fit(X, Y)\n",
    "print(f\"train accuracy: {accuracy_score(Y, dt.predict(X))}\")\n",
    "print(f\"test accuracy: {accuracy_score(Yt, dt.predict(Xt))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a3d5e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split thresholds per feature; leaves carry a negative feature index in sklearn's arrays and are skipped\n",
    "split_thresholds = {name: [] for name in feature_names}\n",
    "for feat, thr in zip(dt.tree_.feature, dt.tree_.threshold):\n",
    "    if feat >= 0:\n",
    "        split_thresholds[feature_names[feat]].append(int(thr))  # int() truncates toward zero\n",
    "\n",
    "# write the thresholds (sorted, duplicates kept) followed by one rule per leaf\n",
    "with open(outputfile, \"w\") as tree_file:\n",
    "    for name in feature_names:\n",
    "        tree_file.write(f\"{name} = {sorted(split_thresholds[name])};\\n\")\n",
    "    get_lineage(dt, feature_names, tree_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "switch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}