diff --git a/model_comparison_bs_trans.ipynb b/model_comparison_bs_trans.ipynb
new file mode 100644
index 0000000..2757219
--- /dev/null
+++ b/model_comparison_bs_trans.ipynb
@@ -0,0 +1,338 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import ml_helper\n",
+    "import ml_plots\n",
+    "import re\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
+    "import matplotlib.gridspec as gridspec\n",
+    "from sklearn.linear_model import LinearRegression"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "COLORS = {\n",
+    "    # Keys are labels only: 'lila' is German for purple; 'rot' (German for red) holds a teal tone.\n",
+    "    'yellow': '#FFBE0B',\n",
+    "    'orange': '#FB5607',\n",
+    "    'blue': '#3A86FF',\n",
+    "    'pink': '#FF006E',\n",
+    "    'lila': '#8338EC',\n",
+    "    'rot': '#2a9d8f'\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the latest data if the keyword is in the file name.\n",
+    "hist_name_trans_norm = ml_helper.get_newest_file('histories/', name='Transformer_history', extension=\".json\", ensemble=False)\n",
+    "print(f\"Loading {hist_name_trans_norm}\")\n",
+    "\n",
+    "# With ensemble=True the result is indexed per run (e.g. [0], [1]) in the cells below.\n",
+    "hist_name_transformer = ml_helper.get_newest_file('histories/', name='Transformer', extension=\".json\", ensemble=True)\n",
+    "print(f\"Loading {hist_name_transformer}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(\"histories/ensemble_preds_Transformer_2025-02-16_17-35-16.json\", 'r') as file:\n",
+    "    ensemble_avg_prediction = json.load(file)\n",
+    "    print(len(ensemble_avg_prediction))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_trans_norm, 'r') as file:\n",
+    "    hist_trans_norm = json.load(file)\n",
+    "\n",
+    "labels_trans_norm = hist_trans_norm['test_labels']\n",
+    "preds_trans_norm = hist_trans_norm['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_transformer[0], 'r') as file:\n",
+    "    hist_transformer = json.load(file)\n",
+    "\n",
+    "labels_t = hist_transformer['test_labels']\n",
+    "preds_t = hist_transformer['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_transformer[1], 'r') as file:\n",
+    "    hist_transformer1 = json.load(file)\n",
+    "\n",
+    "labels_t1 = hist_transformer1['test_labels']\n",
+    "preds_t1 = hist_transformer1['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_transformer[2], 'r') as file:\n",
+    "    hist_transformer2 = json.load(file)\n",
+    "\n",
+    "labels_t2 = hist_transformer2['test_labels']\n",
+    "preds_t2 = hist_transformer2['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_transformer[4], 'r') as file:\n",
+    "    hist_transformer3 = json.load(file)\n",
+    "\n",
+    "labels_t3 = hist_transformer3['test_labels']\n",
+    "preds_t3 = hist_transformer3['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(hist_name_transformer[5], 'r') as file:\n",
+    "    hist_transformer4 = json.load(file)\n",
+    "\n",
+    "labels_t4 = hist_transformer4['test_labels']\n",
+    "preds_t4 = hist_transformer4['test_preds']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_plot(plt, title):\n",
+    "    \"\"\"Placeholder hook for persisting a figure; currently does nothing.\"\"\"\n",
+    "    pass"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def plot_training_histories(hist_datas, names, colors, title='Training History', include_val=True, save=True):\n",
+    "    \"\"\"Plot loss (left) and RMSE (right) curves for several training runs.\n",
+    "\n",
+    "    hist_datas: history dicts holding 'train_loss' and 'train_rmse' sequences\n",
+    "        (plus 'val_loss' and 'val_rmse' when include_val is True).\n",
+    "    names: legend label for each history, in the same order.\n",
+    "    colors: dict of color name -> color; runs cycle through its values.\n",
+    "    title: figure title, also passed to save_plot.\n",
+    "    include_val: also draw the validation curves as dashed lines.\n",
+    "    save: call save_plot once the figure is built.\n",
+    "\n",
+    "    Returns the pyplot module so the caller can chain .show().\n",
+    "    \"\"\"\n",
+    "    fig, axs = plt.subplots(1, 2, figsize=(12, 5))\n",
+    "    palette = list(colors.values())\n",
+    "    panels = ((axs[0], 'Loss', 'loss'), (axs[1], 'RMSE', 'rmse'))\n",
+    "\n",
+    "    for index, (hist_data, name) in enumerate(zip(hist_datas, names)):\n",
+    "        epochs = range(1, len(hist_data['train_loss']) + 1)\n",
+    "        color = palette[index % len(palette)]\n",
+    "        for ax, metric, key in panels:\n",
+    "            ax.plot(epochs, hist_data[f'train_{key}'], label=f'{name} Train {metric}', color=color)\n",
+    "            if include_val:\n",
+    "                ax.plot(epochs, hist_data[f'val_{key}'], label=f'{name} Validation {metric}', color=color, linestyle='dashed')\n",
+    "\n",
+    "    for ax, metric, _key in panels:\n",
+    "        ax.set_title(metric)\n",
+    "        ax.set_xlabel('Epochs')\n",
+    "        ax.set_ylabel(metric)\n",
+    "        ax.legend()\n",
+    "\n",
+    "    # Set the suptitle before tight_layout so the layout can take it into account.\n",
+    "    fig.suptitle(title)\n",
+    "    fig.tight_layout()\n",
+    "\n",
+    "    if save:\n",
+    "        save_plot(plt, title)\n",
+    "\n",
+    "    return plt\n",
+    "\n",
+    "plot_training_histories([hist_transformer, hist_transformer2, hist_transformer3], ['Average', 'Transformer1', 'Transformer2'], colors=COLORS, include_val=False).show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def plot_distributions(true_values, predicted_values_list, names, colors, title='Distribution of Predicted and True Values', save=True):\n",
+    "    \"\"\"Overlay a histogram of the true values with one histogram per model.\n",
+    "\n",
+    "    true_values: ground-truth values, drawn in colors[\"blue\"].\n",
+    "    predicted_values_list: one sequence of predictions per model.\n",
+    "    names: legend label for each model, in the same order.\n",
+    "    colors: dict of color name -> color; must contain the key \"blue\".\n",
+    "    title: figure title, also passed to save_plot.\n",
+    "    save: call save_plot once the figure is built.\n",
+    "\n",
+    "    Returns the pyplot module so the caller can chain .show().\n",
+    "    \"\"\"\n",
+    "    plt.figure(figsize=(10, 6))\n",
+    "    true_color = colors[\"blue\"]\n",
+    "    # Keep the true-value color out of the model palette; otherwise a model\n",
+    "    # histogram can end up in the same blue as the true values.\n",
+    "    palette = [color for key, color in colors.items() if key != \"blue\"] or [true_color]\n",
+    "\n",
+    "    plt.hist(true_values, bins=20, color=true_color, edgecolor='black', alpha=0.7, label='True Values')\n",
+    "\n",
+    "    for index, (predicted_values, name) in enumerate(zip(predicted_values_list, names)):\n",
+    "        plt.hist(predicted_values, bins=20, color=palette[index % len(palette)], edgecolor='black', alpha=0.7, label=f'{name} Predicted Values')\n",
+    "\n",
+    "    plt.title(title)\n",
+    "    plt.xlabel('Score')\n",
+    "    plt.ylabel('Frequency')\n",
+    "    plt.legend()\n",
+    "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
+    "\n",
+    "    if save:\n",
+    "        save_plot(plt, title)\n",
+    "    return plt\n",
+    "\n",
+    "# labels_cnn is not defined in this notebook; use the labels stored with preds_t2.\n",
+    "# NOTE(review): assumes every run was evaluated on the same test set - confirm.\n",
+    "plot_distributions(labels_t2, [preds_t2, preds_t3, ensemble_avg_prediction], ['Transformer1', 'Transformer2', 'Average'], colors=COLORS).show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def plot_multiple_residuals(labels, preds_list, names, colors, title='Residuals Plot', save=True):\n",
+    "    \"\"\"Plot per-model residuals (prediction - label) with trend lines and a side histogram.\n",
+    "\n",
+    "    labels: true values; each entry of preds_list is subtracted element-wise from them.\n",
+    "    preds_list: one sequence of predictions per model.\n",
+    "    names: legend label for each model, in the same order.\n",
+    "    colors: dict of color name -> color; models cycle through its values.\n",
+    "    title: title of the main plot, also passed to save_plot.\n",
+    "    save: call save_plot once the figure is built.\n",
+    "\n",
+    "    Returns the pyplot module so the caller can chain .show().\n",
+    "    \"\"\"\n",
+    "    fig = plt.figure(figsize=(14, 6))\n",
+    "    gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])\n",
+    "    palette = list(colors.values())\n",
+    "    labels_arr = np.array(labels)\n",
+    "    labels_column = labels_arr.reshape(-1, 1)\n",
+    "\n",
+    "    # Compute each model's residuals once; both panels reuse them.\n",
+    "    series = [\n",
+    "        (name, palette[index % len(palette)], np.array(preds) - labels_arr)\n",
+    "        for index, (preds, name) in enumerate(zip(preds_list, names))\n",
+    "    ]\n",
+    "\n",
+    "    # Main plot: residuals against the true values.\n",
+    "    ax0 = fig.add_subplot(gs[0])\n",
+    "    for name, color, residuals in series:\n",
+    "        ax0.scatter(labels_arr, residuals, label=f'{name} Residuals', color=color, alpha=0.5)\n",
+    "\n",
+    "        # Fit a linear regression to the residuals to expose any trend.\n",
+    "        model = LinearRegression()\n",
+    "        model.fit(labels_column, residuals)\n",
+    "        trend_line = model.predict(labels_column)\n",
+    "        ax0.plot(labels_arr, trend_line, color=color, label=f'{name} Trend Line', linewidth=2)\n",
+    "\n",
+    "    ax0.set_xlabel('True Values')\n",
+    "    ax0.set_ylabel('Residuals')\n",
+    "    ax0.axhline(y=0, color='k', linestyle='--')\n",
+    "    ax0.set_title(title)\n",
+    "    ax0.legend()\n",
+    "\n",
+    "    # Side plot: distribution of the residuals, sharing the residual axis.\n",
+    "    ax1 = fig.add_subplot(gs[1], sharey=ax0)\n",
+    "    for name, color, residuals in series:\n",
+    "        ax1.hist(residuals, bins=30, alpha=0.5, color=color, orientation='horizontal', label=f'{name} Residuals')\n",
+    "\n",
+    "    ax1.set_xlabel('Frequency')\n",
+    "    ax1.set_title('Distribution of Residuals')\n",
+    "    ax1.yaxis.tick_right()\n",
+    "    ax1.yaxis.set_label_position(\"right\")\n",
+    "\n",
+    "    plt.tight_layout()\n",
+    "\n",
+    "    if save:\n",
+    "        save_plot(plt, title)\n",
+    "\n",
+    "    return plt\n",
+    "\n",
+    "# labels_cnn is not defined in this notebook; use the labels stored with preds_t2.\n",
+    "# NOTE(review): assumes every run was evaluated on the same test set - confirm.\n",
+    "plot_multiple_residuals(labels_t2, [preds_t2, preds_t3, ensemble_avg_prediction], ['Transformer1', 'Transformer2', 'Average'], colors=COLORS).show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}