diff --git a/.gitignore b/.gitignore index 76c6ab0d12..cda9ebc15a 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,6 @@ runs *.pth *zarr/* +issue38366/ + +issue8366/ diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 766486a807..7929817bf9 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -36,7 +36,73 @@ tqdm, _ = optional_import("tqdm", name="tqdm") _nearest_mode = "nearest-exact" -__all__ = ["sliding_window_inference"] +__all__ = ["ensure_channel_first","sliding_window_inference"] + +def ensure_channel_first( + x: torch.Tensor, + spatial_ndim: Optional[int] = None, + channel_hint: Optional[int] = None, + threshold: int = 32, +) -> tuple[torch.Tensor, int]: + """ + Normalize a tensor to channel-first layout (N, C, spatial...). + + Args: + x: Tensor with shape (N, C, spatial...) or (N, spatial..., C). + spatial_ndim: Number of spatial dimensions. If None, inferred as x.ndim - 2. + channel_hint: If provided, the expected channel size (e.g., num_classes). When present, + we prioritize matching this size at either dim=1 (channel-first) or dim=-1 (channel-last). + threshold: Heuristic upper bound for typical channel counts to disambiguate layouts. + + Returns: + A tuple (x_cf, orig_channel_dim): + - x_cf: the tensor in channel-first layout. + - orig_channel_dim: 1 if input was already channel-first; -1 if the channel was last. + + Raises: + TypeError: if x is not a torch.Tensor. + ValueError: if x.ndim < 3 or the channel dimension cannot be inferred unambiguously. + + Notes: + 1. When channel_hint is provided, it is used as a strong signal to decide layout. + 2. Otherwise, uses a heuristic where channels are usually small (<= threshold). + 3. In ambiguous cases (both candidate dims small or both large), the input layout + is preserved (assumed channel-first) to avoid silent mis-reordering. + """ + if not isinstance(x, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(x)}") + if x.ndim < 3: + raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}") + + if spatial_ndim is None: + spatial_ndim = x.ndim - 2 # informative only + + s1 = int(x.shape[1]) # candidate channel at dim=1 + sl = int(x.shape[-1]) # candidate channel at dim=-1 + + # 1) Strong signal: use channel_hint if provided + if channel_hint is not None: + if s1 == channel_hint and sl != channel_hint: + return x, 1 + if sl == channel_hint and s1 != channel_hint: + return x.movedim(-1, 1), -1 + # if both match or both mismatch, fall back to heuristic + + # 2) Heuristic: channels are usually small + if s1 <= threshold and sl > threshold: + return x, 1 + if sl <= threshold and s1 > threshold: + return x.movedim(-1, 1), -1 + + # 3) Ambiguous: both sides small OR both sides large → preserve as channel-first + if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold): + return x, 1 + + # 4) Should not reach here under normal cases + raise ValueError( + f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]" + ) + def sliding_window_inference( diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 0802cc3364..24426c16c7 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -15,6 +15,8 @@ from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, deprecated_arg +from monai.inferers.utils import ensure_channel_first + from .metric import CumulativeIterationMetric @@ -123,6 +125,7 @@ def __init__( num_classes=self.num_classes, ) + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Compute the dice value using ``DiceHelper``. @@ -306,6 +309,28 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``. y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...). """ + + + # --- Normalize layout to channel-first (N, C, spatial...) --- + # Prefer a strong signal when available. + if self.num_classes is not None: + y_pred, _ = ensure_channel_first(y_pred, channel_hint=self.num_classes) + else: + # First pass: heuristic only. + y_pred, _ = ensure_channel_first(y_pred) + # Fallback: if implausible vs y's layout, retry with a hint from y's last dim. + if y.ndim == y_pred.ndim: + plausible = {1, y.shape[1], y.shape[-1]} + if y_pred.shape[1] not in plausible: + y_pred, _ = ensure_channel_first(y_pred, channel_hint=int(y.shape[-1])) + + # Infer channels after normalization (or use provided). + n_ch = self.num_classes or y_pred.shape[1] + + # Normalize y if it plausibly is channel-last. + if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch): + y, _ = ensure_channel_first(y, channel_hint=n_ch) + _apply_argmax, _threshold = self.apply_argmax, self.threshold if self.num_classes is None: n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores