116 lines
4.0 KiB
Python
116 lines
4.0 KiB
Python
import numpy as np
|
|
from sklearn.metrics import mean_squared_error
|
|
|
|
from datetime import datetime
|
|
import json
|
|
import os
|
|
|
|
|
|
class History:
    """Store the history of a training process.

    Per-batch losses and RMSE values are collected in ``batch_history`` and
    averaged into ``history`` by :meth:`update`. Validation labels and
    predictions, as well as the final test results, are kept in ``history``
    too so that everything can be written to one JSON file.
    """

    # Metrics that are tracked per batch and averaged by update().
    _BATCH_KEYS = ('train_loss', 'val_loss', 'train_rmse', 'val_rmse')

    def __init__(self):
        """Create empty ``history`` and ``batch_history`` containers."""
        self.history = {
            'train_loss': [],
            'val_loss': [],

            'train_rmse': [],
            'val_rmse': [],

            'val_labels': [],
            # val_preds holds one {epoch: [preds]} dict per batch_update_val call
            'val_preds': [],

            # only needed in the end, not in training
            'test_labels': [],
            'test_preds': [],
        }
        self.batch_reset()

    def update(self):
        """Append the mean of every non-empty batch metric to ``history``.

        The batch history is NOT cleared here; call :meth:`batch_reset`
        for that.
        """
        for key in self._BATCH_KEYS:
            values = self.batch_history[key]
            if values:
                self.history[key].append(np.mean(values))

    def get_history(self):
        """Return the ``history`` dict (the live object, not a copy)."""
        return self.history

    def calculate_rmse(self, outputs, labels):
        """Return the root mean squared error between ``labels`` and ``outputs``."""
        return np.sqrt(mean_squared_error(labels, outputs))

    def batch_reset(self):
        """Replace ``batch_history`` with fresh, empty metric lists."""
        self.batch_history = {key: [] for key in self._BATCH_KEYS}

    def batch_update(self, train_loss, val_loss, train_rmse, val_rmse):
        """Append one value for each of the four batch metrics."""
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['train_rmse'].append(train_rmse)
        self.batch_history['val_rmse'].append(val_rmse)

    def batch_update_train(self, train_loss, preds, labels):
        """Record the training loss and the RMSE of ``preds`` vs ``labels``."""
        train_rmse = self.calculate_rmse(preds, labels)
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['train_rmse'].append(train_rmse)

    def batch_update_val(self, val_loss, preds, labels, epoch):
        """Record validation loss/RMSE and store the labels and predictions.

        ``history['val_labels']`` is overwritten on every call, whereas
        ``history['val_preds']`` grows by one ``{epoch: preds}`` entry per
        call.
        """
        val_rmse = self.calculate_rmse(preds, labels)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['val_rmse'].append(val_rmse)

        self.history['val_labels'] = self._as_list(labels)
        self.history['val_preds'].append({epoch: self._as_list(preds)})

    def get_batch_history(self):
        """Return the ``batch_history`` dict (the live object, not a copy)."""
        return self.batch_history

    def add_test_results(self, test_labels, test_preds):
        """Store the test labels and predictions as given (no conversion)."""
        self.history['test_labels'] = test_labels
        self.history['test_preds'] = test_preds

    @staticmethod
    def _as_list(values):
        """Return ``values`` as a list, using ``tolist()`` when available."""
        if hasattr(values, 'tolist'):
            return values.tolist()
        return list(values)

    @classmethod
    def _to_serializable(cls, value):
        """Recursively convert ``value`` into JSON-friendly Python objects.

        Objects exposing ``tolist()`` (e.g. NumPy arrays and scalars) are
        unpacked first, containers are converted element-wise at any nesting
        depth, and every leaf is turned into a built-in ``float``.
        """
        if hasattr(value, 'tolist'):
            value = value.tolist()
        if isinstance(value, dict):
            return {
                # json.dump rejects NumPy scalars as keys, so unwrap them.
                (key.item() if isinstance(key, np.generic) else key):
                    cls._to_serializable(item)
                for key, item in value.items()
            }
        if isinstance(value, (list, tuple)):
            return [cls._to_serializable(item) for item in value]
        return float(value)

    def convert_hist(self):
        """Return a copy of ``history`` that can be dumped to JSON.

        NumPy arrays become (nested) lists and NumPy floats become built-in
        floats. Nested label/prediction lists keep their structure instead
        of raising ``TypeError``.
        """
        return {
            hist_key: self._to_serializable(hist_val)
            for hist_key, hist_val in self.history.items()
        }

    def save_history(self, hist_name, timestamp=None, directory="histories"):
        """Write the converted history to ``<directory>/<hist_name>_<timestamp>.json``.

        Args:
            hist_name: Prefix of the file name.
            timestamp: Suffix of the file name; defaults to the current local
                time formatted as ``%Y%m%d_%H%M%S``.
            directory: Target directory, created if it does not exist.

        Returns:
            The path of the written file.
        """
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(directory, exist_ok=True)
        if timestamp is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(directory, f"{hist_name}_{timestamp}.json")

        # convert numpy arrays to lists and use float instead of numpy float
        history_to_save = self.convert_hist()

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(history_to_save, f, indent=4)
        print(f"History saved to {filepath}")
        return filepath