import torch
import torch.nn as nn
import numpy as np


class BalancedCELoss(nn.Module):
    def __init__(self, alpha=0.1):
        super(BalancedCELoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, predictions, targets):
        # If the model emits a single score per sample, expand it into two
        # columns so it can be treated as a two-class prediction.
        if predictions.dim() == 2 and predictions.size(1) == 1:
            predictions = torch.cat((1 - predictions, predictions), dim=1)

        # Standard cross-entropy loss on the two-class predictions
        ce_loss = self.ce_loss(predictions, targets)

        # Hard class assignments for the ratio comparison below
        predictions = torch.argmax(predictions, dim=1)

        # Predictions assigned to each class
        class_0_preds = predictions[predictions == 0]
        class_1_preds = predictions[predictions == 1]

        # Ground-truth labels belonging to each class
        class_0_labels = targets[targets == 0]
        class_1_labels = targets[targets == 1]

        # Fraction of predictions and of labels falling into each class
        preds_ratio_0 = len(class_0_preds) / len(predictions)
        preds_ratio_1 = len(class_1_preds) / len(predictions)
        labels_ratio_0 = len(class_0_labels) / len(targets)
        labels_ratio_1 = len(class_1_labels) / len(targets)

        # Imbalance penalty: distance between the predicted class distribution
        # and the label distribution. It is computed from argmax outputs, so it
        # is a plain float that shifts the loss value without adding gradients.
        imbalance_penalty = np.abs(preds_ratio_0 - labels_ratio_0) + \
                            np.abs(preds_ratio_1 - labels_ratio_1)

        # Combine the cross-entropy loss with the imbalance penalty
        total_loss = ce_loss + self.alpha * imbalance_penalty
        return total_loss
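

# Minimal usage sketch, assuming the model produces raw two-column logits of
# shape (N, 2) and integer class targets of shape (N,), following
# nn.CrossEntropyLoss conventions; `criterion`, `logits`, and `labels` are
# illustrative placeholder names, not part of the original snippet.
if __name__ == "__main__":
    criterion = BalancedCELoss(alpha=0.1)
    logits = torch.randn(8, 2, requires_grad=True)  # stand-in model outputs
    labels = torch.randint(0, 2, (8,))              # stand-in binary targets

    loss = criterion(logits, labels)
    loss.backward()  # gradients flow only through the cross-entropy term
    print(loss.item())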