From 0883c9c93bcc5f30d8b7925b186396b45c683356 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Thu, 6 Nov 2025 15:07:53 +0800
Subject: [PATCH 01/16] modify unified_focal_loss.py

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 8484eb67ed..4f1a72dc2d 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -162,6 +162,9 @@ def __init__(
         gamma: float = 0.5,
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
+        sigmoid: bool = False,
+        softmax: bool = False,
     ):
         """
         Args:
@@ -188,6 +191,9 @@ def __init__(
         self.weight: float = weight
         self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.include_background = include_background
+        self.sigmoid = sigmoid
+        self.softmax = softmax
 
     # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

From 0a29b5eee3337f6e49840e8df0ed7a126b41bae3 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Thu, 6 Nov 2025 16:42:48 +0800
Subject: [PATCH 02/16] delete use_sigmoid parameter

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 4f1a72dc2d..5c37f47c43 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -163,17 +163,19 @@ def __init__(
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         include_background: bool = True,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        use_softmax: bool = False
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
+            reduction : reduction mode for the loss. Defaults to MEAN.
+            include_background : whether to include the background class in loss calculation. Defaults to True.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
 
         Example:
             >>> import torch
@@ -184,6 +186,8 @@ def __init__(
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
+        if use_sigmoid and use_softmax:
+            raise ValueError("use_sigmoid and use_softmax are mutually exclusive; only one can be True.")
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
@@ -192,8 +196,7 @@ def __init__(
         self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
         self.include_background = include_background
-        self.sigmoid = sigmoid
-        self.softmax = softmax
+        self.use_softmax = use_softmax
 
     # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

From 83d2318dae2963871683f28e26626ef95f3073da Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 10:28:48 +0800
Subject: [PATCH 03/16] Generalize AsymmetricUnifiedFocalLoss to align
 interfaces

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 33 +++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 5c37f47c43..666c4fdc73 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -163,16 +163,16 @@ def __init__(
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         include_background: bool = True,
-        use_softmax: bool = False
+        use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             num_classes : number of classes, it only supports 2 now. Defaults to 2.
-            weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta : weight of the background. Defaults to 0.7.
-            reduction : reduction mode for the loss. Defaults to MEAN.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
             include_background : whether to include the background class in loss calculation. Defaults to True.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -186,15 +186,15 @@ def __init__(
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        if use_sigmoid and use_softmax:
-            raise ValueError("use_sigmoid and use_softmax are mutually exclusive; only one can be True.")
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
+        )
         self.include_background = include_background
         self.use_softmax = use_softmax
 
@@ -205,8 +205,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
                 It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true : the shape should be BNH[WD], where N is the number of classes.
+                a sigmoid or softmax in the forward function.
+            y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
                 It only supports binary segmentation.
 
         Raises:
@@ -235,6 +235,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, removing first channel
+                y_pred = y_pred[:, 1:]
+                y_true = y_true[:, 1:]
+
+        if self.use_softmax:
+            y_pred = torch.softmax(y_pred.float(), dim=1)
+        else:
+            y_pred = torch.sigmoid(y_pred.float())
+
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

From 1a2891727c887bba3a08334e04b849bfdecf6fe9 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 11:04:36 +0800
Subject: [PATCH 04/16] Enhanced AsymmetricUnifiedFocalLoss with
 Sigmoid/Softmax

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 666c4fdc73..56227a8d96 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -39,6 +39,7 @@ def __init__(
         gamma: float = 0.75,
         epsilon: float = 1e-7,
         reduction: LossReduction | str = LossReduction.MEAN,
+        include_background: bool = True,
     ) -> None:
         """
         Args:
@@ -46,12 +47,14 @@ def __init__(
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
             gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            include_background: whether to include background class in loss calculation. Defaults to True.
""" super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.include_background = include_background def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -79,6 +82,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: back_dice = 1 - dice_class[:, 0] fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + if not self.include_background: + back_dice = back_dice * 0.0 + # Average class scores loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) return loss @@ -103,6 +109,7 @@ def __init__( gamma: float = 2, epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + include_background: bool = True, ): """ Args: @@ -110,12 +117,14 @@ def __init__( delta : weight of the background. Defaults to 0.7. gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + include_background: whether to include background class in loss calculation. Defaults to True. """ super().__init__(reduction=LossReduction(reduction).value) self.to_onehot_y = to_onehot_y self.delta = delta self.gamma = gamma self.epsilon = epsilon + self.include_background = include_background def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: n_pred_ch = y_pred.shape[1] @@ -138,6 +147,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: fore_ce = cross_entropy[:, 1] fore_ce = self.delta * fore_ce + if not self.include_background: + back_ce = back_ce * 0.0 + loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) return loss From 836ce42c7fb1c30b599a6f9519004b3eed1489c9 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 7 Nov 2025 11:21:21 +0800 Subject: [PATCH 05/16] Enhanced AsymmetricUnifiedFocalLoss with Sigmoid/Softmax Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 56227a8d96..d84929e43b 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -79,14 +79,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) # Calculate losses separately for each class, enhancing both classes - back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) + back_dice = 1 - dice_class[:, 0:1] + fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) if not self.include_background: back_dice = back_dice * 0.0 + all_dice = torch.cat([back_dice, fore_dice], dim=1) + # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) + loss = torch.mean(all_dice) return loss @@ -141,16 +143,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) cross_entropy = -y_true * torch.log(y_pred) - back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] + back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1] back_ce = (1 - self.delta) * back_ce - fore_ce = 
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce
 
         if not self.include_background:
             back_ce = back_ce * 0.0
 
-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        all_ce = torch.cat([back_ce, fore_ce], dim=1)
+
+        loss = torch.mean(torch.sum(all_ce, dim=1))
         return loss

From 81af139cc4885ad50b62d4a007d9cdf917e7bef6 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 11:44:03 +0800
Subject: [PATCH 06/16] Minor fix: 1. Add stacklevel=2 to warning. 2. Pass
 include_background to sub-losses. 3. Let sub-losses handle
 include_background.

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index d84929e43b..6b429038de 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -207,9 +207,11 @@ def __init__(
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
+        self.asy_focal_loss = AsymmetricFocalLoss(
+            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
+        )
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
+            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
         )
         self.include_background = include_background
         self.use_softmax = use_softmax
@@ -251,14 +253,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
-        if not self.include_background:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
-            else:
-                # if skipping background, removing first channel
-                y_pred = y_pred[:, 1:]
-                y_true = y_true[:, 1:]
-
         if self.use_softmax:
             y_pred = torch.softmax(y_pred.float(), dim=1)
         else:

From 08bfb0e1d429462d3e9a2a84273ffa5f2bf0595f Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 13:25:15 +0800
Subject: [PATCH 07/16] Minor fixes

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 42 ++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 6b429038de..aa5874d735 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -54,7 +54,7 @@ def __init__(
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
-        self.include_background = include_background
+        self.include_background: bool = include_background
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
@@ -77,6 +77,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         fn = torch.sum(y_true * (1 - y_pred), dim=axis)
         fp = torch.sum((1 - y_true) * y_pred, dim=axis)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
+        dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon)
 
         # Calculate losses separately for each class, enhancing both classes
         back_dice = 1 - dice_class[:, 0:1]
@@ -126,7 +127,7 @@ def __init__(
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
-        self.include_background = include_background
+        self.include_background: bool = include_background
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
@@ -154,7 +155,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         all_ce = torch.cat([back_ce, fore_ce], dim=1)
 
-        loss = torch.mean(torch.sum(all_ce, dim=1))
+        loss = torch.mean(all_ce)
         return loss
@@ -184,11 +185,11 @@ def __init__(
         """
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            num_classes : number of classes. Defaults to 2.
+            weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
+            reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
             include_background : whether to include the background class in loss calculation. Defaults to True.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -208,12 +209,20 @@ def __init__(
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
         self.asy_focal_loss = AsymmetricFocalLoss(
-            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
         )
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
+            to_onehot_y=self.to_onehot_y,
+            gamma=self.gamma,
+            delta=self.delta,
+            include_background=self.include_background,
+            reduction=LossReduction.NONE,
         )
-        self.include_background = include_background
+        self.include_background: bool = include_background
         self.use_softmax = use_softmax
@@ -240,10 +249,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
 
         if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
+            if self.num_classes != 2:
+                raise ValueError(
+                    f"Single-channel input only supported for binary (num_classes=2), got {self.num_classes}"
+                )
+            y_pred = torch.cat([torch.zeros_like(y_pred), y_pred], dim=1)
+            if y_true.shape[1] == 1:
+                y_true = one_hot(y_true, num_classes=self.num_classes)
+
+        if y_true.shape[1] != self.num_classes and torch.max(y_true) > self.num_classes - 1:
             raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")

From 3c1ec33a20a9d2b676ef19b3f3e9689d8734b51a Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 13:34:28 +0800
Subject: [PATCH 08/16] Fix docstring default value, validation logic, and
 message.
Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index aa5874d735..025cbbb178 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -45,7 +45,7 @@ def __init__(
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
             include_background: whether to include background class in loss calculation. Defaults to True.
         """
@@ -257,8 +257,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=self.num_classes)
 
-        if y_true.shape[1] != self.num_classes and torch.max(y_true) > self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
+        if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1:
+            raise ValueError(
+                f"y_true must have {self.num_classes} channels (one-hot) or label values in [0, {self.num_classes - 1}], "
+                f"but got shape {y_true.shape} with max value {torch.max(y_true)}"
+            )
 
         n_pred_ch = y_pred.shape[1]
         if self.to_onehot_y:

From c3c25707fe75298b2bdbbb1ce9afcc7e12eb595d Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 13:39:21 +0800
Subject: [PATCH 09/16] Fix docstring default value.

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 025cbbb178..cfbce9b895 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -45,7 +45,7 @@ def __init__(
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
+            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
             include_background: whether to include background class in loss calculation. Defaults to True.
         """

From 912235e252e9a77ca7856fa1dcd16a410ee693d6 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 13:58:01 +0800
Subject: [PATCH 10/16] Fix AttributeError: 'AsymmetricUnifiedFocalLoss'
 object has no attribute 'include_background'.
Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index cfbce9b895..d8428e6092 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -208,6 +208,8 @@ def __init__(
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
+        self.include_background: bool = include_background
+        self.use_softmax = use_softmax
         self.asy_focal_loss = AsymmetricFocalLoss(
             to_onehot_y=self.to_onehot_y,
             gamma=self.gamma,
@@ -222,8 +224,6 @@ def __init__(
             include_background=self.include_background,
             reduction=LossReduction.NONE,
         )
-        self.include_background: bool = include_background
-        self.use_softmax = use_softmax
 
     # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

From 230394ba4ede6b061fa7f1ee742b8c32a10e7fe8 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 14:17:21 +0800
Subject: [PATCH 11/16] Relocate the sigmoid/softmax application to be
 conditional on the number of channels.

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index d8428e6092..5bafe67c35 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -253,9 +253,20 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"Single-channel input only supported for binary (num_classes=2), got {self.num_classes}"
                 )
-            y_pred = torch.cat([torch.zeros_like(y_pred), y_pred], dim=1)
+
+            if self.use_softmax:
+                raise ValueError("use_softmax=True is not compatible with single-channel input")
+
+            y_pred_sigmoid = torch.sigmoid(y_pred.float())
+            y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
+
             if y_true.shape[1] == 1:
                 y_true = one_hot(y_true, num_classes=self.num_classes)
+        else:
+            if self.use_softmax:
+                y_pred = torch.softmax(y_pred.float(), dim=1)
+            else:
+                y_pred = torch.sigmoid(y_pred.float())
 
         if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1:
             raise ValueError(

From b52c57079ba8439fb2c6d056c74f761619e59d21 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 14:23:58 +0800
Subject: [PATCH 12/16] Fix docstring default value.

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 5bafe67c35..d70ed736a1 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -118,7 +118,7 @@ def __init__(
         Args:
             to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
             epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
             include_background: whether to include background class in loss calculation. Defaults to True.
""" From c0e9d78fe05a2596aafc9aa8c5a3cb5b78364ef4 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 7 Nov 2025 15:04:23 +0800 Subject: [PATCH 13/16] fix the loss function activates y_pred more than once (double activation) Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 33 +++++++++--------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index d70ed736a1..3d0954421c 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -248,26 +248,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") - if y_pred.shape[1] == 1: - if self.num_classes != 2: - raise ValueError( - f"Single-channel input only supported for binary (num_classes=2), got {self.num_classes}" - ) - - if self.use_softmax: - raise ValueError("use_softmax=True is not compatible with single-channel input") - - y_pred_sigmoid = torch.sigmoid(y_pred.float()) - y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1) - - if y_true.shape[1] == 1: - y_true = one_hot(y_true, num_classes=self.num_classes) - else: - if self.use_softmax: - y_pred = torch.softmax(y_pred.float(), dim=1) - else: - y_pred = torch.sigmoid(y_pred.float()) - if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1: raise ValueError( f"y_true must have {self.num_classes} channels (one-hot) or label values in [0, {self.num_classes - 1}], " @@ -281,10 +261,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: else: y_true = one_hot(y_true, num_classes=n_pred_ch) - if self.use_softmax: - y_pred = torch.softmax(y_pred.float(), dim=1) + if y_pred.shape[1] == 1: + y_pred_sigmoid = torch.sigmoid(y_pred.float()) + y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1) + + if y_true.shape[1] == 1: + y_true = one_hot(y_true, num_classes=self.num_classes) else: - y_pred = torch.sigmoid(y_pred.float()) + if self.use_softmax: + y_pred = torch.softmax(y_pred.float(), dim=1) + else: + y_pred = torch.sigmoid(y_pred.float()) asy_focal_loss = self.asy_focal_loss(y_pred, y_true) asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) From b7a5013b9c358c15a12ae4ac3ecd65329b89f3fa Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 7 Nov 2025 15:14:02 +0800 Subject: [PATCH 14/16] Validation logic may reject valid inputs. 
Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 3d0954421c..17cc27ed7c 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -248,10 +248,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
 
-        if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1:
+        if y_true.shape[1] == self.num_classes:
+            if not torch.all((y_true == 0) | (y_true == 1)):
+                raise ValueError(f"y_true appears to be one-hot but contains values other than 0 and 1")
+        elif y_true.shape[1] == 1:
+            if torch.max(y_true) >= self.num_classes:
+                raise ValueError(
+                    f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}"
+                )
+        else:
             raise ValueError(
-                f"y_true must have {self.num_classes} channels (one-hot) or label values in [0, {self.num_classes - 1}], "
-                f"but got shape {y_true.shape} with max value {torch.max(y_true)}"
+                f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}"
             )
 
         n_pred_ch = y_pred.shape[1]

From 630ad7a5c3b8faeec852cbae7ec251fedd31028f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 7 Nov 2025 07:16:52 +0000
Subject: [PATCH 15/16] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/losses/unified_focal_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 17cc27ed7c..7b0b9949b7 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -250,7 +250,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         if y_true.shape[1] == self.num_classes:
             if not torch.all((y_true == 0) | (y_true == 1)):
-                raise ValueError(f"y_true appears to be one-hot but contains values other than 0 and 1")
+                raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
         elif y_true.shape[1] == 1:
             if torch.max(y_true) >= self.num_classes:
                 raise ValueError(

From b46f089c59a229893d5e5a876d58b105d06ddc17 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 7 Nov 2025 16:54:58 +0800
Subject: [PATCH 16/16] Fix applying the activation function twice.

Signed-off-by: ytl0623
---
 monai/losses/unified_focal_loss.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 17cc27ed7c..9c7e412c30 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -22,9 +22,9 @@ class AsymmetricFocalTverskyLoss(_Loss):
     """
-    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes.
 
-    Actually, it's only supported for binary image segmentation now.
+    Supports multi-class segmentation with optional background inclusion.
 
     Reimplementation of the Asymmetric Focal Tversky Loss described in:
 
@@ -61,7 +61,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
@@ -95,9 +95,9 @@ class AsymmetricFocalLoss(_Loss):
     """
-    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
+    AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes.
 
-    Actually, it's only supported for binary image segmentation now.
+    Supports multi-class segmentation with optional background inclusion.
 
     Reimplementation of the Asymmetric Focal Loss described in:
 
@@ -134,7 +134,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
 
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
@@ -161,9 +161,9 @@ class AsymmetricUnifiedFocalLoss(_Loss):
     """
-    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
+    AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
 
-    Actually, it's only supported for binary image segmentation now
+    Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion.
 
     Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
 
@@ -201,6 +201,11 @@ def __init__(
             >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
             >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
             >>> fl(pred, grnd)
+            >>> # Multiclass example with 3 classes
+            >>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32)
+            >>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64)
+            >>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
+            >>> fl_mc(pred_mc, grnd_mc)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
@@ -225,16 +230,13 @@ def __init__(
             reduction=LossReduction.NONE,
         )
 
-    # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
             y_pred : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
                 a sigmoid or softmax in the forward function.
             y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
-                It only supports binary segmentation.
 
         Raises:
             ValueError: When input and target are different shape
@@ -264,11 +266,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
         if self.to_onehot_y:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
         if y_pred.shape[1] == 1:
+            warnings.warn("single channel prediction, augmenting with background channel.", stacklevel=2)
             y_pred_sigmoid = torch.sigmoid(y_pred.float())
             y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
@@ -278,7 +281,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             if self.use_softmax:
                 y_pred = torch.softmax(y_pred.float(), dim=1)
             else:
-                y_pred = torch.sigmoid(y_pred.float())
+                y_pred = y_pred.float()
 
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
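
Reviewer note: a minimal usage sketch against the final state of this series. It assumes the signature and behavior introduced above (include_background, use_softmax, single-channel augmentation); the multi-class call mirrors the docstring example added in PATCH 16.

    import torch
    from monai.losses import AsymmetricUnifiedFocalLoss

    # Binary: single-channel logits plus an integer label map. Sigmoid is applied
    # internally and the prediction is expanded to [background, foreground].
    pred = torch.randn(1, 1, 32, 32)
    grnd = torch.ones(1, 1, 32, 32, dtype=torch.int64)
    loss = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
    print(loss(pred, grnd))

    # Multi-class: 3-channel logits with softmax activation and label targets,
    # as in the PATCH 16 docstring example.
    pred_mc = torch.randn(1, 3, 32, 32)
    grnd_mc = torch.randint(0, 3, (1, 1, 32, 32), dtype=torch.int64)
    loss_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
    print(loss_mc(pred_mc, grnd_mc))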
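
The include_background handling added in PATCH 04 and reshaped in PATCH 05/07 zeroes the background term instead of slicing the channel away, so the [background, foreground] concatenation keeps its shape. A toy sketch of that arithmetic, with hypothetical per-class Tversky scores:

    import torch

    # Hypothetical per-class Tversky scores for a batch of 2:
    # column 0 = background, column 1 = foreground.
    dice_class = torch.tensor([[0.9, 0.6], [0.8, 0.7]])
    gamma, include_background = 0.75, False

    back_dice = 1 - dice_class[:, 0:1]  # 0:1 slicing keeps the channel dimension
    fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -gamma)

    if not include_background:
        back_dice = back_dice * 0.0  # zeroed rather than removed, so shapes stay aligned

    all_dice = torch.cat([back_dice, fore_dice], dim=1)  # still shape (B, 2)
    loss = torch.mean(all_dice)  # note: the zeroed background column still enters the mean
    print(loss)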
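
The y_true validation from PATCH 14 (string literal fixed by pre-commit in PATCH 15) accepts either a num_classes-channel one-hot target or a single-channel label map and rejects everything else. A standalone restatement for review, with num_classes and the sample tensors as illustrative assumptions:

    import torch

    num_classes = 2

    def check_y_true(y_true: torch.Tensor) -> None:
        # Mirrors the PATCH 14/15 validation branch for a binary setup.
        if y_true.shape[1] == num_classes:
            if not torch.all((y_true == 0) | (y_true == 1)):
                raise ValueError("y_true appears to be one-hot but contains values other than 0 and 1")
        elif y_true.shape[1] == 1:
            if torch.max(y_true) >= num_classes:
                raise ValueError(f"y_true labels must be in [0, {num_classes - 1}], but got max {torch.max(y_true)}")
        else:
            raise ValueError(f"y_true must have {num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}")

    check_y_true(torch.randint(0, 2, (1, 1, 32, 32)))  # accepted: label map with values in [0, 1]
    onehot = torch.zeros(1, 2, 32, 32)
    onehot[:, 0] = 1.0
    check_y_true(onehot)  # accepted: one-hot with values in {0, 1}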