2024-06-21 16:35:16 +02:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Decison Tree"
|
|
|
|
]
|
|
|
|
},
|
2024-06-21 17:05:39 +02:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 6,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import sqlite3\n",
|
|
|
|
"import os\n",
|
|
|
|
"from datetime import datetime\n",
|
|
|
|
"from joblib import dump, load\n",
|
|
|
|
"import pandas as pd\n",
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
"import xgboost as xgb\n",
|
|
|
|
"from sklearn.model_selection import GridSearchCV\n",
|
|
|
|
"from sklearn.metrics import confusion_matrix, f1_score\n",
|
|
|
|
"from sklearn.ensemble import GradientBoostingClassifier\n",
|
|
|
|
"from sklearn.impute import SimpleImputer\n",
|
|
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
|
|
"from sklearn.preprocessing import MinMaxScaler\n",
|
|
|
|
"\n",
|
|
|
|
"import seaborn as sns"
|
|
|
|
]
|
|
|
|
},
|
2024-06-21 16:35:16 +02:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Import Data from Database"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-06-21 17:05:39 +02:00
|
|
|
"execution_count": 7,
|
2024-06-21 16:35:16 +02:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# connect to the database\n",
|
|
|
|
"conn = sqlite3.connect('../features.db')\n",
|
|
|
|
"c = conn.cursor()\n",
|
|
|
|
"# get training, validation and test data\n",
|
|
|
|
"train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
|
|
|
|
"valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
|
|
|
|
"test = pd.read_sql_query(\"SELECT * FROM test\", conn)\n",
|
|
|
|
"# close the connection\n",
|
|
|
|
"conn.close()"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Format Data for Machine Learning"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-06-21 17:05:39 +02:00
|
|
|
"execution_count": 8,
|
2024-06-21 16:35:16 +02:00
|
|
|
"metadata": {},
|
2024-06-21 17:05:39 +02:00
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"train_x shape: (3502, 10)\n",
|
|
|
|
"test_x shape: (438, 10)\n",
|
|
|
|
"valid_x shape: (438, 10)\n",
|
|
|
|
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
|
|
|
|
"number of classes: 4\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
2024-06-21 16:35:16 +02:00
|
|
|
"source": [
|
|
|
|
"# get the target and features\n",
|
|
|
|
"train_y = train['y']\n",
|
|
|
|
"train_y = train_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
|
|
|
|
"train_x = train.drop(columns=['y'])\n",
|
|
|
|
"\n",
|
|
|
|
"valid_y = valid['y']\n",
|
|
|
|
"valid_y = valid_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
|
|
|
|
"valid_x = valid.drop(columns=['y'])\n",
|
|
|
|
"\n",
|
|
|
|
"test_y = test['y']\n",
|
|
|
|
"test_y = test_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
|
|
|
|
"test_x = test.drop(columns=['y'])\n",
|
|
|
|
"\n",
|
|
|
|
"# drop id column\n",
|
|
|
|
"train_x = train_x.drop(columns=['id'])\n",
|
|
|
|
"valid_x = valid_x.drop(columns=['id'])\n",
|
|
|
|
"test_x = test_x.drop(columns=['id'])\n",
|
|
|
|
"\n",
|
|
|
|
"print('train_x shape:', train_x.shape)\n",
|
|
|
|
"print('test_x shape:', test_x.shape)\n",
|
|
|
|
"print('valid_x shape:', valid_x.shape)\n",
|
|
|
|
"# print column names\n",
|
|
|
|
"print('features:', train_x.columns.to_list())\n",
|
|
|
|
"feature_names = train_x.columns.to_list()\n",
|
|
|
|
"\n",
|
|
|
|
"# Create an imputer object with a mean filling strategy\n",
|
|
|
|
"imputer = SimpleImputer(strategy='mean')\n",
|
|
|
|
"\n",
|
|
|
|
"train_x = imputer.fit_transform(train_x)\n",
|
|
|
|
"valid_x = imputer.transform(valid_x)\n",
|
|
|
|
"test_x = imputer.transform(test_x)\n",
|
|
|
|
"\n",
|
|
|
|
"# Scale Data between 0 and 1\n",
|
|
|
|
"scaler = MinMaxScaler()\n",
|
|
|
|
"# Fit the scaler to your data and then transform it\n",
|
|
|
|
"train_x = scaler.fit_transform(train_x)\n",
|
|
|
|
"valid_x = scaler.transform(valid_x)\n",
|
|
|
|
"test_x = scaler.transform(test_x)\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
"# use xgboost\n",
|
|
|
|
"dtrain = xgb.DMatrix(train_x, label=train_y)\n",
|
|
|
|
"dvalid = xgb.DMatrix(valid_x, label=valid_y)\n",
|
|
|
|
"dtest = xgb.DMatrix(test_x, label=test_y)\n",
|
|
|
|
"\n",
|
|
|
|
"num_classes= len(set(valid_y.to_list()))\n",
|
|
|
|
"print('number of classes:', num_classes)"
|
|
|
|
]
|
2024-06-21 17:05:39 +02:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 9,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"Validierungsgenauigkeit: 0.7557077625570776\n",
|
|
|
|
"Testgenauigkeit: 0.7922374429223744\n"
|
|
|
|
]
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-06-21 17:24:25 +02:00
|
|
|
"\n",
|
2024-06-21 17:05:39 +02:00
|
|
|
"from sklearn.tree import DecisionTreeClassifier\n",
|
|
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
|
|
"\n",
|
2024-06-21 17:24:25 +02:00
|
|
|
"\n",
|
2024-06-21 17:05:39 +02:00
|
|
|
"# Beispiel: Begrenzung der Tiefe des Baumes\n",
|
|
|
|
"dt_classifier = DecisionTreeClassifier(max_depth=5)\n",
|
|
|
|
"\n",
|
2024-06-21 17:24:25 +02:00
|
|
|
"# Schritt 3: Trainieren des Modells mit Trainingsdaten\n",
|
2024-06-21 17:05:39 +02:00
|
|
|
"dt_classifier.fit(train_x, train_y)\n",
|
|
|
|
"\n",
|
2024-06-21 17:24:25 +02:00
|
|
|
"# Schritt 4: Bewertung des Modells mit Validierungsdaten\n",
|
2024-06-21 17:05:39 +02:00
|
|
|
"valid_pred = dt_classifier.predict(valid_x)\n",
|
|
|
|
"valid_accuracy = accuracy_score(valid_y, valid_pred)\n",
|
|
|
|
"print(f'Validierungsgenauigkeit: {valid_accuracy}')\n",
|
|
|
|
"\n",
|
2024-06-21 17:24:25 +02:00
|
|
|
"# Schritt 5: Hyperparameter-Optimierung\n",
|
2024-06-21 17:05:39 +02:00
|
|
|
"\n",
|
|
|
|
"# Schritt 6: Endgültige Bewertung mit Testdaten\n",
|
|
|
|
"test_pred = dt_classifier.predict(test_x)\n",
|
|
|
|
"test_accuracy = accuracy_score(test_y, test_pred)\n",
|
|
|
|
"print(f'Testgenauigkeit: {test_accuracy}')\n"
|
|
|
|
]
|
2024-06-21 16:35:16 +02:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
2024-06-21 17:05:39 +02:00
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
2024-06-21 16:35:16 +02:00
|
|
|
"language_info": {
|
2024-06-21 17:05:39 +02:00
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.11.9"
|
2024-06-21 16:35:16 +02:00
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 2
|
|
|
|
}
|