70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
class History:
    """Store the loss and accuracy history of a training run.

    Two dictionaries are kept, both keyed by ``'train_loss'``,
    ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``:

    * ``batch_history`` collects one value per batch.
    * ``history`` collects one value per call to :meth:`update`, namely the
      mean of whatever ``batch_history`` holds at that moment.

    :meth:`update` does not clear ``batch_history``; call
    :meth:`batch_reset` to start collecting a fresh set of batch values.
    """

    # Single source of truth for the metric names used by both dictionaries.
    _METRIC_KEYS = ('train_loss', 'val_loss', 'train_acc', 'val_acc')

    def __init__(self):
        """Create empty aggregated and per-batch histories."""
        self.history = self._empty_metrics()
        self.batch_history = self._empty_metrics()

    @classmethod
    def _empty_metrics(cls) -> dict:
        """Return a new dict mapping every metric key to its own empty list."""
        return {key: [] for key in cls._METRIC_KEYS}

    def update(self) -> None:
        """Append the mean of each per-batch metric to ``history``.

        A metric with no recorded batch values contributes NaN. That is the
        value ``np.mean`` yields for an empty list, but appending it directly
        avoids NumPy's "Mean of empty slice" RuntimeWarning and keeps the
        four lists the same length.
        """
        for key in self._METRIC_KEYS:
            values = self.batch_history[key]
            self.history[key].append(np.mean(values) if values else np.nan)

    def get_history(self) -> dict:
        """Return the aggregated history dict (the live object, not a copy)."""
        return self.history

    def calculate_accuracy(self, outputs, labels) -> float:
        """Return the fraction of predictions that equal ``labels``.

        Predictions are ``torch.argmax(outputs, dim=1)``.

        Args:
            outputs: Tensor of scores with the candidate classes along
                dimension 1.
            labels: Tensor of target class indices, compared element-wise
                with the predictions.

        Returns:
            Number of matching elements divided by the number of label
            elements, or ``0.0`` when ``labels`` is empty.
        """
        # Divide by the element count rather than len(): the two agree for
        # 1-D labels, and numel() stays correct for multi-dimensional ones.
        total = labels.numel()
        if total == 0:
            # Nothing to score; avoid dividing by zero.
            return 0.0
        preds = torch.argmax(outputs, dim=1)
        correct = (preds == labels).sum().item()
        return correct / total

    def batch_reset(self) -> None:
        """Replace ``batch_history`` with a new dict of empty lists."""
        self.batch_history = self._empty_metrics()

    def batch_update(self, train_loss, val_loss, train_acc, val_acc) -> None:
        """Record one batch worth of all four metrics."""
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['train_acc'].append(train_acc)
        self.batch_history['val_acc'].append(val_acc)

    def batch_update_train(self, train_loss, preds, labels) -> None:
        """Record a training batch's loss and its computed accuracy.

        ``preds`` is forwarded to :meth:`calculate_accuracy` as its
        ``outputs`` argument, so the arg-max over dimension 1 is taken there.
        """
        train_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['train_acc'].append(train_acc)

    def batch_update_val(self, val_loss, preds, labels) -> None:
        """Record a validation batch's loss and its computed accuracy.

        ``preds`` is forwarded to :meth:`calculate_accuracy` as its
        ``outputs`` argument, so the arg-max over dimension 1 is taken there.
        """
        val_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['val_acc'].append(val_acc)

    def get_batch_history(self) -> dict:
        """Return the per-batch history dict (the live object, not a copy)."""
        return self.batch_history

    def print_history(self, epoch, max_epochs, time_elapsed, verbose=True) -> None:
        """Print a one-line summary of the latest entry of each metric.

        Args:
            epoch: Epoch number to display.
            max_epochs: Total number of epochs to display.
            time_elapsed: Elapsed time in seconds, shown with two decimals.
            verbose: When false, nothing is printed and ``history`` is not
                read.

        Raises:
            IndexError: If ``verbose`` is true and :meth:`update` has not
                been called yet, since ``history`` has no last entry.
        """
        if not verbose:
            return
        loss = self.history["train_loss"][-1]
        acc = self.history["train_acc"][-1]
        val_loss = self.history["val_loss"][-1]
        val_acc = self.history["val_acc"][-1]
        print(
            f'Epoch {epoch:>3}/{max_epochs} - {time_elapsed:.2f}s'
            f' - loss: {loss:.4f} - accuracy: {acc:.4f}'
            f' - val_loss: {val_loss:.4f} - val_accuracy: {val_acc:.4f}'
        )
|
|
|