added: single model eval plots
parent
8b655b58ca
commit
64b2d5e411
24
ml_plots.py
24
ml_plots.py
|
|
@ -11,23 +11,23 @@ def save_plot(plt, plot_name):
|
||||||
time_stamp = time.strftime('%Y%m%d-%H%M%S')
|
time_stamp = time.strftime('%Y%m%d-%H%M%S')
|
||||||
plt.savefig(f'plots/{plot_name}_{time_stamp}.png')
|
plt.savefig(f'plots/{plot_name}_{time_stamp}.png')
|
||||||
|
|
||||||
def plot_training_history(hist_data, title='Training History', save=True):
|
def plot_training_history(hist_data, colors, title='Training History', save=True):
|
||||||
|
|
||||||
epochs = range(1, len(hist_data['train_loss']) + 1)
|
epochs = range(1, len(hist_data['train_loss']) + 1)
|
||||||
|
|
||||||
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
|
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
|
||||||
|
|
||||||
# Plot accuracy
|
# Plot accuracy
|
||||||
axs[1].plot(epochs, hist_data['train_rmse'], label='Train RMSE')
|
axs[1].plot(epochs, hist_data['train_rmse'], label='Train RMSE', color=colors['blue'])
|
||||||
axs[1].plot(epochs, hist_data['val_rmse'], label='Validation RMSE')
|
axs[1].plot(epochs, hist_data['val_rmse'], label='Validation RMSE', color=colors['green'])
|
||||||
axs[1].set_title('RMSE')
|
axs[1].set_title('RMSE')
|
||||||
axs[1].set_xlabel('Epochs')
|
axs[1].set_xlabel('Epochs')
|
||||||
axs[1].set_ylabel('RMSE')
|
axs[1].set_ylabel('RMSE')
|
||||||
axs[1].legend()
|
axs[1].legend()
|
||||||
|
|
||||||
# Plot loss
|
# Plot loss
|
||||||
axs[0].plot(epochs, hist_data['train_loss'], label='Train Loss')
|
axs[0].plot(epochs, hist_data['train_loss'], label='Train Loss', color=colors['blue'])
|
||||||
axs[0].plot(epochs, hist_data['val_loss'], label='Validation Loss')
|
axs[0].plot(epochs, hist_data['val_loss'], label='Validation Loss', color=colors['green'])
|
||||||
axs[0].set_title('Loss')
|
axs[0].set_title('Loss')
|
||||||
axs[0].set_xlabel('Epochs')
|
axs[0].set_xlabel('Epochs')
|
||||||
axs[0].set_ylabel('Loss')
|
axs[0].set_ylabel('Loss')
|
||||||
|
|
@ -41,10 +41,10 @@ def plot_training_history(hist_data, title='Training History', save=True):
|
||||||
save_plot(plt, title)
|
save_plot(plt, title)
|
||||||
return plt
|
return plt
|
||||||
|
|
||||||
def plot_distribution(true_values, predicted_values, title='Distribution of Predicted and True Values', save=True):
|
def plot_distribution(true_values, predicted_values, colors, title='Distribution of Predicted and True Values', save=True):
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.hist(true_values, bins=20, color='skyblue', edgecolor='black', alpha=0.7, label='True Values')
|
plt.hist(true_values, bins=20, color=colors['green'], edgecolor='black', alpha=0.7, label='True Values')
|
||||||
plt.hist(predicted_values, bins=20, color='salmon', edgecolor='black', alpha=0.7, label='Predicted Values')
|
plt.hist(predicted_values, bins=20, color=colors['blue'], edgecolor='black', alpha=0.7, label='Predicted Values')
|
||||||
plt.title(title)
|
plt.title(title)
|
||||||
plt.xlabel('Score')
|
plt.xlabel('Score')
|
||||||
plt.ylabel('Frequency')
|
plt.ylabel('Frequency')
|
||||||
|
|
@ -55,15 +55,15 @@ def plot_distribution(true_values, predicted_values, title='Distribution of Pred
|
||||||
save_plot(plt, title)
|
save_plot(plt, title)
|
||||||
return plt
|
return plt
|
||||||
|
|
||||||
def plot_predictions(true_values, predicted_values, title='True vs Predicted Values', threshold=0.3, save=True):
|
def plot_predictions(true_values, predicted_values, colors, title='True vs Predicted Values', threshold=0.3, save=True):
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
# Difference between predicted and true values
|
# Difference between predicted and true values
|
||||||
correct_indices = np.isclose(true_values, predicted_values, atol=threshold)
|
correct_indices = np.isclose(true_values, predicted_values, atol=threshold)
|
||||||
incorrect_indices = ~correct_indices
|
incorrect_indices = ~correct_indices
|
||||||
# Plot
|
# Plot
|
||||||
plt.scatter(np.array(true_values)[correct_indices], np.array(predicted_values)[correct_indices], color='green', label='Correctly predicted')
|
plt.scatter(np.array(true_values)[correct_indices], np.array(predicted_values)[correct_indices], color=colors['green'], alpha=0.5, label='Correctly predicted')
|
||||||
plt.scatter(np.array(true_values)[incorrect_indices], np.array(predicted_values)[incorrect_indices], color='red', label='Incorrectly predicted')
|
plt.scatter(np.array(true_values)[incorrect_indices], np.array(predicted_values)[incorrect_indices], color=colors['red'], alpha=0.5, label='Incorrectly predicted')
|
||||||
plt.plot([min(true_values), max(true_values)], [min(true_values), max(true_values)], color='blue', linestyle='--', label='Ideal Line')
|
plt.plot([min(true_values), max(true_values)], [min(true_values), max(true_values)], color=colors['blue'], linestyle='--', label='Ideal Line')
|
||||||
plt.xlabel('True Values')
|
plt.xlabel('True Values')
|
||||||
plt.ylabel('Predicted Values')
|
plt.ylabel('Predicted Values')
|
||||||
plt.title(title)
|
plt.title(title)
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,7 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"vscode": {
|
|
||||||
"languageId": "plaintext"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# TODO: compare"
|
"# TODO: compare"
|
||||||
|
|
@ -15,8 +11,14 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"name": "python"
|
"name": "python",
|
||||||
|
"version": "3.10.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue