488 lines
109 KiB
Plaintext
488 lines
109 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Gradient Boosting Tree (GBT) Training and Analysis"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 33,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"import os\n",
"import sqlite3\n",
"from contextlib import closing\n",
"from datetime import datetime\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import xgboost as xgb\n",
"from joblib import dump, load\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.preprocessing import MinMaxScaler"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Import Data from Database"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"# Read the train / validation / test splits from the feature database.\n",
"DB_PATH = '../features.db'\n",
"# sqlite3.connect() would silently create an empty file for a wrong path, so fail early\n",
"if not os.path.exists(DB_PATH):\n",
"    raise FileNotFoundError(f'feature database not found: {DB_PATH}')\n",
"\n",
"# closing() releases the connection even if one of the queries raises\n",
"with closing(sqlite3.connect(DB_PATH)) as conn:\n",
"    train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
"    valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
"    test = pd.read_sql_query(\"SELECT * FROM test\", conn)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Format Data for Machine Learning"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 46,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"train_x shape: (3502, 10)\n",
|
||
|
"test_x shape: (438, 10)\n",
|
||
|
"valid_x shape: (438, 10)\n",
|
||
|
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
|
||
|
"number of classes: 4\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Integer encoding of the rhythm classes; the confusion-matrix labels below use the same order\n",
"LABEL_MAP = {'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3}\n",
"\n",
"\n",
"def split_features_target(df):\n",
"    \"\"\"Return (features, encoded target) for one data split.\n",
"\n",
"    Drops the 'y' target and 'id' identifier columns from the features and\n",
"    encodes 'y' with LABEL_MAP. Raises ValueError for labels missing from\n",
"    LABEL_MAP, which Series.map() would otherwise turn into NaN silently.\n",
"    \"\"\"\n",
"    target = df['y'].map(LABEL_MAP)\n",
"    unknown = df.loc[target.isna(), 'y'].unique()\n",
"    if len(unknown) > 0:\n",
"        raise ValueError(f'unknown labels: {list(unknown)}')\n",
"    return df.drop(columns=['y', 'id']), target\n",
"\n",
"\n",
"train_x, train_y = split_features_target(train)\n",
"valid_x, valid_y = split_features_target(valid)\n",
"test_x, test_y = split_features_target(test)\n",
"\n",
"print('train_x shape:', train_x.shape)\n",
"print('test_x shape:', test_x.shape)\n",
"print('valid_x shape:', valid_x.shape)\n",
"# print column names\n",
"print('features:', train_x.columns.to_list())\n",
"feature_names = train_x.columns.to_list()\n",
"\n",
"# Mean-impute missing values; the means are learned on the training split only\n",
"imputer = SimpleImputer(strategy='mean')\n",
"train_x = imputer.fit_transform(train_x)\n",
"valid_x = imputer.transform(valid_x)\n",
"test_x = imputer.transform(test_x)\n",
"\n",
"# Scale features to [0, 1]; the scaler is also fit on training data only to avoid leakage\n",
"scaler = MinMaxScaler()\n",
"train_x = scaler.fit_transform(train_x)\n",
"valid_x = scaler.transform(valid_x)\n",
"test_x = scaler.transform(test_x)\n",
"\n",
"# Native xgboost matrices; the sklearn GradientBoostingClassifier used below does not need them\n",
"dtrain = xgb.DMatrix(train_x, label=train_y)\n",
"dvalid = xgb.DMatrix(valid_x, label=valid_y)\n",
"dtest = xgb.DMatrix(test_x, label=test_y)\n",
"\n",
"# The class count comes from the encoding, not from whichever labels one split happens to contain\n",
"num_classes = len(LABEL_MAP)\n",
"print('number of classes:', num_classes)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
"# Parameter Grid for Hyperparameter Search"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"# Hyperparameter grid explored by GridSearchCV (3 x 3 x 3 = 27 candidates).\n",
"# random_state is a seed, not a hyperparameter, so it belongs on the estimator, not in this grid.\n",
"param_grid = {\n",
"    'n_estimators': [100, 200, 300],\n",
"    'learning_rate': [0.1, 0.2, 0.3],\n",
"    'max_depth': [1, 3, 5],\n",
"}"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"# Create a GradientBoostingClassifier with a fixed seed so the search is reproducible\n",
"model = GradientBoostingClassifier(random_state=42)\n",
"\n",
"# Exhaustive 3-fold CV search over param_grid, ranked by accuracy.\n",
"# n_jobs=-1 evaluates candidates in parallel on all cores; the scores are unaffected.\n",
"grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy', n_jobs=-1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Training"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"CPU times: total: 3min 28s\n",
|
||
|
"Wall time: 4min 16s\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-4 {color: black;background-color: white;}#sk-container-id-4 pre{padding: 0;}#sk-container-id-4 div.sk-toggleable {background-color: white;}#sk-container-id-4 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-4 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-4 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-4 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-4 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-4 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-4 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-4 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-4 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-4 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-4 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-4 div.sk-item {position: relative;z-index: 1;}#sk-container-id-4 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-4 div.sk-item::before, #sk-container-id-4 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-4 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-4 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-4 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-4 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-4 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-4 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-4 div.sk-label-container {text-align: center;}#sk-container-id-4 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-4 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-4\" class=\"sk-top-container\
|
||
|
" param_grid={'learning_rate': [0.1, 0.2, 0.3],\n",
|
||
|
" 'max_depth': [1, 3, 5],\n",
|
||
|
" 'n_estimators': [100, 200, 300]},\n",
|
||
|
" scoring='accuracy')</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3, estimator=GradientBoostingClassifier(),\n",
|
||
|
" param_grid={'learning_rate': [0.1, 0.2, 0.3],\n",
|
||
|
" 'max_depth': [1, 3, 5],\n",
|
||
|
" 'n_estimators': [100, 200, 300]},\n",
|
||
|
" scoring='accuracy')</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier()</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier()</pre></div></div></div></div></div></div></div></div></div></div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"GridSearchCV(cv=3, estimator=GradientBoostingClassifier(),\n",
|
||
|
" param_grid={'learning_rate': [0.1, 0.2, 0.3],\n",
|
||
|
" 'max_depth': [1, 3, 5],\n",
|
||
|
" 'n_estimators': [100, 200, 300]},\n",
|
||
|
" scoring='accuracy')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 29,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"%%time\n",
"# Fit the grid search object to the data. With the default refit=True, the best candidate\n",
"# is refit on all of train_x afterwards and exposed as grid_search.best_estimator_.\n",
"# fit() returns the search object itself; binding it avoids rendering its large HTML repr.\n",
"grid_search = grid_search.fit(train_x, train_y)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Results"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Best parameters: {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 100}\n",
|
||
|
"Best score: 0.796973125095374\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Print the best parameters and their mean cross-validated accuracy,\n",
"# plus the spread of that accuracy across the CV folds\n",
"best_std = grid_search.cv_results_['std_test_score'][grid_search.best_index_]\n",
"print(f'Best parameters: {grid_search.best_params_}')\n",
"print(f'Best score: {grid_search.best_score_}')\n",
"print(f'Best score std across folds: {best_std}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Save Model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"['../ml_models/best_gbt_model_20240611203442.joblib']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 34,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Save the refit best model under a timestamped name so earlier runs are kept\n",
"best_model = grid_search.best_estimator_\n",
"MODEL_DIR = '../ml_models'\n",
"# dump() opens the target file directly and fails if the directory does not exist\n",
"os.makedirs(MODEL_DIR, exist_ok=True)\n",
"timestamp = datetime.now().strftime('%Y%m%d%H%M%S')\n",
"model_path = f'{MODEL_DIR}/best_gbt_model_{timestamp}.joblib'\n",
"dump(best_model, model_path)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
"# Example Training of the Best Model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
"Load the most recently saved best model and reuse its hyperparameters for the example training below."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"# Load the most recently saved best GBT model.\n",
"# os.listdir() returns entries in arbitrary order, so sort explicitly: files written by the\n",
"# save cell embed a %Y%m%d%H%M%S timestamp, so lexicographic order is chronological.\n",
"MODEL_DIR = '../ml_models'\n",
"candidates = sorted(\n",
"    name for name in os.listdir(MODEL_DIR)\n",
"    if 'joblib' in name and 'best' in name and 'gbt' in name\n",
")\n",
"if not candidates:\n",
"    raise FileNotFoundError(f'no saved best GBT model found in {MODEL_DIR}')\n",
"model_path = f'{MODEL_DIR}/{candidates[-1]}'\n",
"# joblib.load unpickles the file, which can execute arbitrary code: only load trusted models\n",
"best_model = load(model_path)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-5 {color: black;background-color: white;}#sk-container-id-5 pre{padding: 0;}#sk-container-id-5 div.sk-toggleable {background-color: white;}#sk-container-id-5 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-5 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-5 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-5 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-5 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-5 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-5 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-5 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-5 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-5 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-5 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-5 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-5 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-5 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-5 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-5 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-5 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-5 div.sk-item {position: relative;z-index: 1;}#sk-container-id-5 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-5 div.sk-item::before, #sk-container-id-5 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-5 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-5 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-5 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-5 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-5 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-5 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-5 div.sk-label-container {text-align: center;}#sk-container-id-5 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-5 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-5\" class=\"sk-top-container\
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"GradientBoostingClassifier()"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Example training of a fresh model with the best hyperparameters. They are taken from the\n",
"# loaded model, so this cell does not require re-running the grid search.\n",
"model = GradientBoostingClassifier(**best_model.get_params())\n",
"model.fit(train_x, train_y)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Accuracy: 0.819634703196347\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Score the persisted best model on the held-out test split\n",
"preds = best_model.predict(test_x)\n",
"accuracy = accuracy_score(y_true=test_y, y_pred=preds)\n",
"print(f\"Model Accuracy: {accuracy}\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Evaluate Model Performance"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 45,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAokAAAIjCAYAAABvUIGpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABZdElEQVR4nO3deZyN9fvH8feZYRYzZsZgNvu+L1myh8guoqT0DQmJCCkju5gSWUPZibJUQqVskRp7lmxZUzSMnTHGmLl/f/g5ddyU4Zy5Z5zXs8f9eHzvz32fz32dc750dX2WYzMMwxAAAADwDx5WBwAAAIC0hyQRAAAAJiSJAAAAMCFJBAAAgAlJIgAAAExIEgEAAGBCkggAAAATkkQAAACYkCQCAADAhCQRwL86ePCg6tWrp8DAQNlsNi1ZssSp/R87dkw2m02zZs1yar/pWa1atVSrVi2rwwDg5kgSgXTg8OHD6ty5s/Lnzy8fHx8FBASoWrVqGjdunOLj41367LZt22r37t0aPny45s6dqwoVKrj0eampXbt2stlsCggIuOPnePDgQdlsNtlsNo0aNSrF/Z88eVKDBw/Wjh07nBAtAKSuDFYHAODfff3113rmmWfk7e2tF198USVLltT169e1YcMG9enTR3v27NHHH3/skmfHx8crOjpab7/9trp16+aSZ+TJk0fx8fHKmDGjS/r/LxkyZNDVq1e1bNkytWrVyuHavHnz5OPjo2vXrt1X3ydPntSQIUOUN29elS1b9p5f9/3339/X8wDAmUgSgTTs6NGjat26tfLkyaM1a9YoPDzcfq1r1646dOiQvv76a5c9PzY2VpIUFBTksmfYbDb5+Pi4rP//4u3trWrVqunTTz81JYnz589X48aN9fnnn6dKLFevXlWmTJnk5eWVKs8DgH/DcDOQho0cOVJXrlzR9OnTHRLEWwoWLKgePXrYz2/cuKFhw4apQIEC8vb2Vt68edWvXz8lJCQ4vC5v3rxq0qSJNmzYoEcffVQ+Pj7Knz+/5syZY79n8ODBypMnjySpT58+stlsyps3r6Sbw7S3/vc/DR48WDabzaFt5cqVql69uoKCguTv768iRYqoX79+9ut3m5O4Zs0a1ahRQ35+fgoKClKzZs20b9++Oz7v0KFDateunYKCghQYGKj27dvr6tWrd/9gb/P888/r22+/1YULF+xtW7Zs0cGDB/X888+b7j937pzeeOMNlSpVSv7+/goICFDDhg21c+dO+z0//PCDKlasKElq3769fdj61vusVauWSpYsqW3btumxxx5TpkyZ7J/L7XMS27ZtKx8fH9P7r1+/vrJkyaKTJ0/e83sFgHtFkgikYcuWLVP+/PlVtWrVe7r/5Zdf1sCBA1WuXDmNGTNGNWvWVFRUlFq3bm2699ChQ3r66af1xBNPaPTo0cqSJYvatWunPXv2SJJatGihMWPGSJKee+45zZ07V2PHjk1R/Hv27FGTJk2UkJCgoUOHavTo0XryySf1008//evrVq1apfr16+v06dMaPHiwevXqpZ9//lnVqlXTsWPHTPe3atVKly9fVlRUlFq1aqVZs2ZpyJAh9xxnixYtZLPZ9MUXX9jb5s+fr6JFi6pcuXKm+48cOaIlS5aoSZMm+uCDD9SnTx/t3r1bNWvWtCdsxYoV09ChQyVJnTp10ty5czV37lw99thj9n7Onj2rhg0bqmzZsho7dqxq1659x/jGjRun7Nmzq23btkpKSpIkffTRR/r+++81YcIERURE3PN7BYB7ZgBIky5evGhIMpo1a3ZP9+/YscOQZLz88ssO7W+88YYhyVizZo29LU+ePIYkY/369fa206dPG97e3kbv3r3tbUePHjUkGe+//75Dn23btjXy5MljimHQoEHGP/9aGTNmjCHJiI2NvWvct54xc+ZMe1vZsmWNkJAQ4+zZs/a2nTt3Gh4eHsaLL75oet5LL73k0OdTTz1lZM2a9a7P/O
f78PPzMwzDMJ5++mmjTp06hmEYRlJSkhEWFmYMGTLkjp/BtWvXjKSkJNP78Pb2NoYOHWpv27Jli+m93VKzZk1DkjFlypQ7XqtZs6ZD23fffWdIMt555x3jyJEjhr+/v9G8efP/fI8AcL+oJAJp1KVLlyRJmTNnvqf7v/nmG0lSr169HNp79+4tSaa5i8WLF1eNGjXs59mzZ1eRIkV05MiR+475drfmMn711VdKTk6+p9f89ddf2rFjh9q1a6fg4GB7e+nSpfXEE0/Y3+c/vfLKKw7nNWrU0NmzZ+2f4b14/vnn9cMPPygmJkZr1qxRTEzMHYeapZvzGD08bv71mZSUpLNnz9qH0rdv337Pz/T29lb79u3v6d569eqpc+fOGjp0qFq0aCEfHx999NFH9/wsAEgpkkQgjQoICJAkXb58+Z7u//333+Xh4aGCBQs6tIeFhSkoKEi///67Q3vu3LlNfWTJkkXnz5+/z4jNnn32WVWrVk0vv/yyQkND1bp1ay1cuPBfE8ZbcRYpUsR0rVixYjpz5ozi4uIc2m9/L1myZJGkFL2XRo0aKXPmzFqwYIHmzZunihUrmj7LW5KTkzVmzBgVKlRI3t7eypYtm7Jnz65du3bp4sWL9/zMHDlypGiRyqhRoxQcHKwdO3Zo/PjxCgkJuefXAkBKkSQCaVRAQIAiIiL066+/puh1ty8cuRtPT887thuGcd/PuDVf7hZfX1+tX79eq1at0v/+9z/t2rVLzz77rJ544gnTvQ/iQd7LLd7e3mrRooVmz56tL7/88q5VREkaMWKEevXqpccee0yffPKJvvvuO61cuVIlSpS454qpdPPzSYlffvlFp0+fliTt3r07Ra8FgJQiSQTSsCZNmujw4cOKjo7+z3vz5Mmj5ORkHTx40KH91KlTunDhgn2lsjNkyZLFYSXwLbdXKyXJw8NDderU0QcffKC9e/dq+PDhWrNmjdauXXvHvm/FeeDAAdO1/fv3K1u2bPLz83uwN3AXzz//vH755Rddvnz5jot9blm8eLFq166t6dOnq3Xr1qpXr57q1q1r+kzuNWG/F3FxcWrfvr2KFy+uTp06aeTIkdqyZYvT+geA25EkAmnYm2++KT8/P7388ss6deqU6frhw4c1btw4STeHSyWZViB/8MEHkqTGjRs7La4CBQro4sWL2rVrl73tr7/+0pdffulw37lz50yvvbWp9O3b8twSHh6usmXLavbs2Q5J16+//qrvv//e/j5doXbt2ho2bJgmTpyosLCwu97n6elpqlIuWrRIJ06ccGi7lczeKaFOqbfeekvHjx/X7Nmz9cEHHyhv3rxq27btXT9HAHhQbKYNpGEFChTQ/Pnz9eyzz6pYsWIOv7jy888/a9GiRWrXrp0kqUyZMmrbtq0+/vhjXbhwQTVr1tTmzZs1e/ZsNW/e/K7bq9yP1q1b66233tJTTz2l7t276+rVq5o8ebIKFy7ssHBj6NChWr9+vRo3bqw8efLo9OnTmjRpknLmzKnq1avftf/3339fDRs2VJUqVdShQwfFx8drwoQJCgwM1ODBg532Pm7n4eGh/v37/+d9TZo00dChQ9W+fXtVrVpVu3fv1rx585Q/f36H+woUKKCgoCBNmTJFmTNnlp+fnypVqqR8+fKlKK41a9Zo0qRJGjRokH1LnpkzZ6pWrVoaMGCARo4cmaL+AOCeWLy6GsA9+O2334yOHTsaefPmNby8vIzMmTMb1apVMyZMmGBcu3bNfl9iYqIxZMgQI1++fEbGjBmNXLlyGZGRkQ73GMbNLXAaN25ses7tW6/cbQscwzCM77//3ihZsqTh5eVlFClSxPjkk09MW+CsXr3aaNasmREREWF4eXkZERERxnPPPWf89ttvpmfcvk3MqlWrjGrVqhm+vr5GQECA0bRpU2Pv3r0O99x63u1b7MycOdOQZBw9evSun6lhOG6Bczd32wKnd+/eRnh4uOHr62tUq1bNiI6OvuPWNV999Z
VRvHhxI0OGDA7vs2bNmkaJEiXu+Mx/9nPp0iUjT548Rrly5YzExESH+3r27Gl4eHgY0dHR//oeAOB+2AwjBTO7AQA
|
||
|
"text/plain": [
|
||
|
"<Figure size 800x600 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Confusion matrix on the test split. Passing labels explicitly keeps it 4x4 and aligned\n",
"# with the tick names even if some class never occurs in test_y or preds.\n",
"labels = ['GSVT', 'AFIB', 'SR', 'SB']  # position i is the name of encoded class i\n",
"cm = confusion_matrix(test_y, preds, labels=list(range(len(labels))))\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',\n",
"            xticklabels=labels, yticklabels=labels, ax=ax)\n",
"ax.set(xlabel='Predicted', ylabel='Actual', title='Confusion Matrix')\n",
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 47,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAJOCAYAAACqS2TfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABfZUlEQVR4nO3deVhUdeP+8XtAAUFZXEAlYhGXzH1Nc8tIs8XM7+OSlkppPW2aqKmVuCamqeTjXppZVpqpLZppZJuZ5kYuuS+oCe4gmIpwfn/4c2oCDZDjYeD9uq65LvjMOTP3HHHgnnPO59gMwzAEAAAAAADynYvVAQAAAAAAKKwo3QAAAAAAmITSDQAAAACASSjdAAAAAACYhNINAAAAAIBJKN0AAAAAAJiE0g0AAAAAgEko3QAAAAAAmITSDQAAAACASSjdAAAAAACYhNINACjQ5s2bJ5vNlu1tyJAhpjznzz//rBEjRujcuXOmPP7NuLY9Nm7caHWUPJs+fbrmzZtndQwAAG6JYlYHAAAgJ0aNGqXQ0FCHsRo1apjyXD///LNGjhypXr16ydfX15TnKMqmT5+usmXLqlevXlZHAQDAdJRuAIBTaNeunRo0aGB1jJuSlpYmLy8vq2NY5sKFC/L09LQ6BgAAtxSHlwMACoWvvvpKzZs3l5eXl0qVKqUHH3xQO3bscFjmt99+U69evRQWFiYPDw+VL19eTz75pE6fPm1fZsSIERo0aJAkKTQ01H4o+6FDh3To0CHZbLZsD4222WwaMWKEw+PYbDbt3LlT3bp1k5+fn5o1a2a//4MPPlD9+vVVokQJlS5dWl27dtWRI0fy9Np79eqlkiVLKiEhQQ899JBKliypwMBATZs2TZK0bds2tW7dWl5eXgoODtaHH37osP61Q9Z/+OEHPfPMMypTpoy8vb3Vo0cPnT17NsvzTZ8+XXfeeafc3d1VsWJFPf/881kOxW/VqpVq1KihTZs2qUWLFvL09NQrr7yikJAQ7dixQ99//71927Zq1UqSdObMGQ0cOFA1a9ZUyZIl5e3trXbt2ik+Pt7hsb/77jvZbDYtWrRIr7/+um677TZ5eHjo3nvv1b59+7LkXb9+vR544AH5+fnJy8tLtWrV0ltvveWwzK5du/Sf//xHpUuXloeHhxo0aKDPP//cYZn09HSNHDlSlStXloeHh8qUKaNmzZpp9erVOfp3AgAUTezpBgA4heTkZJ06dcphrGzZspKk999/Xz179lTbtm31xhtv6MKFC5oxY4aaNWumLVu2KCQkRJK0evVqHThwQJGRkSpfvrx27Nih2bNna8eOHfrll19ks9nUsWNH7dmzRx999JEmT55sf45y5crp5MmTuc7dqVMnVa5cWWPHjpVhGJKk119/XcOGDVPnzp3Vu3dvnTx5Uv/73//UokULbdmyJU+HtGdkZKhdu3Zq0aKFxo8frwULFuiFF16Ql5eXXn31VXXv3l0dO3bUzJkz1aNHDzVp0iTL4fovvPCCfH19NWLECO3evVszZszQ4cOH7SVXuvphwsiRIxUREaFnn33Wvtyvv/6qtWvXqnjx4vbHO336tNq1a6euXbvq8ccfV0BAgFq1aqUXX3xRJUuW1KuvvipJCggIkCQdOHBAy5YtU6dOnRQaGqqkpCTNmjVLLVu21M6dO1WxYkWHvOPGjZOLi4sGDhyo5ORkjR8/Xt27d9f69evty6xevVoPPfSQKlSooH79+ql8+fL6/fff9eWXX6pfv36SpB07dujuu+9WYGCghgwZIi8vLy1atEgdOnTQp59+qkcffdT+2mNiYtS7d281atRIKSkp2rhxozZv3qz77rsv1/9mAIAiwgAAoAB79913DUnZ3gzDMM6fP2/4+voaffr0cVgvMTHR8PHxcRi/cOFClsf/6KOPDEnGDz/8YB+bMGGCIck4ePCgw7IHDx40JBnvvvtulseRZAwfPtz+/fDhww1JxmOPPeaw3KFDhw
xXV1fj9ddfdxjftm2bUaxYsSzj19sev/76q32sZ8+ehiRj7Nix9rGzZ88aJUqUMGw2m/Hxxx/bx3ft2pUl67XHrF+/vnH58mX7+Pjx4w1JxmeffWYYhmGcOHHCcHNzM9q0aWNkZGTYl5s6daohyZg7d659rGXLloYkY+bMmVlew5133mm0bNkyy/jFixcdHtcwrm5zd3d3Y9SoUfaxNWvWGJKMO+64w7h06ZJ9/K233jIkGdu2bTMMwzCuXLlihIaGGsHBwcbZs2cdHjczM9P+9b333mvUrFnTuHjxosP9TZs2NSpXrmwfq127tvHggw9myQ0AwI1weDkAwClMmzZNq1evdrhJV/dknjt3To899phOnTplv7m6uqpx48Zas2aN/TFKlChh//rixYs6deqU7rrrLknS5s2bTcn93//+1+H7JUuWKDMzU507d3bIW758eVWuXNkhb2717t3b/rWvr6+qVq0qLy8vde7c2T5etWpV+fr66sCBA1nWf/rppx32VD/77LMqVqyYVqxYIUn65ptvdPnyZb300ktycfnrT4g+ffrI29tby5cvd3g8d3d3RUZG5ji/u7u7/XEzMjJ0+vRplSxZUlWrVs323ycyMlJubm7275s3by5J9te2ZcsWHTx4UC+99FKWoweu7bk/c+aMvv32W3Xu3Fnnz5+3/3ucPn1abdu21d69e3Xs2DFJV7fpjh07tHfv3hy/JgAAOLwcAOAUGjVqlO1EatcKUOvWrbNdz9vb2/71mTNnNHLkSH388cc6ceKEw3LJycn5mPYv/zyEe+/evTIMQ5UrV852+b+X3tzw8PBQuXLlHMZ8fHx022232Qvm38ezO1f7n5lKliypChUq6NChQ5Kkw4cPS7pa3P/Ozc1NYWFh9vuvCQwMdCjF/yYzM1NvvfWWpk+froMHDyojI8N+X5kyZbIsf/vttzt87+fnJ0n217Z//35JN57lft++fTIMQ8OGDdOwYcOyXebEiRMKDAzUqFGj9Mgjj6hKlSqqUaOG7r//fj3xxBOqVatWjl8jAKDooXQDAJxaZmampKvndZcvXz7L/cWK/fWrrnPnzvr55581aNAg1alTRyVLllRmZqbuv/9+++PcyD/L6zV/L4f/9Pe969fy2mw2ffXVV3J1dc2yfMmSJf81R3aye6wbjRv///xyM/3ztf+bsWPHatiwYXryySc1evRolS5dWi4uLnrppZey/ffJj9d27XEHDhyotm3bZrtMeHi4JKlFixbav3+/PvvsM61atUrvvPOOJk+erJkzZzocZQAAwN9RugEATq1SpUqSJH9/f0VERFx3ubNnzyouLk4jR45UdHS0fTy7Q4WvV66v7Un950zd/9zD+295DcNQaGioqlSpkuP1boW9e/fqnnvusX+fmpqq48eP64EHHpAkBQcHS5J2796tsLAw+3KXL1/WwYMHb7j9/+5623fx4sW65557NGfOHIfxc+fO2Se0y41rPxvbt2+/brZrr6N48eI5yl+6dGlFRkYqMjJSqampatGihUaMGEHpBgBcF+d0AwCcWtu2beXt7a2xY8cqPT09y/3XZhy/tlf0n3tBY2Njs6xz7Vra/yzX3t7eKlu2rH744QeH8enTp+c4b8eOHeXq6qqRI0dmyWIYhsPly2612bNnO2zDGTNm6MqVK2rXrp0kKSIiQm5ubpoyZYpD9jlz5ig5OVkPPvhgjp7Hy8sry7aVrv4b/XObfPLJJ/ZzqnOrXr16Cg0NVWxsbJbnu/Y8/v7+atWqlWbNmqXjx49neYy/z1j/z3+bkiVLKjw8XJcuXcpTPgBA0cCebgCAU/P29taMGTP0xBNPqF69euratavKlSunhIQELV++XHfffbemTp0qb29v++W00tPTFRgYqFWrVungwYNZHrN+/fqSpFdffVVdu3ZV8eLF9fDDD8vLy0u9e/fWuHHj1Lt3bzVo0EA//PCD9uzZk+O8lSpV0pgxYzR06FAdOnRIHTp0UKlSpX
Tw4EEtXbpUTz/9tAYOHJhv2yc3Ll++rHvvvVedO3fW7t27NX36dDVr1kzt27eXdPWyaUOHDtXIkSN1//33q3379vb
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x600 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"# Plot the feature importances of the evaluated best model, largest first.\n",
"# (Same model as the accuracy and confusion-matrix cells above.)\n",
"feature_importances = best_model.feature_importances_\n",
"sorted_idx = np.argsort(feature_importances)[::-1]\n",
"n_features = len(feature_importances)\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.bar(range(n_features), feature_importances[sorted_idx], align='center')\n",
"ax.set_xticks(range(n_features))\n",
"ax.set_xticklabels(np.array(feature_names)[sorted_idx], rotation=90)\n",
"ax.set_xlim([-1, n_features])\n",
"ax.set_title('Feature Importances')\n",
"ax.set_ylabel('Importance')\n",
"fig.tight_layout()\n",
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 48,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIJElEQVR4nO3de5iVVd038O8wwIAgoOGAIjWZpwhPQRLioQOJSvRSpuQhFJXywJMxZYopaKZoecBKpVTEyh5J0x4KwwzjKRMjNUh7hTwhpoHgARRzkJn9/tHr1MRIiMO9gfl8rmtfl3vda+39u7mD1vXd6153RalUKgUAAAAACtSm3AUAAAAA0PoIpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQBaUEVFRc4777zG91OnTk1FRUUWLVpUtpoAADaW448/PjU1NW9pzOzZs1NRUZHZs2dvlJqAzYdQCtisvBHyvPFq27ZtevXqleOPPz7PPPNMucsDANjo/n0+1KFDh+y6664ZM2ZMli5dWu7yANZb23IXALAhvva1r+Xd7353Xnvttdx3332ZOnVq7rnnnjz88MPp0KFDucsDANjo/nU+dM899+Saa67JHXfckYcffjhbbbVVITVce+21aWhoeEtjDjzwwPz9739P+/btN1JVwOZCKAVslg499ND0798/SXLSSSele/fuueSSSzJ9+vQceeSRZa4OAGDj+/f50Dve8Y5cfvnl+Z//+Z8cddRRa/VftWpVOnXq1KI1tGvX7i2PadOmjR8RgSRu3wO2EAcccECS5PHHH29sW7BgQT796U9n2223TYcOHdK/f/9Mnz59rbEvvfRSxo4dm5qamlRVVWXHHXfMyJEjs3z58iTJ6tWrM378+PTr1y9du3ZNp06dcsABB+TXv/51MScHALAePvKRjyRJnnzyyRx//PHp3LlzHn/88Rx22GHZeuutc8wxxyRJGhoaMmnSpLzvfe9Lhw4d0qNHj3z+85/Piy++uNZn/uIXv8hBBx2UrbfeOl26dMkHPvCB/OhHP2o83tyeUjfffHP69evXOGaPPfbIlVde2Xj8zfaUuuWWW9KvX7907Ngx3bt3z7HHHrvW9gxvnNczzzyT4cOHp3Pnztluu+3y5S9/OfX19W/njw8oA6EUsEV4YyPxbbbZJkny5z//OR/84AfzyCOP5Kyzzspll12WTp06Zfjw4bn99tsbx73yyis54IAD8u1vfzsHH3xwrrzyypx88slZsGBB/vrXvyZJVq5cmeuuuy4f+tCHcskll+S8887LsmXLMmTIkMybN6/oUwUAaNYbP8694x3vSJKsWbMmQ4YMSXV1dS699NIcfvjhSZLPf/7zOeOMMzJo0KBceeWVGTVqVG666aYMGTIkr7/+euPnTZ06NUOHDs0LL7yQcePG5eKLL87ee++dmTNnvmkNd911V4466qhss802ueSSS3LxxRfnQx/6UH73u9+ts/apU6fmyCOPTGVlZSZOnJjRo0fntttuy/7775+XXnqpSd/6+voMGTIk73jHO3LppZfmoIMOymWXXZbvfe97G/LHBpSR2/eAzdKKFSuyfPnyvPbaa/n973+f888/P1VVVfn4xz+eJDn99NPzzne+M3/4wx9SVVWVJDn11FOz//7758wzz8wnP/nJJMk3v/nNPPzww7ntttsa25LknHPOSalUSvKPoGvRokVN9j0YPXp0dt9993z729/O9ddfX9RpAwA0+tf50O9+97t87WtfS8eOHfPxj388c+bMSV1dXY444ohMnDixccw999yT6667LjfddFOOPvroxvYPf/jDOeSQQ3LLLbfk6KOPzooVK/KFL3wh++67b2bPnt3kdrs35kjNmTFjRrp06ZI777wzlZWV63Uer7/+es4888z07ds3v/nNbxq/a//998/HP/7xXHHFFTn//PMb+7
/22msZMWJEzj333CTJySefnPe///25/vrrc8opp6zfHx6wSbBSCtgsDR48ONttt1169+6dT3/60+nUqVOmT5+eHXfcMS+88ELuvvvuHHnkkXn55ZezfPnyLF++PM8//3yGDBmSRx99tHEp+E9+8pPstddeTQKpN1RUVCRJKisrGwOphoaGvPDCC1mzZk369++fBx98sLiTBgD4F/86H/rMZz6Tzp075/bbb0+vXr0a+/x7SHPLLbeka9eu+djHPtY4R1q+fHn69euXzp07N25PcNddd+Xll1/OWWedtdb+T2/MkZrTrVu3rFq1Knfdddd6n8f999+f5557LqeeemqT7xo6dGh23333zJgxY60xJ598cpP3BxxwQJ544on1/k5g02ClFLBZuuqqq7LrrrtmxYoVmTJlSn7zm980roh67LHHUiqVcu655zb+gvbvnnvuufTq1SuPP/5441L2dbnxxhtz2WWXZcGCBU2Wtb/73e9umRMCAHiL3pgPtW3bNj169Mhuu+2WNm3+ue6gbdu22XHHHZuMefTRR7NixYpUV1c3+5nPPfdckn/eCti3b9+3VNOpp56aH//4xzn00EPTq1evHHzwwTnyyCNzyCGHvOmYp556Kkmy2267rXVs9913zz333NOkrUOHDtluu+2atG2zzTbN7okFbNqEUsBmad9992182szw4cOz//775+ijj87ChQsbH0v85S9/OUOGDGl2/M4777ze3/XDH/4wxx9/fIYPH54zzjgj1dXVjfsd/OvG6gAARfrX+VBzqqqqmoRUyT9WfVdXV+emm25qdsy/hz1vVXV1debNm5c777wzv/jFL/KLX/wiN9xwQ0aOHJkbb7zxbX32G9b3tkBg0yeUAjZ7bwREH/7wh/Od73wnJ5xwQpJ/PKJ48ODB6xz7nve8Jw8//PA6+9x6663ZaaedcttttzVZrj5hwoS3XzwAQIHe85735Fe/+lUGDRqUjh07rrNfkjz88MNv6ce8JGnfvn2GDRuWYcOGpaGhIaeeemq++93v5txzz232s971rnclSRYuXNj4BME3LFy4sPE4sOWxpxSwRfjQhz6UfffdN5MmTUqXLl3yoQ99KN/97nfzt7/9ba2+y5Yta/zvww8/PPPnz2/yRL43vLGJ5xu/xv3rpp6///3vM2fOnJY+DQCAjerII49MfX19LrjggrWOrVmzpvFJdwcffHC23nrrTJw4Ma+99lqTfuva6Pz5559v8r5NmzbZc889kyR1dXXNjunfv3+qq6szefLkJn1+8Ytf5JFHHsnQoUPX69yAzY+VUsAW44wzzsgRRxyRqVOn5qqrrsr++++fPfbYI6NHj85OO+2UpUuXZs6cOfnrX/+a+fPnN4659dZbc8QRR+SEE05Iv3798sILL2T69OmZPHly9tprr3z84x9vfDrf0KFD8+STT2by5Mnp06dPXnnllTKfNQDA+jvooIPy+c9/PhMnTsy8efNy8MEHp127dnn00Udzyy235Morr8ynP/3pdOnSJVdccUVOOumkfOADH8jRRx+dbbbZJvPnz8+rr776prfinXTSSXnhhRfykY98JDvuuGOeeuqpfPvb387ee++d9773vc2OadeuXS655JKMGjUqBx10UI466qgsXbo0V155ZWpqajJ27NiN+UcClJFQCthifOpTn8p73vOeXHrppRk9enTuv//+nH/++Zk6dWqef/75VFdXZ5999sn48eMbx3Tu3Dm//e1vM2HChNx+++258cYbU11dnY9+9KONG4Mef/zxWbJkSb773e/mzjvvTJ8+ffLDH/4wt9xyS2bPnl2mswUA2DCTJ09Ov3798t3vfjdnn3122rZtm5qamhx77LEZNGhQY78TTzwx1dXVufjii3PBBRekXbt22X333dcZEh177LH53ve+l6uvvjovvfRSevbsmREjRuS8885ba3+rf3X88cdnq622ysUXX5wzzzwznTp1yic/+clccskl6datW0uePr
AJqSita+0lAAAAAGwE9pQCAAAAoHBCKQAAAAAKJ5QCAAAAoHBCKQAAAAAKJ5QCAAAAoHBCKQAAAAAK17bcBRStoaE
|
||
|
"text/plain": [
|
||
|
"<Figure size 1200x600 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Per-class recall and precision derived from the confusion matrix `cm`.\n",
"# Assumes `cm` comes from sklearn.metrics.confusion_matrix (rows = true class,\n",
"# columns = predicted class) and that its class order matches CLASS_NAMES -- TODO confirm.\n",
"import numpy as np\n",
"\n",
"CLASS_NAMES = ['GSVT', 'AFIB', 'SR', 'SB']\n",
"assert len(CLASS_NAMES) == num_classes, 'CLASS_NAMES must have exactly one label per class'\n",
"assert cm.shape == (num_classes, num_classes), 'confusion matrix shape does not match num_classes'\n",
"\n",
"hits = cm.diagonal()             # correctly classified samples per class\n",
"true_per_class = cm.sum(axis=1)  # samples whose true label is each class (support)\n",
"pred_per_class = cm.sum(axis=0)  # samples predicted as each class\n",
"\n",
"# A class with no true (or no predicted) samples would give 0/0 = NaN plus a\n",
"# RuntimeWarning; report 0 instead, like sklearn's precision/recall default.\n",
"recall = np.divide(hits, true_per_class, out=np.zeros(len(hits)), where=true_per_class > 0)\n",
"precision = np.divide(hits, pred_per_class, out=np.zeros(len(hits)), where=pred_per_class > 0)\n",
"\n",
"# Side-by-side bar charts, one panel per metric, sharing the same [0, 1] scale.\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"for ax, values, metric in zip(axes, (recall, precision), ('Recall', 'Precision')):\n",
"    ax.bar(range(num_classes), values)\n",
"    ax.set_xticks(range(num_classes))\n",
"    ax.set_xticklabels(CLASS_NAMES)\n",
"    ax.set_ylim(0, 1)\n",
"    ax.set_xlabel('Class')\n",
"    ax.set_ylabel(metric)\n",
"    ax.set_title(metric)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.4"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|