DSA_SS24/notebooks/ml_xgboost.ipynb

719 lines
179 KiB
Plaintext
Raw Permalink Normal View History

2024-06-05 23:18:20 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extreme Gradient Boosting (XGBoost) Training and Analysis"
]
},
2024-06-05 23:18:20 +02:00
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 2,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import sqlite3\n",
"from datetime import datetime\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import xgboost as xgb\n",
"from sklearn.model_selection import GridSearchCV\n",
2024-06-12 17:19:27 +02:00
"from sklearn.metrics import confusion_matrix, f1_score\n",
"import seaborn as sns\n",
"import numpy as np"
2024-06-05 23:18:20 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Data from Database"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 7,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [],
"source": [
"# connect to the database\n",
"conn = sqlite3.connect('../features.db')\n",
"c = conn.cursor()\n",
"# get training, validation and test data\n",
"train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
"valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
"test = pd.read_sql_query(\"SELECT * FROM test\", conn)\n",
"# close the connection\n",
"conn.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Format Data for Machine Learning"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 8,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_x shape: (3502, 10)\n",
"test_x shape: (438, 10)\n",
"valid_x shape: (438, 10)\n",
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
"number of classes: 4\n"
]
}
],
"source": [
"# get the target and features\n",
"train_y = train['y']\n",
"train_y = train_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
"train_x = train.drop(columns=['y'])\n",
"\n",
"valid_y = valid['y']\n",
"valid_y = valid_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
"valid_x = valid.drop(columns=['y'])\n",
"\n",
"test_y = test['y']\n",
"test_y = test_y.map({'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3})\n",
"test_x = test.drop(columns=['y'])\n",
"\n",
"# drop id column\n",
"train_x = train_x.drop(columns=['id'])\n",
"valid_x = valid_x.drop(columns=['id'])\n",
"test_x = test_x.drop(columns=['id'])\n",
"\n",
"print('train_x shape:', train_x.shape)\n",
"print('test_x shape:', test_x.shape)\n",
"print('valid_x shape:', valid_x.shape)\n",
"\n",
"# print column names\n",
"print('features:', train_x.columns.to_list())\n",
"\n",
"# use xgboost\n",
"dtrain = xgb.DMatrix(train_x, label=train_y)\n",
"dvalid = xgb.DMatrix(valid_x, label=valid_y)\n",
"dtest = xgb.DMatrix(test_x, label=test_y)\n",
"\n",
"num_classes= len(set(valid_y.to_list()))\n",
"print('number of classes:', num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test Grid for Hyperparameter Analysis"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"# Search space for the grid search.\n",
"# 'eta' is XGBoost's alias for 'learning_rate'. Listing both made the grid three\n",
"# times larger (243 instead of 81 combinations) and produced candidates whose two\n",
"# values contradict each other, so only 'learning_rate' is kept.\n",
"param_grid = {\n",
"    'max_depth': [3, 4, 5],\n",
"    'min_child_weight': [1, 2, 3],\n",
"    'learning_rate': [0.1, 0.2, 0.3],\n",
"    'n_estimators': [100, 200, 300]\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# Base estimator: multi-class XGBoost that reports the misclassification rate\n",
"model = xgb.XGBClassifier(\n",
"    objective='multi:softmax',\n",
"    num_class=num_classes,\n",
"    eval_metric='merror',\n",
")\n",
"\n",
"# Exhaustive search over param_grid, scored by accuracy with 3-fold cross-validation\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid, scoring='accuracy', cv=3)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: total: 2h 15min 58s\n",
"Wall time: 10min\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\
" estimator=XGBClassifier(base_score=None, booster=None,\n",
" callbacks=None, colsample_bylevel=None,\n",
" colsample_bynode=None,\n",
" colsample_bytree=None,\n",
" early_stopping_rounds=None,\n",
" enable_categorical=False,\n",
" eval_metric=&#x27;merror&#x27;, gamma=None,\n",
" gpu_id=None, grow_policy=None,\n",
" importance_type=None,\n",
" interaction_constraints=None,\n",
" learning_rate=None, max_bin=None,\n",
" ma...\n",
" max_leaves=None, min_child_weight=None,\n",
" missing=nan, monotone_constraints=None,\n",
" n_estimators=100, n_jobs=None, num_class=4,\n",
" num_parallel_tree=None,\n",
" objective=&#x27;multi:softmax&#x27;, predictor=None,\n",
" random_state=None, ...),\n",
" param_grid={&#x27;eta&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;learning_rate&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;max_depth&#x27;: [3, 4, 5], &#x27;min_child_weight&#x27;: [1, 2, 3],\n",
" &#x27;n_estimators&#x27;: [100, 200, 300]},\n",
" scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=3,\n",
" estimator=XGBClassifier(base_score=None, booster=None,\n",
" callbacks=None, colsample_bylevel=None,\n",
" colsample_bynode=None,\n",
" colsample_bytree=None,\n",
" early_stopping_rounds=None,\n",
" enable_categorical=False,\n",
" eval_metric=&#x27;merror&#x27;, gamma=None,\n",
" gpu_id=None, grow_policy=None,\n",
" importance_type=None,\n",
" interaction_constraints=None,\n",
" learning_rate=None, max_bin=None,\n",
" ma...\n",
" max_leaves=None, min_child_weight=None,\n",
" missing=nan, monotone_constraints=None,\n",
" n_estimators=100, n_jobs=None, num_class=4,\n",
" num_parallel_tree=None,\n",
" objective=&#x27;multi:softmax&#x27;, predictor=None,\n",
" random_state=None, ...),\n",
" param_grid={&#x27;eta&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;learning_rate&#x27;: [0.1, 0.2, 0.3],\n",
" &#x27;max_depth&#x27;: [3, 4, 5], &#x27;min_child_weight&#x27;: [1, 2, 3],\n",
" &#x27;n_estimators&#x27;: [100, 200, 300]},\n",
" scoring=&#x27;accuracy&#x27;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=&#x27;merror&#x27;, gamma=None,\n",
" gpu_id=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_to_onehot=None, max_delta_step=None, max_depth=None,\n",
" max_leaves=None, min_child_weight=None, missing=nan,\n",
" monotone_constraints=None, n_estimators=100, n_jobs=None,\n",
" num_class=4, num_parallel_tree=None, objective=&#x27;multi:softmax&#x27;,\n",
" predictor=None, random_state=None, ...)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">XGBClassifier</label><div class=\"sk-toggleable__content\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=&#x27;merror&#x27;, gamma=None,\n",
" gpu_id=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_to_onehot=None, max_delta_step=None, max_depth=None,\n",
" max_leaves=None, min_child_weight=None, missing=nan,\n",
" monotone_constraints=None, n_estimators=100, n_jobs=None,\n",
" num_class=4, num_parallel_tree=None, objective=&#x27;multi:softmax&#x27;,\n",
" predictor=None, random_state=None, ...)</pre></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"GridSearchCV(cv=3,\n",
" estimator=XGBClassifier(base_score=None, booster=None,\n",
" callbacks=None, colsample_bylevel=None,\n",
" colsample_bynode=None,\n",
" colsample_bytree=None,\n",
" early_stopping_rounds=None,\n",
" enable_categorical=False,\n",
" eval_metric='merror', gamma=None,\n",
" gpu_id=None, grow_policy=None,\n",
" importance_type=None,\n",
" interaction_constraints=None,\n",
" learning_rate=None, max_bin=None,\n",
" ma...\n",
" max_leaves=None, min_child_weight=None,\n",
" missing=nan, monotone_constraints=None,\n",
" n_estimators=100, n_jobs=None, num_class=4,\n",
" num_parallel_tree=None,\n",
" objective='multi:softmax', predictor=None,\n",
" random_state=None, ...),\n",
" param_grid={'eta': [0.1, 0.2, 0.3],\n",
" 'learning_rate': [0.1, 0.2, 0.3],\n",
" 'max_depth': [3, 4, 5], 'min_child_weight': [1, 2, 3],\n",
" 'n_estimators': [100, 200, 300]},\n",
" scoring='accuracy')"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# Fit the grid search object to the data\n",
"grid_search.fit(train_x, train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Results"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'eta': 0.1, 'learning_rate': 0.1, 'max_depth': 5, 'min_child_weight': 3, 'n_estimators': 100}\n",
"Best score: 0.8012537024646579\n"
]
}
],
"source": [
"# Print the best parameters and the best score\n",
"print(f'Best parameters: {grid_search.best_params_}')\n",
"print(f'Best score: {grid_search.best_score_}')\n",
"#{'eta': 0.1, 'learning_rate': 0.1, 'max_depth': 5, 'min_child_weight': 3, 'n_estimators': 100}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# Save the best model\n",
"best_model = grid_search.best_estimator_\n",
"# create the target directory first so a checkout without it can still save\n",
"os.makedirs('../ml_models', exist_ok=True)\n",
"# the timestamp keeps earlier runs from being overwritten\n",
"timestamp = datetime.now().strftime('%Y%m%d%H%M%S')\n",
"best_model.save_model(f'../ml_models/best_xgb_model_{timestamp}.json')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example Training of the Best Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the saved best model and read its tuned hyperparameters back from it."
2024-06-05 23:18:20 +02:00
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 3,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best model: {'best_iteration': '99', 'best_ntree_limit': '100', 'scikit_learn': '{\"use_label_encoder\": false, \"n_estimators\": 100, \"objective\": \"multi:softmax\", \"max_depth\": 5, \"max_leaves\": null, \"max_bin\": null, \"grow_policy\": null, \"learning_rate\": 0.1, \"verbosity\": null, \"booster\": null, \"tree_method\": null, \"gamma\": null, \"min_child_weight\": 3, \"max_delta_step\": null, \"subsample\": null, \"sampling_method\": null, \"colsample_bytree\": null, \"colsample_bylevel\": null, \"colsample_bynode\": null, \"reg_alpha\": null, \"reg_lambda\": null, \"scale_pos_weight\": null, \"base_score\": null, \"missing\": NaN, \"num_parallel_tree\": null, \"random_state\": null, \"n_jobs\": null, \"monotone_constraints\": null, \"interaction_constraints\": null, \"importance_type\": null, \"gpu_id\": null, \"validate_parameters\": null, \"predictor\": null, \"enable_categorical\": false, \"max_cat_to_onehot\": null, \"eval_metric\": \"merror\", \"early_stopping_rounds\": null, \"callbacks\": null, \"kwargs\": {\"num_class\": 4, \"eta\": 0.1}, \"classes_\": [0, 1, 2, 3], \"n_classes_\": 4, \"_estimator_type\": \"classifier\"}'}\n"
]
}
],
"source": [
"# list directory\n",
"models = os.listdir('../ml_models')\n",
"model_path = [model for model in models if 'json' in model and 'best' in model and 'xgb' in model][0]\n",
"model_path = f'../ml_models/{model_path}'\n",
"# load the best model\n",
"best_model = xgb.Booster()\n",
"best_model.load_model(model_path)\n",
"best_params = best_model.attributes()\n",
"print('best model:', best_params)"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 9,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-06-12 17:19:27 +02:00
"[16:58:49] WARNING: C:/Users/administrator/workspace/xgboost-win64_release_1.6.0/src/learner.cc:627: \n",
2024-06-05 23:18:20 +02:00
"Parameters: { \"best_iteration\", \"best_ntree_limit\", \"scikit_learn\" } might not be used.\n",
"\n",
" This could be a false alarm, with some parameters getting used by language bindings but\n",
" then being mistakenly passed down to XGBoost core, or some parameter actually being used\n",
" but getting flagged wrongly here. Please open an issue if you find any such cases.\n",
"\n",
"\n",
2024-06-12 17:19:27 +02:00
"[0]\ttrain-merror:0.16762\teval-merror:0.22603\n",
2024-06-05 23:18:20 +02:00
"[1]\ttrain-merror:0.15220\teval-merror:0.22374\n",
"[2]\ttrain-merror:0.13849\teval-merror:0.21461\n",
"[3]\ttrain-merror:0.13535\teval-merror:0.20776\n",
"[4]\ttrain-merror:0.13278\teval-merror:0.20091\n",
"[5]\ttrain-merror:0.12907\teval-merror:0.20548\n",
"[6]\ttrain-merror:0.12307\teval-merror:0.20320\n",
"[7]\ttrain-merror:0.11850\teval-merror:0.20320\n",
"[8]\ttrain-merror:0.11422\teval-merror:0.19406\n",
"[9]\ttrain-merror:0.10965\teval-merror:0.20091\n",
"[10]\ttrain-merror:0.10280\teval-merror:0.20320\n",
"[11]\ttrain-merror:0.09880\teval-merror:0.19406\n",
"[12]\ttrain-merror:0.09423\teval-merror:0.19406\n",
"[13]\ttrain-merror:0.09109\teval-merror:0.19863\n",
"[14]\ttrain-merror:0.08709\teval-merror:0.19863\n",
"[15]\ttrain-merror:0.08195\teval-merror:0.19863\n",
"[16]\ttrain-merror:0.07910\teval-merror:0.20091\n",
"[17]\ttrain-merror:0.07624\teval-merror:0.19635\n",
"[18]\ttrain-merror:0.06967\teval-merror:0.19863\n",
"[19]\ttrain-merror:0.06710\teval-merror:0.19406\n",
"[20]\ttrain-merror:0.06254\teval-merror:0.19178\n",
"[21]\ttrain-merror:0.06025\teval-merror:0.19863\n",
"[22]\ttrain-merror:0.05682\teval-merror:0.20091\n",
"[23]\ttrain-merror:0.05311\teval-merror:0.20091\n",
"[24]\ttrain-merror:0.05168\teval-merror:0.20320\n",
"[25]\ttrain-merror:0.04940\teval-merror:0.19406\n",
"[26]\ttrain-merror:0.04597\teval-merror:0.20091\n",
"[27]\ttrain-merror:0.04397\teval-merror:0.19863\n",
"[28]\ttrain-merror:0.04112\teval-merror:0.19863\n",
"[29]\ttrain-merror:0.04026\teval-merror:0.19863\n",
"[30]\ttrain-merror:0.03769\teval-merror:0.19635\n",
"[31]\ttrain-merror:0.03712\teval-merror:0.20091\n",
"[32]\ttrain-merror:0.03626\teval-merror:0.20320\n",
"[33]\ttrain-merror:0.03541\teval-merror:0.20091\n",
"[34]\ttrain-merror:0.03370\teval-merror:0.19863\n",
"[35]\ttrain-merror:0.03113\teval-merror:0.19635\n",
"[36]\ttrain-merror:0.02970\teval-merror:0.19635\n",
"[37]\ttrain-merror:0.02798\teval-merror:0.19406\n",
"[38]\ttrain-merror:0.02713\teval-merror:0.19178\n",
"[39]\ttrain-merror:0.02513\teval-merror:0.18950\n",
"[40]\ttrain-merror:0.02370\teval-merror:0.19178\n",
"[41]\ttrain-merror:0.02199\teval-merror:0.18950\n",
"[42]\ttrain-merror:0.01885\teval-merror:0.19406\n",
"[43]\ttrain-merror:0.01828\teval-merror:0.19406\n",
"[44]\ttrain-merror:0.01799\teval-merror:0.19178\n",
"[45]\ttrain-merror:0.01628\teval-merror:0.18950\n",
"[46]\ttrain-merror:0.01656\teval-merror:0.18950\n",
"[47]\ttrain-merror:0.01428\teval-merror:0.19178\n",
"[48]\ttrain-merror:0.01314\teval-merror:0.19406\n",
"[49]\ttrain-merror:0.01199\teval-merror:0.19406\n",
"[50]\ttrain-merror:0.01114\teval-merror:0.19406\n",
"[51]\ttrain-merror:0.01028\teval-merror:0.19178\n",
"[52]\ttrain-merror:0.00885\teval-merror:0.19635\n",
"[53]\ttrain-merror:0.00885\teval-merror:0.19635\n",
"[54]\ttrain-merror:0.00857\teval-merror:0.19635\n",
"[55]\ttrain-merror:0.00771\teval-merror:0.19178\n",
"[56]\ttrain-merror:0.00685\teval-merror:0.19178\n",
"[57]\ttrain-merror:0.00657\teval-merror:0.19178\n",
"[58]\ttrain-merror:0.00514\teval-merror:0.19178\n",
"[59]\ttrain-merror:0.00428\teval-merror:0.19178\n",
"[60]\ttrain-merror:0.00400\teval-merror:0.19178\n",
"[61]\ttrain-merror:0.00343\teval-merror:0.19406\n",
"[62]\ttrain-merror:0.00371\teval-merror:0.18950\n",
"[63]\ttrain-merror:0.00314\teval-merror:0.19178\n",
"[64]\ttrain-merror:0.00257\teval-merror:0.18950\n",
"[65]\ttrain-merror:0.00228\teval-merror:0.18950\n",
"[66]\ttrain-merror:0.00228\teval-merror:0.18950\n",
"[67]\ttrain-merror:0.00200\teval-merror:0.18950\n",
"[68]\ttrain-merror:0.00200\teval-merror:0.18721\n",
"[69]\ttrain-merror:0.00200\teval-merror:0.18493\n",
"[70]\ttrain-merror:0.00171\teval-merror:0.18950\n",
"[71]\ttrain-merror:0.00143\teval-merror:0.18721\n",
"[72]\ttrain-merror:0.00114\teval-merror:0.18493\n",
"[73]\ttrain-merror:0.00114\teval-merror:0.18493\n",
"[74]\ttrain-merror:0.00114\teval-merror:0.18493\n",
"[75]\ttrain-merror:0.00114\teval-merror:0.18721\n",
"[76]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[77]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[78]\ttrain-merror:0.00057\teval-merror:0.18493\n",
"[79]\ttrain-merror:0.00057\teval-merror:0.18950\n",
"[80]\ttrain-merror:0.00057\teval-merror:0.19178\n",
"[81]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[82]\ttrain-merror:0.00057\teval-merror:0.18950\n",
"[83]\ttrain-merror:0.00057\teval-merror:0.19178\n",
"[84]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[85]\ttrain-merror:0.00057\teval-merror:0.18950\n",
"[86]\ttrain-merror:0.00057\teval-merror:0.18493\n",
"[87]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[88]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[89]\ttrain-merror:0.00057\teval-merror:0.18721\n",
"[90]\ttrain-merror:0.00057\teval-merror:0.18493\n",
"[91]\ttrain-merror:0.00029\teval-merror:0.18493\n",
"[92]\ttrain-merror:0.00029\teval-merror:0.18493\n",
"[93]\ttrain-merror:0.00029\teval-merror:0.18265\n",
"[94]\ttrain-merror:0.00029\teval-merror:0.18493\n",
"[95]\ttrain-merror:0.00029\teval-merror:0.18037\n",
"[96]\ttrain-merror:0.00029\teval-merror:0.18265\n",
"[97]\ttrain-merror:0.00029\teval-merror:0.18265\n",
"[98]\ttrain-merror:0.00029\teval-merror:0.18265\n",
"[99]\ttrain-merror:0.00029\teval-merror:0.18265\n",
2024-06-12 17:19:27 +02:00
"CPU times: total: 14.3 s\n",
"Wall time: 1.22 s\n"
2024-06-05 23:18:20 +02:00
]
}
],
"source": [
"%%time\n",
"# train the models\n",
"# add the best parameters to the model\n",
"#best_params = grid_search.best_params_.copy()\n",
"best_params['objective'] = 'multi:softmax'\n",
"best_params['eval_metric'] = 'merror'\n",
"best_params['num_class'] = num_classes\n",
"\n",
"num_round = 100\n",
"\n",
"# Train the model and get the training history\n",
"evals_result = {}\n",
"model = xgb.train(best_params, dtrain, num_round, evals=[(dtrain, 'train'), (dvalid, 'eval')], evals_result=evals_result)"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 10,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAHqCAYAAAAZLi26AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAACJ70lEQVR4nOzdd3RU5drG4XsmPaRQ0ggt9N5LpINSBERRRECUIooFROV4VGxYjvDZEBUUG6IC0kEUBREURHrvvUNICJAKqbO/PzYZiAmQyqT8rrWyzOzZ5ZnMJuaet1kMwzAEAAAAAADynNXRBQAAAAAAUFQRugEAAAAAyCeEbgAAAAAA8gmhGwAAAACAfELoBgAAAAAgnxC6AQAAAADIJ4RuAAAAAADyCaEbAAAAAIB8QugGAAAAACCfELoBoJgZPHiwQkJCcnTsG2+8IYvFkrcFAdcxdepUWSwWHTt2zL6tQ4cO6tChw02P/euvv2SxWPTXX3/laU0Wi0VvvPFGnp4TAFC0EboBoICwWCxZ+srrEFEYPfDAA7JYLHrxxRcdXQokJScny8/PT23atLnuPoZhqEKFCmrSpMktrCxnfv311wIdrF944QVZLBb17dvX0aUAALLAYhiG4egiAADStGnT0j3+/vvvtWzZMv3www/ptnfu3FmBgYE5vk5ycrJsNpvc3NyyfWxKSopSUlLk7u6e4+vnVkxMjAIDAxUUFKTU1FQdP36c1vcC4Mknn9QXX3yho0ePqlKlShmeX7lypTp06KAPP/xQo0aNytI5p06dqiFDhujo0aP23hlJSUmSJFdX1xse+9dff6ljx476888/s9Qyfq0RI0Zo0qRJyuxPpISEBDk7O8vZ2Tlb58wrhmGoYsWKcnZ2Vnh4uMLDw+Xt7e2QWgAAWeOY/2MAADJ46KGH0j1et26dli1blmH7v126dEmenp5Zvo6Li0uO6pPk0LCRZt68eUpNTdWUKVN0++23a9WqVWrfvr1Da8qMYRhKSEiQh4eHo0u5JQYMGKDJkyfrxx9/1EsvvZTh+RkzZshqtapfv365us7NwnZ+c+QHTpL5YcKpU6e0YsUKde3aVfPnz9egQYMcWtP1ZPd3EwAUVXQvB4BCpEOHDqpXr542b96sdu3aydPTUy+//LIk6aefflKPHj0UHBwsNzc3Va1aVW+//bZSU1PTnePfY7qPHTsmi8WiDz74QF9++aWqVq0qNzc3NW/eXBs3bkx3bGZjui0Wi0aMGKGFCxeqXr16cnNzU926dbVkyZIM9f/1119q1qyZ3N3dVbVqVX3xxRfZHic+ffp0de7cWR07dlTt2rU1ffr0TPfbt2+fHnjgAfn7+8vDw0M1a9bUK6+8km6f06dPa+jQofafWeXKlfXkk0/aW1OvV1tmY41DQkJ01113aenSpWrWrJk8PDz0xRdfSJK+/fZb3X777QoICJCbm5vq1Kmjzz//PNO6f/vtN7Vv317e3t7y8fFR8+bNNWPGDEnSmDFj5OLionPnzmU4btiwYSpZsqQSEhIyPe8HH3wgi8Wi48ePZ3hu9OjRcnV11cWLFyVJBw8eVO/evRUUFCR3d3eVL19e/fr1U3R0dKbnlqTWrVsrJCTEXuu1kpOTNXfuXHXs2FHBwcHasWOHBg8erCpVqsjd3V1BQUF65JFHdP78+eueP01mY7pPnTqlXr16qUSJEgoICNBzzz2nxMTEDMf+/fff6tOnjypWrCg3NzdVqFBBzz33nC5fvmzfZ/DgwZo0aZKk9EM+0mQ2pnvr1q3q1q2bfHx85OXlpTvuuEPr1q1Lt0/aPfPPP/9o1KhR8vf3V4kSJXTvvfdm+n5ez/Tp01WnTh117NhRnTp1uu79f7N7W5KioqL03HPPKSQkRG5ubipfvrwGDhyoyMjIdDVfe59LmY+Xz4vfTZK0fv16de/eXaVKlVKJEiXUoEEDffzxx5LMf0cWi0Vbt27NcNzYsWPl5OSk06dPZ/
lnCQC3Ci3dAFDInD9/Xt26dVO/fv300EMP2buaT506VV5eXho1apS8vLy0YsUKvf7664qJidH7779/0/POmDFDsbGxevzxx2WxWPTee+/pvvvu05EjR27aOr569WrNnz9fTz31lLy9vfXJJ5+od+/eOnHihMqUKSPJDCZ33nmnypYtqzfffFOpqal666235O/vn+XXfubMGf3555/67rvvJEn9+/fXRx99pIkTJ6ZrAd2xY4fatm0rFxcXDRs2TCEhITp8+LB+/vlnvfPOO/ZztWjRQlFRURo2bJhq1aql06dPa+7cubp06VKOWlT379+v/v376/HHH9djjz2mmjVrSpI+//xz1a1bV3fffbecnZ31888/66mnnpLNZtPw4cPtx0+dOlWPPPKI6tatq9GjR6tkyZLaunWrlixZogcffFAPP/yw3nrrLc2aNUsjRoywH5eUlKS5c+eqd+/e122JfeCBB/TCCy9o9uzZ+u9//5vuudmzZ6tLly4qVaqUkpKS1LVrVyUmJurpp59WUFCQTp8+rV9++UVRUVHy9fXN9PwWi0UPPvigxo4dq927d6tu3br255YsWaILFy5owIABkqRly5bpyJEjGjJkiIKCgrR79259+eWX2r17t9atW5etD2EuX76sO+64QydOnNDIkSMVHBysH374QStWrMiw75w5c3Tp0iU9+eSTKlOmjDZs2KBPP/1Up06d0pw5cyRJjz/+uM6cOZPp0I7M7N69W23btpWPj49eeOEFubi46IsvvlCHDh20cuVKhYaGptv/6aefVqlSpTRmzBgdO3ZMEyZM0IgRIzRr1qybXisxMVHz5s3Tf/7zH0nm/T9kyBCdPXtWQUFB9v2ycm/HxcWpbdu22rt3rx555BE1adJEkZGRWrRokU6dOiU/P7+b1vNvuf3dtGzZMt11110qW7asnnnmGQUFBWnv3r365Zdf9Mwzz+j+++/X8OHDNX36dDVu3DjdtadPn64OHTqoXLly2a4bAPKdAQAokIYPH278+9d0+/btDUnG5MmTM+x/6dKlDNsef/xxw9PT00hISLBvGzRokFGpUiX746NHjxqSjDJlyhgXLlywb//pp58MScbPP/9s3zZmzJgMNUkyXF1djUOHDtm3bd++3ZBkfPrpp/ZtPXv2NDw9PY3Tp0/btx08eNBwdnbOcM7r+eCDDwwPDw8jJibGMAzDOHDggCHJWLBgQbr92rVrZ3h7exvHjx9Pt91ms9m/HzhwoGG1Wo2NGzdmuE7afpm9XsMwjG+//daQZBw9etS+rVKlSoYkY8mSJRn2z+y96dq1q1GlShX746ioKMPb29sIDQ01Ll++fN26W7ZsaYSGhqZ7fv78+YYk488//8xwnWu1bNnSaNq0abptGzZsMCQZ33//vWEYhrF161ZDkjFnzpwbniszu3fvNiQZo0ePTre9X79+hru7uxEdHW0YRuY/jx9//NGQZKxatcq+LbOfc/v27Y327dvbH0+YMMGQZMyePdu+LT4+3qhWrVqGn0lm1x03bpxhsVjS3SuZ/dtLI8kYM2aM/XGvXr0MV1dX4/Dhw/ZtZ86cMby9vY127dpleC2dOnVK934+99xzhpOTkxEVFZXp9a41d+5cQ5Jx8OBBwzAMIyYmxnB3dzc++uijdPtl5d5+/fXXDUnG/Pnzr7tPZj9/wzCMP//8M8PPNre/m1JSUozKlSsblSpVMi5evJhpPYZhGP379zeCg4ON1NRU+7YtW7YYkoxvv/02w3UAoCCgezkAFDJubm4aMmRIhu3Xjh2OjY1VZGSk2rZtq0uXLmnfvn03PW/fvn1VqlQp++O2bdtKko4cOXLTYzt16qSqVavaHzdo0EA+Pj72Y1NTU/XHH3+oV69eCg4Otu9XrVo1devW7abnTzN9+nT16NHDPnFU9erV1bRp03RdbM+dO6dVq1bpkUceUcWKFdMdn9aCarPZtHDhQvXs2VPNmjXLcJ2cTsxWuXJlde3aNcP2a9
+b6OhoRUZGqn379jpy5Ii9y/ayZcsUGxurl156KUNr9bX1DBw4UOvXr9fhw4ft26ZPn64KFSrcdGx73759tXnz5nT
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Get the training loss and validation loss\n",
"train_loss = evals_result['train']['merror']\n",
"valid_loss = evals_result['eval']['merror']\n",
"\n",
"# Calculate the training accuracy and validation accuracy\n",
"train_accuracy = [1 - x for x in train_loss]\n",
"valid_accuracy = [1 - x for x in valid_loss]\n",
"\n",
"# Create a new figure\n",
"fig = plt.figure(figsize=(10, 5))\n",
"\n",
"# Plot the training accuracy and validation accuracy\n",
"plt.plot(train_accuracy, label='Train')\n",
"plt.plot(valid_accuracy, label='Validation')\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel('Accuracy')\n",
"plt.title('Training Accuracy vs Validation Accuracy')\n",
"plt.legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 11,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8310502283105022\n"
]
}
],
"source": [
"# Get the accuracy of the model\n",
"preds = model.predict(dtest)\n",
"correct = 0\n",
"for i in range(len(test_y)):\n",
" if preds[i] == test_y[i]:\n",
" correct += 1\n",
"accuracy = correct / len(test_y)\n",
"print('Accuracy:', accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate Model Performance"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 12,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAokAAAIjCAYAAABvUIGpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABaXklEQVR4nO3dd3gUVdvH8d8mIQkkJCFAmlTpVaRIFxCUXgQL6qOACD4IUkUJ0hGiiFIFFKkKSFERUFAEAdHQq/QqCiZ0AkkIKfP+wcs+rgNIYDeTsN+P11wXe2bmzL27Jtzc58wZm2EYhgAAAIC/8bA6AAAAAGQ+JIkAAAAwIUkEAACACUkiAAAATEgSAQAAYEKSCAAAABOSRAAAAJiQJAIAAMCEJBEAAAAmJIkAbuvQoUN64oknFBgYKJvNpsWLFzu1/+PHj8tms2nmzJlO7Tcrq1u3rurWrWt1GADcHEkikAUcOXJEr776qh588EH5+voqICBANWvW1Lhx45SYmOjSa7dr1067d+/WiBEj9Nlnn6ly5couvV5Gat++vWw2mwICAm76OR46dEg2m002m02jR49Od/+nTp3SkCFDtGPHDidECwAZy8vqAADc3rfffqunn35aPj4+eumll1S2bFldu3ZN69evV9++fbVnzx598sknLrl2YmKioqOj9fbbb6tbt24uuUbBggWVmJiobNmyuaT/f+Pl5aWEhAQtXbpUzzzzjMO+OXPmyNfXV1evXr2rvk+dOqWhQ4eqUKFCqlChwh2f98MPP9zV9QDAmUgSgUzs2LFjatu2rQoWLKjVq1crPDzcvq9r1646fPiwvv32W5dd/8yZM5KkoKAgl13DZrPJ19fXZf3/Gx8fH9WsWVPz5s0zJYlz585V06ZN9eWXX2ZILAkJCcqRI4e8vb0z5HoAcDsMNwOZ2KhRo3TlyhVNmzbNIUG8oWjRourRo4f9dUpKioYPH64iRYrIx8dHhQoVUv/+/ZWUlORwXqFChdSsWTOtX79ejzzyiHx9ffXggw9q9uzZ9mOGDBmiggULSpL69u0rm82mQoUKSbo+THvjz383ZMgQ2Ww2h7aVK1eqVq1aCgoKkr+/v0qUKKH+/fvb999qTuLq1atVu3Zt+fn5KSgoSC1bttS+fftuer3Dhw+rffv2CgoKUmBgoDp06KCEhIRbf7D/8Pzzz2v58uW6ePGivW3z5s06dOiQnn/+edPx58+f1xtvvKFy5crJ399fAQEBaty4sXbu3Gk/Zs2aNapSpYokqUOHDvZh6xvvs27duipbtqy2bt2qRx99VDly5LB/Lv+ck9iuXTv5+vqa3n/Dhg2VK1cunTp16o7fKwDcKZJEIBNbunSpHnzwQdWoUeOOjn/llVc0aNAgVaxYUWPGjFGdOnUUFRWltm3bmo49fPiwnnrqKT3++OP64IMPlCtXLrVv31579uyRJLVu3VpjxoyRJD333HP67LPPNHbs2HTFv2fPHjVr1kxJSUkaNmyYPvjgA7Vo0UK//PLLbc/78ccf1bBhQ50+fVpDhgxR79699euvv6pmzZo6fvy46fhnnnlGly9fVlRUlJ555hnNnDlTQ4cOveM4W7duLZvNpq+++sreNnfuXJUsWVIVK1Y0HX/06FEtXrxYzZo104cffqi+fftq9+7dqlOnjj1hK1WqlIYNGyZJ6ty5sz777DN99tlnevTRR+39nDt3To0bN1aFChU0duxY1atX76bxjRs3Tnnz5lW7du2UmpoqSfr444/1ww8/aMKECYqIiLjj9woAd8wAkCldunTJkGS0bNnyjo7fsWOHIcl45ZVXHNrfeOMNQ5KxevVqe1vBggUNSca6devsbadPnzZ8fHyMPn362NuOHTtmSDLef/99hz7btWtnFCxY0BTD4MGDjb//WhkzZowhyThz5swt475xjRkzZtjbKlSoYISEhBjnzp2zt+3cudPw8PAwXnrpJdP1Xn75ZYc+n3zySSN37t
y3vObf34efn59hGIbx1FNPGfXr1zcMwzBSU1ONsLAwY+jQoTf9DK5evWqkpqaa3oePj48xbNgwe9vmzZtN7+2GOnXqGJKMKVOm3HRfnTp1HNq+//57Q5LxzjvvGEePHjX8/f2NVq1a/et7BIC7RSURyKTi4uIkSTlz5ryj47/77jtJUu/evR3a+/TpI0mmuYulS5dW7dq17a/z5s2rEiVK6OjRo3cd8z/dmMv4zTffKC0t7Y7O+euvv7Rjxw61b99ewcHB9vby5cvr8ccft7/Pv/vvf//r8Lp27do6d+6c/TO8E88//7zWrFmjmJgYrV69WjExMTcdapauz2P08Lj+6zM1NVXnzp2zD6Vv27btjq/p4+OjDh063NGxTzzxhF599VUNGzZMrVu3lq+vrz7++OM7vhYApBdJIpBJBQQESJIuX758R8f//vvv8vDwUNGiRR3aw8LCFBQUpN9//92hvUCBAqY+cuXKpQsXLtxlxGbPPvusatasqVdeeUWhoaFq27atFixYcNuE8UacJUqUMO0rVaqUzp49q/j4eIf2f76XXLlySVK63kuTJk2UM2dOzZ8/X3PmzFGVKlVMn+UNaWlpGjNmjIoVKyYfHx/lyZNHefPm1a5du3Tp0qU7vuYDDzyQrptURo8ereDgYO3YsUPjx49XSEjIHZ8LAOlFkghkUgEBAYqIiNBvv/2WrvP+eePIrXh6et603TCMu77GjflyN2TPnl3r1q3Tjz/+qBdffFG7du3Ss88+q8cff9x07L24l/dyg4+Pj1q3bq1Zs2bp66+/vmUVUZJGjhyp3r1769FHH9Xnn3+u77//XitXrlSZMmXuuGIqXf980mP79u06ffq0JGn37t3pOhcA0oskEcjEmjVrpiNHjig6Ovpfjy1YsKDS0tJ06NAhh/bY2FhdvHjRfqeyM+TKlcvhTuAb/lmtlCQPDw/Vr19fH374ofbu3asRI0Zo9erV+umnn27a9404Dxw4YNq3f/9+5cmTR35+fvf2Bm7h+eef1/bt23X58uWb3uxzw6JFi1SvXj1NmzZNbdu21RNPPKEGDRqYPpM7TdjvRHx8vDp06KDSpUurc+fOGjVqlDZv3uy0/gHgn0gSgUzszTfflJ+fn1555RXFxsaa9h85ckTjxo2TdH24VJLpDuQPP/xQktS0aVOnxVWkSBFdunRJu3btsrf99ddf+vrrrx2OO3/+vOncG4tK/3NZnhvCw8NVoUIFzZo1yyHp+u233/TDDz/Y36cr1KtXT8OHD9fEiRMVFhZ2y+M8PT1NVcqFCxfq5MmTDm03ktmbJdTp9dZbb+nEiROaNWuWPvzwQxUqVEjt2rW75ecIAPeKxbSBTKxIkSKaO3eunn32WZUqVcrhiSu//vqrFi5cqPbt20uSHnroIbVr106ffPKJLl68qDp16mjTpk2aNWuWWrVqdcvlVe5G27Zt9dZbb+nJJ59U9+7dlZCQoMmTJ6t48eION24MGzZM69atU9OmTVWwYEGdPn1akyZNUr58+VSrVq1b9v/++++rcePGql69ujp27KjExERNmDBBgYGBGjJkiNPexz95eHhowIAB/3pcs2bNNGzYMHXo0EE1atTQ7t27NWfOHD344IMOxxUpUkRBQUGaMmWKcubMKT8/P1WtWlWFCxdOV1yrV6/WpEmTNHjwYPuSPDNmzFDdunU1cOBAjRo1Kl39AcAdsfjuagB34ODBg0anTp2MQoUKGd7e3kbOnDmNmjVrGhMmTDCuXr1qPy45OdkYOnSoUbhwYSNbtmxG/vz5jcjISIdjDOP6EjhNmzY1XeefS6/cagkcwzCMH374wShbtqzh7e1tlChRwvj8889NS+CsWrXKaNmypREREWF4e3sbERERxnPPPWccPHjQdI1/LhPz448/GjVr1jSyZ89uBAQEGM2bNzf27t3rcMyN6/1ziZ0ZM2YYkoxjx47d8jM1DMclcG7lVkvg9OnTxwgPDzeyZ89u1KxZ04iOjr
7p0jXffPONUbp0acPLy8vhfdapU8coU6bMTa/5937i4uKMggULGhUrVjSSk5MdjuvVq5fh4eFhREdH3/Y9AMDdsBl
"text/plain": [
"<Figure size 800x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Get the confusion matrix\n",
"cm = confusion_matrix(test_y, preds)\n",
"\n",
"# Create a new figure\n",
"plt.figure(figsize=(8, 6))\n",
"\n",
"labels = ['GSVT', 'AFIB', 'SR', 'SB']\n",
"# Plot the confusion matrix\n",
"sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)\n",
"plt.xlabel('Predicted')\n",
"plt.ylabel('Actual')\n",
"plt.title('Confusion Matrix')\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 13,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:title={'center':'Feature importance'}, xlabel='F score', ylabel='Features'>"
]
},
2024-06-12 17:19:27 +02:00
"execution_count": 13,
2024-06-05 23:18:20 +02:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAApMAAAHHCAYAAADj4dOBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAACIR0lEQVR4nOzdeVyN6f8/8Nc5Lae9lFba7HsiDBMyomQZjDGMocY2tjFZsoyRsmVn7GPshvGZxWAIE2MZxmRnhMiWQYytVNSpc/3+8Ov+OjqlbpFz5vV8PHrovq7rvu73+5TTu+teUgghBIiIiIiIZFCWdgBEREREpL9YTBIRERGRbCwmiYiIiEg2FpNEREREJBuLSSIiIiKSjcUkEREREcnGYpKIiIiIZGMxSURERESysZgkIiIiItlYTBIR/YetXr0aCoUC165dK+1QiEhPsZgkov+UvOJJ18eYMWNeyzH//PNPREVF4dGjR69l/v+yzMxMREVFYd++faUdCtF/lnFpB0BEVBomTpwIb29vrbZatWq9lmP9+eefiI6ORlhYGOzs7F7LMeTq2bMnunXrBpVKVdqhyJKZmYno6GgAQEBAQOkGQ/QfxWKSiP6T2rRpAz8/v9IO45VkZGTA0tLyleYwMjKCkZFRCUX05mg0GmRnZ5d2GEQEnuYmItJpx44daNq0KSwtLWFtbY22bdsiISFBa8yZM2cQFhaGChUqwMzMDC4uLujduzfu378vjYmKikJERAQAwNvbWzqlfu3aNVy7dg0KhQKrV6/Od3yFQoGoqCiteRQKBc6dO4ePP/4YZcqUgb+/v9T/3XffoX79+jA3N4e9vT26deuGGzduvDRPXddMenl5oV27dti3bx/8/Pxgbm6O2rVrS6eSN23ahNq1a8PMzAz169fHyZMnteYMCwuDlZUVrly5gqCgIFhaWsLNzQ0TJ06EEEJrbEZGBkaMGAF3d3eoVCpUrVoVs2bNyjdOoVBgyJAhWL9+PWrWrAmVSoWlS5fC0dERABAdHS29tnmvW1G+Ps+/tklJSdLqsa2tLT799FNkZmbme82+++47NGzYEBYWFihTpgyaNWuG3377TWtMUb5/iAwFVyaJ6D8pNTUV9+7d02orW7YsAGDdunUIDQ1FUFAQpk+fjszMTCxZsgT+/v44efIkvLy8AABxcXG4cuUKPv30U7i4uCAhIQHLli1DQkIC/vrrLygUCnTu3BkXL17E999/j7lz50rHcHR0xL///lvsuD/88ENUrlwZU6dOlQquKVOmYPz48ejatSv69u2Lf//9FwsWLECzZs1w8uRJWafWk5KS8PHHH+Ozzz7DJ598glmzZqF9+/ZYunQpvvzySwwaNAgAEBMTg65duyIxMRFK5f+tT+Tm5iI4OBjvvPMOZsyYgZ07d2LChAnIycnBxIkTAQBCCHTo0AF79+5Fnz59ULduXezatQsRERG4efMm5s6dqxXT77//jh9++AFDhgxB2bJl4ePjgyVLlmDgwIHo1KkTOnfuDACoU6cOgKJ9fZ7XtWtXeHt7IyYmBidOnMDy5cvh5OSE6dOnS2Oio6MRFRWFJk2aYOLEiTA1NUV8fDx+//13tG7dGkDRv3+IDIYgIvoPWbVqlQCg80MIIR4/fizs7OxEv379tPZLSUkRtra2Wu2ZmZn55v/+++8FAHHgwAGpbebMmQKAuHr1qtbYq1evCgBi1apV+eYBICZMmCBtT5gwQQAQ3bt31xp37do1YWRkJKZMmaLV/vfffwtjY+N87QW9Hs/H5unpKQCIP//8U2rbtWuXACDMzc3F9evXpfZvvvlGABB79+6V2kJDQwUA8fnnn0ttGo1GtG3bVpiamop///1XCCHE5s2bBQAxefJkrZi6dOkiFAqFSEpK0no9lEqlSEhI0Br777//5nut8hT165P32vbu3VtrbKdOnYSDg4O0fenSJaFUKkWnTp1Ebm6u1liNRiOEKN73D5Gh4GluIv
pPWrRoEeLi4rQ+gGerWY8ePUL37t1x79496cPIyAiNGjXC3r17pTnMzc2lz58+fYp79+7hnXfeAQCcOHHitcQ9YMAAre1NmzZBo9Gga9euWvG6uLigcuXKWvEWR40aNdC4cWNpu1GjRgCA9957Dx4eHvnar1y5km+OIUOGSJ/nnabOzs7G7t27AQCxsbEwMjLC0KFDtfYbMWIEhBDYsWOHVnvz5s1Ro0aNIudQ3K/Pi69t06ZNcf/+faSlpQEANm/eDI1Gg8jISK1V2Lz8gOJ9/xAZCp7mJqL/pIYNG+q8AefSpUsAnhVNutjY2EifP3jwANHR0di4cSPu3r2rNS41NbUEo/0/L96BfunSJQghULlyZZ3jTUxMZB3n+YIRAGxtbQEA7u7uOtsfPnyo1a5UKlGhQgWttipVqgCAdH3m9evX4ebmBmtra61x1atXl/qf92LuL1Pcr8+LOZcpUwbAs9xsbGxw+fJlKJXKQgva4nz/EBkKFpNERM/RaDQAnl335uLikq/f2Pj/3ja7du2KP//8ExEREahbty6srKyg0WgQHBwszVOYF6/Zy5Obm1vgPs+vtuXFq1AosGPHDp13ZVtZWb00Dl0KusO7oHbxwg0zr8OLub9Mcb8+JZFbcb5/iAwFv6uJiJ5TsWJFAICTkxMCAwMLHPfw4UPs2bMH0dHRiIyMlNrzVqaeV1DRmLfy9eLDzF9ckXtZvEIIeHt7Syt/bwONRoMrV65oxXTx4kUAkG5A8fT0xO7du/H48WOt1ckLFy5I/S9T0GtbnK9PUVWsWBEajQbnzp1D3bp1CxwDvPz7h8iQ8JpJIqLnBAUFwcbGBlOnToVarc7Xn3cHdt4q1ourVvPmzcu3T96zIF8sGm1sbFC2bFkcOHBAq33x4sVFjrdz584wMjJCdHR0vliEEPkeg/MmLVy4UCuWhQsXwsTEBC1btgQAhISEIDc3V2scAMydOxcKhQJt2rR56TEsLCwA5H9ti/P1KaqOHTtCqVRi4sSJ+VY2845T1O8fIkPClUkioufY2NhgyZIl6NmzJ+rVq4du3brB0dERycnJ2L59O959910sXLgQNjY2aNasGWbMmAG1Wo1y5crht99+w9WrV/PNWb9+fQDAuHHj0K1bN5iYmKB9+/awtLRE3759MW3aNPTt2xd+fn44cOCAtIJXFBUrVsTkyZMxduxYXLt2DR07doS1tTWuXr2KX375Bf3798fIkSNL7PUpKjMzM+zcuROhoaFo1KgRduzYge3bt+PLL7+Ung3Zvn17tGjRAuPGjcO1a9fg4+OD3377DVu2bEF4eLi0ylcYc3Nz1KhRA//73/9QpUoV2Nvbo1atWqhVq1aRvz5FValSJYwbNw6TJk1C06ZN0blzZ6hUKhw9ehRubm6IiYkp8vcPkUEppbvIiYhKRd6jcI4ePVrouL1794qgoCBha2srzMzMRMWKFUVYWJg4duyYNOaff/4RnTp1EnZ2dsLW1lZ8+OGH4tatWzofVTNp0iRRrlw5oVQqtR7Fk5mZKfr06SNsbW2FtbW16Nq1q7h7926BjwbKe6zOi37++Wfh7+8vLC0thaWlpahWrZoYPHiwSExMLNLr8eKjgdq2bZtvLAAxePBgrba8xxvNnDlTagsNDRWWlpbi8uXLonXr1sLCwkI4OzuLCRMm5HukzuPHj8WwYcOEm5ubMDExEZUrVxYzZ86UHrVT2LHz/Pnnn6J+/frC1NRU63Ur6tenoNdW12sjhBArV64Uvr6+QqVSiTJlyojmzZuLuLg4rTFF+f4hMhQKId7AVdNERPSfERYWhp9++gnp6emlHQoRvQG8ZpKIiIiIZGMxSURERESysZgkIiIiItl4zSQRERERycaVSSIiIiKSjcUkEREREcnGh5ZTidNoNLh16xasra0L/FNnRERE9HYRQuDx48dwc3ODUln09UYWk1Tibt26BXd399IOg4iIiGS4ceMGypcvX+TxLCapxFlbWwMArl69Cn
t7+1KOpuSp1Wr89ttvaN26NUxMTEo7nBLH/PSbIednyLkBzE/fGUJ+aWlpcHd3l36OFxWLSSpxeae2ra2tYWNjU8r
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot the feature importance\n",
"xgb.plot_importance(model)"
]
},
{
"cell_type": "code",
2024-06-12 17:19:27 +02:00
"execution_count": 14,
2024-06-05 23:18:20 +02:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIF0lEQVR4nO3de5iVZb038O8wwHAS0HBAkd1kniI8BUGIhw4kKtFLmZKHUFTaHkhjyhRT0EzR8oCVSqWI7W1b0rRNaZhNsdsmRmqQ9gZ5QkwDwROIOcjMev/odWpiMEB4FjCfz3Wt63Ldz32v9Xt4gu7ru+7nfipKpVIpAAAAAFCgNuUuAAAAAIDWRygFAAAAQOGEUgAAAAAUTigFAAAAQOGEUgAAAAAUTigFAAAAQOGEUgAAAAAUTigFAAAAQOGEUgAAAAAUTigFsAlVVFTkwgsvbHo/ffr0VFRUZNGiRWWrCQBgcznxxBNTU1OzQWNmz56dioqKzJ49e7PUBGw9hFLAVuXNkOfNV9u2bdO7d++ceOKJefbZZ8tdHgDAZvfP86EOHTpkjz32yLhx47J06dJylwew3tqWuwCAjfGVr3wl73rXu/L666/ngQceyPTp03Pffffl0UcfTYcOHcpdHgDAZveP86H77rsv119/fe6+++48+uij6dSpUyE1fPe7301jY+MGjTn44IPz17/+Ne3bt99MVQFbC6EUsFU6/PDDM2DAgCTJKaeckh49euTyyy/PzJkzc/TRR5e5OgCAze+f50PveMc7ctVVV+W///u/c8wxx6zVf9WqVencufMmraFdu3YbPKZNmzZ+RASSuH0P2EYcdNBBSZInnniiqW3BggX51Kc+lR122CEdOnTIgAEDMnPmzLXGvvzyyxk/fnxqampSVVWVXXbZJaNHj87y5cuTJKtXr87EiRPTv3//dOvWLZ07d85BBx2UX/7yl8WcHADAevjwhz+cJHnqqady4oknpkuXLnniiSdyxBFHZLvttstxxx2XJGlsbMyUKVPy3ve+Nx06dEjPnj3z7//+73nppZfW+syf/vSnOeSQQ7Lddtula9euef/735/vf//7Tcdb2lPq1ltvTf/+/ZvG7L333rnmmmuajq9rT6nbbrst/fv3T8eOHdOjR48cf/zxa23P8OZ5Pfvssxk5cmS6dOmSHXfcMV/84hfT0NDwdv74gDIQSgHbhDc3Et9+++2TJH/4wx/ygQ98IH/84x9z7rnn5sorr0znzp0zcuTI3HnnnU3jXn311Rx00EH55je/mUMPPTTXXHNNTj311CxYsCB//vOfkyQrVqzIDTfckA9+8IO5/PLLc+GFF2bZsmUZNmxY5s2bV/SpAgC06M0f597xjnckSdasWZNhw4aluro6V1xxRY488sgkyb//+7/n7LPPzpAhQ3LNNddkzJgxueWWWzJs2LC88cYbTZ83ffr0DB8+PC+++GImTJiQyy67LPvtt19mzZq1zhruvffeHHPMMdl+++1z+eWX57LLLssHP/jB/PrXv37L2qdPn56jjz46lZWVmTx5csaOHZs77rgjBx54YF5++eVmfRsaGjJs2LC84x3vyBVXXJFDDjkkV155Zb7zne9szB8bUEZu3wO2Sq+88kqWL1+e119/Pb/5zW9y0UUXpaqqKh/72MeSJGeddVb+7d/+Lb/97W9TVVWVJDn99NNz4IEH5pxzzsknPvGJJMnXv/71PProo7njjjua2pLk/PPPT6lUSvK3oGvRokXN9j0YO3Zs9tprr3zzm9/MjTfeWNRpAwA0+cf50K9//et85StfSceOHfOxj30sc+bMSX19fY466qhMnjy5acx9992XG264IbfcckuOPfbYpvYPfehDOeyww3Lbbbfl2GOPzSuvvJIzzzwzAwcOzOzZs5vdbvfmHKkld911V7p27Zp77rknlZWV63Ueb7zxRs4555z069cvv/rVr5q+68ADD8zHPvaxXH311bnooo
ua+r/++usZNWpULrjggiTJqaeemve973258cYbc9ppp63fHx6wRbBSCtgqDR06NDvuuGP69OmTT33qU+ncuXNmzpyZXXbZJS+++GJ+8Ytf5Oijj87KlSuzfPnyLF++PC+88EKGDRuWxx57rGkp+A9/+MPsu+++zQKpN1VUVCRJKisrmwKpxsbGvPjii1mzZk0GDBiQhx9+uLiTBgD4B/84H/r0pz+dLl265M4770zv3r2b+vxzSHPbbbelW7du+ehHP9o0R1q+fHn69++fLl26NG1PcO+992blypU599xz19r/6c05Uku6d++eVatW5d57713v83jwwQfz/PPP5/TTT2/2XcOHD89ee+2Vu+66a60xp556arP3Bx10UJ588sn1/k5gy2ClFLBVuvbaa7PHHnvklVdeybRp0/KrX/2qaUXU448/nlKplAsuuKDpF7R/9vzzz6d379554oknmpayv5Wbb745V155ZRYsWNBsWfu73vWuTXNCAAAb6M35UNu2bdOzZ8/sueeeadPm7+sO2rZtm1122aXZmMceeyyvvPJKqqurW/zM559/PsnfbwXs16/fBtV0+umn5wc/+EEOP/zw9O7dO4ceemiOPvroHHbYYesc8/TTTydJ9txzz7WO7bXXXrnvvvuatXXo0CE77rhjs7btt9++xT2xgC2bUArYKg0cOLDpaTMjR47MgQcemGOPPTYLFy5seizxF7/4xQwbNqzF8bvtttt6f9d//ud/5sQTT8zIkSNz9tlnp7q6umm/g3/cWB0AoEj/OB9qSVVVVbOQKvnbqu/q6urccsstLY7557BnQ1VXV2fevHm555578tOf/jQ//elPc9NNN2X06NG5+eab39Znv2l9bwsEtnxCKWCr92ZA9KEPfSjf+ta3ctJJJyX52yOKhw4d+pZj3/3ud+fRRx99yz633357dt1119xxxx3NlqtPmjTp7RcPAFCgd7/73fn5z3+eIUOGpGPHjm/ZL0keffTRDfoxL0nat2+fESNGZMSIEWlsbMzpp5+eb3/727ngggta/Kx3vvOdSZKFCxc2PUHwTQsXLmw6Dmx77CkFbBM++MEPZuDAgZkyZUq6du2aD37wg/n2t7+dv/zlL2v1XbZsWdN/H3nkkZk/f36zJ/K96c1NPN/8Ne4fN/X8zW9+kzlz5mzq0wAA2KyOPvroNDQ05OKLL17r2Jo1a5qedHfooYdmu+22y+TJk/P666836/dWG52/8MILzd63adMm++yzT5Kkvr6+xTEDBgxIdXV1pk6d2qzPT3/60/zxj3/M8OHD1+vcgK2PlVLANuPss8/OUUcdlenTp+faa6/NgQcemL333jtjx47NrrvumqVLl2bOnDn585//nPnz5zeNuf3223PUUUflpJNOSv/+/fPiiy9m5syZmTp1avbdd9987GMfa3o63/Dhw/PUU09l6tSp6du3b1599dUynzUAwPo75JBD8u///u+ZPHly5s2bl0MPPTTt2rXLY489lttuuy3XXHNNPvWpT6Vr1665+uqrc8opp+T9739/jj322Gy//faZP39+XnvttXXeinfKKafkxRdfzIc//OHssssuefrpp/PNb34z++23X97znve0OKZdu3a5/PLLM2bMmBxyyCE55phjsnTp0lxzzTWpqanJ+PHjN+cfCVBGQilgm/HJT34y7373u3PFFVdk7NixefDBB3PRRRdl+vTpeeGFF1JdXZ39998/EydObBrTpUuX/O///m8mTZqUO++8MzfffHOqq6vzkY98pGlj0BNPPDFLlizJt7/97dxzzz3p27dv/vM//zO33XZbZs+eXaazBQDYOFOnTk3//v3z7W9/O+edd17atm2bmpqaHH/88RkyZEhTv5NPPjnV1dW57LLLcvHFF6ddu3bZa6+93jIkOv744/Od73wn1113XV5++eX06tUro0aNyoUXXrjW/lb/6MQTT0ynTp1y2WWX5Zxzzknnzp3ziU98Ipdffnm6d+
++KU8f2IJUlN5q7SUAAAAAbAb2lAIAAACgcEIpAAAAAAonlAIAAACgcEIpAAAAAAonlAIAAACgcEIpAAAAAArXttw
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot recall and precision\n",
"# Calculate the recall and precision\n",
"recall = cm.diagonal() / cm.sum(axis=1)\n",
"precision = cm.diagonal() / cm.sum(axis=0)\n",
"\n",
"# plot in a bar chart\n",
"fig, ax = plt.subplots(1, 2, figsize=(12, 6))\n",
"ax[0].bar(range(num_classes), recall)\n",
"ax[0].set_xticks(range(num_classes))\n",
"ax[0].set_xticklabels(['GSVT', 'AFIB', 'SR', 'SB'])\n",
"ax[0].set_xlabel('Class')\n",
"ax[0].set_ylabel('Recall')\n",
"ax[0].set_title('Recall')\n",
"\n",
"ax[1].bar(range(num_classes), precision)\n",
"ax[1].set_xticks(range(num_classes))\n",
"ax[1].set_xticklabels(['GSVT', 'AFIB', 'SR', 'SB'])\n",
"ax[1].set_xlabel('Class')\n",
"ax[1].set_ylabel('Precision')\n",
"ax[1].set_title('Precision')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
2024-06-12 17:19:27 +02:00
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F1 Score: 0.8157211953487169\n"
]
}
],
"source": [
"# Calculate F1 Score for multiclass classification\n",
"f1 = f1_score(test_y, preds, average='macro')\n",
"\n",
"print('F1 Score:', f1)"
]
2024-06-05 23:18:20 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}