Skip to content

Commit a231740

Browse files
committed
Remove batch, use GaussianFilter
Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent 56e0662 commit a231740

File tree

4 files changed

+191
-294
lines changed

4 files changed

+191
-294
lines changed

monai/transforms/post/array.py

Lines changed: 58 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -757,15 +757,14 @@ class GenerateHeatmap(Transform):
757757
Notes:
758758
- Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
759759
- Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
760-
- Output layout uses channel-first convention with one channel per landmark:
761-
- Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D
762-
- Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D
763-
- Each channel corresponds to one landmark.
760+
- Output layout uses channel-first convention with one channel per landmark.
761+
- Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3).
762+
- Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D.
763+
- Each channel index corresponds to one landmark.
764764
765765
Args:
766766
sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
767767
spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
768-
A single int value will be broadcast to all spatial dimensions.
769768
truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
770769
normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
771770
dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).
@@ -787,84 +786,90 @@ def __init__(
787786
) -> None:
788787
if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
789788
if any(s <= 0 for s in sigma):
790-
raise ValueError("sigma values must be positive.")
789+
raise ValueError("Argument `sigma` values must be positive.")
791790
self._sigma = tuple(float(s) for s in sigma)
792791
else:
793792
if float(sigma) <= 0:
794-
raise ValueError("sigma must be positive.")
793+
raise ValueError("Argument `sigma` must be positive.")
795794
self._sigma = (float(sigma),)
796795
if truncated <= 0:
797-
raise ValueError("truncated must be positive.")
796+
raise ValueError("Argument `truncated` must be positive.")
798797
self.truncated = float(truncated)
799798
self.normalize = normalize
800799
self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
801800
self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
802801
# Validate that dtype is floating-point for meaningful Gaussian values
803802
if not self.torch_dtype.is_floating_point:
804-
raise ValueError(f"dtype must be a floating-point type, got {self.torch_dtype}")
803+
raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}")
805804
self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)
806805

807806
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
808807
"""
809808
Args:
810-
points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
811-
ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
812-
spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
813-
the value provided at construction.
809+
points: landmark coordinates as ndarray/Tensor with shape (N, D),
810+
ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number
811+
of landmarks and D is the spatial dimensionality.
812+
spatial_shape: spatial size as a sequence; a length-1 sequence is broadcast to all spatial dimensions. If None, uses the value provided at construction.
814813
815814
Returns:
816-
Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.
815+
Heatmaps with shape (N, *spatial), one channel per landmark.
817816
818817
Raises:
819818
ValueError: if points shape/dimension or spatial_shape is invalid.
820819
"""
821820
original_points = points
822821
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
823822

824-
is_batched = points_t.ndim == 3
825-
if not is_batched:
826-
if points_t.ndim != 2:
827-
raise ValueError(
828-
"points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)."
829-
)
830-
points_t = points_t.unsqueeze(0) # Add a batch dimension
823+
if points_t.ndim != 2:
824+
raise ValueError(
825+
f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}."
826+
)
831827

832828
if points_t.shape[-1] not in (2, 3):
833829
raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")
834830

835831
device = points_t.device
836-
batch_size, num_points, spatial_dims = points_t.shape
832+
num_points, spatial_dims = points_t.shape
837833

838834
target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
839835
sigma = self._resolve_sigma(spatial_dims)
840-
radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma)
841-
842-
heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
843-
image_bounds = tuple(int(s) for s in target_shape)
844-
bounds_t = torch.as_tensor(image_bounds, device=device, dtype=points_t.dtype)
845-
for b_idx in range(batch_size):
846-
for idx, center in enumerate(points_t[b_idx]):
847-
if not torch.isfinite(center).all():
848-
continue
849-
if not ((center >= 0).all() and (center < bounds_t).all()):
850-
continue
851-
# _make_window expects Python floats; convert only when needed
852-
center_vals = center.tolist()
853-
window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
854-
if window_slices is None:
855-
continue
856-
region = heatmap[b_idx, idx][window_slices]
857-
gaussian = self._evaluate_gaussian(coord_shifts, sigma)
858-
updated = torch.maximum(region, gaussian)
859-
# write back
860-
region.copy_(updated)
861-
if self.normalize:
862-
peak = heatmap[b_idx, idx].amax()
863-
denom = torch.where(peak > 0, peak, torch.ones_like(peak))
864-
heatmap[b_idx, idx].div_(denom)
865-
866-
if not is_batched:
867-
heatmap = heatmap.squeeze(0)
836+
837+
# Create sparse image with impulses at landmark locations
838+
heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
839+
bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)
840+
841+
for idx, center in enumerate(points_t):
842+
if not torch.isfinite(center).all():
843+
continue
844+
if not ((center >= 0).all() and (center < bounds_t).all()):
845+
continue
846+
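# Round to nearest integer for impulse placement. NOTE(review): a coordinate such as `size - 0.4` passes the bounds check above but rounds to `size`, which is out of range for indexing - clamp or floor before use.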
847+
center_int = center.round().long()
848+
# Place a unit impulse in this landmark's own channel (the channel starts at zero, so the maximum is only a safeguard)
849+
current_val = heatmap[idx][tuple(center_int)]
850+
heatmap[idx][tuple(center_int)] = max(current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device))
851+
852+
# Apply Gaussian blur using GaussianFilter
853+
# Reshape to (num_points, 1, *spatial) for per-channel filtering
854+
heatmap_input = heatmap.unsqueeze(1) # Add channel dimension
855+
856+
gaussian_filter = GaussianFilter(
857+
spatial_dims=spatial_dims,
858+
sigma=sigma,
859+
truncated=self.truncated,
860+
approx="erf",
861+
requires_grad=False
862+
).to(device)
863+
864+
heatmap_blurred = gaussian_filter(heatmap_input)
865+
heatmap = heatmap_blurred.squeeze(1) # Remove channel dimension
866+
867+
# Normalize per channel if requested
868+
if self.normalize:
869+
for idx in range(num_points):
870+
peak = heatmap[idx].amax()
871+
if peak > 0:
872+
heatmap[idx].div_(peak)
868873

869874
target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
870875
converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
@@ -873,14 +878,14 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
873878
def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
874879
shape = call_shape if call_shape is not None else self.spatial_shape
875880
if shape is None:
876-
raise ValueError("spatial_shape must be provided either at construction time or call time.")
881+
raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.")
877882
shape_tuple = ensure_tuple(shape)
878883
if len(shape_tuple) != spatial_dims:
879884
if len(shape_tuple) == 1:
880885
shape_tuple = shape_tuple * spatial_dims # type: ignore
881886
else:
882887
raise ValueError(
883-
"spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
888+
"Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)."
884889
)
885890
return tuple(int(s) for s in shape_tuple)
886891

@@ -889,53 +894,7 @@ def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
889894
return self._sigma
890895
if len(self._sigma) == 1:
891896
return self._sigma * spatial_dims
892-
raise ValueError("sigma sequence length must equal the number of spatial dimensions.")
893-
894-
@staticmethod
895-
def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
896-
for c, size in zip(center, bounds):
897-
if not (0 <= c < size):
898-
return False
899-
return True
900-
901-
def _make_window(
902-
self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
903-
) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
904-
slices: list[slice] = []
905-
coord_shifts: list[torch.Tensor] = []
906-
for _dim, (c, r, size) in enumerate(zip(center, radius, bounds)):
907-
start = max(int(np.floor(c - r)), 0)
908-
stop = min(int(np.ceil(c + r)) + 1, size)
909-
if start >= stop:
910-
return None, ()
911-
slices.append(slice(start, stop))
912-
coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
913-
return tuple(slices), tuple(coord_shifts)
914-
915-
def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
916-
"""
917-
Evaluate Gaussian at given coordinate shifts with specified sigmas.
918-
919-
Args:
920-
coord_shifts: Per-dimension coordinate offsets from center.
921-
sigma: Per-dimension standard deviations.
922-
923-
Returns:
924-
Gaussian values at the specified coordinates.
925-
"""
926-
device = coord_shifts[0].device
927-
shape = tuple(len(axis) for axis in coord_shifts)
928-
if 0 in shape:
929-
return torch.zeros(shape, dtype=self.torch_dtype, device=device)
930-
exponent = torch.zeros(shape, dtype=torch.float32, device=device)
931-
for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
932-
shift32 = shift.to(torch.float32)
933-
scaled = (shift32 / float(sig)) ** 2
934-
reshape_shape = [1] * len(coord_shifts)
935-
reshape_shape[dim] = shift.numel()
936-
exponent += scaled.reshape(reshape_shape)
937-
gauss = torch.exp(-0.5 * exponent)
938-
return gauss.to(dtype=self.torch_dtype)
897+
raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.")
939898

940899

941900
class ProbNMS(Transform):

monai/transforms/post/dictionary.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,9 @@ class GenerateHeatmapd(MapTransform):
519519
Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.
520520
521521
Args:
522-
keys: keys of the corresponding items in the dictionary.
522+
keys: keys of the corresponding items in the dictionary. Each key references an array or tensor
523+
of landmark point coordinates with shape (N, D), where N is the number of landmarks
524+
and D is the spatial dimensionality (2 or 3).
523525
sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number
524526
of spatial dimensions.
525527
heatmap_keys: keys to store output heatmaps. Default: "{key}_heatmap" for each key.
@@ -539,25 +541,54 @@ class GenerateHeatmapd(MapTransform):
539541
Raises:
540542
ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.
541543
ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).
542-
ValueError: If input points have invalid shape (must be 2D or 3D).
544+
ValueError: If input points have invalid shape (must be a 2D array with shape (N, D)).
545+
546+
Example:
547+
.. code-block:: python
548+
549+
import numpy as np
550+
from monai.transforms import GenerateHeatmapd
551+
552+
# Create sample data with landmark points and a reference image
553+
data = {
554+
"landmarks": np.array([[10.0, 15.0], [20.0, 25.0]]), # 2 points in 2D
555+
"image": np.zeros((32, 32)) # reference image
556+
}
557+
558+
# Transform with reference image
559+
transform = GenerateHeatmapd(
560+
keys="landmarks",
561+
sigma=2.0,
562+
ref_image_keys="image"
563+
)
564+
result = transform(data)
565+
# result["landmarks_heatmap"] has shape (2, 32, 32) - one channel per landmark
566+
567+
# Or with explicit spatial_shape
568+
transform = GenerateHeatmapd(
569+
keys="landmarks",
570+
sigma=2.0,
571+
spatial_shape=(64, 64)
572+
)
573+
result = transform(data)
574+
# result["landmarks_heatmap"] has shape (2, 64, 64)
543575
544576
Notes:
545577
- Default heatmap_keys are generated as "{key}_heatmap" for each input key
546578
- Shape inference precedence: static spatial_shape > ref_image
547-
- Output shapes:
548-
- Non-batched points (N, D): (N, H, W[, D])
549-
- Batched points (B, N, D): (B, N, H, W[, D])
579+
- Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions
580+
- Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D
550581
- When using ref_image_keys, heatmaps inherit affine and spatial metadata from reference
551582
"""
552583

553584
backend = GenerateHeatmap.backend
554585

555586
# Error messages
556-
_ERR_HEATMAP_KEYS_LEN = "heatmap_keys length must match keys length."
557-
_ERR_REF_KEYS_LEN = "ref_image_keys length must match keys length when provided."
558-
_ERR_SHAPE_LEN = "spatial_shape length must match keys length when providing per-key shapes."
587+
_ERR_HEATMAP_KEYS_LEN = "Argument `heatmap_keys` length must match keys length."
588+
_ERR_REF_KEYS_LEN = "Argument `ref_image_keys` length must match keys length when provided."
589+
_ERR_SHAPE_LEN = "Argument `spatial_shape` length must match keys length when providing per-key shapes."
559590
_ERR_NO_SHAPE = "Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys."
560-
_ERR_INVALID_POINTS = "landmark arrays must be 2D or 3D with shape (N, D) or (B, N, D)."
591+
_ERR_INVALID_POINTS = "Landmark arrays must be 2D with shape (N, D)."
561592
_ERR_REF_NO_SHAPE = "Reference data must define a shape attribute."
562593

563594
def __init__(

0 commit comments

Comments
 (0)