{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "938dec51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import argparse\n",
    "from sklearn.tree import DecisionTreeClassifier, plot_tree, _tree\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.tree import export_graphviz\n",
    "import pydotplus\n",
    "from matplotlib import pyplot as plt\n",
    "from labels import mac_to_label\n",
    "import json\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "442624c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every CSV row is split into its first three values (the features, X)\n",
    "# and its fourth value (the label, Y).\n",
    "Set1 = pd.read_csv('data/combined/data.csv').to_numpy().tolist()\n",
    "X = [row[0:3] for row in Set1]\n",
    "Y = [row[3] for row in Set1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "12ad454d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature names in the order their values appear in each row of X.\n",
    "FEATURES = ['protocol', 'src', 'dst']\n",
    "\n",
    "\n",
    "def _in_interval(value, lower, upper):\n",
    "    \"\"\"Return True if value lies in the half-open interval (lower, upper].\n",
    "\n",
    "    A bound of None leaves that side of the interval open-ended.\n",
    "    \"\"\"\n",
    "    # Written as `not a < b` so that NaN values never match, exactly like\n",
    "    # the chained comparisons this helper replaces.\n",
    "    if lower is not None and not lower < value:\n",
    "        return False\n",
    "    if upper is not None and not value <= upper:\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def _first_matching_node(nodes, value):\n",
    "    \"\"\"Return the first node whose (min, max] interval contains value, or None.\"\"\"\n",
    "    # Only the first hit is used. That is safe as long as the intervals of\n",
    "    # one layer do not overlap -- TODO confirm against the code that writes\n",
    "    # compressed_tree.json.\n",
    "    for node in nodes:\n",
    "        if _in_interval(value, node['min'], node['max']):\n",
    "            return node\n",
    "    return None\n",
    "\n",
    "\n",
    "def predict_path(row, tree, features=FEATURES):\n",
    "    \"\"\"Return the one tree path that is consistent with every feature of row.\n",
    "\n",
    "    Each feature layer contributes the path set of the node matching that\n",
    "    feature's value; the answer is the intersection of those sets. A feature\n",
    "    that has no layer, or no matching node, adds no constraint.\n",
    "\n",
    "    Raises AssertionError when the intersection is not exactly one path.\n",
    "    \"\"\"\n",
    "    class_sets = []\n",
    "    path_sets = []\n",
    "    layers = tree['layers']\n",
    "    # enumerate() pins each feature to its own column, so a feature without\n",
    "    # a layer can never shift the values read for the features after it.\n",
    "    for column, feature in enumerate(features):\n",
    "        if feature not in layers:\n",
    "            continue\n",
    "        node = _first_matching_node(layers[feature], row[column])\n",
    "        if node is not None:\n",
    "            class_sets.append(node['classes'])\n",
    "            path_sets.append(node['paths'])\n",
    "\n",
    "    paths = set.intersection(*map(set, path_sets)) if path_sets else set()\n",
    "    if len(paths) != 1:\n",
    "        # Dump the offending row before failing so it can be inspected.\n",
    "        candidates = set.intersection(*map(set, class_sets)) if class_sets else set()\n",
    "        print(paths)\n",
    "        print(row)\n",
    "        print(candidates)\n",
    "    assert len(paths) == 1, f'expected exactly one path for {row}, got {len(paths)}'\n",
    "    return next(iter(paths))\n",
    "\n",
    "\n",
    "# Read the tree once; the file is closed again before the prediction loop.\n",
    "with open('compressed_tree.json', 'r') as file:\n",
    "    data = json.load(file)\n",
    "classes = data['classes']\n",
    "\n",
    "predict_Yt = []\n",
    "for x in X:\n",
    "    path = predict_path(x, data)\n",
    "    pred = data['path_to_class'][str(path)]\n",
    "    predict_Yt.append(classes[pred])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "8b4c56b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8410252791654538\n"
     ]
    }
   ],
   "source": [
    "# Accuracy: the fraction of rows whose predicted class equals the label in Y.\n",
    "assert len(predict_Yt) == len(Y), 'predict_Yt and Y differ in length; re-run the prediction cell'\n",
    "\n",
    "correct = sum(\n",
    "    1\n",
    "    for truth, prediction in zip(Y, predict_Yt)\n",
    "    if prediction is not None and truth == prediction\n",
    ")\n",
    "\n",
    "# An empty data set has no defined accuracy; report nan instead of dividing by zero.\n",
    "print(correct / len(Y) if Y else float('nan'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}