{ "cells": [ { "cell_type": "markdown", "id": "f0ad356c-2f8d-4aea-968f-ebcdb9c8e857", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "# Exploring the Titanic Dataset: Analyzing Survival Factors\n", "This Jupyter notebook explores a Titanic passenger dataset to identify the factors that influenced survival. The dataset contains passenger information such as age, gender, ticket class, and fare, which lets us investigate correlations and patterns related to survival outcomes.\n", "\n", "## Data Preprocessing\n", "Before starting the analysis, it helps to know how the dataset was prepared. The raw data was first run through a series of Python preprocessing scripts. These scripts handled missing ids, split names into separate variables (salutation, first name, last name), and ensured data integrity.\n", "\n", "Following preprocessing, the dataset was loaded into a MySQL database for efficient storage and retrieval.\n", "\n", "## Objective\n", "Our primary goal is to determine whether certain variables played a significant role in the survival of passengers aboard the Titanic. 
By analyzing features like age, gender, ticket class, and familial relationships, we aim to unravel potential correlations and uncover underlying trends that influenced survival rates.\n", "\n", "## Sources\n", "- [Kaggle - Titanic Data Set](https://www.kaggle.com/datasets/sakshisatre/titanic-dataset/data)\n", "- [ChatGPT](https://chat.openai.com/)\n", "- [IBM - Logistic Regression](https://www.ibm.com/topics/logistic-regression)\n", "\n", "-----" ] }, { "cell_type": "markdown", "id": "98d38553-0bf7-4aaf-a052-8aa6b84a620a", "metadata": {}, "source": [ "# Install needed packages" ] }, { "cell_type": "code", "execution_count": 1, "id": "d45b62de-3b2a-42f6-ac79-fc9adaacc3e2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: sqlalchemy in /opt/conda/lib/python3.11/site-packages (2.0.22)\n", "Requirement already satisfied: mysql-connector-python in /opt/conda/lib/python3.11/site-packages (8.3.0)\n", "Requirement already satisfied: typing-extensions>=4.2.0 in /opt/conda/lib/python3.11/site-packages (from sqlalchemy) (4.11.0)\n", "Requirement already satisfied: greenlet!=0.4.17 in /opt/conda/lib/python3.11/site-packages (from sqlalchemy) (3.0.0)\n", "Requirement already satisfied: pandas in /opt/conda/lib/python3.11/site-packages (2.2.1)\n", "Requirement already satisfied: numpy<2,>=1.23.2 in /opt/conda/lib/python3.11/site-packages (from pandas) (1.26.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas) (2023.3.post1)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", "Requirement already satisfied: matplotlib in 
/opt/conda/lib/python3.11/site-packages (3.8.3)\n", "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (1.2.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (4.50.0)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (1.4.5)\n", "Requirement already satisfied: numpy<2,>=1.21 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (23.2)\n", "Requirement already satisfied: pillow>=8 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (10.2.0)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (3.1.2)\n", "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.11/site-packages (from matplotlib) (2.8.2)\n", "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.11/site-packages (1.4.2)\n", "Requirement already satisfied: numpy>=1.19.5 in /opt/conda/lib/python3.11/site-packages (from scikit-learn) (1.26.4)\n", "Requirement already satisfied: scipy>=1.6.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn) (1.11.4)\n", "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn) (1.4.0)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn) (3.4.0)\n", "Requirement already satisfied: ydata_profiling in /opt/conda/lib/python3.11/site-packages (4.7.0)\n", "Requirement already 
satisfied: scipy<1.12,>=1.4.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (1.11.4)\n", "Requirement already satisfied: pandas!=1.4.0,<3,>1.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (2.2.1)\n", "Requirement already satisfied: matplotlib<3.9,>=3.2 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (3.8.3)\n", "Requirement already satisfied: pydantic>=2 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (2.6.4)\n", "Requirement already satisfied: PyYAML<6.1,>=5.0.0 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (6.0.1)\n", "Requirement already satisfied: jinja2<3.2,>=2.11.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (3.1.2)\n", "Requirement already satisfied: visions<0.7.7,>=0.7.5 in /opt/conda/lib/python3.11/site-packages (from visions[type_image_path]<0.7.7,>=0.7.5->ydata_profiling) (0.7.6)\n", "Requirement already satisfied: numpy<2,>=1.16.0 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (1.26.4)\n", "Requirement already satisfied: htmlmin==0.1.12 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (0.1.12)\n", "Requirement already satisfied: phik<0.13,>=0.11.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (0.12.4)\n", "Requirement already satisfied: requests<3,>=2.24.0 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (2.31.0)\n", "Requirement already satisfied: tqdm<5,>=4.48.2 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (4.66.1)\n", "Requirement already satisfied: seaborn<0.13,>=0.10.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (0.12.2)\n", "Requirement already satisfied: multimethod<2,>=1.4 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (1.11.2)\n", "Requirement already satisfied: statsmodels<1,>=0.13.2 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (0.14.1)\n", "Requirement already satisfied: 
typeguard<5,>=4.1.2 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (4.2.1)\n", "Requirement already satisfied: imagehash==4.3.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (4.3.1)\n", "Requirement already satisfied: wordcloud>=1.9.1 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (1.9.3)\n", "Requirement already satisfied: dacite>=1.8 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (1.8.1)\n", "Requirement already satisfied: numba<1,>=0.56.0 in /opt/conda/lib/python3.11/site-packages (from ydata_profiling) (0.59.1)\n", "Requirement already satisfied: PyWavelets in /opt/conda/lib/python3.11/site-packages (from imagehash==4.3.1->ydata_profiling) (1.6.0)\n", "Requirement already satisfied: pillow in /opt/conda/lib/python3.11/site-packages (from imagehash==4.3.1->ydata_profiling) (10.2.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2<3.2,>=2.11.1->ydata_profiling) (2.1.3)\n", "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (1.2.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (4.50.0)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (1.4.5)\n", "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (23.2)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (3.1.2)\n", "Requirement already satisfied: python-dateutil>=2.7 in 
/opt/conda/lib/python3.11/site-packages (from matplotlib<3.9,>=3.2->ydata_profiling) (2.8.2)\n", "Requirement already satisfied: llvmlite<0.43,>=0.42.0dev0 in /opt/conda/lib/python3.11/site-packages (from numba<1,>=0.56.0->ydata_profiling) (0.42.0)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas!=1.4.0,<3,>1.1->ydata_profiling) (2023.3.post1)\n", "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas!=1.4.0,<3,>1.1->ydata_profiling) (2024.1)\n", "Requirement already satisfied: joblib>=0.14.1 in /opt/conda/lib/python3.11/site-packages (from phik<0.13,>=0.11.1->ydata_profiling) (1.4.0)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /opt/conda/lib/python3.11/site-packages (from pydantic>=2->ydata_profiling) (0.6.0)\n", "Requirement already satisfied: pydantic-core==2.16.3 in /opt/conda/lib/python3.11/site-packages (from pydantic>=2->ydata_profiling) (2.16.3)\n", "Requirement already satisfied: typing-extensions>=4.6.1 in /opt/conda/lib/python3.11/site-packages (from pydantic>=2->ydata_profiling) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests<3,>=2.24.0->ydata_profiling) (3.3.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests<3,>=2.24.0->ydata_profiling) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.11/site-packages (from requests<3,>=2.24.0->ydata_profiling) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests<3,>=2.24.0->ydata_profiling) (2023.7.22)\n", "Requirement already satisfied: patsy>=0.5.4 in /opt/conda/lib/python3.11/site-packages (from statsmodels<1,>=0.13.2->ydata_profiling) (0.5.6)\n", "Requirement already satisfied: attrs>=19.3.0 in /opt/conda/lib/python3.11/site-packages (from 
visions<0.7.7,>=0.7.5->visions[type_image_path]<0.7.7,>=0.7.5->ydata_profiling) (23.1.0)\n", "Requirement already satisfied: networkx>=2.4 in /opt/conda/lib/python3.11/site-packages (from visions<0.7.7,>=0.7.5->visions[type_image_path]<0.7.7,>=0.7.5->ydata_profiling) (3.3)\n", "Requirement already satisfied: six in /opt/conda/lib/python3.11/site-packages (from patsy>=0.5.4->statsmodels<1,>=0.13.2->ydata_profiling) (1.16.0)\n", "Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.11/site-packages (8.1.2)\n", "Requirement already satisfied: comm>=0.1.3 in /opt/conda/lib/python3.11/site-packages (from ipywidgets) (0.1.4)\n", "Requirement already satisfied: ipython>=6.1.0 in /opt/conda/lib/python3.11/site-packages (from ipywidgets) (8.16.1)\n", "Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.11/site-packages (from ipywidgets) (5.11.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.10 in /opt/conda/lib/python3.11/site-packages (from ipywidgets) (4.0.10)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.10 in /opt/conda/lib/python3.11/site-packages (from ipywidgets) (3.0.10)\n", "Requirement already satisfied: backcall in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (0.2.0)\n", "Requirement already satisfied: decorator in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (5.1.1)\n", "Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (0.19.1)\n", "Requirement already satisfied: matplotlib-inline in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (0.1.6)\n", "Requirement already satisfied: pickleshare in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (0.7.5)\n", "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (3.0.39)\n", "Requirement 
already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (2.16.1)\n", "Requirement already satisfied: stack-data in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (0.6.2)\n", "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.11/site-packages (from ipython>=6.1.0->ipywidgets) (4.8.0)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /opt/conda/lib/python3.11/site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.3)\n", "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.11/site-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets) (0.7.0)\n", "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.11/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.1.0->ipywidgets) (0.2.8)\n", "Requirement already satisfied: executing>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (1.2.0)\n", "Requirement already satisfied: asttokens>=2.1.0 in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.4.0)\n", "Requirement already satisfied: pure-eval in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n", "Requirement already satisfied: six>=1.12.0 in /opt/conda/lib/python3.11/site-packages (from asttokens>=2.1.0->stack-data->ipython>=6.1.0->ipywidgets) (1.16.0)\n" ] } ], "source": [ "# Packages needed to communicate with MySQL from Python\n", "!pip install sqlalchemy mysql-connector-python\n", "\n", "!pip install pandas\n", "!pip install matplotlib\n", "!pip install scikit-learn\n", "!pip install ydata_profiling\n", "!pip install ipywidgets" ] }, { "cell_type": "markdown", "id": "c1090b5b-54c2-40f0-9626-dea7d94b9b6d", "metadata": {}, "source": [ "# Load Libraries" ] }, { "cell_type": "code", "execution_count": 2, "id": "d4b2f0d1-fc1c-45ca-92bb-7db94de10749", "metadata": {}, "outputs": 
[], "source": [ "from sqlalchemy import create_engine, text\n", "from ydata_profiling import ProfileReport\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.metrics import roc_curve, auc, accuracy_score, confusion_matrix, classification_report\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import numpy as np" ] }, { "cell_type": "markdown", "id": "baeb36de-752b-42dd-aa0a-7db5dcef545b", "metadata": {}, "source": [ "# Connect to the Database" ] }, { "cell_type": "code", "execution_count": 3, "id": "9e51ef55-9662-45cf-b5bf-458f37a53d1d", "metadata": {}, "outputs": [], "source": [ "# Connect to MySQL\n", "connection_string = 'mysql+mysqlconnector://root:pw@172.17.0.1:3306/titanic'" ] }, { "cell_type": "code", "execution_count": 4, "id": "d3fa9c72-37c9-4ffd-8e37-6b61279b3f02", "metadata": {}, "outputs": [], "source": [ "engine = create_engine(connection_string)\n", "conn = engine.connect()" ] }, { "cell_type": "markdown", "id": "f5812623-59a6-4bd3-ac39-09daed7cd2f0", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Test Connection" ] }, { "cell_type": "code", "execution_count": 5, "id": "fc04a792-acf9-48b1-8127-ef5f264afc32", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " salutation first_name last_name\n", "0 Miss Helen Loraine Allison\n", "1 Mr Hudson Joshua Creighton Allison\n", "2 Mrs Hudson J C (Bessie Waldo Daniels) Allison\n", "3 Mr Thomas Jr Andrews\n", "4 Mr Ramon Artagaveytia\n", ".. ... ... 
...\n", "792 Miss Hileni Zabour\n", "793 Miss Thamine Zabour\n", "794 Mr Mapriededer Zakarian\n", "795 Mr Ortin Zakarian\n", "796 Mr Leo Zimmerman\n", "\n", "[797 rows x 3 columns]\n" ] } ], "source": [ "query = text(\"SELECT salutation, first_name, last_name FROM passengers WHERE survived = false;\")\n", "result = conn.execute(query)\n", "df = pd.DataFrame(result.fetchall(), columns=result.keys())\n", "\n", "# Display the results\n", "print(df)" ] }, { "cell_type": "markdown", "id": "d477c239-bb05-4f5b-bcd1-c0f353971335", "metadata": {}, "source": [ "# Fetch all the data into pandas" ] }, { "cell_type": "code", "execution_count": 6, "id": "4d272bb4-daa3-4b75-b09c-d100692abe05", "metadata": {}, "outputs": [], "source": [ "query = text(\"SELECT * FROM passengers;\")\n", "result = conn.execute(query)\n", "df = pd.DataFrame(result.fetchall(), columns=result.keys())" ] }, { "cell_type": "markdown", "id": "57e0ecb7-d73b-4a97-b683-d48dcd148ca7", "metadata": {}, "source": [ "# Analyze feature distributions against the survived target class" ] }, { "cell_type": "markdown", "id": "ccb8f573-f5fe-4e62-86fa-3923a2071425", "metadata": {}, "source": [ "## Bin by age and see survival distribution" ] }, { "cell_type": "code", "execution_count": 7, "id": "732bb2e5-be75-4ea7-a55f-f300393d574d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Survived percentage\n", "age_bin \n", "(0, 10] 58.139535\n", "(10, 20] 39.490446\n", "(20, 30] 36.676218\n", "(30, 40] 42.307692\n", "(40, 50] 39.230769\n", "(50, 60] 49.180328\n", "(60, 70] 22.222222\n", "(70, 80] 33.333333\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Create bins\n", "min_age = df['age'].min()\n", "max_age = df['age'].max()\n", "bins = range(int(min_age // 10 * 10), int((max_age // 10 + 1) * 10), 10)\n", "\n", "# Cut the 'age' column into bins\n", "df['age_bin'] = pd.cut(df['age'], bins=bins)\n", "\n", "# Count all passengers and the survivors in each age bin\n", "age_counts = df['age_bin'].value_counts().sort_index()\n", "survived_counts = df.groupby('age_bin', observed=False)['survived'].sum()\n", "\n", "# Calculate percentage of survivors per age bin\n", "survival_percentages_age = (survived_counts / age_counts) * 100\n", "\n", "# Plot the results\n", "# Survivors are drawn over the totals, so the visible rest of each total bar is the non-survivors\n", "bar_width = 0.35\n", "fig, ax = plt.subplots()\n", "age_bar = ax.bar(age_counts.index.astype(str), age_counts.values, bar_width, label='Did not survive')\n", "survived_bar = ax.bar(age_counts.index.astype(str), survived_counts, bar_width, label='Survived')\n", "\n", "ax.set_xlabel('Age Bins')\n", "ax.set_ylabel('Frequency')\n", "ax.set_title('Age Distribution and Survival')\n", "ax.legend()\n", "\n", "plt.xticks(rotation=45)\n", "plt.show()\n", "\n", "# Create a DataFrame for the table\n", "table_data = pd.DataFrame({'Survived percentage': survival_percentages_age})\n", "\n", "print(table_data)" ] }, { "cell_type": "markdown", "id": "edab23bd-faa7-49be-933d-81d110ec58b1", "metadata": {}, "source": [ "## See ticket class and survival distribution" ] }, { "cell_type": "code", "execution_count": 8, "id": "ae056dc8-aea5-4426-aaad-d220a2bcc206", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " Survived percentage\n", "pclass \n", "1 61.919505\n", "2 42.745098\n", "3 25.528914\n" ] } ], "source": [ "# Count all passengers and the survivors in each ticket class\n", "pclass_counts = df['pclass'].value_counts().sort_index()\n", "survived_counts = df.groupby('pclass')['survived'].sum()\n", "\n", "# Calculate percentage of survivors\n", "survival_percentages_class = (survived_counts / pclass_counts) * 100\n", "\n", "# Plot the results\n", "bar_width = 0.35\n", "fig, ax = plt.subplots()\n", "pclass_bar = ax.bar(pclass_counts.index, pclass_counts, bar_width, label='Did not survive')\n", "survived_bar = ax.bar(pclass_counts.index, survived_counts, bar_width, label='Survived')\n", "\n", "ax.set_xlabel('Ticket class')\n", "ax.set_ylabel('Frequency')\n", "ax.set_title('Survival by Ticket Class')\n", "ax.legend()\n", "\n", "# Add labels for each class\n", "labels = {1: 'Upper', 2: 'Middle', 3: 'Lower'}\n", "ax.set_xticks(pclass_counts.index)\n", "ax.set_xticklabels([labels[x] for x in pclass_counts.index])\n", "\n", "plt.show()\n", "\n", "# Create a DataFrame for the table\n", "table_data = pd.DataFrame({'Survived percentage': survival_percentages_class})\n", "print(table_data)\n" ] }, { "cell_type": "markdown", "id": "116f00d1-9b9a-495c-9e51-fa083d892d07", "metadata": {}, "source": [ "## See sex and survival distribution" ] }, { "cell_type": "code", "execution_count": 9, "id": "29f8187d-e298-413d-950a-427c335db207", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " Survived percentage\n", "sex \n", "female 72.149123\n", "male 19.374248\n" ] } ], "source": [ "# Count all passengers and the survivors for each sex\n", "sex_counts = df['sex'].value_counts()\n", "survived_counts_by_sex = df.groupby('sex')['survived'].sum()\n", "\n", "# Calculate percentage of survivors\n", "survival_percentages_sex = (survived_counts_by_sex / sex_counts) * 100\n", "\n", "# Plot the results\n", "fig, ax = plt.subplots()\n", "bar_width = 0.35\n", "sex_bar = ax.bar(sex_counts.index, sex_counts, bar_width, label='Did not survive')\n", "survived_bar = ax.bar(survived_counts_by_sex.index, survived_counts_by_sex, bar_width, label='Survived')\n", "\n", "ax.set_xlabel('Sex')\n", "ax.set_ylabel('Frequency')\n", "ax.set_title('Survival by Sex')\n", "ax.legend()\n", "\n", "plt.show()\n", "\n", "# Create a DataFrame for the table\n", "table_data = pd.DataFrame({'Survived percentage': survival_percentages_sex})\n", "print(table_data)" ] }, { "cell_type": "markdown", "id": "b0130007-d0cc-4358-9ea1-b9ac8c77b040", "metadata": {}, "source": [ "## See fare and survival distribution" ] }, { "cell_type": "code", "execution_count": 10, "id": "d072465f-bcc1-40ab-9bbf-9b6e604c14d7", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " Survived percentage\n", "fare_bin \n", "(0, 50] 32.038835\n", "(50, 100] 63.225806\n", "(100, 150] 78.787879\n", "(150, 200] 61.538462\n", "(200, 250] 57.142857\n", "(250, 300] 76.923077\n" ] } ], "source": [ "# Create bins for fare\n", "# Note: np.arange excludes the stop value, so a fare above the last bin edge\n", "# (or equal to the first edge) is not assigned to any bin\n", "min_fare = df['fare'].min()\n", "max_fare = df['fare'].max()\n", "bins = np.arange(int(min_fare // 50) * 50, int((max_fare // 50 + 1) * 50), 50)\n", "\n", "# Cut the 'fare' column into bins\n", "df['fare_bin'] = pd.cut(df['fare'], bins=bins)\n", "\n", "\n", "# Count all passengers and the survivors in each fare bin\n", "fare_counts = df['fare_bin'].value_counts().sort_index()\n", "survived_counts = df.groupby('fare_bin', observed=False)['survived'].sum()\n", "\n", "# Remove bins with no data\n", "fare_counts = fare_counts[fare_counts > 0]\n", "survived_counts = survived_counts[fare_counts.index] # Update survived counts accordingly\n", "\n", "# Calculate percentage of survivors per fare bin\n", "survival_percentages_fare = (survived_counts / fare_counts) * 100\n", "\n", "# Plot the results\n", "bar_width = 0.35\n", "fig, ax = plt.subplots()\n", "fare_bar = ax.bar(fare_counts.index.astype(str), fare_counts.values, bar_width, label='Did not survive')\n", "survived_bar = ax.bar(fare_counts.index.astype(str), survived_counts, bar_width, label='Survived')\n", "\n", "ax.set_xlabel('Fare Bins')\n", "ax.set_ylabel('Frequency')\n", "ax.set_title('Fare Distribution and Survival')\n", "ax.legend()\n", "\n", "plt.xticks(rotation=45)\n", "plt.show()\n", "\n", "# Create a DataFrame for the table\n", "table_data = pd.DataFrame({'Survived percentage': survival_percentages_fare})\n", "print(table_data)\n" ] }, { "cell_type": "markdown", "id": "38843fc3-81fb-4d4a-9045-eae86ff6bd1e", "metadata": {}, "source": [ "# Correlation Heatmap\n", "A correlation heatmap is a visual representation of the correlation between variables in a 
dataset. It uses colors to show the strength and direction of these relationships.\n", "\n", "## How to read\n", "- Color Gradient: Colors indicate correlation strength, with red for positive correlation and blue for negative correlation.\n", "- Diagonal Line: Represents perfect correlation (1) of variables with themselves.\n", "- Symmetry: The heatmap is symmetrical around the diagonal." ] }, { "cell_type": "code", "execution_count": 11, "id": "126eb3d9-67a0-4b00-8c57-860d10b5645e", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Categorize 'pclass'\n", "pclass_mapping = {1: 'Upper', 2: 'Middle', 3: 'Lower'}\n", "df['pclass_category'] = df['pclass'].map(pclass_mapping)\n", "\n", "# Encode categorical parameters as integer codes so they can enter the correlation matrix\n", "categorical_params = ['pclass_category', 'sex', 'embarked']\n", "correlation_data = pd.DataFrame(columns=categorical_params)\n", "for param in categorical_params:\n", "    if df[param].dtype == 'O': # Object dtype: the parameter holds string categories\n", "        df[param + '_encoded'] = pd.Categorical(df[param]).codes\n", "        correlation_data[param] = df[param + '_encoded']\n", "\n", "# Add 'age', 'fare', 'sibsp', 'parch', and 'survived' to correlation data\n", "correlation_data['age'] = df['age']\n", "correlation_data['fare'] = df['fare']\n", "correlation_data['sibsp'] = df['sibsp']\n", "correlation_data['parch'] = df['parch']\n", "correlation_data['survived'] = df['survived']\n", "\n", "# Calculate correlation matrix\n", "correlation_matrix = correlation_data.corr()\n", "\n", "# Plot correlation heatmap\n", "plt.figure(figsize=(10, 6))\n", "sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=\".2f\", linewidths=0.5)\n", "plt.title('Correlation Heatmap')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a722f2a8-f219-4aee-9324-5af2f4ff6750", "metadata": {}, "source": [ "# Generate general report" ] }, { "cell_type": "code", "execution_count": 12, "id": "6d7d55c5-16db-4b46-abc4-bc9b00b6066a", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "963786f49b904d95b40120ed1b3ecdf7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Summarize dataset: 0%| | 0/5 [00:00" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df.drop(columns=['age_bin', 'fare_bin'], inplace=True)\n", "profile = ProfileReport(df, title=\"Profiling Report\")\n", "\n", "profile.to_notebook_iframe()" ] }, { "cell_type": 
"markdown", "id": "400e2a71-2d6b-41d5-a553-618e302e5f68", "metadata": {}, "source": [ "# Multivariable Logistic Regression\n", "A multivariable logistic regression with a binary target variable aims to predict a binary outcome based on multiple predictor variables.\n", "\n", "- Objective: The goal is to predict the probability of a binary outcome (e.g., Yes/No, 0/1) based on the values of multiple predictor variables.\n", "- Model: Logistic regression estimates the relationship between the predictors and the probability of the binary outcome.\n", "- Prediction: Given values for the predictor variables, the model calculates the predicted probability of the binary outcome. This probability can then be converted into a binary decision based on a chosen threshold.\n", "\n", "## Interpretation of classification metrics\n", "- Accuracy\n", "  - The accuracy score measures the proportion of correct predictions made by the model out of the total number of predictions.\n", "  - It is calculated as the number of correct predictions divided by the total number of predictions.\n", "- Confusion Matrix\n", "  - A confusion matrix is a table that describes the performance of a classification model on a set of test data for which the true values are known.\n", "  - It has four cells, one for each combination of predicted and actual class: True Positive (TP), False Positive (FP), True Negative (TN), and False Negative (FN).\n", "  - It provides a more detailed breakdown of the model's performance than the accuracy score alone.\n", "- Classification Report\n", "  - Precision: Out of all predicted positive instances, how many were actually positive.\n", "  - Recall (Sensitivity): Out of all actual positive instances, how many were predicted correctly.\n", "  - F1-score: Harmonic mean of precision and recall. 
It gives a balance between precision and recall.\n", "  - Support: Number of actual occurrences of the class in the dataset.\n", "\n", "## ROC Curve (Receiver Operating Characteristic Curve)\n", "The ROC curve is a graphical representation of the performance of a binary classification model across different thresholds. It plots the True Positive Rate (sensitivity) against the False Positive Rate (1 - specificity) for various threshold values. Each point on the curve represents a sensitivity/specificity pair corresponding to a particular decision threshold.\n", "\n", "### Interpretation of ROC Area (AUC):\n", "The Area Under the ROC Curve (AUC) quantifies the overall performance of a binary classification model. It ranges from 0 to 1, where:\n", "\n", "- AUC = 1 indicates a perfect classifier that separates the classes completely.\n", "- AUC = 0.5 indicates a classifier that performs no better than random guessing.\n", "- AUC < 0.5 indicates a classifier that performs worse than random guessing (inverted predictions).\n", "\n", "Typically, an AUC above 0.7 is considered acceptable, and an AUC above 0.8 is considered good." ] }, { "cell_type": "code", "execution_count": 13, "id": "8f122f3d-bfbf-4270-b394-e36f1bc8998e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.7622739018087855\n", "Confusion Matrix:\n", " [[193 40]\n", " [ 52 102]]\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 0.79 0.83 0.81 233\n", " 1 0.72 0.66 0.69 154\n", "\n", " accuracy 0.76 387\n", " macro avg 0.75 0.75 0.75 387\n", "weighted avg 0.76 0.76 0.76 387\n", "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Select independent variables (X) and dependent variable (y)\n", "independent_variables = df[['age', 'fare', 'pclass', 'sex']]\n", "dependent_variable = df['survived']\n", "\n", "# One-hot encode the 'sex' column to convert it into a numerical format;\n", "# the remaining columns are passed through unchanged\n", "column_transformer = ColumnTransformer(\n", "    [('one_hot_encoder', OneHotEncoder(), ['sex'])],\n", "    remainder='passthrough'\n", ")\n", "independent_variables_encoded = column_transformer.fit_transform(independent_variables)\n", "\n", "# Split data into training and testing sets\n", "# Adjust test_size to change the share of rows held out for testing\n", "X_train, X_test, y_train, y_test = train_test_split(independent_variables_encoded, dependent_variable, test_size=0.3, random_state=42)\n", "\n", "# Initialize the logistic regression model\n", "model = LogisticRegression()\n", "\n", "# Fit the model to the training data\n", "model.fit(X_train, y_train)\n", "\n", "# Predict probabilities on the test set\n", "# [:, 1] selects the probability of class 1 (survived);\n", "# [:, 0] would select the probability of class 0 (did not survive)\n", "y_pred_proba = model.predict_proba(X_test)[:, 1]\n", "\n", "# Compute the ROC curve and the area under it (AUC) for the survived class\n", "fpr, tpr, _ = roc_curve(y_test, y_pred_proba)\n", "roc_auc = auc(fpr, tpr)\n", "\n", "# Make predictions on the testing data\n", "y_pred = model.predict(X_test)\n", "\n", "# Evaluate the model\n", "accuracy = accuracy_score(y_test, y_pred)\n", "conf_matrix = confusion_matrix(y_test, y_pred)\n", "class_report = classification_report(y_test, y_pred)\n", "\n", "print(\"Accuracy:\", accuracy)\n", "print(\"Confusion Matrix:\\n\", conf_matrix)\n", "print(\"Classification Report:\\n\", class_report)\n", "\n", "# Plot ROC curve\n", "plt.figure()\n", "plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)\n", "plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n", "plt.xlim([0.0, 1.0])\n", 
"plt.ylim([0.0, 1.05])\n", "plt.xlabel('False Positive Rate')\n", "plt.ylabel('True Positive Rate')\n", "plt.title('Receiver Operating Characteristic (ROC) Curve')\n", "plt.legend(loc=\"lower right\")\n", "plt.show()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }