Merge branch 'main' of https://gitty.informatik.hs-mannheim.de/3016498/ANLP_WS24_CA2
commit
e1e9ac57ba
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
class BalancedCELoss(nn.Module):
|
||||
def __init__(self, alpha=0.1):
|
||||
super(BalancedCELoss, self).__init__()
|
||||
self.bce_loss = nn.CrossEntropyLoss()
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, predictions, targets):
|
||||
# detect num of unique classes
|
||||
num_classes = len(torch.unique(targets))
|
||||
if num_classes == 1:
|
||||
# If only one class than split it into two classes
|
||||
predictions = torch.cat((1 - predictions, predictions), dim=1)
|
||||
|
||||
|
||||
# Calculate the standard binary cross-entropy loss
|
||||
bce_loss = self.bce_loss(predictions, targets)
|
||||
|
||||
predictions = torch.argmax(predictions, dim=1)
|
||||
|
||||
# Calculate the number of predictions for each class
|
||||
class_0_preds_n = predictions[predictions == 0]
|
||||
class_1_preds_n = predictions[predictions == 1]
|
||||
|
||||
# Calculate the number of labels for each class based on predictions
|
||||
class_0_labels_n = targets[targets == 0]
|
||||
class_1_labels_n = targets[targets == 1]
|
||||
|
||||
preds_ratio_0 = len(class_0_preds_n) / len(predictions)
|
||||
preds_ratio_1 = len(class_1_preds_n) / len(predictions)
|
||||
|
||||
labels_ratio_0 = len(class_0_labels_n) / len(targets)
|
||||
labels_ratio_1 = len(class_1_labels_n) / len(targets)
|
||||
|
||||
# Calculate the imbalance penalty
|
||||
imbalance_penalty = np.abs(preds_ratio_0 - labels_ratio_0) + np.abs(preds_ratio_1 - labels_ratio_1)
|
||||
|
||||
# Combine the BCE loss with the imbalance penalty
|
||||
total_loss = bce_loss + self.alpha * imbalance_penalty
|
||||
|
||||
return total_loss
|
||||
Loading…
Reference in New Issue