Fix: add ensure_channel_first in utils and integrate into DiceHelper (refs #8366) #8532
base: dev
@@ -165,3 +165,4 @@ runs
*.pth

*zarr/*
issue38366/
@@ -38,6 +38,58 @@

__all__ = ["sliding_window_inference"]


def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
    """
    Normalize a tensor to channel-first layout (N, C, spatial...).

    Args:
        x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
        spatial_ndim: Number of spatial dimensions. If None, inferred as x.ndim - 2.

    Returns:
        A tuple (x_cf, orig_channel_dim):
            - x_cf: the tensor in channel-first layout.
            - orig_channel_dim: 1 if input was already channel-first; -1 if the channel was last.

    Raises:
        TypeError: if x is not a torch.Tensor.
        ValueError: if x.ndim < 3 or the channel dimension cannot be inferred unambiguously.

    Notes:
        Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
        prefers preserving the input layout or raises ValueError to avoid silent errors.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(x)}")
    if x.ndim < 3:
        raise ValueError(f"Expected >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")

    # Infer spatial dims if not provided (handles 1D/2D/3D uniformly).
    if spatial_ndim is None:
        spatial_ndim = x.ndim - 2  # not directly used for logic; informative only

    # Heuristic: channels are usually small (e.g., <=32) in segmentation/classification.
    threshold = 32
    s1 = int(x.shape[1])  # candidate channel at dim=1 (N, C, ...)
    sl = int(x.shape[-1])  # candidate channel at last dim (..., C)

    # Unambiguous cases first.
    if s1 <= threshold and sl > threshold:
        # Looks like NCHW/D already.
        return x, 1
    if sl <= threshold and s1 > threshold:
        # Looks like NHWC/D: move last dim to channel dim.
        return x.movedim(-1, 1), -1

    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
    if s1 <= threshold and sl <= threshold:
        return x, 1

    raise ValueError(
        f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
        f"both dim1={s1} and dim-1={sl} look like spatial dims"
    )
Comment on lines +84 to +91

Bug: both-large case wrongly raises; breaks valid NCHW with many classes (C > 32).

For y_pred in channel-first with many classes (e.g., N, 64, H, W), s1 > 32 and sl (W) > 32; this path raises ValueError and will break Dice. Prefer preserving layout in both-large cases (consistent with the comment on Line 84).

Apply:

-    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
-    if s1 <= threshold and sl <= threshold:
-        return x, 1
-
-    raise ValueError(
-        f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
-        f"both dim1={s1} and dim-1={sl} look like spatial dims"
-    )
+    # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
+    if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
+        return x, 1
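The effect of the suggested change can be checked without tensors by applying the same rule to shape tuples. This is a minimal sketch, not the PR's code: `infer_channel_dim` and `to_channel_first_shape` are hypothetical names, and the threshold of 32 is copied from the diff.

```python
from typing import Sequence, Tuple

THRESHOLD = 32  # small-channel cutoff copied from the diff; an assumption of the heuristic


def infer_channel_dim(shape: Sequence[int], threshold: int = THRESHOLD) -> int:
    """Hypothetical shape-only helper: return 1 for channel-first, -1 for channel-last."""
    if len(shape) < 3:
        raise ValueError(f"Expected >=3 dims, got shape={tuple(shape)}")
    s1, sl = int(shape[1]), int(shape[-1])
    if s1 > threshold and sl <= threshold:
        return -1  # only the last dim looks like channels: treat as channel-last
    # dim 1 looks like channels, or both are small, or both are large: keep the layout
    return 1


def to_channel_first_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Shape that x.movedim(-1, 1) would produce when the input is judged channel-last."""
    shape = tuple(shape)
    if infer_channel_dim(shape) == 1:
        return shape
    return (shape[0], shape[-1]) + shape[1:-1]


print(to_channel_first_shape((2, 64, 64, 64)))  # both large: layout preserved, no exception
print(to_channel_first_shape((2, 64, 48, 3)))   # channel-last: channel moved to dim 1
```

Under this rule a channel-last tensor with more than 32 channels is also left unchanged, because shape alone cannot distinguish it from a channel-first one.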


def sliding_window_inference(
    inputs: torch.Tensor | MetaTensor,

@@ -15,6 +15,8 @@

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.inferers.utils import ensure_channel_first

from .metric import CumulativeIterationMetric

@@ -123,6 +125,7 @@ def __init__(
            num_classes=self.num_classes,
        )

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """
        Compute the dice value using ``DiceHelper``.

@@ -306,6 +309,13 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
            y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
        """
        y_pred, _ = ensure_channel_first(y_pred)

        n_ch = self.num_classes or y_pred.shape[1]
        if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch):
            y, _ = ensure_channel_first(y)

Comment on lines +312 to +317

💡 Verification agent
🧩 Analysis chain

Channel-last normalization is correctly integrated; ensure upstream heuristic doesn’t blow up.

This is the right place to normalize y_pred and y. However, with the current ensure_channel_first, channel-first tensors with many classes (C > 32) will raise unless the utils fix is applied. Once that’s in, this path is solid.

To be safe, verify both channel-first and channel-last with larger class counts. Expected: both print tensor shapes without exceptions.

🏁 Script executed:

    #!/bin/bash
    python - <<'PY'
    import torch
    from monai.metrics.meandice import DiceHelper
    def run(n=2, c=64, h=64, w=64):
        y_pred_cf = torch.rand(n, c, h, w)
        y_cf = torch.randint(0, c, (n, 1, h, w))
        print("CF/C>32:", DiceHelper(num_classes=c)(y_pred_cf, y_cf).shape)
        y_pred_cl = torch.rand(n, h, w, c)
        y_cl = torch.randint(0, c, (n, h, w, 1))
        print("CL/C>32:", DiceHelper(num_classes=c)(y_pred_cl, y_cl).shape)
    run()
    PY

Length of output: 128

🏁 Script executed:

    #!/bin/bash
    set -eo pipefail
    # Locate ensure_channel_first definition
    rg -n "def ensure_channel_first" -g "*.py"

Length of output: 197

🏁 Script executed:

    #!/bin/bash
    set -eo pipefail
    # Show ensure_channel_first implementation
    sed -n '1,200p' monai/inferers/utils.py

Length of output: 10537

Support >32 channels in ensure_channel_first
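The guard that decides whether `y` is passed to `ensure_channel_first` depends only on shapes, so it can be exercised on plain tuples. The sketch below mirrors the condition from the diff; `label_needs_layout_check` is a hypothetical name, and `y_pred_shape` is assumed to be channel-first already, as it is at that point in `__call__`.

```python
from typing import Optional, Sequence


def label_needs_layout_check(
    y_shape: Sequence[int], y_pred_shape: Sequence[int], num_classes: Optional[int] = None
) -> bool:
    """Mirror the guard in the diff: same rank, and the last dim of y is 1 or the channel count."""
    n_ch = num_classes or y_pred_shape[1]  # y_pred_shape is assumed channel-first here
    return len(y_shape) == len(y_pred_shape) and y_shape[-1] in (1, n_ch)


# Channel-last label (N, H, W, 1): the guard fires and the heuristic gets to decide.
print(label_needs_layout_check((2, 64, 64, 1), (2, 4, 64, 64)))
# Channel-first label (N, 1, H, W) with W=64 and 4 classes: the guard does not fire.
print(label_needs_layout_check((2, 1, 64, 64), (2, 4, 64, 64)))
# Channel-first label whose last spatial size equals the class count: the guard fires too.
print(label_needs_layout_check((2, 1, 64, 64), (2, 64, 64, 64), num_classes=64))
```

In the third case the guard passes a label that is already channel-first, so the outcome rests entirely on the heuristic inside `ensure_channel_first`.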

        _apply_argmax, _threshold = self.apply_argmax, self.threshold
        if self.num_classes is None:
            n_pred_ch = y_pred.shape[1]  # y_pred is in one-hot format or multi-channel scores

💡 Verification agent
🧩 Analysis chain
Verify import path used by DiceHelper.
PR text says the function was added to monai/metrics/utils.py, but it’s implemented here (monai/inferers/utils.py). Confirm call sites import the correct module path, or move the helper to the intended module.
Run:
🏁 Script executed:
Length of output: 3584
Ensure correct import path for ensure_channel_first
The helper lives in monai/inferers/utils.py, but in monai/metrics/meandice.py (line 18) it’s imported from monai.metrics.utils, which doesn’t define it.
• Either move ensure_channel_first into monai/metrics/utils.py and keep the existing import
• Or update the import in monai/metrics/meandice.py to:
from monai.inferers.utils import ensure_channel_first
🧰 Tools
🪛 Ruff (0.12.2)
41-41: Undefined name Optional (F821)
41-41: Undefined name Tuple (F821)
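The two F821 findings point at annotations that use `Optional` and `Tuple` without importing them. One possible fix, sketched here on a torch-free stand-in, is to import both names from `typing`. `resolve_spatial_ndim` is a hypothetical function used only to show the annotations resolving; it is not part of the PR.

```python
from typing import Optional, Sequence, Tuple


def resolve_spatial_ndim(shape: Sequence[int], spatial_ndim: Optional[int] = None) -> Tuple[int, int]:
    """Return (ndim, spatial_ndim), inferring spatial_ndim as ndim - 2 when it is None."""
    ndim = len(shape)
    if spatial_ndim is None:
        spatial_ndim = ndim - 2  # same inference rule as the docstring in the diff
    return ndim, spatial_ndim


print(resolve_spatial_ndim((2, 3, 64, 64)))
```

An alternative, if the module already has `from __future__ import annotations`, would be to write `int | None` and `tuple[torch.Tensor, int]` so that no `typing` import is needed.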