DSA_SS24/notebooks/ml_grad_boost_tree.ipynb

508 lines
110 KiB
Plaintext
Raw Permalink Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient Boosting Tree (GBT) Training and Analysis"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sqlite3\n",
"from datetime import datetime\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import xgboost as xgb\n",
"from joblib import dump, load\n",
"from sklearn.model_selection import GridSearchCV\n",
2024-06-12 17:19:27 +02:00
"from sklearn.metrics import confusion_matrix, f1_score\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"\n",
"import seaborn as sns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Data from Database"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# connect to the feature database and load the three pre-split tables\n",
"DB_PATH = '../features.db'\n",
"# sqlite3.connect would silently create an empty database for a wrong path\n",
"if not os.path.exists(DB_PATH):\n",
"    raise FileNotFoundError(f'feature database not found: {DB_PATH}')\n",
"conn = sqlite3.connect(DB_PATH)\n",
"try:\n",
"    # get training, validation and test data\n",
"    train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
"    valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
"    test = pd.read_sql_query(\"SELECT * FROM test\", conn)\n",
"finally:\n",
"    # always release the connection, even if one of the queries fails\n",
"    conn.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Format Data for Machine Learning"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_x shape: (3502, 10)\n",
"test_x shape: (438, 10)\n",
"valid_x shape: (438, 10)\n",
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
"number of classes: 4\n"
]
}
],
"source": [
"# integer encoding of the rhythm labels, shared by all three splits\n",
"LABEL_MAP = {'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3}\n",
"\n",
"def split_features_target(df):\n",
"    \"\"\"Return (features, encoded target) for one split, dropping the 'y' and 'id' columns.\"\"\"\n",
"    target = df['y'].map(LABEL_MAP)\n",
"    # .map turns unknown labels into NaN, which would otherwise only surface much later\n",
"    if target.isna().any():\n",
"        unknown = sorted(df.loc[target.isna(), 'y'].astype(str).unique())\n",
"        raise ValueError(f'unexpected labels in split: {unknown}')\n",
"    return df.drop(columns=['y', 'id']), target\n",
"\n",
"# get the target and features\n",
"train_x, train_y = split_features_target(train)\n",
"valid_x, valid_y = split_features_target(valid)\n",
"test_x, test_y = split_features_target(test)\n",
"\n",
"print('train_x shape:', train_x.shape)\n",
"print('test_x shape:', test_x.shape)\n",
"print('valid_x shape:', valid_x.shape)\n",
"# print column names\n",
"feature_names = train_x.columns.to_list()\n",
"print('features:', feature_names)\n",
"\n",
"# mean imputation and [0, 1] scaling are fitted on the training split only,\n",
"# then applied to validation/test so their statistics do not leak into training\n",
"imputer = SimpleImputer(strategy='mean')\n",
"train_x = imputer.fit_transform(train_x)\n",
"valid_x = imputer.transform(valid_x)\n",
"test_x = imputer.transform(test_x)\n",
"\n",
"scaler = MinMaxScaler()\n",
"train_x = scaler.fit_transform(train_x)\n",
"valid_x = scaler.transform(valid_x)\n",
"test_x = scaler.transform(test_x)\n",
"\n",
"# the number of classes comes from the label encoding, not from whichever\n",
"# classes happen to appear in a single split\n",
"num_classes = len(LABEL_MAP)\n",
"print('number of classes:', num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Grid for the Grid Search"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameter grid searched by GridSearchCV (3 x 3 x 3 = 27 candidates)\n",
"param_grid = {\n",
"    'learning_rate': [0.1, 0.2, 0.3],\n",
"    'max_depth': [1, 3, 5],\n",
"    'n_estimators': [100, 200, 300],\n",
"}"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Create a scikit-learn GradientBoostingClassifier (not XGBoost);\n",
"# a fixed random_state makes the grid search results reproducible\n",
"model = GradientBoostingClassifier(random_state=42)\n",
"\n",
"# Create the grid search object: 3-fold CV on accuracy, using all CPU cores\n",
"grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy', n_jobs=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-06-12 17:19:27 +02:00
"CPU times: total: 2min 49s\n",
"Wall time: 4min 28s\n"
]
},
{
"data": {
"text/html": [
2024-06-12 17:19:27 +02:00
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\
" param_grid={&#x27;learning_rate&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;max_depth&#x27;: [1, 3, 5],\n",
" &#x27;n_estimators&#x27;: [100, 200, 300]},\n",
2024-06-12 17:19:27 +02:00
" scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=GradientBoostingClassifier(),\n",
" param_grid={&#x27;learning_rate&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;max_depth&#x27;: [1, 3, 5],\n",
" &#x27;n_estimators&#x27;: [100, 200, 300]},\n",
2024-06-12 17:19:27 +02:00
" scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier()</pre></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"GridSearchCV(cv=3, estimator=GradientBoostingClassifier(),\n",
" param_grid={'learning_rate': [0.1, 0.2, 0.3],\n",
" 'max_depth': [1, 3, 5],\n",
" 'n_estimators': [100, 200, 300]},\n",
" scoring='accuracy')"
]
},
2024-06-12 17:19:27 +02:00
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# run the cross-validated search over every candidate in param_grid;\n",
"# with the default refit=True the winner is refit on all of train_x/train_y\n",
"grid_search.fit(X=train_x, y=train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Results"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 100}\n",
2024-06-12 17:19:27 +02:00
"Best score: 0.7969733696438982\n"
]
}
],
"source": [
"# Print the best parameters and the mean cross-validated accuracy of the best candidate\n",
"print(f'Best parameters: {grid_search.best_params_}')\n",
"print(f'Best score: {grid_search.best_score_}')\n",
"# spread of the best candidate's accuracy across the CV folds\n",
"best_std = grid_search.cv_results_['std_test_score'][grid_search.best_index_]\n",
"print(f'Best score std: {best_std}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-06-12 17:19:27 +02:00
"['../ml_models/best_gbt_model_20240612171757.joblib']"
]
},
2024-06-12 17:19:27 +02:00
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Save the best model (the grid-search winner refit on the full training set)\n",
"best_model = grid_search.best_estimator_\n",
"MODEL_DIR = '../ml_models'\n",
"# the output directory may not exist on a fresh checkout\n",
"os.makedirs(MODEL_DIR, exist_ok=True)\n",
"# the timestamp makes every run write a new file instead of overwriting the last one\n",
"timestamp = datetime.now().strftime('%Y%m%d%H%M%S')\n",
"model_path = f'{MODEL_DIR}/best_gbt_model_{timestamp}.joblib'\n",
"dump(best_model, model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example Training of the Best Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the saved best model (the grid-search winner) so that its hyperparameters can be reused."
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# pick the newest saved GBT model. os.listdir returns entries in arbitrary order,\n",
"# but the file names embed a fixed-width %Y%m%d%H%M%S timestamp, so the\n",
"# lexicographic maximum is the most recent run\n",
"MODEL_DIR = '../ml_models'\n",
"model_files = [f for f in os.listdir(MODEL_DIR)\n",
"               if f.startswith('best_gbt_model_') and f.endswith('.joblib')]\n",
"if not model_files:\n",
"    raise FileNotFoundError(f'no best_gbt_model_*.joblib found in {MODEL_DIR}')\n",
"model_path = f'{MODEL_DIR}/{max(model_files)}'\n",
"# load the best model (joblib unpickles the file: only load trusted model files)\n",
"best_model = load(model_path)"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
2024-06-12 17:19:27 +02:00
"<style>#sk-container-id-2 {color: black;background-color: white;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\
],
"text/plain": [
"GradientBoostingClassifier()"
]
},
2024-06-12 17:19:27 +02:00
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# example training of a fresh model with the hyperparameters of the loaded best model;\n",
"# reading them from best_model keeps this section usable without re-running the grid search\n",
"model = GradientBoostingClassifier(**best_model.get_params())\n",
"model.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Accuracy: 0.819634703196347\n"
]
}
],
"source": [
"# evaluate the freshly trained example model on the held-out test split;\n",
"# the later evaluation cells reuse these preds, and the feature-importance\n",
"# plot reads model.feature_importances_, so all results describe the same model\n",
"preds = model.predict(test_x)\n",
"accuracy = accuracy_score(test_y, preds)\n",
"print(f\"Model Accuracy: {accuracy}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate Model Performance"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAokAAAIjCAYAAABvUIGpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABZdElEQVR4nO3deZyN9fvH8feZYRYzZsZgNvu+L1myh8guoqT0DQmJCCkju5gSWUPZibJUQqVskRp7lmxZUzSMnTHGmLl/f/g5ddyU4Zy5Z5zXs8f9eHzvz32fz32dc750dX2WYzMMwxAAAADwDx5WBwAAAIC0hyQRAAAAJiSJAAAAMCFJBAAAgAlJIgAAAExIEgEAAGBCkggAAAATkkQAAACYkCQCAADAhCQRwL86ePCg6tWrp8DAQNlsNi1ZssSp/R87dkw2m02zZs1yar/pWa1atVSrVi2rwwDg5kgSgXTg8OHD6ty5s/Lnzy8fHx8FBASoWrVqGjdunOLj41367LZt22r37t0aPny45s6dqwoVKrj0eampXbt2stlsCggIuOPnePDgQdlsNtlsNo0aNSrF/Z88eVKDBw/Wjh07nBAtAKSuDFYHAODfff3113rmmWfk7e2tF198USVLltT169e1YcMG9enTR3v27NHHH3/skmfHx8crOjpab7/9trp16+aSZ+TJk0fx8fHKmDGjS/r/LxkyZNDVq1e1bNkytWrVyuHavHnz5OPjo2vXrt1X3ydPntSQIUOUN29elS1b9p5f9/3339/X8wDAmUgSgTTs6NGjat26tfLkyaM1a9YoPDzcfq1r1646dOiQvv76a5c9PzY2VpIUFBTksmfYbDb5+Pi4rP//4u3trWrVqunTTz81JYnz589X48aN9fnnn6dKLFevXlWmTJnk5eWVKs8DgH/DcDOQho0cOVJXrlzR9OnTHRLEWwoWLKgePXrYz2/cuKFhw4apQIEC8vb2Vt68edWvXz8lJCQ4vC5v3rxq0qSJNmzYoEcffVQ+Pj7Knz+/5syZY79n8ODBypMnjySpT58+stlsyps3r6Sbw7S3/vc/DR48WDabzaFt5cqVql69uoKCguTv768iRYqoX79+9ut3m5O4Zs0a1ahRQ35+fgoKClKzZs20b9++Oz7v0KFDateunYKCghQYGKj27dvr6tWrd/9gb/P888/r22+/1YULF+xtW7Zs0cGDB/X888+b7j937pzeeOMNlSpVSv7+/goICFDDhg21c+dO+z0//PCDKlasKElq3769fdj61vusVauWSpYsqW3btumxxx5TpkyZ7J/L7XMS27ZtKx8fH9P7r1+/vrJkyaKTJ0/e83sFgHtFkgikYcuWLVP+/PlVtWrVe7r/5Zdf1sCBA1WuXDmNGTNGNWvWVFRUlFq3bm2699ChQ3r66af1xBNPaPTo0cqSJYvatWunPXv2SJJatGihMWPGSJKee+45zZ07V2PHjk1R/Hv27FGTJk2UkJCgoUOHavTo0XryySf1008//evrVq1apfr16+v06dMaPHiwevXqpZ9//lnVqlXTsWPHTPe3atVKly9fVlRUlFq1aqVZs2ZpyJAh9xxnixYtZLPZ9MUXX9jb5s+fr6JFi6pcuXKm+48cOaIlS5aoSZMm+uCDD9SnTx/t3r1bNWvWtCdsxYoV09ChQyVJnTp10ty5czV37lw99thj9n7Onj2rhg0bqmzZsho7dqxq1659x/jGjRun7Nmzq23btkpKSpIkffTRR/r+++81YcIERURE3PN7BYB7ZgBIky5evGhIMpo1a3ZP9+/YscOQZLz88ssO7W+88YYhyVizZo29LU+ePIYkY/369fa206dPG97e3kbv3r3tbUePHjUkGe+//75Dn23btjXy5MljimHQoEHGP/9aGTNmjCHJiI2NvWvct54xc+ZMe1vZsmWNkJAQ4+zZs/a2nTt3Gh4eHsaLL75oet5LL73k0OdTTz1lZM2a9a7P/O
f78PPzMwzDMJ5++mmjTp06hmEYRlJSkhEWFmYMGTLkjp/BtWvXjKSkJNP78Pb2NoYOHWpv27Jli+m93VKzZk1DkjFlypQ7XqtZs6ZD23fffWdIMt555x3jyJEjhr+/v9G8efP/fI8AcL+oJAJp1KVLlyRJmTNnvqf7v/nmG0lSr169HNp79+4tSaa5i8WLF1eNGjXs59mzZ1eRIkV05MiR+475drfmMn711VdKTk6+p9f89ddf2rFjh9q1a6fg4GB7e+nSpfXEE0/Y3+c/vfLKKw7nNWrU0NmzZ+2f4b14/vnn9cMPPygmJkZr1qxRTEzMHYeapZvzGD08bv71mZSUpLNnz9qH0rdv337Pz/T29lb79u3v6d569eqpc+fOGjp0qFq0aCEfHx999NFH9/wsAEgpkkQgjQoICJAkXb58+Z7u//333+Xh4aGCBQs6tIeFhSkoKEi///67Q3vu3LlNfWTJkkXnz5+/z4jNnn32WVWrVk0vv/yyQkND1bp1ay1cuPBfE8ZbcRYpUsR0rVixYjpz5ozi4uIc2m9/L1myZJGkFL2XRo0aKXPmzFqwYIHmzZunihUrmj7LW5KTkzVmzBgVKlRI3t7eypYtm7Jnz65du3bp4sWL9/zMHDlypGiRyqhRoxQcHKwdO3Zo/PjxCgkJuefXAkBKkSQCaVRAQIAiIiL066+/puh1ty8cuRtPT887thuGcd/PuDVf7hZfX1+tX79eq1at0v/+9z/t2rVLzz77rJ544gnTvQ/iQd7LLd7e3mrRooVmz56tL7/88q5VREkaMWKEevXqpccee0yffPKJvvvuO61cuVIlSpS454qpdPPzSYlffvlFp0+fliTt3r07Ra8FgJQiSQTSsCZNmujw4cOKjo7+z3vz5Mmj5ORkHTx40KH91KlTunDhgn2lsjNkyZLFYSXwLbdXKyXJw8NDderU0QcffKC9e/dq+PDhWrNmjdauXXvHvm/FeeDAAdO1/fv3K1u2bPLz83uwN3AXzz//vH755Rddvnz5jot9blm8eLFq166t6dOnq3Xr1qpXr57q1q1r+kzuNWG/F3FxcWrfvr2KFy+uTp06aeTIkdqyZYvT+geA25EkAmnYm2++KT8/P7388ss6deqU6frhw4c1btw4STeHSyWZViB/8MEHkqTGjRs7La4CBQro4sWL2rVrl73tr7/+0pdffulw37lz50yvvbWp9O3b8twSHh6usmXLavbs2Q5J16+//qrvv//e/j5doXbt2ho2bJgmTpyosLCwu97n6elpqlIuWrRIJ06ccGi7lczeKaFOqbfeekvHjx/X7Nmz9cEHHyhv3rxq27btXT9HAHhQbKYNpGEFChTQ/Pnz9eyzz6pYsWIOv7jy888/a9GiRWrXrp0kqUyZMmrbtq0+/vhjXbhwQTVr1tTmzZs1e/ZsNW/e/K7bq9yP1q1b66233tJTTz2l7t276+rVq5o8ebIKFy7ssHBj6NChWr9+vRo3bqw8efLo9OnTmjRpknLmzKnq1avftf/3339fDRs2VJUqVdShQwfFx8drwoQJCgwM1ODBg532Pm7n4eGh/v37/+d9TZo00dChQ9W+fXtVrVpVu3fv1rx585Q/f36H+woUKKCgoCBNmTJFmTNnlp+fnypVqqR8+fKlKK41a9Zo0qRJGjRokH1LnpkzZ6pWrVoaMGCARo4cmaL+AOCeWLy6GsA9+O2334yOHTsaefPmNby8vIzMmTMb1apVMyZMmGBcu3bNfl9iYqIxZMgQI1++fEbGjBmNXLlyGZGRkQ73GMbNLXAaN25ses7tW6/cbQscwzCM77//3ihZsqTh5eVlFClSxPjkk09MW+CsXr3aaNasmREREWF4eXkZERERxnPPPWf89ttvpmfcvk3MqlWrjGrVqhm+vr5GQECA0bRpU2Pv3r0O99x63u1b7MycOdOQZBw9evSun6lhOG6Bczd32wKnd+/eRnh4uOHr62tUq1bNiI6OvuPWNV999Z
VRvHhxI0OGDA7vs2bNmkaJEiXu+Mx/9nPp0iUjT548Rrly5YzExESH+3r27Gl4eHgY0dHR//oeAOB+2AwjBTO7AQA
"text/plain": [
"<Figure size 800x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"labels = ['GSVT', 'AFIB', 'SR', 'SB']\n",
"# Get the confusion matrix (rows = actual, columns = predicted). Passing labels=\n",
"# keeps it 4x4 in encoding order even if a class is absent from test_y and preds.\n",
"cm = confusion_matrix(test_y, preds, labels=list(range(len(labels))))\n",
"\n",
"# Plot the confusion matrix\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=ax)\n",
"ax.set_xlabel('Predicted')\n",
"ax.set_ylabel('Actual')\n",
"ax.set_title('Confusion Matrix')\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
2024-06-12 17:19:27 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAJOCAYAAACqS2TfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABfaUlEQVR4nO3deVhUdeP+8XtAAUFZXEAlYhGXzH1Nc8tIs8XM76OWlkppPWVpoqZW4pqQppKPe2lmWWlmu7lEVmbmrrnkvqAmuINgKsL5/eHPqQk0QI6Hgffruua65DNnZm6OOnDP+ZzPsRmGYQgAAAAAAOQ7F6sDAAAAAABQWFG6AQAAAAAwCaUbAAAAAACTULoBAAAAADAJpRsAAAAAAJNQugEAAAAAMAmlGwAAAAAAk1C6AQAAAAAwCaUbAAAAAACTULoBAAAAADAJpRsAUKDNnTtXNpst29uQIUNMec1ffvlFI0aM0Llz50x5/ptxbX9s2LDB6ih5Nm3aNM2dO9fqGAAA3BLFrA4AAEBOjBo1SqGhoQ5jNWrUMOW1fvnlF40cOVI9e/aUr6+vKa9RlE2bNk1ly5ZVz549rY4CAIDpKN0AAKfQrl07NWjQwOoYNyUtLU1eXl5Wx7DMhQsX5OnpaXUMAABuKaaXAwAKhW+//VbNmzeXl5eXSpUqpQcffFA7duxw2Oa3335Tz549FRYWJg8PD5UvX15PPfWUTp8+bd9mxIgRGjRokCQpNDTUPpX90KFDOnTokGw2W7ZTo202m0aMGOHwPDabTTt37lTXrl3l5+enZs2a2e//4IMPVL9+fZUoUUKlS5fWY489piNHjuTpe+/Zs6dKliyphIQEPfTQQypZsqQCAwM1depUSdK2bdvUunVreXl5KTg4WB9++KHD469NWf/pp5/07LPPqkyZMvL29lb37t119uzZLK83bdo03XnnnXJ3d1fFihXVp0+fLFPxW7VqpRo1amjjxo1q0aKFPD099corrygkJEQ7duzQjz/+aN+3rVq1kiSdOXNGAwcOVM2aNVWyZEl5e3urXbt22rp1q8Nz//DDD7LZbFq4cKFef/113XbbbfLw8NC9996rffv2Zcm7du1aPfDAA/Lz85OXl5dq1aqlt956y2GbXbt26T//+Y9Kly4tDw8PNWjQQF9++aXDNunp6Ro5cqQqV64sDw8PlSlTRs2aNdOKFSty9PcEACiaONINAHAKycnJOnXqlMNY2bJlJUnvv/++evToobZt2+qNN97QhQsXNH36dDVr1kybN29WSEiIJGnFihU6cOCAIiMjVb58ee3YsUOzZs3Sjh079Ouvv8pms6ljx47as2ePPvroI02aNMn+GuXKldPJkydznbtTp06qXLmyxo4dK8MwJEmvv/66hg0bps6dO6tXr146efKk/ve//6lFixbavHlznqa0Z2RkqF27dmrRooXGjRun+fPn64UXXpCXl5deffVVdevWTR07dtSMGTPUvXt3NWnSJMt0/RdeeEG+vr4aMWKEdu/erenTp+vw4cP2kitd/TBh5MiRioiI0HPPPWffbv369Vq9erWKFy9uf77Tp0+rXbt2euyxx/TEE08oICBArVq10osvvqiSJUvq1VdflSQFBARIkg4cOKDPP/9cnTp1UmhoqJKSkjRz5ky1bNlSO3fuVMWKFR3yxsbGysXFRQMHDlRycrLGjRunbt26ae3atfZtVqxYoYceekgVKlRQv379VL58ef3+++/6+uuv1a9fP0nSjh07dPfddyswMFBDhgyRl5eXFi5cqA4dOujTTz/Vo48+av/eY2Ji1KtXLzVq1EgpKSnasGGDNm3apPvuuy/Xf2cAgCLCAACgAHv33XcNSdneDMMwzp8/b/j6+hq9e/d2eFxiYqLh4+PjMH7hwoUsz//RRx8ZkoyffvrJPjZ+/HhDknHw4EGHbQ8ePGhIMt59990szyPJGD58uP3r4cOHG5KMxx9/3GG7Q4cOGa
6ursbrr7/uML5t2zajWLFiWcavtz/Wr19vH+vRo4chyRg7dqx97OzZs0aJEiUMm81mfPzxx/bxXbt2Zcl67Tnr169vXL582T4+btw4Q5LxxRdfGIZhGCdOnDDc3NyMNm3aGBkZGfbtpkyZYkgy5syZYx9r2bKlIcmYMWNGlu/hzjvvNFq2bJll/OLFiw7PaxhX97m7u7sxatQo+9jKlSsNScYdd9xhXLp0yT7+1ltvGZKMbdu2GYZhGFeuXDFCQ0ON4OBg4+zZsw7Pm5mZaf/zvffea9SsWdO4ePGiw/1NmzY1KleubB+rXbu28eCDD2bJDQDAjTC9HADgFKZOnaoVK1Y43KSrRzLPnTunxx9/XKdOnbLfXF1d1bhxY61cudL+HCVKlLD/+eLFizp16pTuuusuSdKmTZtMyf3f//7X4evFixcrMzNTnTt3dshbvnx5Va5c2SFvbvXq1cv+Z19fX1WtWlVeXl7q3Lmzfbxq1ary9fXVgQMHsjz+mWeecThS/dxzz6lYsWJasmSJJOm7777T5cuX9dJLL8nF5a9fIXr37i1vb2998803Ds/n7u6uyMjIHOd3d3e3P29GRoZOnz6tkiVLqmrVqtn+/URGRsrNzc3+dfPmzSXJ/r1t3rxZBw8e1EsvvZRl9sC1I/dnzpzR999/r86dO+v8+fP2v4/Tp0+rbdu22rt3r44dOybp6j7dsWOH9u7dm+PvCQAAppcDAJxCo0aNsl1I7VoBat26dbaP8/b2tv/5zJkzGjlypD7++GOdOHHCYbvk5OR8TPuXf07h3rt3rwzDUOXKlbPd/u+lNzc8PDxUrlw5hzEfHx/ddttt9oL59/HsztX+Z6aSJUuqQoUKOnTokCTp8OHDkq4W979zc3NTWFiY/f5rAgMDHUrxv8nMzNRbb72ladOm6eDBg8rIyLDfV6ZMmSzb33777Q5f+/n5SZL9e9u/f7+kG69yv2/fPhmGoWHDhmnYsGHZbnPixAkFBgZq1KhReuSRR1SlShXVqFFD999/v5588knVqlUrx98jAKDooXQDAJxaZmampKvndZcvXz7L/cWK/fWjrnPnzvrll180aNAg1alTRyVLllRmZqbuv/9++/PcyD/L6zV/L4f/9Pej69fy2mw2ffvtt3J1dc2yfcmSJf81R3aye64bjRv///xyM/3ze/83Y8eO1bBhw/TUU09p9OjRKl26tFxcXPTSSy9l+/eTH9/btecdOHCg2rZtm+024eHhkqQWLVpo//79+uKLL7R8+XK98847mjRpkmbMmOEwywAAgL+jdAMAnFqlSpUkSf7+/oqIiLjudmfPnlV8fLxGjhyp6Oho+3h2U4WvV66vHUn950rd/zzC+295DcNQaGioqlSpkuPH3Qp79+7VPffcY/86NTVVx48f1wMPPCBJCg4OliTt3r1bYWFh9u0uX76sgwcP3nD//9319u+iRYt0zz33aPbs2Q7j586dsy9olxvX/m1s3779utmufR/FixfPUf7SpUsrMjJSkZGRSk1NVYsWLTRixAhKNwDgujinGwDg1Nq2bStvb2+NHTtW6enpWe6/tuL4taOi/zwKGhcXl+Ux166l/c9y7e3trbJly+qnn35yGJ82bVqO83bs2FGurq4aOXJkliyGYThcvuxWmzVrlsM+nD59uq5cuaJ27dpJkiIiIuTm5qbJkyc7ZJ89e7aSk5P14IMP5uh1vLy8suxb6erf0T/3ySeffGI/pzq36tWrp9DQUMXFxWV5vWuv4+/vr1atWmnmzJk6fvx4luf4+4r1//y7KVmypMLDw3Xp0qU85QMAFA0c6QYAODVvb29Nnz5dTz75pOrVq6fHHntM5cqVU0JCgr755hvdfffdmjJliry9ve2X00pPT1dgYKCWL1+ugwcPZnnO+vXrS5JeffVVPfbYYypevLgefvhheXl5qVevXoqNjVWvXr3UoEED/fTTT9qzZ0+O81aqVEljxozR0KFDdejQIXXo0EGlSp
XSwYMH9dlnn+mZZ57RwIED823/5Mbly5d17733qnPnztq9e7emTZumZs2aqX379pKuXjZt6NChGjlypO6//361b9/
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot the impurity-based feature importances of the example model, largest first\n",
"# (numpy and matplotlib are imported in the imports cell at the top)\n",
"feature_importances = model.feature_importances_\n",
"sorted_idx = np.argsort(feature_importances)[::-1]\n",
"positions = np.arange(len(feature_importances))\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.set_title(\"Feature Importances\")\n",
"ax.bar(positions, feature_importances[sorted_idx], align=\"center\")\n",
"ax.set_xticks(positions)\n",
"ax.set_xticklabels(np.array(feature_names)[sorted_idx], rotation=90)\n",
"ax.set_ylabel(\"Importance\")\n",
"ax.set_xlim([-1, len(feature_importances)])\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIJElEQVR4nO3de5iVVd038O8wwIAgoOGAIjWZpwhPQRLioQOJSvRSpuQhFJXywJMxZYopaKZoecBKpVTEyh5J0x4KwwzjKRMjNUh7hTwhpoHgARRzkJn9/tHr1MRIiMO9gfl8rmtfl3vda+39u7mD1vXd6153RalUKgUAAAAACtSm3AUAAAAA0PoIpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQBaUEVFRc4777zG91OnTk1FRUUWLVpUtpoAADaW448/PjU1NW9pzOzZs1NRUZHZs2dvlJqAzYdQCtisvBHyvPFq27ZtevXqleOPPz7PPPNMucsDANjo/n0+1KFDh+y6664ZM2ZMli5dWu7yANZb23IXALAhvva1r+Xd7353Xnvttdx3332ZOnVq7rnnnjz88MPp0KFDucsDANjo/nU+dM899+Saa67JHXfckYcffjhbbbVVITVce+21aWhoeEtjDjzwwPz9739P+/btN1JVwOZCKAVslg499ND0798/SXLSSSele/fuueSSSzJ9+vQceeSRZa4OAGDj+/f50Dve8Y5cfvnl+Z//+Z8cddRRa/VftWpVOnXq1KI1tGvX7i2PadOmjR8RgSRu3wO2EAcccECS5PHHH29sW7BgQT796U9n2223TYcOHdK/f/9Mnz59rbEvvfRSxo4dm5qamlRVVWXHHXfMyJEjs3z58iTJ6tWrM378+PTr1y9du3ZNp06dcsABB+TXv/51MScHALAePvKRjyRJnnzyyRx//PHp3LlzHn/88Rx22GHZeuutc8wxxyRJGhoaMmnSpLzvfe9Lhw4d0qNHj3z+85/Piy++uNZn/uIXv8hBBx2UrbfeOl26dMkHPvCB/OhHP2o83tyeUjfffHP69evXOGaPPfbIlVde2Xj8zfaUuuWWW9KvX7907Ngx3bt3z7HHHrvW9gxvnNczzzyT4cOHp3Pnztluu+3y5S9/OfX19W/njw8oA6EUsEV4YyPxbbbZJkny5z//OR/84AfzyCOP5Kyzzspll12WTp06Zfjw4bn99tsbx73yyis54IAD8u1vfzsHH3xwrrzyypx88slZsGBB/vrXvyZJVq5cmeuuuy4f+tCHcskll+S8887LsmXLMmTIkMybN6/oUwUAaNYbP8694x3vSJKsWbMmQ4YMSXV1dS699NIcfvjhSZLPf/7zOeOMMzJo0KBceeWVGTVqVG666aYMGTIkr7/+euPnTZ06NUOHDs0LL7yQcePG5eKLL87ee++dmTNnvmkNd911V4466qhss802ueSSS3LxxRfnQx/6UH73u9+ts/apU6fmyCOPTGVlZSZOnJjRo0fntttuy/7775+XXnqpSd/6+voMGTIk73jHO3LppZfmoIMOymWXXZbvfe97G/LHBpSR2/eAzdKKFSuyfPnyvPbaa/n973+f888/P1VVVfn4xz+eJDn99NPzzne+M3/4wx9SVVWVJDn11FOz//7758wzz8wnP/nJJMk3v/nNPPzww7ntttsa25LknHPOSalUSvKPoGvRokVN9j0YPXp0dt9993z729/O9ddfX9RpAwA0+tf50O9+97t87WtfS8eOHfPxj388c+bMSV1dXY444ohMnDixccw999yT6667LjfddFOOPvroxvYPf/jDOeSQQ3LLLbfk6KOPzooVK/KFL3wh++67b2bPnt3kdrs35kjNmTFjRrp06ZI777wzlZWV63Uer7/+es4888z07ds3v/nNbxq/a//998/HP/7xXHHFFTn//PMb+7
/22msZMWJEzj333CTJySefnPe///25/vrrc8opp6zfHx6wSbBSCtgsDR48ONttt1169+6dT3/60+nUqVOmT5+eHXfcMS+88ELuvvvuHHnkkXn55ZezfPnyLF++PM8//3yGDBmSRx99tHEp+E9+8pPstddeTQKpN1RUVCRJKisrGwOphoaGvPDCC1mzZk369++fBx98sLiTBgD4F/86H/rMZz6Tzp075/bbb0+vXr0a+/x7SHPLLbeka9eu+djHPtY4R1q+fHn69euXzp07N25PcNddd+Xll1/OWWedtdb+T2/MkZrTrVu3rFq1Knfdddd6n8f999+f5557LqeeemqT7xo6dGh23333zJgxY60xJ598cpP3BxxwQJ544on1/k5g02ClFLBZuuqqq7LrrrtmxYoVmTJlSn7zm980roh67LHHUiqVcu655zb+gvbvnnvuufTq1SuPP/5441L2dbnxxhtz2WWXZcGCBU2Wtb/73e9umRMCAHiL3pgPtW3bNj169Mhuu+2WNm3+ue6gbdu22XHHHZuMefTRR7NixYpUV1c3+5nPPfdckn/eCti3b9+3VNOpp56aH//4xzn00EPTq1evHHzwwTnyyCNzyCGHvOmYp556Kkmy2267rXVs9913zz333NOkrUOHDtluu+2atG2zzTbN7okFbNqEUsBmad9992182szw4cOz//775+ijj87ChQsbH0v85S9/OUOGDGl2/M4777ze3/XDH/4wxx9/fIYPH54zzjgj1dXVjfsd/OvG6gAARfrX+VBzqqqqmoRUyT9WfVdXV+emm25qdsy/hz1vVXV1debNm5c777wzv/jFL/KLX/wiN9xwQ0aOHJkbb7zxbX32G9b3tkBg0yeUAjZ7bwREH/7wh/Od73wnJ5xwQpJ/PKJ48ODB6xz7nve8Jw8//PA6+9x6663ZaaedcttttzVZrj5hwoS3XzwAQIHe85735Fe/+lUGDRqUjh07rrNfkjz88MNv6ce8JGnfvn2GDRuWYcOGpaGhIaeeemq++93v5txzz232s971rnclSRYuXNj4BME3LFy4sPE4sOWxpxSwRfjQhz6UfffdN5MmTUqXLl3yoQ99KN/97nfzt7/9ba2+y5Yta/zvww8/PPPnz2/yRL43vLGJ5xu/xv3rpp6///3vM2fOnJY+DQCAjerII49MfX19LrjggrWOrVmzpvFJdwcffHC23nrrTJw4Ma+99lqTfuva6Pz5559v8r5NmzbZc889kyR1dXXNjunfv3+qq6szefLkJn1+8Ytf5JFHHsnQoUPX69yAzY+VUsAW44wzzsgRRxyRqVOn5qqrrsr++++fPfbYI6NHj85OO+2UpUuXZs6cOfnrX/+a+fPnN4659dZbc8QRR+SEE05Iv3798sILL2T69OmZPHly9tprr3z84x9vfDrf0KFD8+STT2by5Mnp06dPXnnllTKfNQDA+jvooIPy+c9/PhMnTsy8efNy8MEHp127dnn00Udzyy235Morr8ynP/3pdOnSJVdccUVOOumkfOADH8jRRx+dbbbZJvPnz8+rr776prfinXTSSXnhhRfykY98JDvuuGOeeuqpfPvb387ee++d9773vc2OadeuXS655JKMGjUqBx10UI466qgsXbo0V155ZWpqajJ27NiN+UcClJFQCthifOpTn8p73vOeXHrppRk9enTuv//+nH/++Zk6dWqef/75VFdXZ5999sn48eMbx3Tu3Dm//e1vM2HChNx+++258cYbU11dnY9+9KONG4Mef/zxWbJkSb773e/mzjvvTJ8+ffLDH/4wt9xyS2bPnl2mswUA2DCTJ09Ov3798t3vfjdnn3122rZtm5qamhx77LEZNGhQY78TTzwx1dXVufjii3PBBRekXbt22X333dcZEh177LH53ve+l6uvvjovvfRSevbsmREjRuS8885ba3+rf3X88cdnq622ysUXX5wzzzwznTp1yic/+clccskl6datW0uePr
AJqSita+0lAAAAAGwE9pQCAAAAoHBCKQAAAAAKJ5QCAAAAoHBCKQAAAAAKJ5QCAAAAoHBCKQAAAAAK17bcBRStoaE
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# per-class recall and precision from the confusion matrix\n",
"# (rows of cm are actual classes, columns are predicted classes)\n",
"class_names = ['GSVT', 'AFIB', 'SR', 'SB']\n",
"support = cm.sum(axis=1)\n",
"predicted = cm.sum(axis=0)\n",
"# a class with no samples / no predictions would divide by zero; report 0 instead of NaN\n",
"recall = np.divide(cm.diagonal(), support, out=np.zeros(len(support)), where=support > 0)\n",
"precision = np.divide(cm.diagonal(), predicted, out=np.zeros(len(predicted)), where=predicted > 0)\n",
"\n",
"# plot in a bar chart, one panel per metric\n",
"fig, ax = plt.subplots(1, 2, figsize=(12, 6))\n",
"for axis, values, metric in [(ax[0], recall, 'Recall'), (ax[1], precision, 'Precision')]:\n",
"    axis.bar(range(num_classes), values)\n",
"    axis.set_xticks(range(num_classes))\n",
"    axis.set_xticklabels(class_names)\n",
"    axis.set_xlabel('Class')\n",
"    axis.set_ylabel(metric)\n",
"    axis.set_title(metric)\n",
"    axis.set_ylim(0, 1)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
2024-06-12 17:19:27 +02:00
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F1 Score: 0.8004770573896727\n"
]
}
],
"source": [
"# macro-averaged F1: the unweighted mean of the per-class F1 scores,\n",
"# so every rhythm class counts equally regardless of how frequent it is\n",
"f1 = f1_score(y_true=test_y, y_pred=preds, average='macro')\n",
"print('F1 Score:', f1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}