Bug: RuntimeError: max() in ObjectCentricSSI._compute_scale_and_shift when mask is empty #168
Description
I am a student using this tool for research and don't have time at the moment to turn this into a tested pull request, but I figured it would be helpful to report, since the same problem seems to arise with the min() function as well.
When processing an image where the valid mask area becomes completely empty (e.g., an off-center object that gets cropped out, or an empty mask being passed in), the pipeline crashes with a RuntimeError during the pointmap normalization step.
In sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py, inside the ObjectCentricSSI._compute_scale_and_shift method, the code filters pointmap_flat using the mask. If no valid points are found, mask_points becomes an empty tensor.
Calling .max() on an empty tensor in PyTorch without a specified dimension throws an error, causing the script to crash before it can reach the intended ValueError or fallback logging.
File "sam-3d-objects/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py", line 597, in normalize
_scale, _shift = self._compute_scale_and_shift(pointmap, mask)
File "sam-3d-objects/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py", line 550, in _compute_scale_and_shift
if mask_points.isfinite().max() == 0:
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
Steps to Reproduce:
- Pass an image and a completely black mask (all zeros) through the preprocessing pipeline using the ObjectCentricSSI normalizer.
- The mask_points tensor evaluates to size [3, 0].
- The check mask_points.isfinite().max() == 0 throws the RuntimeError.
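The steps above can be reproduced outside the pipeline with a minimal sketch. The variable names and shapes here (a pointmap flattened to [3, N] and a boolean mask of length N) are assumptions for illustration, not copied from the repository:

```python
import torch

# Hypothetical stand-ins for the flattened pointmap and mask (shapes assumed).
pointmap_flat = torch.randn(3, 16)
mask_flat = torch.zeros(16, dtype=torch.bool)  # completely empty mask

# Boolean indexing with an all-False mask leaves zero columns.
mask_points = pointmap_flat[:, mask_flat]
print(mask_points.shape)

try:
    # Full reduction over an empty tensor has no defined result.
    mask_points.isfinite().max()
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```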
Suggested Fix:
Adding a quick numel() == 0 check before .max() safely catches the empty mask and allows the intended fallback logic to execute. Because "or" short-circuits, .max() is never called when the tensor is empty. This allowed my instance of the model to run as expected.
File: sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py
Line ~550:
if mask_points.numel() == 0 or mask_points.isfinite().max() == 0:
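As a rough sketch of how the guarded condition behaves, the check can be pulled into a standalone function. The helper names below are hypothetical and do not exist in the repository. The second variant uses .any(), which returns False for an empty tensor, as one possible alternative to the explicit numel() check:

```python
import torch

def has_no_finite_points(mask_points: torch.Tensor) -> bool:
    """Hypothetical helper mirroring the suggested guard."""
    # numel() is tested first, so .max() never sees an empty tensor.
    return mask_points.numel() == 0 or bool(mask_points.isfinite().max() == 0)

def has_no_finite_points_any(mask_points: torch.Tensor) -> bool:
    """Hypothetical alternative: .any() on an empty tensor returns False."""
    return not bool(mask_points.isfinite().any())

empty = torch.empty(3, 0)                    # empty mask case
all_nan = torch.full((3, 4), float("nan"))   # points present, none finite
valid = torch.ones(3, 4)                     # ordinary finite points

for name, t in [("empty", empty), ("all_nan", all_nan), ("valid", valid)]:
    print(name, has_no_finite_points(t), has_no_finite_points_any(t))
```

Both variants treat the empty case and the all-non-finite case the same way, so the existing ValueError or fallback path would be reached in either situation.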