Skip to content

Fix: add ensure_channel_first in utils and integrate into DiceHelper (refs #8366) #8532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ runs
*.pth

*zarr/*
issue38366/
52 changes: 52 additions & 0 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,58 @@

__all__ = ["sliding_window_inference"]

def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify import path used by DiceHelper.

PR text says the function was added to monai/metrics/utils.py, but it’s implemented here (monai/inferers/utils.py). Confirm call sites import the correct module path, or move the helper to the intended module.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail
echo "Usages of ensure_channel_first:"
rg -n --hidden -S '\bensure_channel_first\b' || true
echo
echo "Imports of ensure_channel_first:"
rg -n --hidden -S 'from monai\.(metrics|inferers)\.utils import ensure_channel_first' || true
echo
echo "Definition locations:"
ast-grep --pattern $'def ensure_channel_first($_):\n  $$$'

Length of output: 3584


Ensure correct import path for ensure_channel_first
The helper lives in monai/inferers/utils.py, but in monai/metrics/meandice.py (line 18) it’s imported from monai.metrics.utils, which doesn’t define it.

• Either move ensure_channel_first into monai/metrics/utils.py and keep the existing import
• Or update the import in monai/metrics/meandice.py to:
from monai.inferers.utils import ensure_channel_first

🧰 Tools
🪛 Ruff (0.12.2)

41-41: Undefined name Optional

(F821)


41-41: Undefined name Tuple

(F821)

🤖 Prompt for AI Agents
In monai/metrics/meandice.py at line 18, the import statement for
ensure_channel_first is incorrect because it imports from monai.metrics.utils
where the function is not defined. To fix this, update the import statement to
import ensure_channel_first from monai.inferers.utils instead, i.e., change it
to 'from monai.inferers.utils import ensure_channel_first'. Alternatively, if
preferred, move the ensure_channel_first function from monai/inferers/utils.py
to monai/metrics/utils.py and keep the existing import unchanged.

"""
Normalize a tensor to channel-first layout (N, C, spatial...).

Args:
x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
spatial_ndim: Number of spatial dimensions. If None, inferred as x.ndim - 2.

Returns:
A tuple (x_cf, orig_channel_dim):
- x_cf: the tensor in channel-first layout.
- orig_channel_dim: 1 if input was already channel-first; -1 if the channel was last.

Raises:
TypeError: if x is not a torch.Tensor.
ValueError: if x.ndim < 3 or the channel dimension cannot be inferred unambiguously.

Notes:
Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
prefers preserving the input layout or raises ValueError to avoid silent errors.
"""
if not isinstance(x, torch.Tensor):
raise TypeError(f"Expected torch.Tensor, got {type(x)}")
if x.ndim < 3:
raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")

# Infer spatial dims if not provided (handles 1D/2D/3D uniformly).
if spatial_ndim is None:
spatial_ndim = x.ndim - 2 # not directly used for logic; informative only

# Heuristic: channels are usually small (e.g., <=32) in segmentation/classification.
threshold = 32
s1 = int(x.shape[1]) # candidate channel at dim=1 (N, C, ...)
sl = int(x.shape[-1]) # candidate channel at last dim (..., C)

# Unambiguous cases first.
if s1 <= threshold and sl > threshold:
# Looks like NCHW/D already.
return x, 1
if sl <= threshold and s1 > threshold:
# Looks like NHWC/D: move last dim to channel dim.
return x.movedim(-1, 1), -1

# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if s1 <= threshold and sl <= threshold:
return x, 1

raise ValueError(
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
f"both dim1={s1} and dim-1={sl} look like spatial dims"
)
Comment on lines +84 to +91
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bug: both-large case wrongly raises; breaks valid NCHW with many classes (C > 32).

For y_pred in channel-first with many classes (e.g., N, 64, H, W), s1 > 32 and sl (W) > 32; this path raises ValueError and will break Dice. Prefer preserving layout in both-large cases (consistent with the comment on Line 84).

Apply:

-    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
-    if s1 <= threshold and sl <= threshold:
-        return x, 1
-
-    raise ValueError(
-        f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
-        f"both dim1={s1} and dim-1={sl} look like spatial dims"
-    )
+    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
+    if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
+        return x, 1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if s1 <= threshold and sl <= threshold:
return x, 1
raise ValueError(
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
f"both dim1={s1} and dim-1={sl} look like spatial dims"
)
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
return x, 1
🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 84 to 91, the code currently raises a
ValueError when both dim1 and dim-1 are large (e.g., C>32), which incorrectly
rejects valid channel-first tensors; change the branching so that when both s1
and sl are small OR both are large the function preserves the existing layout
(return x, 1) instead of raising; only raise when sizes are truly ambiguous
(e.g., one small and one small-ish conflict) or adjust the thresholds/conditions
accordingly and update the error message to reflect remaining ambiguous cases.



def sliding_window_inference(
inputs: torch.Tensor | MetaTensor,
Expand Down
10 changes: 10 additions & 0 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.inferers.utils import ensure_channel_first


from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
num_classes=self.num_classes,
)


def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Compute the dice value using ``DiceHelper``.
Expand Down Expand Up @@ -306,6 +309,13 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
"""
y_pred, _ = ensure_channel_first(y_pred)

n_ch = self.num_classes or y_pred.shape[1]
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
y, _ = ensure_channel_first(y)

Comment on lines +312 to +317
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Channel-last normalization is correctly integrated; ensure upstream heuristic doesn’t blow up.

This is the right place to normalize y_pred and y. However, with current ensure_channel_first, channel-first tensors with many classes (C > 32) will raise unless the utils fix is applied. Once that’s in, this path is solid.

To be safe, verify both channel-first and channel-last with larger class counts:

Expected: both print tensor shapes without exceptions.


🏁 Script executed:

#!/bin/bash
python - <<'PY'
import torch
from monai.metrics.meandice import DiceHelper

def run(n=2, c=64, h=64, w=64):
    y_pred_cf = torch.rand(n, c, h, w)
    y_cf = torch.randint(0, c, (n, 1, h, w))
    print("CF/C>32:", DiceHelper(num_classes=c)(y_pred_cf, y_cf).shape)

    y_pred_cl = torch.rand(n, h, w, c)
    y_cl = torch.randint(0, c, (n, h, w, 1))
    print("CL/C>32:", DiceHelper(num_classes=c)(y_pred_cl, y_cl).shape)

run()
PY

Length of output: 128


🏁 Script executed:

#!/bin/bash
set -eo pipefail

# Locate ensure_channel_first definition
rg -n "def ensure_channel_first" -g "*.py"

Length of output: 197


🏁 Script executed:

#!/bin/bash
set -eo pipefail

# Show ensure_channel_first implementation
sed -n '1,200p' monai/inferers/utils.py

Length of output: 10537


Support >32 channels in ensure_channel_first
ensure_channel_first in monai/inferers/utils.py hard-codes a 32-channel threshold and will throw on both NCHW and NHWC inputs when C>32. Before merging, please:

  • Update ensure_channel_first to accept a configurable threshold (e.g., use num_classes or expose a kwarg) so it no longer raises for large class counts.
  • Or catch its ValueError in DiceHelper and default to treating dim-1 as the channel axis.
  • Add unit tests for C>32 on both channel-first and channel-last inputs to verify no exceptions.
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 312 to 317, ensure_channel_first can
raise for channel counts >32; modify the code to catch ValueError from
ensure_channel_first and on exception treat dim-1 as the channel axis (i.e.,
transpose NHWC to NCHW manually or accept input as already channel-first),
ensuring n_ch is computed from num_classes or y_pred.shape[1] accordingly;
alternatively, update monai/inferers/utils.py:ensure_channel_first to accept a
configurable threshold or a kwarg (e.g., max_channel_threshold or
force_last_dim_channel) and use that in this call so it won’t raise for large C;
add unit tests that cover C>32 for both channel-first (NCHW) and channel-last
(NHWC) inputs to verify no exceptions and correct channel handling.


_apply_argmax, _threshold = self.apply_argmax, self.threshold
if self.num_classes is None:
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
Expand Down
Loading