added plots

main
Felix Jan Michael Mucha 2025-02-16 11:42:38 +01:00
parent 282cb128e3
commit cd2e2e5858
3 changed files with 391 additions and 115 deletions

View File

@ -1,6 +1,11 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
import matplotlib.cm as cm
import scipy.stats as stats
import matplotlib.gridspec as gridspec
from sklearn.linear_model import LinearRegression
import os import os
import time import time
@ -72,4 +77,87 @@ def plot_predictions(true_values, predicted_values, colors, title='True vs Predi
# save plot # save plot
if save: if save:
save_plot(plt, title) save_plot(plt, title)
return plt return plt
def plot_residuals(labels, preds, colors, title='Residuals Plot', save=True):
residuals = np.array(preds) - np.array(labels)
fig = plt.figure(figsize=(14, 6))
gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
# Main plot
ax0 = plt.subplot(gs[0])
ax0.scatter(labels, residuals, label='Residuals', color=colors['blue'], alpha=0.5)
# Fit linear regression model to residuals
labels_reshaped = np.array(labels).reshape(-1, 1)
model = LinearRegression()
model.fit(labels_reshaped, residuals)
trend_line = model.predict(labels_reshaped)
# Plot trend line
ax0.plot(labels, trend_line, color=colors['red'], label='Trend Line', linewidth=2)
ax0.set_xlabel('True Values')
ax0.set_ylabel('Residuals')
ax0.axhline(y=0, color='k', linestyle='--')
ax0.set_title(title)
ax0.legend()
# Side plot for distribution of true values
ax1 = plt.subplot(gs[1], sharey=ax0)
ax1.hist(residuals, bins=30, alpha=0.5, color=colors['blue'], orientation='horizontal')
ax1.set_xlabel('Frequency')
ax1.set_title('Distribution of residuals')
ax1.yaxis.tick_right()
ax1.yaxis.set_label_position("right")
plt.tight_layout()
# save plot
if save:
save_plot(plt, title)
return plt
def plot_qq(labels, preds, colors, title='Q-Q Plot of Residuals', save=True):
residuals = np.array(preds) - np.array(labels)
# Generate a Normal Q-Q plot
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
stats.probplot(residuals, dist="norm", plot=ax)
# Set colors
line = ax.get_lines()
line[0].set_color(colors['blue']) # Data points
line[1].set_color(colors['red']) # Fit line
plt.title(title)
# save plot
if save:
save_plot(plt, title)
return plt
def plot_val_preds(val_preds, val_labels, colors, title='Histogram of Validation Predictions', save=True):
plt.figure(figsize=(10, 6))
plt.hist(val_labels, bins=20, alpha=0.5, label='True Values', color=colors['green'],)
cmap = cm.get_cmap('coolwarm', len(val_preds)) # Use 'coolwarm' colormap for gradient from red to blue
for epoch, preds in val_preds.items():
color = cmap(len(val_preds) - epoch ) # Get color from colormap
plt.hist(preds, bins=20, alpha=0.5, label=f'Epoch {epoch}', color=color)
plt.xlabel('Predicted Values')
plt.ylabel('Frequency')
plt.title(title)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
# save plot
if save:
save_plot(plt, title)
return plt
####################################################################################################
############### Comparison Plots ###################################################################
####################################################################################################

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long