added plots
parent
282cb128e3
commit
cd2e2e5858
90
ml_plots.py
90
ml_plots.py
|
|
@ -1,6 +1,11 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib.cm as cm
|
||||
import scipy.stats as stats
|
||||
import matplotlib.gridspec as gridspec
|
||||
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import os
|
||||
import time
|
||||
|
||||
|
|
@ -72,4 +77,87 @@ def plot_predictions(true_values, predicted_values, colors, title='True vs Predi
|
|||
# save plot
|
||||
if save:
|
||||
save_plot(plt, title)
|
||||
return plt
|
||||
return plt
|
||||
|
||||
def plot_residuals(labels, preds, colors, title='Residuals Plot', save=True):
|
||||
residuals = np.array(preds) - np.array(labels)
|
||||
|
||||
fig = plt.figure(figsize=(14, 6))
|
||||
gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
|
||||
|
||||
# Main plot
|
||||
ax0 = plt.subplot(gs[0])
|
||||
ax0.scatter(labels, residuals, label='Residuals', color=colors['blue'], alpha=0.5)
|
||||
|
||||
# Fit linear regression model to residuals
|
||||
labels_reshaped = np.array(labels).reshape(-1, 1)
|
||||
model = LinearRegression()
|
||||
model.fit(labels_reshaped, residuals)
|
||||
trend_line = model.predict(labels_reshaped)
|
||||
|
||||
# Plot trend line
|
||||
ax0.plot(labels, trend_line, color=colors['red'], label='Trend Line', linewidth=2)
|
||||
|
||||
ax0.set_xlabel('True Values')
|
||||
ax0.set_ylabel('Residuals')
|
||||
ax0.axhline(y=0, color='k', linestyle='--')
|
||||
ax0.set_title(title)
|
||||
ax0.legend()
|
||||
|
||||
# Side plot for distribution of true values
|
||||
ax1 = plt.subplot(gs[1], sharey=ax0)
|
||||
ax1.hist(residuals, bins=30, alpha=0.5, color=colors['blue'], orientation='horizontal')
|
||||
ax1.set_xlabel('Frequency')
|
||||
ax1.set_title('Distribution of residuals')
|
||||
ax1.yaxis.tick_right()
|
||||
ax1.yaxis.set_label_position("right")
|
||||
|
||||
plt.tight_layout()
|
||||
# save plot
|
||||
if save:
|
||||
save_plot(plt, title)
|
||||
return plt
|
||||
|
||||
def plot_qq(labels, preds, colors, title='Q-Q Plot of Residuals', save=True):
|
||||
residuals = np.array(preds) - np.array(labels)
|
||||
|
||||
# Generate a Normal Q-Q plot
|
||||
fig = plt.figure(figsize=(8, 6))
|
||||
ax = fig.add_subplot(111)
|
||||
stats.probplot(residuals, dist="norm", plot=ax)
|
||||
|
||||
# Set colors
|
||||
line = ax.get_lines()
|
||||
line[0].set_color(colors['blue']) # Data points
|
||||
line[1].set_color(colors['red']) # Fit line
|
||||
|
||||
plt.title(title)
|
||||
# save plot
|
||||
if save:
|
||||
save_plot(plt, title)
|
||||
return plt
|
||||
|
||||
def plot_val_preds(val_preds, val_labels, colors, title='Histogram of Validation Predictions', save=True):
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(val_labels, bins=20, alpha=0.5, label='True Values', color=colors['green'],)
|
||||
|
||||
cmap = cm.get_cmap('coolwarm', len(val_preds)) # Use 'coolwarm' colormap for gradient from red to blue
|
||||
for epoch, preds in val_preds.items():
|
||||
color = cmap(len(val_preds) - epoch ) # Get color from colormap
|
||||
plt.hist(preds, bins=20, alpha=0.5, label=f'Epoch {epoch}', color=color)
|
||||
|
||||
plt.xlabel('Predicted Values')
|
||||
plt.ylabel('Frequency')
|
||||
plt.title(title)
|
||||
plt.legend()
|
||||
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
||||
# save plot
|
||||
if save:
|
||||
save_plot(plt, title)
|
||||
return plt
|
||||
|
||||
|
||||
####################################################################################################
|
||||
############### Comparison Plots ###################################################################
|
||||
####################################################################################################
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue