@@ -757,15 +757,14 @@ class GenerateHeatmap(Transform):
757757 Notes:
758758 - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
759759 - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
760- - Output layout uses channel-first convention with one channel per landmark:
761- - Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D
762- - Batched points (B, N, D) : (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D
763- - Each channel corresponds to one landmark.
760+ - Output layout uses channel-first convention with one channel per landmark.
761+ - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3).
762+ - Output heatmap shape : (N, Y, X) for 2D or (N, Z, Y, X) for 3D.
763+ - Each channel index corresponds to one landmark.
764764
765765 Args:
766766 sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
767767 spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
768- A single int value will be broadcast to all spatial dimensions.
769768 truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
770769 normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
771770 dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).
@@ -787,84 +786,90 @@ def __init__(
787786 ) -> None :
788787 if isinstance (sigma , Sequence ) and not isinstance (sigma , (str , bytes )):
789788 if any (s <= 0 for s in sigma ):
790- raise ValueError ("sigma values must be positive." )
789+ raise ValueError ("Argument ` sigma` values must be positive." )
791790 self ._sigma = tuple (float (s ) for s in sigma )
792791 else :
793792 if float (sigma ) <= 0 :
794- raise ValueError ("sigma must be positive." )
793+ raise ValueError ("Argument ` sigma` must be positive." )
795794 self ._sigma = (float (sigma ),)
796795 if truncated <= 0 :
797- raise ValueError ("truncated must be positive." )
796+ raise ValueError ("Argument ` truncated` must be positive." )
798797 self .truncated = float (truncated )
799798 self .normalize = normalize
800799 self .torch_dtype = get_equivalent_dtype (dtype , torch .Tensor )
801800 self .numpy_dtype = get_equivalent_dtype (dtype , np .ndarray )
802801 # Validate that dtype is floating-point for meaningful Gaussian values
803802 if not self .torch_dtype .is_floating_point :
804- raise ValueError (f"dtype must be a floating-point type, got { self .torch_dtype } " )
803+ raise ValueError (f"Argument ` dtype` must be a floating-point type, got { self .torch_dtype } " )
805804 self .spatial_shape = None if spatial_shape is None else tuple (int (s ) for s in spatial_shape )
806805
def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
    """
    Render one Gaussian heatmap channel per landmark.

    A unit impulse is written at the voxel nearest to each valid landmark and the
    resulting stack is smoothed with ``GaussianFilter``.

    Args:
        points: landmark coordinates as ndarray/Tensor with shape (N, D),
            ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number
            of landmarks and D is the spatial dimensionality.
        spatial_shape: spatial size as a sequence. If None, uses the value provided at construction.

    Returns:
        Heatmaps with shape (N, *spatial), one channel per landmark. A landmark that is
        non-finite or lies outside ``[0, size)`` on any axis receives no impulse.

    Raises:
        ValueError: if points shape/dimension or spatial_shape is invalid.
    """
    original_points = points
    points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

    if points_t.ndim != 2:
        raise ValueError(
            f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}."
        )
    if points_t.shape[-1] not in (2, 3):
        raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

    device = points_t.device
    num_points, spatial_dims = points_t.shape

    target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
    sigma = self._resolve_sigma(spatial_dims)

    # Sparse image stack: one channel per landmark, filled with impulses below.
    heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
    bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)

    # A landmark is usable only if every coordinate is finite and inside [0, size).
    valid = torch.isfinite(points_t).all(dim=1) & ((points_t >= 0) & (points_t < bounds_t)).all(dim=1)
    channels = torch.nonzero(valid).squeeze(1)

    # Round to the nearest voxel. A coordinate in [size - 0.5, size) passes the bounds
    # check above but rounds up to ``size``, so clamp to the last valid index to avoid
    # indexing past the end of the axis.
    last_index = torch.as_tensor(target_shape, device=device, dtype=torch.long) - 1
    voxels = torch.minimum(points_t[valid].round().long(), last_index)

    # Each landmark owns its channel, so a single indexed write places every impulse.
    heatmap[(channels, *voxels.unbind(dim=1))] = 1

    # Blur each landmark channel independently: the landmark axis is passed as the
    # batch axis with a singleton channel, i.e. (num_points, 1, *spatial).
    gaussian_filter = GaussianFilter(
        spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False
    ).to(device)
    heatmap = gaussian_filter(heatmap.unsqueeze(1)).squeeze(1)
    # NOTE(review): without normalization the peak height is whatever the filter kernel
    # yields at its centre, which is presumably below 1 -- confirm this is intended.

    # Normalize per channel if requested; all-zero channels are divided by 1 and stay zero.
    if self.normalize:
        peaks = heatmap.amax(dim=tuple(range(1, heatmap.ndim)), keepdim=True)
        heatmap = heatmap / torch.where(peaks > 0, peaks, torch.ones_like(peaks))

    target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
    converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
    # NOTE(review): the remainder of this method (returning ``converted``) lies past the
    # visible hunk and is left unchanged.
