diff --git a/BalancedCELoss.py b/BalancedCELoss.py new file mode 100644 index 0000000..51aa0a4 --- /dev/null +++ b/BalancedCELoss.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import numpy as np + +class BalancedCELoss(nn.Module): + def __init__(self, alpha=0.1): + super(BalancedCELoss, self).__init__() + self.bce_loss = nn.CrossEntropyLoss() + self.alpha = alpha + + def forward(self, predictions, targets): + # detect num of unique classes + num_classes = len(torch.unique(targets)) + if num_classes == 1: + # If only one class than split it into two classes + predictions = torch.cat((1 - predictions, predictions), dim=1) + + + # Calculate the standard binary cross-entropy loss + bce_loss = self.bce_loss(predictions, targets) + + predictions = torch.argmax(predictions, dim=1) + + # Calculate the number of predictions for each class + class_0_preds_n = predictions[predictions == 0] + class_1_preds_n = predictions[predictions == 1] + + # Calculate the number of labels for each class based on predictions + class_0_labels_n = targets[targets == 0] + class_1_labels_n = targets[targets == 1] + + preds_ratio_0 = len(class_0_preds_n) / len(predictions) + preds_ratio_1 = len(class_1_preds_n) / len(predictions) + + labels_ratio_0 = len(class_0_labels_n) / len(targets) + labels_ratio_1 = len(class_1_labels_n) / len(targets) + + # Calculate the imbalance penalty + imbalance_penalty = np.abs(preds_ratio_0 - labels_ratio_0) + np.abs(preds_ratio_1 - labels_ratio_1) + + # Combine the BCE loss with the imbalance penalty + total_loss = bce_loss + self.alpha * imbalance_penalty + + return total_loss \ No newline at end of file