# ANLP_WS24_CA2/ml_history.py
#
# (source-viewer metadata: 116 lines, 4.0 KiB, Python)

import numpy as np
from sklearn.metrics import mean_squared_error
from datetime import datetime
import json
import os
class History:
    """
    Class to store the history of the training process.
    Used to store the loss and rmse of the training and validation sets.

    Per-batch metrics are collected in ``batch_history`` through the
    ``batch_update*`` methods. ``update`` appends the mean of every non-empty
    per-batch list to ``history``. ``update`` does not clear the batch lists;
    call ``batch_reset`` for that.
    """

    # Metric names tracked both per batch and per epoch (order matters for the
    # key order of ``batch_history``).
    _METRIC_KEYS = ('train_loss', 'val_loss', 'train_rmse', 'val_rmse')

    def __init__(self):
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_rmse': [],
            'val_rmse': [],
            # Labels passed to the most recent batch_update_val call
            # (overwritten on every call).
            'val_labels': [],
            # val_preds contains one dict {epoch: [preds]} per batch_update_val call
            'val_preds': [],
            # only needed in the end not in training
            'test_labels': [],
            'test_preds': [],
        }
        # Creates self.batch_history with one empty list per metric.
        self.batch_reset()

    def update(self):
        """Append the mean of each non-empty per-batch metric list to ``history``.

        Metrics with no recorded batch values are skipped, so the epoch-level
        lists may end up with different lengths.
        """
        for key in self._METRIC_KEYS:
            batch_values = self.batch_history[key]
            if batch_values:
                self.history[key].append(np.mean(batch_values))

    def get_history(self):
        """Return the (live, not copied) epoch-level history dict."""
        return self.history

    def calculate_rmse(self, outputs, labels):
        """Return the root-mean-squared error between ``labels`` and ``outputs``."""
        return np.sqrt(mean_squared_error(labels, outputs))

    def batch_reset(self):
        """Replace ``batch_history`` with fresh, empty per-metric lists."""
        self.batch_history = {key: [] for key in self._METRIC_KEYS}

    def batch_update(self, train_loss, val_loss, train_rmse, val_rmse):
        """Record one value for each of the four per-batch metrics."""
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['train_rmse'].append(train_rmse)
        self.batch_history['val_rmse'].append(val_rmse)

    def batch_update_train(self, train_loss, preds, labels):
        """Record a training batch's loss and the RMSE of ``preds`` vs ``labels``."""
        train_rmse = self.calculate_rmse(preds, labels)
        self.batch_history['train_loss'].append(train_loss)
        self.batch_history['train_rmse'].append(train_rmse)

    def batch_update_val(self, val_loss, preds, labels, epoch):
        """Record a validation batch's loss/RMSE and store its labels and predictions.

        ``preds`` and ``labels`` must support ``.tolist()`` (e.g. numpy arrays).
        """
        val_rmse = self.calculate_rmse(preds, labels)
        self.batch_history['val_loss'].append(val_loss)
        self.batch_history['val_rmse'].append(val_rmse)
        # NOTE(review): val_labels is overwritten and val_preds grows by one
        # entry on every call; if this is called once per batch, only the last
        # batch's labels survive -- confirm against the training loop.
        self.history['val_labels'] = labels.tolist()
        self.history['val_preds'].append({epoch: preds.tolist()})

    def get_batch_history(self):
        """Return the (live, not copied) per-batch metric dict."""
        return self.batch_history

    def add_test_results(self, test_labels, test_preds):
        """Store final test-set labels and predictions (stored as given, not copied)."""
        self.history['test_labels'] = test_labels
        self.history['test_preds'] = test_preds

    @staticmethod
    def _to_float_list(values):
        """Convert an iterable of numbers into a list of Python floats.

        Nested lists/tuples and non-scalar numpy arrays are converted
        recursively, preserving their nesting (an ``(N, 1)`` array becomes
        ``[[x0], [x1], ...]``). This lets multi-dimensional labels or
        predictions be serialized instead of failing inside ``float()``.
        """
        converted = []
        for item in values:
            if isinstance(item, (list, tuple)) or (
                isinstance(item, np.ndarray) and item.ndim > 0
            ):
                converted.append(History._to_float_list(item))
            else:
                # numpy scalars (e.g. np.float32) are not JSON-serializable.
                converted.append(float(item))
        return converted

    def convert_hist(self):
        """Return a JSON-serializable copy of ``history``.

        Numpy arrays become (possibly nested) lists and numpy scalars become
        Python floats. ``val_preds`` keeps its list-of-``{epoch: preds}``
        structure; note that ``json`` will turn the epoch keys into strings.
        """
        history_to_save = {}
        for hist_key, hist_val in self.history.items():
            if hist_key == 'val_preds':
                history_to_save[hist_key] = [
                    {k: self._to_float_list(v) for k, v in val.items()}
                    for val in hist_val
                ]
            else:
                history_to_save[hist_key] = self._to_float_list(hist_val)
        return history_to_save

    def save_history(self, hist_name, timestamp=None):
        """Write the history to ``histories/<hist_name>_<timestamp>.json``.

        ``timestamp`` defaults to the current local time as ``%Y%m%d_%H%M%S``.
        The ``histories`` directory (relative to the working directory) is
        created if needed.
        """
        directory = "histories"
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(directory, exist_ok=True)
        if timestamp is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(directory, f"{hist_name}_{timestamp}.json")
        # json cannot serialize numpy arrays/scalars, so convert first.
        history_to_save = self.convert_hist()
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(history_to_save, f, indent=4)
        print(f"History saved to {filepath}")