{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Decision Tree Training and Analysis\n",
    "\n",
    "Loads the pre-split `train`, `validation`, and `test` tables from `features.db`, preprocesses the features (mean imputation, min-max scaling), and tunes a `DecisionTreeClassifier` with a cross-validated grid search over the four target classes (`GSVT`, `AFIB`, `SR`, `SB`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sqlite3\n",
    "from contextlib import closing\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from joblib import dump, load\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the SQLite database holding the pre-split feature tables\n",
    "DB_PATH = '../features.db'\n",
    "\n",
    "# Integer encoding of the target labels; every label in the data must appear here\n",
    "LABEL_MAP = {'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3}\n",
    "\n",
    "# Seed for DecisionTreeClassifier so the grid-search result is reproducible\n",
    "RANDOM_STATE = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data from Database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sqlite3.connect() silently creates an empty database for a missing path,\n",
    "# which would only surface later as a confusing 'no such table' error\n",
    "if not os.path.exists(DB_PATH):\n",
    "    raise FileNotFoundError(f'Feature database not found: {DB_PATH}')\n",
    "\n",
    "# Load the pre-split training, validation, and test tables.\n",
    "# closing() releases the connection even if a query fails; a sqlite3 connection's\n",
    "# own context manager only commits/rolls back and does not close it.\n",
    "with closing(sqlite3.connect(DB_PATH)) as conn:\n",
    "    train = pd.read_sql_query('SELECT * FROM train', conn)\n",
    "    valid = pd.read_sql_query('SELECT * FROM validation', conn)\n",
    "    test = pd.read_sql_query('SELECT * FROM test', conn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Format Data for Machine Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_x shape: (3502, 10)\n",
      "test_x shape: (438, 10)\n",
      "valid_x shape: (438, 10)\n",
      "features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
      "number of classes: 4\n"
     ]
    }
   ],
   "source": [
    "# Encode the target labels as integers\n",
    "train_y = train['y'].map(LABEL_MAP)\n",
    "valid_y = valid['y'].map(LABEL_MAP)\n",
    "test_y = test['y'].map(LABEL_MAP)\n",
    "\n",
    "# .map() turns any label missing from LABEL_MAP into NaN; fail loudly instead\n",
    "for split_name, split_y in [('train', train_y), ('valid', valid_y), ('test', test_y)]:\n",
    "    assert split_y.notna().all(), f'{split_name} contains labels not in LABEL_MAP'\n",
    "\n",
    "# Features: every column except the target and the id column\n",
    "train_x = train.drop(columns=['y', 'id'])\n",
    "valid_x = valid.drop(columns=['y', 'id'])\n",
    "test_x = test.drop(columns=['y', 'id'])\n",
    "\n",
    "print('train_x shape:', train_x.shape)\n",
    "print('test_x shape:', test_x.shape)\n",
    "print('valid_x shape:', valid_x.shape)\n",
    "\n",
    "# Record the feature names before the DataFrames become NumPy arrays below\n",
    "feature_names = train_x.columns.to_list()\n",
    "print('features:', feature_names)\n",
    "\n",
    "# Fill missing values with the per-column training mean\n",
    "imputer = SimpleImputer(strategy='mean')\n",
    "train_x = imputer.fit_transform(train_x)\n",
    "valid_x = imputer.transform(valid_x)\n",
    "test_x = imputer.transform(test_x)\n",
    "\n",
    "# SimpleImputer drops columns that are entirely missing in the training data,\n",
    "# which would silently misalign feature_names with the array columns\n",
    "assert train_x.shape[1] == len(feature_names), 'imputer dropped an all-NaN feature column'\n",
    "\n",
    "# Scale features to [0, 1] using the training range. Tree splits are invariant to\n",
    "# this monotonic rescaling, so it does not change what the decision tree can learn.\n",
    "scaler = MinMaxScaler()\n",
    "train_x = scaler.fit_transform(train_x)\n",
    "valid_x = scaler.transform(valid_x)\n",
    "test_x = scaler.transform(test_x)\n",
    "\n",
    "# Count the classes present in the training labels (the classes the model can learn)\n",
    "num_classes = train_y.nunique()\n",
    "print('number of classes:', num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter Search Grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2 criteria x 6 depths x 3 split sizes x 3 leaf sizes = 108 candidate settings\n",
    "param_grid = {\n",
    "    'criterion': ['gini', 'entropy'],\n",
    "    'max_depth': [None, 10, 20, 30, 40, 50],\n",
    "    'min_samples_split': [2, 10, 20],\n",
    "    'min_samples_leaf': [1, 5, 10],\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DecisionTreeClassifier randomly permutes the features at each split, so ties between\n",
    "# equally good splits can resolve differently across runs; fixing random_state makes\n",
    "# the grid-search result reproducible\n",
    "model = DecisionTreeClassifier(random_state=RANDOM_STATE)\n",
    "\n",
    "# 3-fold cross-validation (stratified, since the estimator is a classifier), ranked by accuracy\n",
    "grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
GridSearchCV(cv=3, estimator=DecisionTreeClassifier(),\n", " param_grid={'criterion': ['gini', 'entropy'],\n", " 'max_depth': [None, 10, 20, 30, 40, 50],\n", " 'min_samples_leaf': [1, 5, 10],\n", " 'min_samples_split': [2, 10, 20]},\n", " scoring='accuracy')In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
GridSearchCV(cv=3, estimator=DecisionTreeClassifier(),\n", " param_grid={'criterion': ['gini', 'entropy'],\n", " 'max_depth': [None, 10, 20, 30, 40, 50],\n", " 'min_samples_leaf': [1, 5, 10],\n", " 'min_samples_split': [2, 10, 20]},\n", " scoring='accuracy')
DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)
DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)
DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)