48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
import numpy as np
|
|
|
|
class History:
    """
    Class to store the history of the training process.

    Used to store the loss and accuracy of the training and validation sets.

    Two dictionaries are kept, both keyed by 'loss', 'train_acc' and
    'val_acc' and both mapping each key to a list:

    * ``batch_history`` collects the raw values passed to the
      ``batch_update*`` methods.
    * ``history`` collects one aggregated value per call to ``update()``,
      namely the mean of the corresponding ``batch_history`` list.

    ``update()`` does not clear ``batch_history``; call ``batch_reset()``
    to start a fresh accumulation.
    """

    # Metric names tracked in both dictionaries, in insertion order.
    _METRICS = ('loss', 'train_acc', 'val_acc')

    def __init__(self):
        """Create a history with empty lists for every tracked metric."""
        self.history = {name: [] for name in self._METRICS}
        self.batch_history = {name: [] for name in self._METRICS}

    def update(self):
        """Append the mean of each per-batch list to ``history``.

        A metric with no recorded batch values (for example 'val_acc' when
        only ``batch_update_train`` was called) is recorded as NaN.
        """
        for name in self._METRICS:
            values = self.batch_history[name]
            # np.mean([]) also yields NaN but emits a RuntimeWarning
            # ("Mean of empty slice"); short-circuit to keep the same value
            # without the warning.
            if len(values):
                aggregated = np.mean(values)
            else:
                aggregated = np.float64(np.nan)
            self.history[name].append(aggregated)

    def get_history(self):
        """Return the aggregated history dictionary (the live object, not a copy)."""
        return self.history

    def batch_reset(self):
        """Replace ``batch_history`` with a new dictionary of empty lists."""
        self.batch_history = {name: [] for name in self._METRICS}

    def batch_update(self, loss, train_acc, val_acc):
        """Record one batch's loss, training accuracy and validation accuracy."""
        self.batch_history['loss'].append(loss)
        self.batch_history['train_acc'].append(train_acc)
        self.batch_history['val_acc'].append(val_acc)

    def batch_update_train(self, loss, train_acc):
        """Record one batch's loss and training accuracy only."""
        self.batch_history['loss'].append(loss)
        self.batch_history['train_acc'].append(train_acc)

    def batch_update_val(self, val_acc):
        """Record one batch's validation accuracy only."""
        self.batch_history['val_acc'].append(val_acc)

    def get_batch_history(self):
        """Return the per-batch history dictionary (the live object, not a copy)."""
        return self.batch_history