# ANLP_WS24_CA2/ml_plots.py


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.cm as cm
import scipy.stats as stats
import matplotlib.gridspec as gridspec
from sklearn.linear_model import LinearRegression
import os
import time


def save_plot(plt, plot_name):
    if not os.path.exists('plots'):
        os.makedirs('plots')
    # create timestamp so repeated runs do not overwrite earlier plots
    time_stamp = time.strftime('%Y%m%d-%H%M%S')
    plt.savefig(f'plots/{plot_name}_{time_stamp}.png')
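
# Example of the resulting file path (hypothetical timestamp, format taken from the
# time.strftime call above):
#   plots/Training History_20240115-093045.png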


def plot_training_history(hist_data, colors, title='Training History', save=True):
    epochs = range(1, len(hist_data['train_loss']) + 1)
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    # Plot loss
    axs[0].plot(epochs, hist_data['train_loss'], label='Train Loss', color=colors['blue'])
    axs[0].plot(epochs, hist_data['val_loss'], label='Validation Loss', color=colors['green'])
    axs[0].set_title('Loss')
    axs[0].set_xlabel('Epochs')
    axs[0].set_ylabel('Loss')
    axs[0].legend()
    # Plot RMSE
    axs[1].plot(epochs, hist_data['train_rmse'], label='Train RMSE', color=colors['blue'])
    axs[1].plot(epochs, hist_data['val_rmse'], label='Validation RMSE', color=colors['green'])
    axs[1].set_title('RMSE')
    axs[1].set_xlabel('Epochs')
    axs[1].set_ylabel('RMSE')
    axs[1].legend()
    plt.suptitle(title)
    plt.tight_layout()
    # save plot
    if save:
        save_plot(plt, title)
    return plt
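
# Minimal usage sketch (hypothetical values; the dict keys 'train_loss', 'val_loss',
# 'train_rmse', 'val_rmse' and the color names 'blue'/'green'/'red' are what the
# functions in this module expect):
#
#   colors = {'blue': '#1f77b4', 'green': '#2ca02c', 'red': '#d62728'}
#   hist_data = {
#       'train_loss': [0.9, 0.6, 0.4], 'val_loss': [1.0, 0.7, 0.5],
#       'train_rmse': [0.95, 0.77, 0.63], 'val_rmse': [1.0, 0.84, 0.71],
#   }
#   plot_training_history(hist_data, colors, save=False)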


def plot_distribution(true_values, predicted_values, colors, title='Distribution of Predicted and True Values', save=True):
    plt.figure(figsize=(10, 6))
    plt.hist(true_values, bins=20, color=colors['green'], edgecolor='black', alpha=0.7, label='True Values')
    plt.hist(predicted_values, bins=20, color=colors['blue'], edgecolor='black', alpha=0.7, label='Predicted Values')
    plt.title(title)
    plt.xlabel('Score')
    plt.ylabel('Frequency')
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # save plot
    if save:
        save_plot(plt, title)
    return plt


def plot_predictions(true_values, predicted_values, colors, title='True vs Predicted Values', threshold=0.3, save=True):
    plt.figure(figsize=(10, 6))
    # A prediction counts as "correct" if it lies within `threshold` of the true value
    correct_indices = np.isclose(true_values, predicted_values, atol=threshold)
    incorrect_indices = ~correct_indices
    # Plot
    plt.scatter(np.array(true_values)[correct_indices], np.array(predicted_values)[correct_indices],
                color=colors['green'], alpha=0.5, label='Correctly predicted')
    plt.scatter(np.array(true_values)[incorrect_indices], np.array(predicted_values)[incorrect_indices],
                color=colors['red'], alpha=0.5, label='Incorrectly predicted')
    plt.plot([min(true_values), max(true_values)], [min(true_values), max(true_values)],
             color=colors['blue'], linestyle='--', label='Ideal Line')
    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    # save plot
    if save:
        save_plot(plt, title)
    return plt
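
# Sketch of the correctness rule above, on hypothetical numbers: with the default
# threshold of 0.3, a prediction of 3.2 for a true score of 3.0 counts as correct
# (|3.2 - 3.0| <= 0.3), while 3.4 does not:
#
#   np.isclose([3.0, 3.0], [3.2, 3.4], atol=0.3)  # -> array([ True, False])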


def plot_residuals(labels, preds, colors, title='Residuals Plot', save=True):
    residuals = np.array(preds) - np.array(labels)
    fig = plt.figure(figsize=(14, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])
    # Main plot: residuals against true values
    ax0 = plt.subplot(gs[0])
    ax0.scatter(labels, residuals, label='Residuals', color=colors['blue'], alpha=0.5)
    # Fit a linear regression to the residuals to visualise any systematic trend
    labels_reshaped = np.array(labels).reshape(-1, 1)
    model = LinearRegression()
    model.fit(labels_reshaped, residuals)
    trend_line = model.predict(labels_reshaped)
    # Plot trend line
    ax0.plot(labels, trend_line, color=colors['red'], label='Trend Line', linewidth=2)
    ax0.set_xlabel('True Values')
    ax0.set_ylabel('Residuals')
    ax0.axhline(y=0, color='k', linestyle='--')
    ax0.set_title(title)
    ax0.legend()
    # Side plot: distribution of the residuals
    ax1 = plt.subplot(gs[1], sharey=ax0)
    ax1.hist(residuals, bins=30, alpha=0.5, color=colors['blue'], orientation='horizontal')
    ax1.set_xlabel('Frequency')
    ax1.set_title('Distribution of residuals')
    ax1.yaxis.tick_right()
    ax1.yaxis.set_label_position("right")
    plt.tight_layout()
    # save plot
    if save:
        save_plot(plt, title)
    return plt


def plot_qq(labels, preds, colors, title='Q-Q Plot of Residuals', save=True):
    residuals = np.array(preds) - np.array(labels)
    # Generate a Normal Q-Q plot
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    stats.probplot(residuals, dist="norm", plot=ax)
    # Set colors
    lines = ax.get_lines()
    lines[0].set_color(colors['blue'])  # Data points
    lines[1].set_color(colors['red'])   # Fit line
    plt.title(title)
    # save plot
    if save:
        save_plot(plt, title)
    return plt


def plot_val_preds(val_preds, val_labels, colors, title='Histogram of Validation Predictions', save=True):
    plt.figure(figsize=(10, 6))
    plt.hist(val_labels, bins=20, alpha=0.5, label='True Values', color=colors['green'])
    cmap = cm.get_cmap('coolwarm', len(val_preds))  # Use 'coolwarm' colormap for gradient from red to blue
    for epoch, preds in val_preds.items():
        color = cmap(len(val_preds) - epoch)  # Get color from colormap
        plt.hist(preds, bins=20, alpha=0.5, label=f'Epoch {epoch}', color=color)
    plt.xlabel('Predicted Values')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # save plot
    if save:
        save_plot(plt, title)
    return plt
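
# Minimal usage sketch for plot_val_preds (hypothetical data; it assumes val_preds is
# a dict keyed by epoch number with one list of predictions per epoch, and epoch keys
# starting at 1 so the colormap index len(val_preds) - epoch stays within range):
#
#   val_labels = [2.0, 3.5, 4.0, 1.5]
#   val_preds = {1: [2.8, 3.0, 3.1, 2.9], 2: [2.3, 3.4, 3.8, 1.9]}
#   plot_val_preds(val_preds, val_labels, colors, save=False)
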
####################################################################################################
############### Comparison Plots ###################################################################
####################################################################################################