import math
import os

import torch


class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves, the model's ``state_dict`` is saved to
    ``<checkpoint_dir>/checkpoint.pt``. After ``patience`` consecutive calls
    without improvement, ``early_stop`` is set to True; the caller is expected
    to check it and break out of the training loop.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If True, print a message every time a checkpoint is saved.
        delta: Minimum decrease in validation loss that counts as an
            improvement. The default 0.0 treats any decrease (or an equal
            loss) as an improvement, matching the original behaviour.
        checkpoint_dir: Directory checkpoints are written to; created on
            demand if it does not exist.
    """

    def __init__(self, patience=5, verbose=False, delta=0.0,
                 checkpoint_dir='checkpoints'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.checkpoint_dir = checkpoint_dir
        self.counter = 0
        # Negated best validation loss seen so far (higher is better);
        # None until the first call.
        self.best_score = None
        # Lowest validation loss checkpointed so far, used for the verbose
        # "old --> new" message.
        self.val_loss_min = math.inf
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint on improvement.

        Args:
            val_loss: Scalar validation loss (a Python number or a
                single-element tensor).
            model: Object exposing ``state_dict()``; saved when the loss
                improves.
        """
        # Work with a plain float so a tensor (and any graph it references)
        # is not retained as the best score.
        val_loss = float(val_loss)
        score = -val_loss

        # NaN compares False against everything. Without this guard it would
        # be stored as the best score, every later epoch would count as an
        # "improvement", and early stopping could never trigger.
        if math.isnan(score):
            self._register_no_improvement()
            return

        if self.best_score is not None and score < self.best_score + self.delta:
            self._register_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def _register_no_improvement(self):
        """Bump the patience counter and flag early stopping when exhausted."""
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, filename='checkpoint.pt'):
        """Save ``model.state_dict()`` to ``<checkpoint_dir>/<filename>``.

        Args:
            val_loss: The validation loss that triggered the save; it becomes
                the new ``val_loss_min``.
            model: Object exposing ``state_dict()``.
            filename: Checkpoint file name inside ``checkpoint_dir``.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # torch.save does not create missing parent directories.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(self.checkpoint_dir, filename))
        self.val_loss_min = float(val_loss)