ANLP_WS24_CA2/ml_history.py

48 lines
1.4 KiB
Python

import numpy as np
class History:
    """
    Class to store the history of the training process.
    Used to store the loss and accuracy of the training and validation sets.

    Per-batch values are accumulated in ``batch_history`` through the
    ``batch_update*`` methods. ``update`` then appends the mean of each
    metric's batch values to ``history``, giving one entry per ``update``
    call (presumably once per epoch -- confirm against the training loop).
    ``update`` does not clear the batch lists; call ``batch_reset`` for that.
    """

    # Metric names tracked in both ``history`` and ``batch_history``, in the
    # order their lists appear in those dicts.
    _KEYS = ('loss', 'train_acc', 'val_acc')

    def __init__(self):
        """Create a history with no recorded epochs and no batch values."""
        self.history = self._empty()
        self.batch_history = self._empty()

    @classmethod
    def _empty(cls) -> dict:
        """Return a fresh dict mapping every metric name to a new empty list."""
        return {key: [] for key in cls._KEYS}

    def update(self):
        """
        Append the mean of each metric's current batch values to ``history``.

        A metric with no recorded batch values (e.g. ``val_acc`` when only
        ``batch_update_train`` was called) gets NaN. This is the same value
        ``np.mean([])`` returns, but without the "Mean of empty slice"
        RuntimeWarning it would emit on every such call.
        """
        for key in self._KEYS:
            values = self.batch_history[key]
            # len() rather than truthiness, so array-valued entries also work.
            mean = np.mean(values) if len(values) > 0 else np.float64(np.nan)
            self.history[key].append(mean)

    def get_history(self) -> dict:
        """Return the per-update history dict (the live object, not a copy)."""
        return self.history

    def batch_reset(self):
        """Replace the batch history with fresh, empty per-metric lists."""
        self.batch_history = self._empty()

    def batch_update(self, loss, train_acc, val_acc):
        """
        Record one batch's loss, training accuracy and validation accuracy.

        Values are presumably floats or NumPy scalars -- anything ``np.mean``
        accepts (TODO confirm against callers).
        """
        self.batch_history['loss'].append(loss)
        self.batch_history['train_acc'].append(train_acc)
        self.batch_history['val_acc'].append(val_acc)

    def batch_update_train(self, loss, train_acc):
        """Record one training batch's loss and accuracy (``val_acc`` untouched)."""
        self.batch_history['loss'].append(loss)
        self.batch_history['train_acc'].append(train_acc)

    def batch_update_val(self, val_acc):
        """Record one validation batch's accuracy (loss/train_acc untouched)."""
        self.batch_history['val_acc'].append(val_acc)

    def get_batch_history(self) -> dict:
        """Return the current batch history dict (the live object, not a copy)."""
        return self.batch_history