added helpful functionality

parent: 394167488f
commit: c444b0d451
@@ -4,8 +4,6 @@ __pycache__/
 # Ignore virtual environment directory
 .venv/
 
-# Ignore requirements file
-reqs_venv.txt
 
 # Ignore models directory
 models/
@@ -15,6 +13,8 @@ models/
 *.keras
 *.pth
 
+checkpoints/
+
 # Ignore plots directory
 plots/
 
@@ -0,0 +1,28 @@
+import torch
+
+class EarlyStopping:
+    def __init__(self, patience=5, verbose=False):
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+
+    def __call__(self, val_loss, model):
+        score = -val_loss
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model)
+        elif score < self.best_score:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model)
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, filename='checkpoint.pt'):
+        if self.verbose:
+            print(f'Validation loss decreased ({self.best_score:.6f} --> {val_loss:.6f}). Saving model ...')
+        torch.save(model.state_dict(), f'checkpoints/{filename}')
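
Illustrative usage (a minimal sketch, not part of this commit): the new EarlyStopping helper is driven once per epoch with the validation loss; train_one_epoch and evaluate are assumed placeholder names, and the checkpoints/ directory (git-ignored above) must already exist for save_checkpoint to succeed.

    early_stopping = EarlyStopping(patience=5, verbose=True)
    for epoch in range(50):
        train_one_epoch(model)           # assumed training step
        val_loss = evaluate(model)       # assumed mean validation loss for the epoch
        early_stopping(val_loss, model)  # saves checkpoints/checkpoint.pt whenever val_loss improves
        if early_stopping.early_stop:
            print(f'Stopping early after epoch {epoch}')
            break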
@@ -3,7 +3,42 @@ This file contains the HumorDataset class.
 """
 import torch
 import numpy as np
+from nltk.tokenize import word_tokenize
 
+class TextDataset(torch.utils.data.Dataset):
+    def __init__(self, texts, labels, word_index, max_len=50):
+
+        self.original_indices = labels.index.to_list()
+
+        self.texts = texts.reset_index(drop=True)
+        self.labels = labels.reset_index(drop=True)
+        self.word_index = word_index
+        self.max_len = max_len
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        texts = self.texts[idx]
+        tokens = word_tokenize(texts.lower())
+
+        label = self.labels[idx]
+
+        # Tokenize and convert to indices
+        input_ids = [self.word_index.get(word, self.word_index['<UNK>']) for word in tokens]
+
+        # Pad or truncate to max_len
+        if len(input_ids) < self.max_len:
+            input_ids += [self.word_index['<PAD>']] * (self.max_len - len(input_ids))
+        else:
+            input_ids = input_ids[:self.max_len]
+
+        # Convert to PyTorch tensors
+        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        label = torch.tensor(label, dtype=torch.long)
+
+        return input_ids, label
+
 class HumorDataset(torch.utils.data.Dataset):
     def __init__(self, data, labels, vocab_size=0, emb_dim=None):
         self.original_indices = labels.index.to_list()
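
Illustrative usage (a minimal sketch, not part of this commit): TextDataset expects pandas Series for texts and labels (it calls reset_index and reads labels.index) plus a word_index dict containing '<UNK>' and '<PAD>' entries, such as the one built by create_embbedings_matrix later in this diff; train_texts and train_labels are assumed names.

    from torch.utils.data import DataLoader

    # requires the nltk 'punkt' tokenizer data: nltk.download('punkt')
    train_dataset = TextDataset(train_texts, train_labels, word_index, max_len=50)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    for input_ids, labels in train_loader:
        # input_ids: LongTensor of shape (batch_size, 50); labels: LongTensor of shape (batch_size,)
        break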
@@ -9,47 +9,101 @@ import gensim
 import torch
 import os
 import copy
+import regex as re
 
-from HumorDataset import HumorDataset
+import HumorDataset
 
-def get_embedding_idx(model, word):
-    if word in model.wv:
-        return model.wv.key_to_index[word]
-    else:
-        return unk_index
+# def load_glove_embeddings(glove_file_path):
+#     embeddings_index = {}
+#     with open(glove_file_path, 'r', encoding='utf-8') as f:
+#         for line in f:
+#             try:
+#                 values = line.split()
+#                 #print(values)
+#                 word = values[0]
+#                 coefs = np.asarray(values[1:], dtype='float32')
+#                 embeddings_index[word] = coefs
+#             except ValueError:
+#                 print('Error with line:', line[:100])
+#     return embeddings_index
 
-def get_embedding_vector(model, word):
-    if word in model.wv:
-        return model.wv[word]
-    else:
-        return np.zeros(model.vector_size)
-
-def load_glove_embeddings(glove_file_path):
+def load_glove_embeddings(glove_file_path, emb_len=100):
     embeddings_index = {}
     with open(glove_file_path, 'r', encoding='utf-8') as f:
         for line in f:
-            values = line.split()
-            word = values[0]
-            coefs = np.asarray(values[1:], dtype='float32')
-            embeddings_index[word] = coefs
+            try:
+                # Use regex to split the line into word and coefficients
+                match = re.match(r"(.+?)\s+([\d\s\.\-e]+)", line)
+                # regex explanation: Match word followed by one or more spaces and then the coefficients
+                if match:
+                    word = match.group(1)
+                    coefs = np.fromstring(match.group(2), sep=' ', dtype='float32')
+
+                    #check list length
+                    if len(coefs) != emb_len:
+                        print('Skip: Length mismatch with line:', line[:100])
+                    else:
+                        embeddings_index[word] = coefs
+                else:
+                    print('Error with line:', line[:100])
+            except ValueError:
+                print('Error with line:', line[:100])
     return embeddings_index
 
-def get_embedding_glove_vector(tokens, embeddings_index, default_vector_len=100, pad_tok='<PAD>'):
-    default_vec = [0] * default_vector_len
-    emb_matrix = []
-    for token in tokens:
-        if token == pad_tok:
-            embedding_vector = default_vec
-        else:
-            embedding_vector = embeddings_index.get(token, default_vec)
-        emb_matrix.append(embedding_vector)
-    return emb_matrix
-
-def encode_tokens(tokens, vector=False):
-    if vector:
-        return [get_embedding_vector(model_embedding, token) for token in tokens]
-    else:
-        return [get_embedding_idx(model_embedding, token) for token in tokens]
+def create_embbedings_matrix(embeddings_glove, max_len=100):
+    embeddings_glove['<UNK>'] = np.random.rand(max_len)
+    embeddings_glove['<PAD>'] = np.zeros(max_len)
+    # Create a word index (vocabulary)
+    word_index = {word: idx for idx, word in enumerate(embeddings_glove.keys())}
+    # Special tokens are in the word index
+    word_index['<UNK>'] = len(word_index) - 2
+    word_index['<PAD>'] = len(word_index) - 1
+    # print len of word_index
+    print(len(word_index))
+    # Create an embedding matrix
+    embedding_dim = len(next(iter(embeddings_glove.values())))
+
+    embedding_matrix = np.zeros((len(word_index), embedding_dim))
+
+    for word, idx in word_index.items():
+        embedding_vector = embeddings_glove.get(word)
+        if embedding_vector is not None:
+            embedding_matrix[idx] = embedding_vector
+
+    # Convert the embedding matrix to a tensor
+    embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32)
+    return embedding_matrix, word_index
+
+
+def create_embedding_matrix(gloVe_path='glove.6B/glove.6B.100d.txt', emb_len=100):
+    embeddings_glove = load_glove_embeddings(gloVe_path, emb_len=emb_len)
+
+    embedding_matrix, word_index = create_embbedings_matrix(embeddings_glove)
+
+    vocab_size = len(embedding_matrix)
+    d_model = len(embedding_matrix[0])
+    vocab_size, d_model = embedding_matrix.size()
+    print(f"vocab_size: {vocab_size}, d_model: {d_model}")
+
+    return embedding_matrix, word_index, vocab_size, d_model
+
+
+def load_preprocess_data(path_data='data/hack.csv'):
+    df = pd.read_csv(path_data)
+    df = df.dropna(subset=['humor_rating'])
+    # find median of humor_rating
+    median_rating = df['humor_rating'].median()
+    df['y'] = df['humor_rating'] > median_rating
+    X = df['text']
+    y = df['y']
+    return X, y
+
+
+def encode_tokens(tokens, embedding_index, default_vector_len=100):
+    return [embedding_index.get(token, np.random.zeros(default_vector_len)) for token in tokens]
+
 
 def pad_sequences(sequences, max_len, pad_index):
     return np.array([np.pad(seq, (0, max_len - len(seq)), mode='constant', constant_values=pad_index) if len(seq) < max_len else seq[:max_len] for seq in sequences])
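
Illustrative usage (a minimal sketch, not part of this commit): create_embedding_matrix ties the GloVe loading and matrix construction above together, and the returned tensor can seed a frozen nn.Embedding whose padding row is the '<PAD>' index; the GloVe path below is simply the default already used in this file.

    import torch.nn as nn

    embedding_matrix, word_index, vocab_size, d_model = create_embedding_matrix(
        gloVe_path='glove.6B/glove.6B.100d.txt', emb_len=100)

    embedding_layer = nn.Embedding.from_pretrained(embedding_matrix,
                                                   freeze=True,
                                                   padding_idx=word_index['<PAD>'])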
@@ -82,7 +136,9 @@ def save_data(data_dict, path, prefix, vocab_size=0, emb_dim=None):
         dataset = HumorDataset(value['X'], value['y'], vocab_size, emb_dim)
         # save dataset
         torch.save(dataset, path + prefix + key + '.pt')
 
+
+
 if __name__ == "__main__":
     # Load the data from csv
     df = pd.read_csv('data/hack.csv')
@@ -114,41 +170,38 @@ if __name__ == "__main__":
 
     # split data into train, test, and validation
     data_dict = split_data(padded_indices, y)
 
-    # Embed the data with word2vec
-    model_embedding = gensim.models.Word2Vec(tokens, window=5, min_count=1, workers=4)
-
-    # Add a special token for out-of-vocabulary words
-    model_embedding.wv.add_vector('<UNK>', np.zeros(model_embedding.vector_size))
-    unk_index = model_embedding.wv.key_to_index['<UNK>']
+    # data_idx_based = copy.deepcopy(data_dict)
+    # vector_based = False
 
-    # Add padding index for padding
-    model_embedding.wv.add_vector('<PAD>', np.zeros(model_embedding.vector_size))
-    pad_index = model_embedding.wv.key_to_index['<PAD>']
+    # for key in data_idx_based.keys():
+    #     data_idx_based[key]['X'] = [encode_tokens(tokens, vector_based) for tokens in data_dict[key]['X']]
+    #     # print shape of data
+    #     #print(key, len(data_dict[key]['X']), len(data_dict[key]['y']))
 
-    data_idx_based = copy.deepcopy(data_dict)
-    vector_based = False
-
-    for key in data_idx_based.keys():
-        data_idx_based[key]['X'] = [encode_tokens(tokens, vector_based) for tokens in data_dict[key]['X']]
-        # print shape of data
-        #print(key, len(data_dict[key]['X']), len(data_dict[key]['y']))
-
-    # save the data
-    save_data(data_idx_based, 'data/idx_based_padded/', '', vocab_size)
+    # # save the data
+    # save_data(data_idx_based, 'data/idx_based_padded/', '', vocab_size)
 
     print('loading GloVe embeddings')
-    vector_based = True
     # Load GloVe embeddings
     glove_file_path = 'glove.6B/glove.6B.100d.txt'
+    #glove_file_path = 'glove.840B.300d/glove.840B.300d.txt'
    embeddings_index = load_glove_embeddings(glove_file_path)
+    emb_len = 100
     print('starting with embedding the data')
     # Encode the tokens
-    for key in data_dict.keys():
-        data_dict[key]['X'] = [get_embedding_glove_vector(tokens, embeddings_index) for tokens in data_dict[key]['X']]
+    #for key in data_dict.keys():
+        #data_dict[key]['X'] = [get_embedding_glove_vector(tokens, embeddings_index, default_vector_len=emb_len) for tokens in data_dict[key]['X']]
         # print shape of data
         #print(key, len(data_dict[key]['X']), len(data_dict[key]['y']))
 
     # Save the data
-    save_data(data_dict, 'data/embedded_padded/', '', vocab_size, emb_dim=model_embedding.vector_size)
+    #save_data(data_dict, 'data/embedded_padded/', '', vocab_size, emb_dim=model_embedding.vector_size)
+
+    max_len = 100
+    gloVe_path = 'glove.6B/glove.6B.100d.txt'
+    embeddings_glove = load_glove_embeddings(gloVe_path, emb_len=max_len)
+
+    embeddings_glove['<UNK>'] = np.random.rand(max_len)
+    embeddings_glove['<PAD>'] = np.zeros(max_len)
@@ -0,0 +1,129 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import confusion_matrix, f1_score
+import pandas as pd
+import matplotlib.patches as mpatches
+
+def get_accuracy(outputs, labels):
+    correct = np.array([p == l for p, l in zip(outputs, labels)])
+    accuracy = correct.sum() / len(labels)
+    return accuracy
+
+def get_f1_score(outputs, labels):
+    outputs = torch.tensor(outputs)
+    labels = torch.tensor(labels)
+    f1 = f1_score(labels, outputs)
+    return f1
+
+def plot_confusion_matrix(outputs, labels, class_names=['No Humor', 'Humor'], title='Confusion Matrix'):
+    conf_matrix = confusion_matrix(labels, outputs)
+
+    plt.figure(figsize=(6,5))
+    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap="Blues", xticklabels=class_names, yticklabels=class_names)
+    plt.xlabel("Predicted Label")
+    plt.ylabel("True Label")
+    plt.title(title)
+    return plt
+
+
+def get_label_distribution(labels, preds):
+    # Calculate wrong predictions
+    wrong_preds = np.array(labels) != np.array(preds)
+
+    # Calculate the number of wrong predictions for each class
+    class_0_wrong_preds = np.sum(np.array(labels)[wrong_preds] == 0)
+    class_1_wrong_preds = np.sum(np.array(labels)[wrong_preds] == 1)
+    # Calculate the total number of wrong predictions
+    total_wrong_preds = np.sum(wrong_preds)
+    # Calculate and print the ratio of wrong predictions for each class
+    class_0_ratio = class_0_wrong_preds / total_wrong_preds
+    class_1_ratio = class_1_wrong_preds / total_wrong_preds
+
+    print(f"Class 0: {class_0_ratio:.2f}")
+    print(f"Class 1: {class_1_ratio:.2f}")
+
+def plot_training_history(history, title='Training History'):
+    hist_data = history.get_history()
+
+    epochs = range(1, len(hist_data['train_loss']) + 1)
+
+    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
+
+    # Plot accuracy
+    axs[1].plot(epochs, hist_data['train_acc'], label='Train Accuracy')
+    axs[1].plot(epochs, hist_data['val_acc'], label='Validation Accuracy')
+    axs[1].set_title('Accuracy')
+    axs[1].set_xlabel('Epochs')
+    axs[1].set_ylabel('Accuracy')
+    axs[1].legend()
+
+    # Plot loss
+    axs[0].plot(epochs, hist_data['train_loss'], label='Train Loss')
+    axs[0].plot(epochs, hist_data['val_loss'], label='Validation Loss')
+    axs[0].set_title('Loss')
+    axs[0].set_xlabel('Epochs')
+    axs[0].set_ylabel('Loss')
+    axs[0].legend()
+
+    plt.tight_layout()
+    plt.suptitle(title)
+    return plt
+
+
+
+def load_data(filepath):
+    """
+    Load the data from a CSV file.
+    """
+    df = pd.read_csv(filepath)
+    #print(df.shape)
+    return df
+
+def process_data(df, test_dataset, all_preds, all_labels):
+    """
+    Process the data to prepare it for plotting.
+    """
+    df_test = df.iloc[test_dataset.original_indices].copy()
+    df_test['prediction'] = all_preds
+    df_test['label'] = all_labels
+    df_test['pred_correct'] = (df_test['prediction'] == df_test['label'])
+    df_test_sorted = df_test.sort_values(by='humor_rating').reset_index(drop=True)
+    return df_test_sorted
+
+def plot_rating_df_based(df_test_sorted, title='Humor Rating vs Prediction for Test Set'):
+    """
+    Plot the results of the predictions.
+    """
+    median_rating = df_test_sorted['humor_rating'].median()
+    median_idx = df_test_sorted[df_test_sorted['humor_rating'] > median_rating].index[0]
+    #print(median_idx)
+
+    range_idx = range(len(df_test_sorted))
+    colors = df_test_sorted['pred_correct'].map({True: 'g', False: 'r'})
+
+    plt.figure(figsize=(12, 6))
+    plt.bar(range_idx, df_test_sorted['humor_rating'], color=colors)
+    plt.axvline(x=median_idx, color='black', linestyle='--')
+
+    green_patch = mpatches.Patch(color='g', label='Correct Prediction')
+    red_patch = mpatches.Patch(color='r', label='Incorrect Prediction')
+    line_patch = mpatches.Patch(color='black', label='humor_rating cut off')
+
+    plt.title(title)
+    plt.xlabel('Index')
+    plt.ylabel('Humor Rating')
+    plt.legend(handles=[green_patch, red_patch, line_patch])
+    return plt
+
+
+def plot_rating_preds(all_preds, all_labels,
+                      test_dataset,
+                      title='Humor Rating vs Prediction for Test Set',
+                      data_path = 'data/hack.csv'):
+
+    data = load_data(data_path)
+    df_test_sorted = process_data(data, test_dataset, all_preds, all_labels)
+    plt = plot_rating_df_based(df_test_sorted, title=title)
+    return plt
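
Illustrative usage (a minimal sketch, not part of this commit): the evaluation helpers above expect flat lists of integer predictions and labels collected over the test set; all_preds, all_labels, and history (an instance of the History class changed below) are assumed names, and the git-ignored plots/ directory is an assumed save location.

    print('accuracy:', get_accuracy(all_preds, all_labels))
    print('f1:', get_f1_score(all_preds, all_labels))

    plot_confusion_matrix(all_preds, all_labels, title='Confusion Matrix (test)').savefig('plots/confusion_matrix.png')
    plot_training_history(history, title='Training History').savefig('plots/training_history.png')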
@@ -1,4 +1,5 @@
 import numpy as np
+import torch
 
 class History:
     """
@@ -7,42 +8,63 @@ class History:
     """
     def __init__(self):
         self.history = {
-            'loss': [],
+            'train_loss': [],
+            'val_loss': [],
             'train_acc': [],
             'val_acc': [],
         }
         self.batch_history = {
-            'loss': [],
+            'train_loss': [],
+            'val_loss': [],
             'train_acc': [],
             'val_acc': [],
         }
 
     def update(self):
-        self.history['loss'].append(np.mean(self.batch_history['loss']))
+        self.history['train_loss'].append(np.mean(self.batch_history['train_loss']))
+        self.history['val_loss'].append(np.mean(self.batch_history['val_loss']))
         self.history['train_acc'].append(np.mean(self.batch_history['train_acc']))
         self.history['val_acc'].append(np.mean(self.batch_history['val_acc']))
 
     def get_history(self):
         return self.history
 
+    def calculate_accuracy(self, outputs, labels):
+        preds = torch.argmax(outputs, dim=1)
+        correct = (preds == labels).sum().item()
+        accuracy = correct / len(labels)
+        return accuracy
+
     def batch_reset(self):
         self.batch_history = {
-            'loss': [],
+            'train_loss': [],
+            'val_loss': [],
             'train_acc': [],
             'val_acc': [],
         }
 
-    def batch_update(self, loss, train_acc, val_acc):
-        self.batch_history['loss'].append(loss)
+    def batch_update(self, train_loss, val_loss, train_acc, val_acc):
+        self.batch_history['train_loss'].append(train_loss)
+        self.batch_history['val_loss'].append(val_loss)
         self.batch_history['train_acc'].append(train_acc)
         self.batch_history['val_acc'].append(val_acc)
 
-    def batch_update_train(self, loss, train_acc):
-        self.batch_history['loss'].append(loss)
+    def batch_update_train(self, train_loss, preds, labels):
+        train_acc = self.calculate_accuracy(preds, labels)
+        self.batch_history['train_loss'].append(train_loss)
         self.batch_history['train_acc'].append(train_acc)
 
-    def batch_update_val(self, val_acc):
+    def batch_update_val(self, val_loss, preds, labels):
+        val_acc = self.calculate_accuracy(preds, labels)
+        self.batch_history['val_loss'].append(val_loss)
         self.batch_history['val_acc'].append(val_acc)
 
     def get_batch_history(self):
         return self.batch_history
+
+    def print_history(self, epoch, max_epochs, time_elapsed, verbose=True):
+        if verbose:
+            print(f'Epoch {epoch:>3}/{max_epochs} - {time_elapsed:.2f}s - loss: {self.history["train_loss"][-1]:.4f} - accuracy: {self.history["train_acc"][-1]:.4f} - val_loss: {self.history["val_loss"][-1]:.4f} - val_accuracy: {self.history["val_acc"][-1]:.4f}')
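
Illustrative usage (a minimal sketch, not part of this commit): with the reworked History API, batch_update_train and batch_update_val now take the batch loss plus raw model outputs and labels, and print_history reports the epoch means; model, criterion, train_loader, val_loader and max_epochs are assumed names, and optimizer steps are omitted for brevity.

    import time
    import torch

    history = History()
    for epoch in range(1, max_epochs + 1):
        start = time.time()
        history.batch_reset()
        for inputs, labels in train_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            history.batch_update_train(loss.item(), outputs, labels)  # accuracy derived via calculate_accuracy
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                history.batch_update_val(loss.item(), outputs, labels)
        history.update()
        history.print_history(epoch, max_epochs, time_elapsed=time.time() - start)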