import math
import os

import torch


class EarlyStoppingCallback:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. On the first call, and whenever the loss improves, the model's
    ``state_dict()`` is written with :func:`torch.save` to
    ``os.path.join(checkpoint_dir, model_name)``. After ``patience``
    consecutive calls without improvement, :attr:`early_stop` is set to
    ``True``; acting on it is the caller's responsibility.

    Args:
        model_name: File name of the checkpoint inside ``checkpoint_dir``.
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If true, print a message every time a checkpoint is saved.
        delta: Minimum decrease in validation loss required to count as an
            improvement. The default ``0.0`` treats an equal loss as an
            improvement (the original behavior).
        checkpoint_dir: Directory the checkpoint is written to; created on
            demand.
    """

    def __init__(self, model_name, patience=5, verbose=False, delta=0.0,
                 checkpoint_dir="models/checkpoints"):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # negated best validation loss (higher is better)
        self.val_loss_min = math.inf  # best validation loss seen; used in the log message
        self.early_stop = False
        self.model_name = model_name
        self.checkpoint_dir = checkpoint_dir

    def __call__(self, val_loss, model):
        """Record one validation result and checkpoint ``model`` on improvement.

        Args:
            val_loss: Scalar validation loss (Python number, NumPy scalar or
                single-element tensor). It is converted to ``float`` so no
                tensor, and no autograd graph, is kept alive between calls.
            model: Object exposing ``state_dict()``, e.g. an ``nn.Module``.
        """
        val_loss = float(val_loss)
        score = -val_loss
        # A NaN loss never counts as an improvement. Every comparison with NaN
        # is False, so without this check NaN would take the "improved" path,
        # overwrite the best checkpoint and reset the patience counter.
        improved = not math.isnan(score) and (
            self.best_score is None or score >= self.best_score + self.delta
        )
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` and record ``val_loss`` as the best loss.

        The checkpoint goes to ``os.path.join(checkpoint_dir, model_name)``;
        the directory is created if it does not exist.
        """
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        if self.verbose:
            # Report the previous best *loss*, not the (negated) score.
            print(f'└ Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), os.path.join(self.checkpoint_dir, self.model_name))
        self.val_loss_min = val_loss