-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Fix: add ensure_channel_first in utils and integrate into DiceHelper (refs #8366) #8532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
Walkthrough.gitignore now ignores the directory Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (2)
.gitignore
(1 hunks)monai/metrics/meandice.py
(3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit Configuration File
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/metrics/meandice.py
🔇 Additional comments (3)
.gitignore (1)
168-168
: Ignore entry LGTMAdding issue38366/ to .gitignore is harmless and scoped.
monai/metrics/meandice.py (2)
18-19
: Import of ensure_channel_first is correctGood placement and usage intent.
128-128
: No-op whitespace changeNothing to do here.
…ect-MONAI#8366) Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🔭 Outside diff range comments (1)
monai/inferers/utils.py (1)
16-16
: Fix Ruff F821: import Optional and Tuple.Type hints use Optional and Tuple but they aren’t imported.
Apply:
-from typing import Any +from typing import Any, Optional, Tuple
🧹 Nitpick comments (2)
monai/inferers/utils.py (2)
39-39
: Consider exporting ensure_channel_first via all.If this is part of the public surface, add it to all for discoverability and star-import safety.
Add:
__all__ = ["sliding_window_inference", "ensure_channel_first"]
41-73
: Add unit tests for layout normalization.Cover NCHW/NHWC (2D), NCD/ND C (1D), ambiguity/error paths, and round-trip behavior.
Suggested cases:
- (N, C=1/2/3/4, H, W) → unchanged; orig=1.
- (N, H, W, C=1/2/3/4) → movedim to (N, C, H, W); orig=-1.
- (N, C, L) and (N, L, C) → both paths.
- ndim < 3 → ValueError.
- Ambiguous both-large dims → ValueError.
I can draft tests in monai/tests/test_ensure_channel_first.py if helpful.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/inferers/utils.py
(1 hunks)monai/metrics/meandice.py
(3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- monai/metrics/meandice.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit Configuration File
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/inferers/utils.py
🪛 Ruff (0.12.2)
monai/inferers/utils.py
41-41: Undefined name Optional
(F821)
41-41: Undefined name Tuple
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
@@ -38,7 +38,39 @@ | |||
|
|||
__all__ = ["sliding_window_inference"] | |||
|
|||
def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Verify import path used by DiceHelper.
PR text says the function was added to monai/metrics/utils.py, but it’s implemented here (monai/inferers/utils.py). Confirm call sites import the correct module path, or move the helper to the intended module.
Run:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Usages of ensure_channel_first:"
rg -n --hidden -S '\bensure_channel_first\b' || true
echo
echo "Imports of ensure_channel_first:"
rg -n --hidden -S 'from monai\.(metrics|inferers)\.utils import ensure_channel_first' || true
echo
echo "Definition locations:"
ast-grep --pattern $'def ensure_channel_first($_):\n $$$'
Length of output: 3584
Ensure correct import path for ensure_channel_first
The helper lives in monai/inferers/utils.py, but in monai/metrics/meandice.py (line 18) it’s imported from monai.metrics.utils, which doesn’t define it.
• Either move ensure_channel_first into monai/metrics/utils.py and keep the existing import
• Or update the import in monai/metrics/meandice.py to:
from monai.inferers.utils import ensure_channel_first
🧰 Tools
🪛 Ruff (0.12.2)
41-41: Undefined name Optional
(F821)
41-41: Undefined name Tuple
(F821)
🤖 Prompt for AI Agents
In monai/metrics/meandice.py at line 18, the import statement for
ensure_channel_first is incorrect because it imports from monai.metrics.utils
where the function is not defined. To fix this, update the import statement to
import ensure_channel_first from monai.inferers.utils instead, i.e., change it
to 'from monai.inferers.utils import ensure_channel_first'. Alternatively, if
preferred, move the ensure_channel_first function from monai/inferers/utils.py
to monai/metrics/utils.py and keep the existing import unchanged.
…st y channel-last detection (refs Project-MONAI#8366) Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
…st y channel-last detection (refs Project-MONAI#8366) Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
…nnel_first (refs Project-MONAI#8366) Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
♻️ Duplicate comments (1)
monai/inferers/utils.py (1)
41-41
: Fix Ruff F821: use PEP 604/PEP 585 types (no need to import Optional/Tuple).Ruff flags undefined names Optional/Tuple. Prefer modern annotations to avoid adding imports.
Apply:
-def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]: +def ensure_channel_first(x: torch.Tensor, spatial_ndim: int | None = None) -> tuple[torch.Tensor, int]:
🧹 Nitpick comments (5)
monai/inferers/utils.py (3)
58-61
: Docstring tweak: align “ambiguous” behavior with implementation.After the fix above, ambiguous (both small or both large) always preserves layout. Update the Notes to reflect that we preserve rather than raise.
Proposed wording:
- “Uses a small-channel heuristic (<=32). In ambiguous cases (both candidate dims small or both large), preserves input layout to avoid silent mis-reordering.”
67-70
: Minor: clarify intentionally unused spatial_ndim.The inferred spatial_ndim isn’t used. Either drop the block or rename the local to underscore to signal intent.
Apply:
- if spatial_ndim is None: - spatial_ndim = x.ndim - 2 # not directly used for logic; informative only + if spatial_ndim is None: + _spatial_ndim = x.ndim - 2 # informational only; not used in logic
39-39
: Export ensure_channel_first in all for public use.You’re using this utility across modules; consider exposing it via all to mark it public.
Outside this hunk, update:
__all__ = ["sliding_window_inference", "ensure_channel_first"]monai/metrics/meandice.py (2)
128-145
: _compute_tensor addition LGTM; tiny style nit.Validation and delegation are correct. Prefer y_pred.ndim over ndimension().
Apply:
- dims = y_pred.ndimension() + dims = y_pred.ndim
304-311
: Docstring: explicitly state channel-last is accepted and auto-normalized.Reflect the new behavior for clarity.
Apply:
- Args: - y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...). - the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``. - y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...). + Args: + y_pred: predictions in channel-first or channel-last layout: + - (batch_size, num_classes or 1, spatial_dims...) [channel-first] + - (batch_size, spatial_dims..., num_classes or 1) [channel-last] + The number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes`` is None. + y: ground truth in channel-first or channel-last layout; if trailing dim is 1 or matches + the number of classes, it will be normalized to channel-first automatically.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/inferers/utils.py
(1 hunks)monai/metrics/meandice.py
(3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit Configuration File
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/metrics/meandice.py
monai/inferers/utils.py
🧬 Code Graph Analysis (1)
monai/metrics/meandice.py (1)
monai/inferers/utils.py (1)
ensure_channel_first
(41-91)
🪛 Ruff (0.12.2)
monai/inferers/utils.py
41-41: Undefined name Optional
(F821)
41-41: Undefined name Tuple
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: packaging
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
🔇 Additional comments (1)
monai/metrics/meandice.py (1)
18-19
: Correct import path; good catch.Importing ensure_channel_first from monai.inferers.utils matches the implementation location.
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering. | ||
if s1 <= threshold and sl <= threshold: | ||
return x, 1 | ||
|
||
raise ValueError( | ||
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; " | ||
f"both dim1={s1} and dim-1={sl} look like spatial dims" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: both-large case wrongly raises; breaks valid NCHW with many classes (C > 32).
For y_pred in channel-first with many classes (e.g., N, 64, H, W), s1 > 32 and sl (W) > 32; this path raises ValueError and will break Dice. Prefer preserving layout in both-large cases (consistent with the comment on Line 84).
Apply:
- # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
- if s1 <= threshold and sl <= threshold:
- return x, 1
-
- raise ValueError(
- f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
- f"both dim1={s1} and dim-1={sl} look like spatial dims"
- )
+ # Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering.
+ if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold):
+ return x, 1
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering. | |
if s1 <= threshold and sl <= threshold: | |
return x, 1 | |
raise ValueError( | |
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; " | |
f"both dim1={s1} and dim-1={sl} look like spatial dims" | |
) | |
# Ambiguous: both sides small (or both large). Prefer preserving to avoid silent mis-reordering. | |
if (s1 <= threshold and sl <= threshold) or (s1 > threshold and sl > threshold): | |
return x, 1 |
🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 84 to 91, the code currently raises a
ValueError when both dim1 and dim-1 are large (e.g., C>32), which incorrectly
rejects valid channel-first tensors; change the branching so that when both s1
and sl are small OR both are large the function preserves the existing layout
(return x, 1) instead of raising; only raise when sizes are truly ambiguous
(e.g., one small and one small-ish conflict) or adjust the thresholds/conditions
accordingly and update the error message to reflect remaining ambiguous cases.
y_pred, _ = ensure_channel_first(y_pred) | ||
|
||
n_ch = self.num_classes or y_pred.shape[1] | ||
if y.ndim == y_pred.ndim and y.shape[-1] in (1, n_ch): | ||
y, _ = ensure_channel_first(y) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Channel-last normalization is correctly integrated; ensure upstream heuristic doesn’t blow up.
This is the right place to normalize y_pred and y. However, with current ensure_channel_first, channel-first tensors with many classes (C > 32) will raise unless the utils fix is applied. Once that’s in, this path is solid.
To be safe, verify both channel-first and channel-last with larger class counts:
Expected: both print tensor shapes without exceptions.
🏁 Script executed:
#!/bin/bash
python - <<'PY'
import torch
from monai.metrics.meandice import DiceHelper
def run(n=2, c=64, h=64, w=64):
y_pred_cf = torch.rand(n, c, h, w)
y_cf = torch.randint(0, c, (n, 1, h, w))
print("CF/C>32:", DiceHelper(num_classes=c)(y_pred_cf, y_cf).shape)
y_pred_cl = torch.rand(n, h, w, c)
y_cl = torch.randint(0, c, (n, h, w, 1))
print("CL/C>32:", DiceHelper(num_classes=c)(y_pred_cl, y_cl).shape)
run()
PY
Length of output: 128
🏁 Script executed:
#!/bin/bash
set -eo pipefail
# Locate ensure_channel_first definition
rg -n "def ensure_channel_first" -g "*.py"
Length of output: 197
🏁 Script executed:
#!/bin/bash
set -eo pipefail
# Show ensure_channel_first implementation
sed -n '1,200p' monai/inferers/utils.py
Length of output: 10537
Support >32 channels in ensure_channel_first
ensure_channel_first
in monai/inferers/utils.py
hard-codes a 32-channel threshold and will throw on both NCHW and NHWC inputs when C>32. Before merging, please:
- Update
ensure_channel_first
to accept a configurable threshold (e.g., usenum_classes
or expose a kwarg) so it no longer raises for large class counts. - Or catch its
ValueError
inDiceHelper
and default to treating dim-1 as the channel axis. - Add unit tests for C>32 on both channel-first and channel-last inputs to verify no exceptions.
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 312 to 317, ensure_channel_first can
raise for channel counts >32; modify the code to catch ValueError from
ensure_channel_first and on exception treat dim-1 as the channel axis (i.e.,
transpose NHWC to NCHW manually or accept input as already channel-first),
ensuring n_ch is computed from num_classes or y_pred.shape[1] accordingly;
alternatively, update monai/inferers/utils.py:ensure_channel_first to accept a
configurable threshold or a kwarg (e.g., max_channel_threshold or
force_last_dim_channel) and use that in this call so it won’t raise for large C;
add unit tests that cover C>32 for both channel-first (NCHW) and channel-last
(NHWC) inputs to verify no exceptions and correct channel handling.
What does this PR do?
Adds ensure_channel_first utility in monai/metrics/utils.py and integrates it into DiceHelper.call in monai/metrics/meandice.py to normalize input tensors to channel-first layout before Dice computation. This prevents metric calculation errors when the input is provided in channel-last format.
Issue reference
Refs #8366
Motivation and context
Currently, DiceHelper assumes the channel dimension is at index 1. If a user passes channel-last tensors (e.g., [N, H, W, C]), it can lead to incorrect indexing and metric values. This PR ensures that y_pred and (when applicable) y are converted to channel-first format before further processing, preserving original metric behavior while adding channel-last support.
How I tested
Verified that channel-last inputs now produce identical results to channel-first inputs.
Ran pytest -q tests/metrics -k dice locally without failures.
Checked that no changes are introduced to the outputs when inputs are already channel-first.