ANLP_WS24_CA2/ml_plots.py

75 lines
2.8 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import time
def save_plot(plt, plot_name):
    """Save the current figure as a timestamped PNG inside the 'plots' directory.

    Args:
        plt: Object exposing ``savefig(path)``; callers in this file pass the
            ``matplotlib.pyplot`` module.
        plot_name: Stem of the output file name. Path separators in it are
            replaced with underscores so the file always lands directly in
            'plots' instead of a (non-existent) sub-directory.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('plots', exist_ok=True)
    # The path below is built with '/', so neutralise both '/' and the
    # platform separator; plot titles such as "Train/Val" would otherwise
    # make savefig fail.
    safe_name = str(plot_name).replace('/', '_').replace(os.sep, '_')
    # Second-resolution timestamp keeps successive runs from overwriting
    # each other.
    time_stamp = time.strftime('%Y%m%d-%H%M%S')
    plt.savefig(f'plots/{safe_name}_{time_stamp}.png')
def plot_training_history(hist_data, colors, title='Training History', save=True):
    """Plot training and validation loss and RMSE curves side by side.

    Args:
        hist_data: Mapping with the keys 'train_loss', 'val_loss',
            'train_rmse' and 'val_rmse'. Every series is plotted against
            epochs 1..len(hist_data['train_loss']).
        colors: Mapping with 'blue' (training curves) and 'green'
            (validation curves) entries.
        title: Figure super-title; also used as the file name stem when
            ``save`` is true.
        save: When true, the figure is written to disk through ``save_plot``.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current
        figure.
    """
    epochs = range(1, len(hist_data['train_loss']) + 1)
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel shows the loss, right panel the RMSE.
    panels = (
        (axs[0], 'loss', 'Loss'),
        (axs[1], 'rmse', 'RMSE'),
    )
    for ax, key, label in panels:
        ax.plot(epochs, hist_data[f'train_{key}'], label=f'Train {label}', color=colors['blue'])
        ax.plot(epochs, hist_data[f'val_{key}'], label=f'Validation {label}', color=colors['green'])
        ax.set_title(label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(label)
        ax.legend()

    # Add the super-title before tight_layout: the layout pass can only
    # reserve room for it if it already exists, otherwise it is drawn on top
    # of the subplot titles.
    fig.suptitle(title)
    fig.tight_layout()

    if save:
        save_plot(plt, title)
    return plt
def plot_distribution(true_values, predicted_values, colors, title='Distribution of Predicted and True Values', save=True):
    """Overlay histograms of the true and the predicted values.

    Args:
        true_values: Array-like of reference values.
        predicted_values: Array-like of model outputs.
        colors: Mapping with 'green' (true values) and 'blue' (predicted
            values) entries.
        title: Plot title; also used as the file name stem when ``save`` is
            true.
        save: When true, the figure is written to disk through ``save_plot``.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current
        figure.
    """
    plt.figure(figsize=(10, 6))

    # Compute one set of 20 bins over the pooled data. Binning each series
    # over its own range gives bars of different width and position, which
    # makes the overlaid frequencies incomparable.
    pooled = np.concatenate([np.ravel(true_values), np.ravel(predicted_values)])
    bin_edges = np.histogram_bin_edges(pooled, bins=20)

    plt.hist(true_values, bins=bin_edges, color=colors['green'], edgecolor='black', alpha=0.7, label='True Values')
    plt.hist(predicted_values, bins=bin_edges, color=colors['blue'], edgecolor='black', alpha=0.7, label='Predicted Values')
    plt.title(title)
    plt.xlabel('Score')
    plt.ylabel('Frequency')
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    if save:
        save_plot(plt, title)
    return plt
def plot_predictions(true_values, predicted_values, colors, title='True vs Predicted Values', threshold=0.3, save=True):
    """Scatter predicted against true values, coloured by absolute error.

    A point counts as correctly predicted when
    ``|true - predicted| <= threshold``.

    Args:
        true_values: Array-like of reference values.
        predicted_values: Array-like of model outputs, same shape as
            ``true_values``.
        colors: Mapping with 'green' (correct points), 'red' (incorrect
            points) and 'blue' (ideal line) entries.
        title: Plot title; also used as the file name stem when ``save`` is
            true.
        threshold: Maximum absolute error for a prediction to count as
            correct.
        save: When true, the figure is written to disk through ``save_plot``.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as current
        figure.
    """
    plt.figure(figsize=(10, 6))

    # Convert once so boolean-mask indexing works for lists as well.
    true_arr = np.asarray(true_values)
    pred_arr = np.asarray(predicted_values)

    # rtol=0.0: np.isclose otherwise adds 1e-5 * |predicted| on top of atol,
    # silently widening the tolerance for large values.
    within = np.isclose(true_arr, pred_arr, rtol=0.0, atol=threshold)
    outside = ~within

    plt.scatter(true_arr[within], pred_arr[within], color=colors['green'], alpha=0.5, label='Correctly predicted')
    plt.scatter(true_arr[outside], pred_arr[outside], color=colors['red'], alpha=0.5, label='Incorrectly predicted')

    # The y = x reference line spans the range of the true values; with no
    # data there is no range to draw (and min/max would raise).
    if true_arr.size:
        low, high = true_arr.min(), true_arr.max()
        plt.plot([low, high], [low, high], color=colors['blue'], linestyle='--', label='Ideal Line')

    plt.xlabel('True Values')
    plt.ylabel('Predicted Values')
    plt.title(title)
    plt.legend()
    plt.grid(True)

    if save:
        save_plot(plt, title)
    return plt