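"""
Evaluation and plotting utilities for a binary humor-classification model.
"""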
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score


def get_accuracy(outputs, labels):
    """
    Fraction of predictions that match the gold labels.
    """
    correct = np.array(outputs) == np.array(labels)
    return correct.mean()


def get_f1_score(outputs, labels):
    """
    Binary F1 score of the predictions against the gold labels.
    """
    return f1_score(labels, outputs)


def plot_confusion_matrix(outputs, labels, class_names=('No Humor', 'Humor'), title='Confusion Matrix'):
    """
    Plot a labelled confusion matrix as a seaborn heatmap.
    """
    conf_matrix = confusion_matrix(labels, outputs)

    plt.figure(figsize=(6, 5))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(title)
    return plt


def get_label_distribution(labels, preds):
    """
    Print the share of wrong predictions that belongs to each class.
    """
    labels = np.array(labels)
    preds = np.array(preds)

    # Mask of wrong predictions
    wrong_preds = labels != preds

    # Total number of wrong predictions; guard against division by zero
    total_wrong_preds = np.sum(wrong_preds)
    if total_wrong_preds == 0:
        print("No wrong predictions.")
        return

    # Number of wrong predictions for each true class
    class_0_wrong_preds = np.sum(labels[wrong_preds] == 0)
    class_1_wrong_preds = np.sum(labels[wrong_preds] == 1)

    # Ratio of wrong predictions for each class
    class_0_ratio = class_0_wrong_preds / total_wrong_preds
    class_1_ratio = class_1_wrong_preds / total_wrong_preds

    print(f"Class 0: {class_0_ratio:.2f}")
    print(f"Class 1: {class_1_ratio:.2f}")


def plot_training_history(history, title='Training History'):
    """
    Plot loss and accuracy curves side by side. `history.get_history()` is
    expected to return a dict with 'train_loss', 'val_loss', 'train_acc'
    and 'val_acc' lists.
    """
    hist_data = history.get_history()

    epochs = range(1, len(hist_data['train_loss']) + 1)

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # Plot loss
    axs[0].plot(epochs, hist_data['train_loss'], label='Train Loss')
    axs[0].plot(epochs, hist_data['val_loss'], label='Validation Loss')
    axs[0].set_title('Loss')
    axs[0].set_xlabel('Epochs')
    axs[0].set_ylabel('Loss')
    axs[0].legend()

    # Plot accuracy
    axs[1].plot(epochs, hist_data['train_acc'], label='Train Accuracy')
    axs[1].plot(epochs, hist_data['val_acc'], label='Validation Accuracy')
    axs[1].set_title('Accuracy')
    axs[1].set_xlabel('Epochs')
    axs[1].set_ylabel('Accuracy')
    axs[1].legend()

    # Set the suptitle first, then reserve room for it so it does not
    # overlap the subplot titles
    fig.suptitle(title)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    return plt


def load_data(filepath):
    """
    Load the data from a CSV file.
    """
    return pd.read_csv(filepath)


def process_data(df, test_dataset, all_preds, all_labels):
    """
    Process the data to prepare it for plotting. `test_dataset` must expose
    the `original_indices` of its rows in `df`.
    """
    df_test = df.iloc[test_dataset.original_indices].copy()
    df_test['prediction'] = all_preds
    df_test['label'] = all_labels
    df_test['pred_correct'] = (df_test['prediction'] == df_test['label'])
    return df_test.sort_values(by='humor_rating').reset_index(drop=True)


def plot_rating_df_based(df_test_sorted, title='Humor Rating vs Prediction for Test Set'):
    """
    Bar-plot the sorted humor ratings, coloured by prediction correctness,
    with a dashed line at the first rating above the median.
    """
    median_rating = df_test_sorted['humor_rating'].median()
    above_median = df_test_sorted[df_test_sorted['humor_rating'] > median_rating]

    range_idx = range(len(df_test_sorted))
    colors = df_test_sorted['pred_correct'].map({True: 'g', False: 'r'})

    plt.figure(figsize=(12, 6))
    plt.bar(range_idx, df_test_sorted['humor_rating'], color=colors)
    # Guard against an empty selection when all ratings equal the median
    if not above_median.empty:
        plt.axvline(x=above_median.index[0], color='black', linestyle='--')

    green_patch = mpatches.Patch(color='g', label='Correct Prediction')
    red_patch = mpatches.Patch(color='r', label='Incorrect Prediction')
    line_patch = mpatches.Patch(color='black', label='humor_rating cutoff')

    plt.title(title)
    plt.xlabel('Index')
    plt.ylabel('Humor Rating')
    plt.legend(handles=[green_patch, red_patch, line_patch])
    return plt


def plot_rating_preds(all_preds, all_labels,
                      test_dataset,
                      title='Humor Rating vs Prediction for Test Set',
                      data_path='data/hack.csv'):
    """
    Load the raw data, merge in the predictions, and plot rating vs correctness.
    """
    data = load_data(data_path)
    df_test_sorted = process_data(data, test_dataset, all_preds, all_labels)
    return plot_rating_df_based(df_test_sorted, title=title)

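
# Example usage (sketch): `model_outputs` and `gold_labels` are hypothetical
# lists of 0/1 ints, and `test_dataset` is assumed to expose `original_indices`
# as process_data above expects.
#
#   print(f"Accuracy: {get_accuracy(model_outputs, gold_labels):.3f}")
#   print(f"F1 score: {get_f1_score(model_outputs, gold_labels):.3f}")
#   plot_confusion_matrix(model_outputs, gold_labels).show()
#   plot_rating_preds(model_outputs, gold_labels, test_dataset).show()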