@@ -873,14 +878,14 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
    """
    Pick the spatial shape to render into and expand it to ``spatial_dims`` entries.

    Args:
        call_shape: shape passed at call time; takes precedence over the one stored
            on the instance when it is not ``None``.
        spatial_dims: number of spatial dimensions of the landmarks.

    Returns:
        A tuple of ``spatial_dims`` positive ints.

    Raises:
        ValueError: if no shape is available, if its length is neither 1 nor
            ``spatial_dims``, or if any size is not positive.
    """
    shape = self.spatial_shape if call_shape is None else call_shape
    if shape is None:
        raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.")

    shape_tuple = tuple(int(s) for s in ensure_tuple(shape))
    if len(shape_tuple) != spatial_dims:
        if len(shape_tuple) != 1:
            raise ValueError(
                "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)."
            )
        # A single value is broadcast to every spatial axis.
        shape_tuple = shape_tuple * spatial_dims

    # Reject empty or negative axes here instead of failing later during allocation or filtering.
    if any(size <= 0 for size in shape_tuple):
        raise ValueError(f"Argument `spatial_shape` must contain only positive sizes, got {shape_tuple}.")
    return shape_tuple
886891
@@ -889,53 +894,7 @@ def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
889894 return self ._sigma
890895 if len (self ._sigma ) == 1 :
891896 return self ._sigma * spatial_dims
892- raise ValueError ("sigma sequence length must equal the number of spatial dimensions." )
893-
894- @staticmethod
895- def _is_inside (center : Sequence [float ], bounds : tuple [int , ...]) -> bool :
896- for c , size in zip (center , bounds ):
897- if not (0 <= c < size ):
898- return False
899- return True
900-
901- def _make_window (
902- self , center : Sequence [float ], radius : tuple [int , ...], bounds : tuple [int , ...], device : torch .device
903- ) -> tuple [tuple [slice , ...] | None , tuple [torch .Tensor , ...]]:
904- slices : list [slice ] = []
905- coord_shifts : list [torch .Tensor ] = []
906- for _dim , (c , r , size ) in enumerate (zip (center , radius , bounds )):
907- start = max (int (np .floor (c - r )), 0 )
908- stop = min (int (np .ceil (c + r )) + 1 , size )
909- if start >= stop :
910- return None , ()
911- slices .append (slice (start , stop ))
912- coord_shifts .append (torch .arange (start , stop , device = device , dtype = torch .float32 ) - float (c ))
913- return tuple (slices ), tuple (coord_shifts )
914-
915- def _evaluate_gaussian (self , coord_shifts : tuple [torch .Tensor , ...], sigma : tuple [float , ...]) -> torch .Tensor :
916- """
917- Evaluate Gaussian at given coordinate shifts with specified sigmas.
918-
919- Args:
920- coord_shifts: Per-dimension coordinate offsets from center.
921- sigma: Per-dimension standard deviations.
922-
923- Returns:
924- Gaussian values at the specified coordinates.
925- """
926- device = coord_shifts [0 ].device
927- shape = tuple (len (axis ) for axis in coord_shifts )
928- if 0 in shape :
929- return torch .zeros (shape , dtype = self .torch_dtype , device = device )
930- exponent = torch .zeros (shape , dtype = torch .float32 , device = device )
931- for dim , (shift , sig ) in enumerate (zip (coord_shifts , sigma )):
932- shift32 = shift .to (torch .float32 )
933- scaled = (shift32 / float (sig )) ** 2
934- reshape_shape = [1 ] * len (coord_shifts )
935- reshape_shape [dim ] = shift .numel ()
936- exponent += scaled .reshape (reshape_shape )
937- gauss = torch .exp (- 0.5 * exponent )
938- return gauss .to (dtype = self .torch_dtype )
897+ raise ValueError ("Argument `sigma` sequence length must equal the number of spatial dimensions." )
939898
940899
941900class ProbNMS (Transform ):
0 commit comments