arman 2025-02-09 11:35:49 +01:00
commit e1e9ac57ba
1 changed files with 44 additions and 0 deletions

44
BalancedCELoss.py 100644
View File

@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import numpy as np
class BalancedCELoss(nn.Module):
def __init__(self, alpha=0.1):
super(BalancedCELoss, self).__init__()
self.bce_loss = nn.CrossEntropyLoss()
self.alpha = alpha
def forward(self, predictions, targets):
# detect num of unique classes
num_classes = len(torch.unique(targets))
if num_classes == 1:
# If only one class than split it into two classes
predictions = torch.cat((1 - predictions, predictions), dim=1)
# Calculate the standard binary cross-entropy loss
bce_loss = self.bce_loss(predictions, targets)
predictions = torch.argmax(predictions, dim=1)
# Calculate the number of predictions for each class
class_0_preds_n = predictions[predictions == 0]
class_1_preds_n = predictions[predictions == 1]
# Calculate the number of labels for each class based on predictions
class_0_labels_n = targets[targets == 0]
class_1_labels_n = targets[targets == 1]
preds_ratio_0 = len(class_0_preds_n) / len(predictions)
preds_ratio_1 = len(class_1_preds_n) / len(predictions)
labels_ratio_0 = len(class_0_labels_n) / len(targets)
labels_ratio_1 = len(class_1_labels_n) / len(targets)
# Calculate the imbalance penalty
imbalance_penalty = np.abs(preds_ratio_0 - labels_ratio_0) + np.abs(preds_ratio_1 - labels_ratio_1)
# Combine the BCE loss with the imbalance penalty
total_loss = bce_loss + self.alpha * imbalance_penalty
return total_loss