Problem: input_points shape mismatch when batch has mixed point prompts (BS>1)
In collate.py (collate_fn_api), input_points_embedding_dim is set to 257. During collation, when q.input_points exists, we append q.input_points.squeeze(0); otherwise we append an empty tensor with shape (0, input_points_embedding_dim).
```python
if q.input_points is not None:
    stages[stage_id].input_points.append(
        q.input_points.squeeze(0)  # strip a trivial batch index
    )
    stages[stage_id].input_points_mask.append(
        torch.zeros(q.input_points.shape[1])
    )
else:
    stages[stage_id].input_points.append(
        torch.empty(0, input_points_embedding_dim)
    )
    stages[stage_id].input_points_mask.append(torch.empty(0))
```

When batch size > 1 and the batch contains mixed samples (some with point prompts, some without):
- Samples with point prompts: stages[stage_id].input_points element shape becomes [N, 3]
- Samples without point prompts: default empty tensor shape is [0, 257] (because input_points_embedding_dim=257)
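The mismatch can be reproduced directly from the two shapes involved; this is a hypothetical minimal example, independent of the repo's code:

```python
import torch

# Sample that has point prompts: [N, 3] (here N = 5).
with_points = torch.randn(5, 3)
# Sample without prompts: the default placeholder [0, 257],
# since input_points_embedding_dim = 257.
without_points = torch.empty(0, 257)

# Concatenating (or padding and stacking) fails because the trailing
# dimensions disagree (3 vs. 257), even though one tensor is empty.
try:
    torch.cat([with_points, without_points], dim=0)
except RuntimeError as e:
    print("cat failed:", e)
```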
Later, in convert_my_tensors, these tensors cannot be stacked because their trailing dimensions are inconsistent (3 vs. 257). Should input_points_embedding_dim be 3 instead of 257? Any guidance on what the intended representation is at this stage?
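For context, here is a sketch of the workaround I am considering. It assumes the placeholder should match the per-point width of real prompts (3, as in the [N, 3] tensors above); the names `placeholder_points` and `pad_and_stack` are mine, not from the repo, and the mask convention (1 = padded) is a guess:

```python
import torch

INPUT_POINTS_DIM = 3  # assumption: matches the [N, 3] shape of real point prompts

def placeholder_points() -> torch.Tensor:
    """Empty placeholder whose trailing dim matches real point prompts."""
    return torch.empty(0, INPUT_POINTS_DIM)

def pad_and_stack(points_list):
    """Pad a list of [N_i, D] tensors to a common N_max and stack to [B, N_max, D]."""
    n_max = max(p.shape[0] for p in points_list)
    d = points_list[0].shape[1]
    out = torch.zeros(len(points_list), n_max, d)
    mask = torch.ones(len(points_list), n_max)  # 1 marks padded positions (assumed convention)
    for i, p in enumerate(points_list):
        out[i, : p.shape[0]] = p
        mask[i, : p.shape[0]] = 0
    return out, mask

batch = [torch.randn(5, INPUT_POINTS_DIM), placeholder_points()]
stacked, mask = pad_and_stack(batch)
print(stacked.shape)  # torch.Size([2, 5, 3])
```

With a matching trailing dimension, mixed batches pad and stack cleanly; the open question is whether 3 is the representation the collate code actually intends at this stage.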