2024-06-21 16:35:16 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decision Tree"
]
},
2024-06-21 17:05:39 +02:00
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sqlite3\n",
"from contextlib import closing\n",
"from datetime import datetime\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import xgboost as xgb\n",
"from joblib import dump, load\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import confusion_matrix, f1_score\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.tree import DecisionTreeClassifier"
]
},
2024-06-21 16:35:16 +02:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Data from Database"
]
},
{
"cell_type": "code",
2024-06-21 17:05:39 +02:00
"execution_count": 7,
2024-06-21 16:35:16 +02:00
"metadata": {},
"outputs": [],
"source": [
"# Load the train / validation / test tables from the SQLite feature database.\n",
"# `closing` guarantees the connection is closed even if a query raises\n",
"# (sqlite3's own connection context manager only commits/rolls back, it does not close).\n",
"DB_PATH = '../features.db'\n",
"\n",
"with closing(sqlite3.connect(DB_PATH)) as conn:\n",
"    train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
"    valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
"    test = pd.read_sql_query(\"SELECT * FROM test\", conn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Format Data for Machine Learning"
]
},
{
"cell_type": "code",
2024-06-21 17:05:39 +02:00
"execution_count": 8,
2024-06-21 16:35:16 +02:00
"metadata": {},
2024-06-21 17:05:39 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_x shape: (3502, 10)\n",
"test_x shape: (438, 10)\n",
"valid_x shape: (438, 10)\n",
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
"number of classes: 4\n"
]
}
],
2024-06-21 16:35:16 +02:00
"source": [
"# Class labels as stored in the `y` column, mapped to integer class ids.\n",
"LABEL_MAP = {'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3}\n",
"\n",
"\n",
"def split_features_target(df):\n",
"    \"\"\"Split one data split into (features, encoded target).\n",
"\n",
"    Drops the `y` target and the `id` column from the features and maps the\n",
"    target labels through LABEL_MAP.\n",
"\n",
"    Raises:\n",
"        ValueError: if a label is missing or not in LABEL_MAP; `.map` would\n",
"            otherwise silently turn it into NaN and only fail at fit time.\n",
"    \"\"\"\n",
"    y = df['y'].map(LABEL_MAP)\n",
"    if y.isna().any():\n",
"        unknown = sorted(df.loc[y.isna(), 'y'].astype(str).unique())\n",
"        raise ValueError(f'Unknown or missing labels in y: {unknown}')\n",
"    x = df.drop(columns=['y', 'id'])\n",
"    return x, y\n",
"\n",
"\n",
"train_x, train_y = split_features_target(train)\n",
"valid_x, valid_y = split_features_target(valid)\n",
"test_x, test_y = split_features_target(test)\n",
"\n",
"print('train_x shape:', train_x.shape)\n",
"print('test_x shape:', test_x.shape)\n",
"print('valid_x shape:', valid_x.shape)\n",
"# Keep the column names: the imputer/scaler below return plain numpy arrays.\n",
"feature_names = train_x.columns.to_list()\n",
"print('features:', feature_names)\n",
"\n",
"# Fill missing values with the per-column training mean (fit on train only,\n",
"# so no statistics leak from validation/test).\n",
"imputer = SimpleImputer(strategy='mean')\n",
"train_x = imputer.fit_transform(train_x)\n",
"valid_x = imputer.transform(valid_x)\n",
"test_x = imputer.transform(test_x)\n",
"\n",
"# Scale features to [0, 1] based on the training min/max\n",
"# (validation/test values may fall slightly outside this range).\n",
"scaler = MinMaxScaler()\n",
"train_x = scaler.fit_transform(train_x)\n",
"valid_x = scaler.transform(valid_x)\n",
"test_x = scaler.transform(test_x)\n",
"\n",
"# XGBoost matrices (the decision tree in this notebook uses the arrays directly).\n",
"dtrain = xgb.DMatrix(train_x, label=train_y)\n",
"dvalid = xgb.DMatrix(valid_x, label=valid_y)\n",
"dtest = xgb.DMatrix(test_x, label=test_y)\n",
"\n",
"# Size of the full label space, independent of which classes appear in one split.\n",
"num_classes = len(LABEL_MAP)\n",
"print('number of classes:', num_classes)"
]
2024-06-21 17:05:39 +02:00
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validierungsgenauigkeit: 0.7557077625570776\n",
"Testgenauigkeit: 0.7922374429223744\n"
]
}
],
"source": [
2024-06-21 17:24:25 +02:00
"# Decision tree baseline: fit on the training split, score on the validation\n",
"# split, and report the final score on the held-out test split.\n",
"MAX_DEPTH = 5  # example depth limit, hand-picked (not tuned yet)\n",
"RANDOM_STATE = 42  # features are permuted randomly at each split; fix the seed for reproducible scores\n",
"\n",
"dt_classifier = DecisionTreeClassifier(max_depth=MAX_DEPTH, random_state=RANDOM_STATE)\n",
"dt_classifier.fit(train_x, train_y)\n",
"\n",
"# Validation score.\n",
"valid_pred = dt_classifier.predict(valid_x)\n",
"valid_accuracy = accuracy_score(valid_y, valid_pred)\n",
"print(f'Validierungsgenauigkeit: {valid_accuracy}')\n",
"\n",
"# TODO: hyperparameter optimisation (e.g. choosing MAX_DEPTH by validation\n",
"# accuracy) is not implemented yet.\n",
"\n",
"# Final evaluation on the test split.\n",
"test_pred = dt_classifier.predict(test_x)\n",
"test_accuracy = accuracy_score(test_y, test_pred)\n",
"print(f'Testgenauigkeit: {test_accuracy}')\n"
]
2024-06-21 19:02:32 +02:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Die Validierungsgenauigkeit des Modells liegt bei rund 75,6 %, das Modell sagt auf den Validierungsdaten also in etwa drei von vier Fällen die richtige Klasse vorher. Das ist eine solide Ausgangsbasis, lässt aber noch Verbesserungspotenzial erkennen, insbesondere durch eine gezielte Verfeinerung des Modells (z. B. der maximalen Baumtiefe), um die Fehlerquote zu senken."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
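"Mit einer Testgenauigkeit von rund 79,2 % klassifiziert das Modell die Testdaten überwiegend korrekt. Da die Testgenauigkeit nicht unter der Validierungsgenauigkeit liegt, gibt es keinen Hinweis auf eine starke Überanpassung, und das Modell scheint auf neue, unbekannte Daten gut zu generalisieren. Bei nur 438 Testbeispielen und einer einzigen Datenaufteilung ist dieser Wert allerdings mit einer gewissen Unsicherheit behaftet."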
]
2024-06-21 16:35:16 +02:00
}
],
"metadata": {
2024-06-21 17:05:39 +02:00
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
2024-06-21 16:35:16 +02:00
"language_info": {
2024-06-21 17:05:39 +02:00
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
2024-06-21 16:35:16 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 2
}