DSA_SS24/notebooks/ml_decision_tree.ipynb

1281 lines
128 KiB
Plaintext
Raw Normal View History

2024-06-21 18:29:21 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decision Tree Training and Analysis"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"import os\n",
"from datetime import datetime\n",
"from joblib import dump, load\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
"from sklearn.metrics import confusion_matrix, accuracy_score\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"import seaborn as sns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Data from Database"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Read the three dataset splits from the SQLite feature database.\n",
"# try/finally guarantees that the connection is released even when a query fails.\n",
"conn = sqlite3.connect('../features.db')\n",
"try:\n",
"    train = pd.read_sql_query(\"SELECT * FROM train\", conn)\n",
"    valid = pd.read_sql_query(\"SELECT * FROM validation\", conn)\n",
"    test = pd.read_sql_query(\"SELECT * FROM test\", conn)\n",
"finally:\n",
"    conn.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Format Data for Machine Learning"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_x shape: (3502, 10)\n",
"test_x shape: (438, 10)\n",
"valid_x shape: (438, 10)\n",
"features: ['age', 'gender', 'artial_rate', 'ventricular_rate', 'qrs_duration', 'qt_length', 'qrs_count', 'q_peak', 'r_axis', 't_axis']\n",
"number of classes: 4\n"
]
}
],
"source": [
"# Integer encoding of the class labels stored in column 'y'\n",
"LABEL_MAP = {'GSVT': 0, 'AFIB': 1, 'SR': 2, 'SB': 3}\n",
"\n",
"\n",
"def split_features_and_target(frame, split_name):\n",
"    \"\"\"Split one dataset table into a feature frame and an integer-encoded target.\n",
"\n",
"    The 'y' and 'id' columns are removed from the features and the 'y' labels\n",
"    are mapped through LABEL_MAP.\n",
"\n",
"    Raises:\n",
"        ValueError: if a label is missing from LABEL_MAP (Series.map would\n",
"            otherwise silently turn it into NaN).\n",
"    \"\"\"\n",
"    target = frame['y'].map(LABEL_MAP)\n",
"    unmapped = target.isna()\n",
"    if unmapped.any():\n",
"        unknown = sorted(frame.loc[unmapped, 'y'].astype(str).unique())\n",
"        raise ValueError(f'{split_name}: labels not in LABEL_MAP: {unknown}')\n",
"    return frame.drop(columns=['y', 'id']), target\n",
"\n",
"\n",
"train_x, train_y = split_features_and_target(train, 'train')\n",
"valid_x, valid_y = split_features_and_target(valid, 'validation')\n",
"test_x, test_y = split_features_and_target(test, 'test')\n",
"\n",
"print('train_x shape:', train_x.shape)\n",
"print('test_x shape:', test_x.shape)\n",
"print('valid_x shape:', valid_x.shape)\n",
"\n",
"# Remember the column names before the transformers below replace the frames\n",
"feature_names = train_x.columns.to_list()\n",
"print('features:', feature_names)\n",
"\n",
"# Fill missing values with the column mean; the statistics are learned on the\n",
"# training split only and then re-applied to validation and test\n",
"imputer = SimpleImputer(strategy='mean')\n",
"train_x = imputer.fit_transform(train_x)\n",
"valid_x = imputer.transform(valid_x)\n",
"test_x = imputer.transform(test_x)\n",
"\n",
"# Scale with the min/max of the training split (which maps the training data to [0, 1])\n",
"scaler = MinMaxScaler()\n",
"train_x = scaler.fit_transform(train_x)\n",
"valid_x = scaler.transform(valid_x)\n",
"test_x = scaler.transform(test_x)\n",
"\n",
"# Number of distinct encoded classes present in the validation split\n",
"num_classes = valid_y.nunique()\n",
"print('number of classes:', num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test Grid for Hyperparameter Analysis"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"param_grid = {\n",
" 'criterion': ['gini', 'entropy'],\n",
" 'max_depth': [None, 10, 20, 30, 40, 50],\n",
" 'min_samples_split': [2, 10, 20],\n",
" 'min_samples_leaf': [1, 5, 10]\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so that tree construction, and therefore the search result, is repeatable\n",
"RANDOM_STATE = 42\n",
"\n",
"# Base estimator whose hyperparameters are tuned\n",
"model = DecisionTreeClassifier(random_state=RANDOM_STATE)\n",
"\n",
"# Exhaustive search over param_grid with 3-fold cross-validation, ranked by accuracy\n",
"grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=3, estimator=DecisionTreeClassifier(),\n",
" param_grid={&#x27;criterion&#x27;: [&#x27;gini&#x27;, &#x27;entropy&#x27;],\n",
" &#x27;max_depth&#x27;: [None, 10, 20, 30, 40, 50],\n",
" &#x27;min_samples_leaf&#x27;: [1, 5, 10],\n",
" &#x27;min_samples_split&#x27;: [2, 10, 20]},\n",
" scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;GridSearchCV<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=3, estimator=DecisionTreeClassifier(),\n",
" param_grid={&#x27;criterion&#x27;: [&#x27;gini&#x27;, &#x27;entropy&#x27;],\n",
" &#x27;max_depth&#x27;: [None, 10, 20, 30, 40, 50],\n",
" &#x27;min_samples_leaf&#x27;: [1, 5, 10],\n",
" &#x27;min_samples_split&#x27;: [2, 10, 20]},\n",
" scoring=&#x27;accuracy&#x27;)</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">best_estimator_: DecisionTreeClassifier</label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;DecisionTreeClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.tree.DecisionTreeClassifier.html\">?<span>Documentation for DecisionTreeClassifier</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)</pre></div> </div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"GridSearchCV(cv=3, estimator=DecisionTreeClassifier(),\n",
" param_grid={'criterion': ['gini', 'entropy'],\n",
" 'max_depth': [None, 10, 20, 30, 40, 50],\n",
" 'min_samples_leaf': [1, 5, 10],\n",
" 'min_samples_split': [2, 10, 20]},\n",
" scoring='accuracy')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Run the cross-validated search on the training split; the fitted search object is the cell output\n",
"grid_search.fit(X=train_x, y=train_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Results"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'criterion': 'gini', 'max_depth': 10, 'min_samples_leaf': 10, 'min_samples_split': 10}\n",
"Best score: 0.769842911809933\n"
]
}
],
"source": [
"\n",
"# Report the winning hyperparameter combination and its cross-validated score\n",
"print('Best parameters:', grid_search.best_params_)\n",
"print('Best score:', grid_search.best_score_)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Save Model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../ml_models/best_decision_tree_model_20240621173105.joblib']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Keep a handle on the best estimator found by the grid search\n",
"best_model = grid_search.best_estimator_\n",
"\n",
"# Make sure the output directory exists; dump() does not create it\n",
"os.makedirs('../ml_models', exist_ok=True)\n",
"\n",
"# A timestamp in the file name keeps earlier runs from being overwritten\n",
"timestamp = datetime.now().strftime('%Y%m%d%H%M%S')\n",
"model_path = f'../ml_models/best_decision_tree_model_{timestamp}.joblib'\n",
"dump(best_model, model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example Training of the Best Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the saved best model so that its hyperparameters can be reused."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Collect every saved decision-tree model in the model directory\n",
"models = os.listdir('../ml_models')\n",
"candidates = [name for name in models if 'joblib' in name and 'best' in name and 'decision_tree' in name]\n",
"if not candidates:\n",
"    raise FileNotFoundError('no saved decision tree model found in ../ml_models')\n",
"\n",
"# os.listdir returns entries in arbitrary order, so choose deterministically:\n",
"# saved models carry a %Y%m%d%H%M%S suffix, hence the largest name is the newest file\n",
"model_path = f'../ml_models/{max(candidates)}'\n",
"\n",
"# NOTE: joblib.load unpickles the file, which can execute arbitrary code;\n",
"# only load model files that were written by this project\n",
"best_model = load(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-2 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-2 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-2 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-2 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-2 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-2 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-2 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-2 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-2 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" checked><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;DecisionTreeClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.tree.DecisionTreeClassifier.html\">?<span>Documentation for DecisionTreeClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)</pre></div> </div></div></div></div>"
],
"text/plain": [
"DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Retrain a standalone model with the hyperparameters selected by the grid search.\n",
"# scikit-learn permutes the candidate features at every split, so an unseeded tree\n",
"# can differ between runs. A fixed seed keeps this fit, and the feature importances\n",
"# later read from `model`, reproducible. The seed is only a default: a\n",
"# `random_state` entry in `best_params_`, if present, still takes precedence.\n",
"model_params = {\"random_state\": 42, **grid_search.best_params_}\n",
"model = DecisionTreeClassifier(**model_params)\n",
"model.fit(train_x, train_y)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Accuracy: 0.7990867579908676\n"
]
}
],
"source": [
"# Predict the test split and report the fraction of correctly classified samples.\n",
"# NOTE(review): this cell uses `best_model`, while the preceding cell fits `model`;\n",
"# confirm `best_model` is defined by an earlier cell so Restart & Run All works.\n",
"preds = best_model.predict(test_x)\n",
"accuracy = accuracy_score(y_true=test_y, y_pred=preds)\n",
"print(f\"Model Accuracy: {accuracy}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate Model Performance"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAHHCAYAAAAWM5p0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABLm0lEQVR4nO3dd1gUVxcG8HcXYekgKE1FsHesMYiKhdgNthijiWgsiYoNNQZjN4ox9ooaW4xEE40lJhaEKDFiQzFW7MFGVUDagux8fxj3ywY0oDvMwrw/n3ke987snTOLwuHce2cUgiAIICIiIhKJUuoAiIiIqHRjskFERESiYrJBREREomKyQURERKJiskFERESiYrJBREREomKyQURERKJiskFERESiYrJBREREomKyQSSiGzduoEOHDrCxsYFCocCePXv02v/du3ehUCiwefNmvfZbkrVp0wZt2rSROgwi+gcmG1Tq3bp1C5988gmqVKkCU1NTWFtbw8vLC8uWLUNWVpao5/bz88PFixcxd+5cbN26FU2bNhX1fMVp0KBBUCgUsLa2LvBzvHHjBhQKBRQKBRYuXFjk/h8+fIiZM2ciOjpaD9ESkZTKSB0AkZh++eUXvPfee1CpVBg4cCDq1auHnJwcHD9+HJMmTcLly5exbt06Uc6dlZWFyMhIfPHFF/D39xflHJUrV0ZWVhaMjY1F6f+/lClTBpmZmfj555/Rt29fnX3btm2DqakpsrOzX6vvhw8fYtasWXBzc0PDhg0L/b7Dhw+/1vmISDxMNqjUunPnDvr164fKlSsjPDwczs7O2n2jRo3CzZs38csvv4h2/sTERACAra2taOdQKBQwNTUVrf//olKp4OXlhe+//z5fshESEoKuXbti165dxRJLZmYmzM3NYWJiUiznI6LC4zAKlVoLFixAeno6NmzYoJNovFCtWjWMHTtW+/rZs2eYM2cOqlatCpVKBTc3N0yZMgVqtVrnfW5ubujWrRuOHz+Ot956C6ampqhSpQq+/fZb7TEzZ85E5cqVAQCTJk2CQqGAm5sbgOfDDy/+/k8zZ86EQqHQaQsNDUXLli1ha2sLS0tL1KxZE1OmTNHuf9mcjfDwcLRq1QoWFhawtbWFr68vrl69WuD5bt68iUGDBsHW1hY2NjYYPHgwMjMzX/7B/kv//v1x4MABpKSkaNvOnDmDGzduoH///vmOf/z4MSZOnIj69evD0tIS1tbW6Ny5My5cuKA95ujRo2jWrBkAYPDgwdrhmBfX2aZNG9SrVw9RUVFo3bo1zM3NtZ/Lv+ds+Pn5wdTUNN/1d+zYEWXLlsXDhw8Lfa1E9HqYbFCp9fPPP6NKlSpo0aJFoY4fOnQopk+fjsaNG2PJkiXw9vZGUFAQ+vXrl+/Ymzdvok+fPnjnnXewaNEilC1bFoMGDcLly5cBAL169cKSJUsAAB988AG2bt2KpUuXFin+y5cvo1u3blCr1Zg9ezYWLVqEd999F3/88ccr33fkyBF07NgRCQkJmDlzJgICAnDixAl4eXnh7t27+Y7v27cvnj59iqCgIPTt2xebN2/GrFmzCh1nr169oFAo8NNPP2nbQkJCUKtWLTRu3Djf8bdv38aePXvQrVs3LF68GJMmTcLFixfh7e2t/cFfu3ZtzJ49GwAwfPhwbN26FVu3bkXr1q21/SQnJ6Nz585o2LAhli5dirZt2xYY37Jly1C+fHn4+fkhLy8PALB27VocPnwYK1asgIuLS6GvlYhek0BUCqWmpgoABF9f30IdHx0dLQAQhg4dqtM+ceJEAYAQHh6ubatcubIAQIiIiNC2JSQkCCqVSpgwYYK27c6dOwIA4euvv9bp08/PT6hcuXK+GGbMmCH887/kkiVLBABCYmLiS+N+cY5NmzZp2xo2bCg4ODgIycnJ2rYLFy4ISqVSGDhwYL7zffzxxzp99uzZU7C3t3/pOf95HRYWFoIgCEKfPn2E9u3bC4IgCH
l5eYKTk5Mwa9asAj+D7OxsIS8vL991qFQqYfbs2dq2M2fO5Lu2F7y9vQUAQnBwcIH7vL29ddoOHTokABC+/PJL4fbt24KlpaXQo0eP/7xGItIPVjaoVEpLSwMAWFlZFer4X3/9FQAQEBCg0z5hwgQAyDe3o06dOmjVqpX2dfny5VGzZk3cvn37tWP+txdzPfbu3QuNRlOo9zx69AjR0dEYNGgQ7OzstO0NGjTAO++8o73Of/r00091Xrdq1QrJycnaz7Aw+vfvj6NHjyIuLg7h4eGIi4srcAgFeD7PQ6l8/q0nLy8PycnJ2iGic+fOFfqcKpUKgwcPLtSxHTp0wCeffILZs2ejV69eMDU1xdq1awt9LiJ6M0w2qFSytrYGADx9+rRQx//1119QKpWoVq2aTruTkxNsbW3x119/6bS7urrm66Ns2bJ48uTJa0ac3/vvvw8vLy8MHToUjo6O6NevH3744YdXJh4v4qxZs2a+fbVr10ZSUhIyMjJ02v99LWXLlgWAIl1Lly5dYGVlhR07dmDbtm1o1qxZvs/yBY1GgyVLlqB69epQqVQoV64cypcvjz///BOpqamFPmeFChWKNBl04cKFsLOzQ3R0NJYvXw4HB4dCv5eI3gyTDSqVrK2t4eLigkuXLhXpff+eoPkyRkZGBbYLgvDa53gxn+AFMzMzRERE4MiRI/joo4/w559/4v3338c777yT79g38SbX8oJKpUKvXr2wZcsW7N69+6VVDQCYN28eAgIC0Lp1a3z33Xc4dOgQQkNDUbdu3UJXcIDnn09RnD9/HgkJCQCAixcvFum9RPRmmGxQqdWtWzfcunULkZGR/3ls5cqVodFocOPGDZ32+Ph4pKSkaFeW6EPZsmV1Vm688O/qCQAolUq0b98eixcvxpUrVzB37lyEh4fjt99+K7DvF3HGxMTk23ft2jWUK1cOFhYWb3YBL9G/f3+cP38eT58+LXBS7Qs7d+5E27ZtsWHDBvTr1w8dOnSAj49Pvs+ksIlfYWRkZGDw4MGoU6cOhg8fjgULFuDMmTN665+IXo3JBpVan332GSwsLDB06FDEx8fn23/r1i0sW7YMwPNhAAD5VowsXrwYANC1a1e9xVW1alWkpqbizz//1LY9evQIu3fv1jnu8ePH+d774uZW/16O+4KzszMaNmyILVu26PzwvnTpEg4fPqy9TjG0bdsWc+bMwcqVK+Hk5PTS44yMjPJVTX788Uc8ePBAp+1FUlRQYlZUkydPRmxsLLZs2YLFixfDzc0Nfn5+L/0ciUi/eFMvKrWqVq2KkJAQvP/++6hdu7bOHURPnDiBH3/8EYMGDQIAeHh4wM/PD+vWrUNKSgq8vb1x+vRpbNmyBT169HjpssrX0a9fP0yePBk9e/bEmDFjkJmZiTVr1qBGjRo6EyRnz56NiIgIdO3aFZUrV0ZCQgJWr16NihUromXLli/t/+uvv0bnzp3h6emJIUOGICsrCytWrICNjQ1mzpypt+v4N6VSialTp/7ncd26dcPs2bMxePBgtGjRAhcvXsS2bdtQpUoVneOqVq0KW1tbBAcHw8rKChYWFmjevDnc3d2LFFd4eDhWr16NGTNmaJfibtq0CW3atMG0adOwYMGCIvVHRK9B4tUwRKK7fv26MGzYMMHNzU0wMTERrKysBC8vL2HFihVCdna29rjc3Fxh1qxZgru7u2BsbCxUqlRJCAwM1DlGEJ4vfe3atWu+8/x7yeXLlr4KgiAcPnxYqFevnmBiYiLUrFlT+O677/ItfQ0LCxN8fX0FFxcXwcTERHBxcRE++OAD4fr16/nO8e/loUeOHBG8vLwEMzMzwdraWujevbtw5coVnWNenO/fS2s3bdokABDu3Lnz0s9UEHSXvr7My5a+TpgwQXB2dhbMzMwELy8vITIyssAlq3v37hXq1KkjlClTRuc6vb29hbp16xZ4zn/2k5aWJlSuXFlo3LixkJubq3Pc+PHjBaVSKURGRr7yGo
jozSkEoQizwIiIiIiKiHM2iIiISFRMNoiIiEhUTDaIiIhIVEw2iIiISFRMNoiIiEhUTDaIiIhIVEw2iIiISFSl8g6
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion matrix of the test predictions, drawn as an annotated heatmap\n",
"# (rows: actual class, columns: predicted class).\n",
"cm = confusion_matrix(test_y, preds)\n",
"\n",
"fig, ax = plt.subplots()\n",
"sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", ax=ax)\n",
"ax.set(xlabel=\"Predicted\", ylabel=\"Actual\", title=\"Confusion Matrix\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAJOCAYAAACqS2TfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABfeklEQVR4nO3deVhUdeP+8XtAAUFZXEAlBBTX3Nc0t4xcssz8PmppqZTW02qilmbikglpKvm4l2aWlWZqm7lEtlvu5pL7gprgDoqpCOf3hz+nJtBAORyGeb+ua65LPnNm5uYEMfecz/kcm2EYhgAAAAAAQJ5zszoAAAAAAACFFaUbAAAAAACTULoBAAAAADAJpRsAAAAAAJNQugEAAAAAMAmlGwAAAAAAk1C6AQAAAAAwCaUbAAAAAACTULoBAAAAADAJpRsAAAAAAJNQugEABdrcuXNls9myvQ0ZMsSU1/z55581cuRInT171pTnvxXX9sf69eutjnLTpk2bprlz51odAwCAfFHE6gAAAOTE6NGjFR4e7jBWs2ZNU17r559/1qhRo9SnTx/5+/ub8hqubNq0aSpdurT69OljdRQAAExH6QYAOIUOHTqoYcOGVse4JWlpafLx8bE6hmUuXLggb29vq2MAAJCvmF4OACgUvvrqK7Vo0UI+Pj4qUaKEOnbsqO3btzts89tvv6lPnz6qWLGivLy8VLZsWT322GM6deqUfZuRI0dq8ODBkqTw8HD7VPaDBw/q4MGDstls2U6NttlsGjlypMPz2Gw27dixQz169FBAQICaN29uv//9999XgwYNVKxYMZUsWVIPPfSQDh8+fFPfe58+fVS8eHElJibqvvvuU/HixRUcHKypU6dKkrZu3ao2bdrIx8dHoaGh+uCDDxwef23K+vfff68nn3xSpUqVkq+vr3r16qUzZ85keb1p06bp9ttvl6enp8qXL69nnnkmy1T81q1bq2bNmtqwYYNatmwpb29vvfzyywoLC9P27dv13Xff2fdt69atJUmnT5/WoEGDVKtWLRUvXly+vr7q0KGDtmzZ4vDc3377rWw2mxYuXKjXXntNt912m7y8vHT33Xdr7969WfL++uuvuvfeexUQECAfHx/Vrl1bb775psM2O3fu1H/+8x+VLFlSXl5eatiwoT777DOHbdLT0zVq1ChVrlxZXl5eKlWqlJo3b65Vq1bl6L8TAMA1caQbAOAUUlJSdPLkSYex0qVLS5Lee+899e7dW+3atdPrr7+uCxcuaPr06WrevLk2bdqksLAwSdKqVau0f/9+RUVFqWzZstq+fbtmzZql7du365dffpHNZlOXLl20e/duffjhh5o0aZL9NcqUKaMTJ07kOnfXrl1VuXJljR07VoZhSJJee+01DR8+XN26dVPfvn114sQJ/e9//1PLli21adOmm5rSnpGRoQ4dOqhly5YaN26c5s+fr2effVY+Pj4aNmyYevbsqS5dumjGjBnq1auXmjZtmmW6/rPPPit/f3+NHDlSu3bt0vTp03Xo0CF7yZWufpgwatQoRUZG6qmnnrJvt27dOv30008qWrSo/flOnTqlDh066KGHHtIjjzyioKAgtW7dWs8995yKFy+uYcOGSZKCgoIkSfv379fSpUvVtWtXhYeHKzk5WTNnzlSrVq20Y8cOlS9f3iFvXFyc3NzcNGjQIKWkpGjcuHHq2bOnfv31V/s2q1at0n333ady5cqpf//+Klu2rH7//Xd98cUX6t+/vyRp+/btuvPOOxUcHKwhQ4bIx8dHCxcuVOfOnfXJJ5/owQcftH/vsbGx6tu3rxo3bqzU1FStX79eGzdu1D333JPr/2YAABdhAABQgL3zzjuGpGxvhmEY586dM/z9/Y1+/fo5PC4pKcnw8/NzGL9w4UKW5//www8NScb3339vHxs/frwhyThw4IDDtgcOHDAkGe+8806W55FkjBgxwv71iBEjDEnGww8/7LDdwYMHDX
d3d+O1115zGN+6datRpEiRLOPX2x/r1q2zj/Xu3duQZIwdO9Y+dubMGaNYsWKGzWYzPvroI/v4zp07s2S99pwNGjQwLl++bB8fN26cIcn49NNPDcMwjOPHjxseHh5G27ZtjYyMDPt2U6ZMMSQZc+bMsY+1atXKkGTMmDEjy/dw++23G61atcoyfvHiRYfnNYyr+9zT09MYPXq0fWz16tWGJKN69erGpUuX7ONvvvmmIcnYunWrYRiGceXKFSM8PNwIDQ01zpw54/C8mZmZ9n/ffffdRq1atYyLFy863N+sWTOjcuXK9rE6deoYHTt2zJIbAIAbYXo5AMApTJ06VatWrXK4SVePZJ49e1YPP/ywTp48ab+5u7urSZMmWr16tf05ihUrZv/3xYsXdfLkSd1xxx2SpI0bN5qS+7///a/D14sXL1ZmZqa6devmkLds2bKqXLmyQ97c6tu3r/3f/v7+qlq1qnx8fNStWzf7eNWqVeXv76/9+/dnefwTTzzhcKT6qaeeUpEiRbRs2TJJ0tdff63Lly/rhRdekJvbX28h+vXrJ19fX3355ZcOz+fp6amoqKgc5/f09LQ/b0ZGhk6dOqXixYuratWq2f73iYqKkoeHh/3rFi1aSJL9e9u0aZMOHDigF154IcvsgWtH7k+fPq1vvvlG3bp107lz5+z/PU6dOqV27dppz549Onr0qKSr+3T79u3as2dPjr8nAACYXg4AcAqNGzfOdiG1awWoTZs22T7O19fX/u/Tp09r1KhR+uijj3T8+HGH7VJSUvIw7V/+OYV7z549MgxDlStXznb7v5fe3PDy8lKZMmUcxvz8/HTbbbfZC+bfx7M7V/ufmYoXL65y5crp4MGDkqRDhw5Julrc/87Dw0MVK1a0339NcHCwQyn+N5mZmXrzzTc1bdo0HThwQBkZGfb7SpUqlWX7ChUqOHwdEBAgSfbvbd++fZJuvMr93r17ZRiGhg8fruHDh2e7zfHjxxUcHKzRo0frgQceUJUqVVSzZk21b99ejz76qGrXrp3j7xEA4Hoo3QAAp5aZmSnp6nndZcuWzXJ/kSJ//anr1q2bfv75Zw0ePFh169ZV8eLFlZmZqfbt29uf50b+WV6v+Xs5/Ke/H12/ltdms+mrr76Su7t7lu2LFy/+rzmyk91z3Wjc+P/nl5vpn9/7vxk7dqyGDx+uxx57TK+++qpKliwpNzc3vfDCC9n+98mL7+3a8w4aNEjt2rXLdpuIiAhJUsuWLbVv3z59+umnWrlypd5++21NmjRJM2bMcJhlAADA31G6AQBOrVKlSpKkwMBARUZGXne7M2fOKCEhQaNGjVJMTIx9PLupwtcr19eOpP5zpe5/HuH9t7yGYSg8PFxVqlTJ8ePyw549e3TXXXfZvz5//ryOHTume++9V5IUGhoqSdq1a5cqVqxo3+7y5cs6cODADff/311v/y5atEh33XWXZs+e7TB+9uxZ+4J2uXHtZ2Pbtm3XzXbt+yhatGiO8pcsWVJRUVGKiorS+fPn1bJlS40cOZLSDQC4Ls7pBgA4tXbt2snX11djx45Venp6lvuvrTh+7ajoP4+CxsfHZ3nMtWtp/7Nc+/r6qnTp0vr+++8dxqdNm5bjvF26dJG7u7tGjRqVJYthGA6XL8tvs2bNctiH06dP15UrV9ShQwdJUmRkpDw8PDR58mSH7LNnz1ZKSoo6duyYo9fx8fHJsm+lq/+N/rlPPv74Y/s51blVv359hYeHKz4+PsvrXXudwMBAtW7dWjNnztSxY8eyPMffV6z/53+b4sWLKyIiQpcuXbqpfAAA18CRbgCAU/P19dX06dP16KOPqn79+nrooYdUpkwZJSYm6ssvv9Sdd96pKVOmyNfX1345rfT0dAUHB2vlypU6cOBAluds0KCBJGnYsGF66KGHVLRoUd1///3y8fFR3759FRcXp759+6phw4b6/vvvtXv37hznrVSpksaMGaOhQ4fq4MGD6ty5s0qUKK
EDBw5oyZIleuKJJzRo0KA82z+5cfnyZd19993q1q2bdu3apWnTpql58+bq1KmTpKuXTRs6dKhGjRql9u3bq1OnTvb
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Bar chart of the retrained tree's `feature_importances_`, largest first.\n",
"# numpy is not part of the notebook's import cell, so it is imported here;\n",
"# `plt` already comes from that import cell.\n",
"import numpy as np\n",
"\n",
"feature_importances = model.feature_importances_\n",
"# Indices of the features ordered from most to least important\n",
"sorted_idx = np.argsort(feature_importances)[::-1]\n",
"positions = range(len(feature_importances))\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.bar(positions, feature_importances[sorted_idx], align=\"center\")\n",
"ax.set_xticks(positions)\n",
"# Tick labels follow the same ordering as the bars\n",
"ax.set_xticklabels(np.array(feature_names)[sorted_idx], rotation=90)\n",
"ax.set_xlim(-1, len(feature_importances))\n",
"ax.set(title=\"Feature Importances\", xlabel=\"Feature\", ylabel=\"Importance\")\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABBfUlEQVR4nO3de5iVVd038O8wwHAGFUFFcjJP4bEgDfHQgURFezFT8hCKSnng0aRMMRXNFC0PWKmYiljpo3nqoVTMUJ6yMFOCtFdNTcS04eABEHMQZr9/+DI5MRrgzL1h5vO5rn1d7rXX2vt3cwet63uve90VpVKpFAAAAAAoUJtyFwAAAABA6yOUAgAAAKBwQikAAAAACieUAgAAAKBwQikAAAAACieUAgAAAKBwQikAAAAACieUAgAAAKBwQikAAAAACieUAmhCFRUVOffcc+vfT548ORUVFZkzZ07ZagIAaC5HH310qqur12jM9OnTU1FRkenTpzdLTcD6QygFrFdWhjwrX23btk2fPn1y9NFH56WXXip3eQAAze7f50MdOnTINttsk9GjR2fevHnlLg9gtbUtdwEAa+Pb3/52PvzhD+ett97Kww8/nMmTJ+ehhx7KE088kQ4dOpS7PACAZvfu+dBDDz2Uq6++Ovfcc0+eeOKJdOrUqZAarr322tTV1a3RmL322iv//Oc/0759+2aqClhfCKWA9dJ+++2XAQMGJEmOO+649OzZMxdffHGmTJmSQw89tMzVAQA0v3+fD2200Ua57LLL8j//8z857LDDVum/dOnSdO7cuUlraNeu3RqPadOmjYuIQBK37wEtxJ577pkkee655+rbnnrqqXzxi1/MhhtumA4dOmTAgAGZMmXKKmNff/31nHrqqamurk5VVVU233zzjBgxIgsXLkySLFu2LOecc0769++f7t27p3Pnztlzzz3z4IMPFnNwAACr4TOf+UyS5Pnnn8/RRx+dLl265Lnnnsv++++frl275ogjjkiS1NXVZcKECdl+++3ToUOH9O7dO1/96lfz2muvrfKd9957b/bee+907do13bp1yyc+8YncfPPN9Z83tqfULbfckv79+9eP2XHHHXPFFVfUf/5ee0rddttt6d+/fzp27JiePXvmyCOPXGV7hpXH9dJLL2XYsGHp0qVLNt5443zjG9/IihUrPsgfH1AGQimgRVi5kfgGG2yQJPnLX/6ST37yk3nyySdzxhln5NJLL03nzp0zbNiw3HXXXfXj3njjjey55575wQ9+kH322SdXXHFFjj/++Dz11FP5+9//niRZvHhxrrvuunzqU5/KxRdfnHPPPTcLFizIkCFDMmvWrKIPFQCgUSsvzm200UZJkuXLl2fIkCHp1atXLrnkkhx88MFJkq9+9as57bTTMmjQoFxxxRUZOXJkbrrppgwZMiRvv/12/fdNnjw5Q4cOzauvvpqxY8fmoosuyi677JKpU6e+Zw33339/DjvssGywwQa5+OKLc9FFF+VTn/pUfve7371v7ZMnT86hhx6aysrKjB8/PqNGjcqdd96ZPfbYI6+//nqDvitWrMiQIUOy0UYb5ZJLLsnee++dSy+9ND/60Y/W5o8NKCO37wHrpUWLFmXhwoV566238oc//CHnnXdeqqqqcsABByRJTjnllHzoQx/KH//4x1RVVSVJTjzxxOyxxx45/fTTc9BBByVJvve97+WJJ57InXfeWd+WJGeddVZKpVKSd4KuOXPmNNj3YNSoUdluu+3ygx/8INdff31Rhw0AUO/d86Hf/e53+fa3v52OHTvmgAMOyIwZM1JbW5tDDjkk48ePrx/z0EMP5brrrstNN92Uww8/vL7905/+dPbdd9/cdtttOfzww7No0aKcfPLJ2XXXXTN9+vQGt9utnCM15u677063bt1y3333pbKycrWO4+23387pp5+eHXbYIb/5zW/qf2uPPfbIAQcckMsvvzznnXdeff+33norw4
cPz9lnn50kOf744/Pxj388119/fU444YTV+8MD1glWSgHrpcGDB2fjjTdO375988UvfjGdO3fOlClTsvnmm+fVV1/NAw88kEMPPTRLlizJwoULs3DhwrzyyisZMmRInnnmmfql4HfccUd23nnnBoHUShUVFUmSysrK+kCqrq4ur776apYvX54BAwZk5syZxR00AMC7vHs+9KUvfSldunTJXXfdlT59+tT3+feQ5rbbbkv37t3zuc99rn6OtHDhwvTv3z9dunSp357g/vvvz5IlS3LGGWessv/TyjlSY3r06JGlS5fm/vvvX+3jePTRRzN//vyceOKJDX5r6NCh2W677XL33XevMub4449v8H7PPffM3/72t9X+TWDdYKUUsF668sors80222TRokWZNGlSfvOb39SviHr22WdTKpVy9tln119B+3fz589Pnz598txzz9UvZX8/N954Yy699NI89dRTDZa1f/jDH26aAwIAWEMr50Nt27ZN7969s+2226ZNm3+tO2jbtm0233zzBmOeeeaZLFq0KL169Wr0O+fPn5/kX7cC7rDDDmtU04knnpif/exn2W+//dKnT5/ss88+OfTQQ7Pvvvu+55gXXnghSbLtttuu8tl2222Xhx56qEFbhw4dsvHGGzdo22CDDRrdEwtYtwmlgPXSrrvuWv+0mWHDhmWPPfbI4Ycfnqeffrr+scTf+MY3MmTIkEbHb7XVVqv9Wz/96U9z9NFHZ9iwYTnttNPSq1ev+v0O3r2xOgBAkd49H2pMVVVVg5AqeWfVd69evXLTTTc1Oubfw5411atXr8yaNSv33Xdf7r333tx777254YYbMmLEiNx4440f6LtXWt3bAoF1n1AKWO+tDIg+/elP54c//GGOOeaYJO88onjw4MHvO/YjH/lInnjiifftc/vtt2fLLbfMnXfe2WC5+rhx4z548QAABfrIRz6SX//61xk0aFA6duz4vv2S5Iknnliji3lJ0r59+xx44IE58MADU1dXlxNPPDHXXHNNzj777Ea/a4sttkiSPP300/VPEFzp6aefrv8caHnsKQW0CJ/61Key6667ZsKECenWrVs+9alP5Zprrsk//vGPVfouWLCg/r8PPvjgzJ49u8ET+VZauYnnyqtx797U8w9/+ENmzJjR1IcBANCsDj300KxYsSLnn3/+Kp8tX768/kl3++yzT7p27Zrx48fnrbfeatDv/TY6f+WVVxq8b9OmTXbaaackSW1tbaNjBgwYkF69emXixIkN+tx777158sknM3To0NU6NmD9Y6UU0GKcdtppOeSQQzJ58uRceeWV2WOPPbLjjjtm1KhR2XLLLTNv3rzMmDEjf//73zN79uz6MbfffnsOOeSQHHPMMenfv39effXVTJkyJRMnTszOO++cAw44oP7pfEOHDs3zzz+fiRMnpl+/fnnjjTfKfNQAAKtv7733zle/+tWMHz8+s2bNyj777JN27drlmWeeyW233ZYrrrgiX/ziF9OtW7dcfvnlOe644/KJT3wihx9+eDbYYIPMnj07b7755nveinfcccfl1VdfzWc+85lsvvnmeeGFF/KDH/wgu+yySz760Y82OqZdu3a5+OKLM3LkyOy999457LDDMm/evFxxxRWprq7Oqaee2px/JEAZCaWAFuMLX/hCPvKRj+SSSy7JqFGj8uijj+a8887L5MmT88orr6RXr1752Mc+lnPOOad+TJcuXfLb3/4248aNy1133ZUbb7wxvXr1ymc/+9n6jUGPPvro1NTU5Jprrsl9992Xfv365ac//Wluu+22TJ8+vUxHCwCwdiZOnJj+/fvnmmuuyZlnnpm2bdumuro6Rx55ZAYNGlTf79hjj02vXr1y0UUX5fzzz0+7du2y3XbbvW9IdOSRR+ZHP/pRrrrqqrz++uvZZJNNMnz48Jx77rmr7G/1bkcffXQ6deqUiy66KKeffno6d+6cgw46KBdffHF69OjRlIcPrEMqSu+39h
IAAAAAmoE9pQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMIJpQAAAAAonFAKAAAAgMK1LXcBRaurq8vLL7+crl2
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-class recall and precision derived from the confusion matrix `cm`:\n",
"# recall divides each diagonal entry by its row total, precision by its column total.\n",
"recall = cm.diagonal() / cm.sum(axis=1)\n",
"precision = cm.diagonal() / cm.sum(axis=0)\n",
"\n",
"# One bar chart per metric, side by side, with identical axis setup\n",
"fig, ax = plt.subplots(1, 2, figsize=(12, 6))\n",
"for panel, metric_name, metric_values in zip(ax, ('Recall', 'Precision'), (recall, precision)):\n",
"    panel.bar(range(num_classes), metric_values)\n",
"    panel.set_xticks(range(num_classes))\n",
"    panel.set_xticklabels(['GSVT', 'AFIB', 'SR', 'SB'])\n",
"    panel.set_xlabel('Class')\n",
"    panel.set_ylabel(metric_name)\n",
"    panel.set_title(metric_name)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}