added plots

main
Felix Jan Michael Mucha 2025-02-16 11:42:38 +01:00
parent 282cb128e3
commit cd2e2e5858
3 changed files with 391 additions and 115 deletions

View File

@ -1,6 +1,11 @@
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.cm as cm
import scipy.stats as stats
import matplotlib.gridspec as gridspec
from sklearn.linear_model import LinearRegression
import os
import time
@ -72,4 +77,87 @@ def plot_predictions(true_values, predicted_values, colors, title='True vs Predi
# save plot
if save:
save_plot(plt, title)
return plt
return plt
def plot_residuals(labels, preds, colors, title='Residuals Plot', save=True):
residuals = np.array(preds) - np.array(labels)
fig = plt.figure(figsize=(14, 6))
gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
# Main plot
ax0 = plt.subplot(gs[0])
ax0.scatter(labels, residuals, label='Residuals', color=colors['blue'], alpha=0.5)
# Fit linear regression model to residuals
labels_reshaped = np.array(labels).reshape(-1, 1)
model = LinearRegression()
model.fit(labels_reshaped, residuals)
trend_line = model.predict(labels_reshaped)
# Plot trend line
ax0.plot(labels, trend_line, color=colors['red'], label='Trend Line', linewidth=2)
ax0.set_xlabel('True Values')
ax0.set_ylabel('Residuals')
ax0.axhline(y=0, color='k', linestyle='--')
ax0.set_title(title)
ax0.legend()
# Side plot for distribution of true values
ax1 = plt.subplot(gs[1], sharey=ax0)
ax1.hist(residuals, bins=30, alpha=0.5, color=colors['blue'], orientation='horizontal')
ax1.set_xlabel('Frequency')
ax1.set_title('Distribution of residuals')
ax1.yaxis.tick_right()
ax1.yaxis.set_label_position("right")
plt.tight_layout()
# save plot
if save:
save_plot(plt, title)
return plt
def plot_qq(labels, preds, colors, title='Q-Q Plot of Residuals', save=True):
residuals = np.array(preds) - np.array(labels)
# Generate a Normal Q-Q plot
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
stats.probplot(residuals, dist="norm", plot=ax)
# Set colors
line = ax.get_lines()
line[0].set_color(colors['blue']) # Data points
line[1].set_color(colors['red']) # Fit line
plt.title(title)
# save plot
if save:
save_plot(plt, title)
return plt
def plot_val_preds(val_preds, val_labels, colors, title='Histogram of Validation Predictions', save=True):
plt.figure(figsize=(10, 6))
plt.hist(val_labels, bins=20, alpha=0.5, label='True Values', color=colors['green'],)
cmap = cm.get_cmap('coolwarm', len(val_preds)) # Use 'coolwarm' colormap for gradient from red to blue
for epoch, preds in val_preds.items():
color = cmap(len(val_preds) - epoch ) # Get color from colormap
plt.hist(preds, bins=20, alpha=0.5, label=f'Epoch {epoch}', color=color)
plt.xlabel('Predicted Values')
plt.ylabel('Frequency')
plt.title(title)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
# save plot
if save:
save_plot(plt, title)
return plt
####################################################################################################
############### Comparison Plots ###################################################################
####################################################################################################

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long