import numpy as np
import torch


class History:
    """Store the loss and accuracy history of a training run.

    Per-batch values are accumulated in ``batch_history`` through the
    ``batch_update*`` methods. ``update()`` then averages them into one entry
    per metric in ``history``. Neither ``update()`` nor the ``batch_update*``
    methods clear the batch buffer, so call ``batch_reset()`` between epochs.
    """

    # Metric names shared by ``history`` and ``batch_history`` (in this order).
    _METRIC_KEYS = ('train_loss', 'val_loss', 'train_acc', 'val_acc')

    def __init__(self):
        self.history = self._empty_metrics()
        self.batch_history = self._empty_metrics()

    @classmethod
    def _empty_metrics(cls):
        """Return a fresh dict mapping every metric name to an empty list."""
        return {key: [] for key in cls._METRIC_KEYS}

    @staticmethod
    def _mean(values):
        """Return the mean of *values*, or NaN (with no warning) when empty."""
        # np.mean([]) also gives NaN, but it emits a "Mean of empty slice"
        # RuntimeWarning, e.g. on every epoch where validation is skipped.
        return np.mean(values) if len(values) else np.nan

    def update(self):
        """Append the mean of each metric's batch values to ``history``.

        A metric with no recorded batches gets NaN, so all four history lists
        stay aligned by epoch. The batch buffer is not cleared here.
        """
        for key in self._METRIC_KEYS:
            self.history[key].append(self._mean(self.batch_history[key]))

    def get_history(self):
        """Return the per-epoch history dict (the live object, not a copy)."""
        return self.history

    def calculate_accuracy(self, outputs, labels):
        """Return the fraction of rows whose argmax over dim 1 equals the label.

        Assumes *outputs* is (batch, num_classes) scores and *labels* is a 1-D
        tensor of class indices (TODO confirm against callers).

        Raises:
            ZeroDivisionError: if *labels* is empty.
        """
        preds = torch.argmax(outputs, dim=1)
        correct = (preds == labels).sum().item()
        return correct / len(labels)

    def batch_reset(self):
        """Discard all accumulated per-batch values."""
        self.batch_history = self._empty_metrics()

    def batch_update(self, train_loss, val_loss, train_acc, val_acc):
        """Record precomputed losses and accuracies for one batch."""
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['train_acc'].append(train_acc)
        self.batch_history['val_acc'].append(val_acc)

    def batch_update_train(self, train_loss, preds, labels):
        """Record a training batch's loss and its accuracy.

        The accuracy is computed from *preds*, the raw model outputs passed to
        ``calculate_accuracy``.
        """
        train_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['train_acc'].append(train_acc)

    def batch_update_val(self, val_loss, preds, labels):
        """Record a validation batch's loss and its accuracy.

        The accuracy is computed from *preds*, the raw model outputs passed to
        ``calculate_accuracy``.
        """
        val_acc = self.calculate_accuracy(preds, labels)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['val_acc'].append(val_acc)

    def get_batch_history(self):
        """Return the per-batch buffer dict (the live object, not a copy)."""
        return self.batch_history

    def print_history(self, epoch, max_epochs, time_elapsed, verbose=True):
        """Print the most recent epoch's metrics on one line when *verbose*.

        Raises:
            IndexError: if ``update()`` has never been called.
        """
        if not verbose:
            return
        hist = self.history
        print(
            f'Epoch {epoch:>3}/{max_epochs} - {time_elapsed:.2f}s'
            f' - loss: {hist["train_loss"][-1]:.4f}'
            f' - accuracy: {hist["train_acc"][-1]:.4f}'
            f' - val_loss: {hist["val_loss"][-1]:.4f}'
            f' - val_accuracy: {hist["val_acc"][-1]:.4f}'
        )