ich kann das alles nicht mehr

main
arman 2025-02-16 20:43:16 +01:00
parent c9109e1430
commit 8e8b8612da
2 changed files with 329 additions and 1 deletions

View File

@ -0,0 +1,328 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import ml_helper\n",
"import ml_plots\n",
"import re\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"import matplotlib.gridspec as gridspec\n",
"from sklearn.linear_model import LinearRegression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shared plot palette; the plotting helpers cycle through it in insertion order.\n",
"# NOTE(review): 'rot' (German for red) holds a teal hex value - confirm intended.\n",
"COLORS = dict(\n",
"    yellow='#FFBE0B',\n",
"    orange='#FB5607',\n",
"    blue='#3A86FF',\n",
"    pink='#FF006E',\n",
"    lila='#8338EC',\n",
"    rot='#2a9d8f',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ask ml_helper for the newest (per the helper's name) history file(s) under\n",
"# 'histories/' whose name contains the given keyword.\n",
"# The ensemble=True result is indexed in later cells, so it is expected to be a\n",
"# sequence of paths - TODO confirm against ml_helper.\n",
"hist_name_cnn_norm = ml_helper.get_newest_file(\n",
"    'histories/', name='CNN_history', extension='.json', ensemble=False\n",
")\n",
"print('Loading', hist_name_cnn_norm)\n",
"\n",
"hist_name_cnn = ml_helper.get_newest_file(\n",
"    'histories/', name='CNN', extension='.json', ensemble=True\n",
")\n",
"print('Loading', hist_name_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the stored ensemble predictions from a fixed file and report how many\n",
"# entries were loaded.\n",
"with open('histories/ensemble_preds_CNN_2025-02-16_18-06-10.json', 'r') as file:\n",
"    ensemble_avg_prediction = json.load(file)\n",
"\n",
"print(len(ensemble_avg_prediction))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the single-run history file and keep its test labels and predictions.\n",
"with open(hist_name_cnn_norm, 'r') as file:\n",
"    hist_cnn_norm = json.load(file)\n",
"\n",
"labels_cnn_norm, preds_cnn_norm = hist_cnn_norm['test_labels'], hist_cnn_norm['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the history stored at hist_name_cnn[0] and keep its test labels and\n",
"# test predictions.\n",
"with open(hist_name_cnn[0], 'r') as file:\n",
"    hist_cnn = json.load(file)\n",
"\n",
"labels_cnn, preds_cnn = hist_cnn['test_labels'], hist_cnn['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the history stored at hist_name_cnn[1] and keep its test labels and\n",
"# test predictions.\n",
"with open(hist_name_cnn[1], 'r') as file:\n",
"    hist_cnn1 = json.load(file)\n",
"\n",
"labels_cnn1, preds_cnn1 = hist_cnn1['test_labels'], hist_cnn1['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the history stored at hist_name_cnn[2] and keep its test labels and\n",
"# test predictions.\n",
"with open(hist_name_cnn[2], 'r') as file:\n",
"    hist_cnn2 = json.load(file)\n",
"\n",
"labels_cnn2 = hist_cnn2['test_labels']\n",
"preds_cnn2 = hist_cnn2['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the history stored at hist_name_cnn[4] and keep its test labels and\n",
"# test predictions.\n",
"# NOTE(review): the previous cell reads index 2 and this one index 4, so index 3\n",
"# is never loaded - confirm that the gap is intentional.\n",
"with open(hist_name_cnn[4], 'r') as file:\n",
"    hist_cnn3 = json.load(file)\n",
"\n",
"labels_cnn3, preds_cnn3 = hist_cnn3['test_labels'], hist_cnn3['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the history stored at hist_name_cnn[5] and keep its test labels and\n",
"# test predictions.\n",
"with open(hist_name_cnn[5], 'r') as file:\n",
"    hist_cnn4 = json.load(file)\n",
"\n",
"labels_cnn4, preds_cnn4 = hist_cnn4['test_labels'], hist_cnn4['test_preds']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def save_plot(plt, title):\n",
"    \"\"\"Placeholder hook for persisting a figure; currently does nothing.\n",
"\n",
"    The plotting helpers in this notebook call it with the pyplot module and\n",
"    the plot title when their ``save`` flag is true.\n",
"    \"\"\"\n",
"    return None"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_training_histories(hist_datas, names, colors, title='Training History', include_val=True, save=True):\n",
"    \"\"\"Draw loss (left panel) and RMSE (right panel) curves for several histories.\n",
"\n",
"    Args:\n",
"        hist_datas: History dicts providing 'train_loss' and 'train_rmse', plus\n",
"            'val_loss' and 'val_rmse' when include_val is true.\n",
"        names: Legend prefixes, paired with hist_datas via zip.\n",
"        colors: Mapping whose values are colours; reused cyclically, one per history.\n",
"        title: Figure title, also passed on to save_plot.\n",
"        include_val: Whether to add the dashed validation curves.\n",
"        save: Whether to call save_plot once the figure is finished.\n",
"\n",
"    Returns:\n",
"        The matplotlib.pyplot module, so callers can chain .show().\n",
"    \"\"\"\n",
"    fig, (loss_ax, rmse_ax) = plt.subplots(1, 2, figsize=(12, 5))\n",
"    palette = list(colors.values())\n",
"    panels = ((loss_ax, 'loss', 'Loss'), (rmse_ax, 'rmse', 'RMSE'))\n",
"\n",
"    for idx, (hist_data, name) in enumerate(zip(hist_datas, names)):\n",
"        epochs = range(1, len(hist_data['train_loss']) + 1)\n",
"        color = palette[idx % len(palette)]\n",
"\n",
"        for ax, key, metric in panels:\n",
"            ax.plot(epochs, hist_data[f'train_{key}'], label=f'{name} Train {metric}', color=color)\n",
"            if include_val:\n",
"                ax.plot(epochs, hist_data[f'val_{key}'], label=f'{name} Validation {metric}', color=color, linestyle='dashed')\n",
"\n",
"    # Titles, axis labels and the legend are the same for every history, so set them once.\n",
"    for ax, _, metric in panels:\n",
"        ax.set_title(metric)\n",
"        ax.set_xlabel('Epochs')\n",
"        ax.set_ylabel(metric)\n",
"        if ax.get_legend_handles_labels()[0]:\n",
"            ax.legend()\n",
"\n",
"    # The suptitle has to exist before tight_layout runs so the layout can leave room for it.\n",
"    fig.suptitle(title)\n",
"    fig.tight_layout()\n",
"\n",
"    if save:\n",
"        save_plot(plt, title)\n",
"\n",
"    return plt\n",
"\n",
"# NOTE(review): these are CNN histories but the labels read 'Average'/'Transformer*' - confirm the labels.\n",
"plot_training_histories([hist_cnn, hist_cnn2, hist_cnn3], ['Average', 'Transformer1', 'Transformer2'], colors=COLORS, include_val=False).show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_distributions(true_values, predicted_values_list, names, colors, title='Distribution of Predicted and True Values', save=True):\n",
"    \"\"\"Overlay histograms of the true values and of each set of predictions.\n",
"\n",
"    Args:\n",
"        true_values: Reference values; drawn in colors['blue'].\n",
"        predicted_values_list: One collection of predicted values per model.\n",
"        names: Legend prefixes, paired with predicted_values_list via zip.\n",
"        colors: Mapping of colour names to colours; must contain 'blue'.\n",
"        title: Plot title, also passed on to save_plot.\n",
"        save: Whether to call save_plot once the figure is finished.\n",
"\n",
"    Returns:\n",
"        The matplotlib.pyplot module, so callers can chain .show().\n",
"    \"\"\"\n",
"    plt.figure(figsize=(10, 6))\n",
"    hist_style = dict(bins=20, edgecolor='black', alpha=0.7)\n",
"\n",
"    true_color = colors['blue']\n",
"    plt.hist(true_values, color=true_color, label='True Values', **hist_style)\n",
"\n",
"    # Cycle the predictions through every colour except the one reserved for the\n",
"    # true values; otherwise a prediction histogram can share their colour.\n",
"    palette = [c for key, c in colors.items() if key != 'blue'] or [true_color]\n",
"    for idx, (predicted_values, name) in enumerate(zip(predicted_values_list, names)):\n",
"        plt.hist(predicted_values, color=palette[idx % len(palette)], label=f'{name} Predicted Values', **hist_style)\n",
"\n",
"    plt.title(title)\n",
"    plt.xlabel('Score')\n",
"    plt.ylabel('Frequency')\n",
"    plt.legend()\n",
"    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
"\n",
"    if save:\n",
"        save_plot(plt, title)\n",
"    return plt\n",
"\n",
"# NOTE(review): the predictions come from CNN histories but are labelled 'Transformer*' - confirm the labels.\n",
"plot_distributions(labels_cnn,[ preds_cnn2, preds_cnn3, ensemble_avg_prediction ], [ 'Transformer1', 'Transformer2', 'Average'], colors=COLORS).show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def plot_multiple_residuals(labels, preds_list, names, colors, title='Residuals Plot', save=True):\n",
"    \"\"\"Plot residuals (prediction minus label) against the labels for several models.\n",
"\n",
"    The wide left panel shows a residual scatter and a fitted linear trend line\n",
"    per model; the narrow right panel shows horizontal residual histograms on a\n",
"    shared y-axis.\n",
"\n",
"    Args:\n",
"        labels: True values shared by all models.\n",
"        preds_list: One collection of predictions per model, each as long as labels.\n",
"        names: Legend prefixes, paired with preds_list via zip.\n",
"        colors: Mapping whose values are colours; reused cyclically, one per model.\n",
"        title: Title of the scatter panel, also passed on to save_plot.\n",
"        save: Whether to call save_plot once the figure is finished.\n",
"\n",
"    Returns:\n",
"        The matplotlib.pyplot module, so callers can chain .show().\n",
"    \"\"\"\n",
"    fig = plt.figure(figsize=(14, 6))\n",
"    gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])\n",
"    scatter_ax = fig.add_subplot(gs[0])\n",
"    hist_ax = fig.add_subplot(gs[1], sharey=scatter_ax)\n",
"    palette = list(colors.values())\n",
"\n",
"    # Flatten once so that (N, 1)-shaped inputs cannot broadcast against (N,)\n",
"    # inputs into an (N, N) residual matrix.\n",
"    targets = np.asarray(labels).ravel()\n",
"    features = targets.reshape(-1, 1)\n",
"\n",
"    for idx, (preds, name) in enumerate(zip(preds_list, names)):\n",
"        color = palette[idx % len(palette)]\n",
"        # Computed once and reused by the scatter, the trend fit and the histogram.\n",
"        residuals = np.asarray(preds).ravel() - targets\n",
"\n",
"        scatter_ax.scatter(targets, residuals, label=f'{name} Residuals', color=color, alpha=0.5)\n",
"\n",
"        # Linear trend of the residuals over the true values.\n",
"        trend_line = LinearRegression().fit(features, residuals).predict(features)\n",
"        scatter_ax.plot(targets, trend_line, color=color, label=f'{name} Trend Line', linewidth=2)\n",
"\n",
"        hist_ax.hist(residuals, bins=30, alpha=0.5, color=color, orientation='horizontal', label=f'{name} Residuals')\n",
"\n",
"    scatter_ax.set_xlabel('True Values')\n",
"    scatter_ax.set_ylabel('Residuals')\n",
"    scatter_ax.axhline(y=0, color='k', linestyle='--')\n",
"    scatter_ax.set_title(title)\n",
"    scatter_ax.legend()\n",
"\n",
"    hist_ax.set_xlabel('Frequency')\n",
"    hist_ax.set_title('Distribution of Residuals')\n",
"    hist_ax.yaxis.tick_right()\n",
"    hist_ax.yaxis.set_label_position('right')\n",
"\n",
"    fig.tight_layout()\n",
"\n",
"    if save:\n",
"        save_plot(plt, title)\n",
"\n",
"    return plt\n",
"\n",
"# NOTE(review): the predictions come from CNN histories but are labelled 'Transformer*' - confirm the labels.\n",
"plot_multiple_residuals(labels_cnn, [preds_cnn2, preds_cnn3, ensemble_avg_prediction], ['Transformer1', 'Transformer2', 'Average'], colors=COLORS).show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -46,7 +46,7 @@
"hist_name_trans_norm = ml_helper.get_newest_file('histories/', name='Transformer_history', extension=\".json\", ensemble=False)\n",
"print(f\"Loading {hist_name_trans_norm}\")\n",
"\n",
"hist_name_transformer = get_newest_file('histories/', name='Transformer', extension=\".json\", ensemble=True)\n",
"hist_name_transformer = ml_helper.get_newest_file('histories/', name='Transformer', extension=\".json\", ensemble=True)\n",
"print(f\"Loading {hist_name_transformer}\")"
]
},