ANLP_WS24_CA2/ml_history.py

70 lines
2.5 KiB
Python

import numpy as np
import torch
class History:
    """
    Class to store the history of the training process.

    Per-batch losses and accuracies for the training and validation sets are
    appended to ``batch_history`` via the ``batch_update*`` methods.
    ``update()`` averages the collected batch values and appends one entry per
    metric to ``history``. ``update()`` does not clear the batch buffers; call
    ``batch_reset()`` for that.
    """

    # Metric keys tracked in both ``history`` and ``batch_history`` (order
    # preserved from the original dict literals).
    _KEYS = ('train_loss', 'val_loss', 'train_acc', 'val_acc')

    def __init__(self):
        # One averaged value per metric is appended here on every update().
        self.history = self._empty_metrics()
        # Raw per-batch values accumulated since the last batch_reset().
        self.batch_history = self._empty_metrics()

    @classmethod
    def _empty_metrics(cls):
        """Return a fresh dict mapping every metric key to a new empty list."""
        return {key: [] for key in cls._KEYS}

    @staticmethod
    def _to_scalar(value):
        """Return a Python scalar for single-element tensors, else ``value``.

        Storing ``loss.item()`` instead of the tensor itself keeps the autograd
        graph (and any device memory) from being retained for the whole epoch,
        and lets ``np.mean`` handle values that came from CUDA tensors or
        tensors that require grad. Non-tensor values are stored unchanged.
        """
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    @staticmethod
    def _mean(values):
        """Average ``values``; return NaN for an empty list.

        ``np.mean([])`` also yields NaN but emits RuntimeWarnings; returning
        NaN directly keeps epochs without e.g. a validation pass quiet.
        """
        if not values:
            return float('nan')
        return np.mean(values)

    def update(self):
        """Append the mean of each metric's batch values to ``history``.

        A metric with no recorded batch values is stored as NaN. The batch
        buffers are not cleared here.
        """
        for key in self._KEYS:
            self.history[key].append(self._mean(self.batch_history[key]))

    def get_history(self):
        """Return the per-update history dict (the live object, not a copy)."""
        return self.history

    def calculate_accuracy(self, outputs, labels):
        """Return the fraction of rows whose argmax over dim 1 equals the label.

        Presumably ``outputs`` are class scores of shape (batch, num_classes)
        and ``labels`` are class indices of shape (batch,) — TODO confirm
        against callers.

        Raises:
            ZeroDivisionError: if ``labels`` is empty.
        """
        preds = torch.argmax(outputs, dim=1)
        correct = (preds == labels).sum().item()
        accuracy = correct / len(labels)
        return accuracy

    def batch_reset(self):
        """Clear all per-batch buffers (``history`` is left untouched)."""
        self.batch_history = self._empty_metrics()

    def batch_update(self, train_loss, val_loss, train_acc, val_acc):
        """Record one batch value for each of the four metrics."""
        self.batch_history['train_loss'].append(self._to_scalar(train_loss))
        self.batch_history['val_loss'].append(self._to_scalar(val_loss))
        self.batch_history['train_acc'].append(self._to_scalar(train_acc))
        self.batch_history['val_acc'].append(self._to_scalar(val_acc))

    def batch_update_train(self, train_loss, preds, labels):
        """Record a training batch's loss and accuracy.

        ``preds`` is passed to ``calculate_accuracy`` as the raw model outputs,
        which takes the argmax over dim 1 itself.
        """
        train_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['train_loss'].append(self._to_scalar(train_loss))
        self.batch_history['train_acc'].append(train_acc)

    def batch_update_val(self, val_loss, preds, labels):
        """Record a validation batch's loss and accuracy.

        ``preds`` is passed to ``calculate_accuracy`` as the raw model outputs,
        which takes the argmax over dim 1 itself.
        """
        val_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['val_loss'].append(self._to_scalar(val_loss))
        self.batch_history['val_acc'].append(val_acc)

    def get_batch_history(self):
        """Return the per-batch buffers dict (the live object, not a copy)."""
        return self.batch_history

    def print_history(self, epoch, max_epochs, time_elapsed, verbose=True):
        """Print the most recent entry of every metric in ``history``.

        Raises IndexError if ``update()`` has never been called.
        """
        if verbose:
            print(f'Epoch {epoch:>3}/{max_epochs} - {time_elapsed:.2f}s - loss: {self.history["train_loss"][-1]:.4f} - accuracy: {self.history["train_acc"][-1]:.4f} - val_loss: {self.history["val_loss"][-1]:.4f} - val_accuracy: {self.history["val_acc"][-1]:.4f}')