input_points shape mismatch when batch has mixed point prompts (BS>1) #458

@Casperqian

Description

Problem: input_points shape mismatch when batch has mixed point prompts (BS>1)

In collate.py (collate_fn_api), input_points_embedding_dim is set to 257. During collation, when q.input_points exists we append q.input_points.squeeze(0); otherwise we append an empty placeholder tensor of shape (0, input_points_embedding_dim).

if q.input_points is not None:
    stages[stage_id].input_points.append(
        q.input_points.squeeze(0)  # strip the trivial batch dim -> shape [N, 3]
    )
    stages[stage_id].input_points_mask.append(
        torch.zeros(q.input_points.shape[1])  # one mask entry per point
    )
else:
    stages[stage_id].input_points.append(
        torch.empty(0, input_points_embedding_dim)  # shape [0, 257]
    )
    stages[stage_id].input_points_mask.append(torch.empty(0))
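The mismatch can be reproduced in isolation. A minimal sketch, assuming standard PyTorch (the tensor values are made up; only the shapes matter):

```python
import torch

with_points = torch.zeros(4, 3)       # sample with 4 point prompts: [N, 3]
without_points = torch.empty(0, 257)  # placeholder: [0, input_points_embedding_dim]

# Concatenating (or stacking) along dim 0 requires all other dims to match,
# so 3 vs 257 raises a RuntimeError even though one of the tensors is empty.
try:
    torch.cat([with_points, without_points], dim=0)
    print("cat succeeded")
except RuntimeError as exc:
    print(f"cat failed: {exc}")
```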

When batch size > 1 and the batch contains mixed samples (some have point prompts, some don't):

  • Samples with point prompts: stages[stage_id].input_points element shape becomes [N, 3]
  • Samples without point prompts: default empty tensor shape is [0, 257] (because input_points_embedding_dim=257)

Later, in convert_my_tensors, these tensors cannot be stacked because their trailing dimensions disagree (3 vs 257). Should input_points_embedding_dim be 3 instead of 257? Any guidance on what the intended representation is at this stage?
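If the intended representation at this stage really is a raw (x, y, label) triple per point, one possible fix (a sketch under that assumption, not the confirmed intent of the maintainers) is to make the placeholder's trailing dimension match the point dimension, after which variable-length point lists pad cleanly into one batch tensor:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

POINT_DIM = 3  # assumed (x, y, label) layout of a single point prompt

batch = [
    torch.zeros(2, POINT_DIM),  # sample with 2 point prompts: [2, 3]
    torch.empty(0, POINT_DIM),  # sample without prompts, placeholder now [0, 3]
]

# pad_sequence pads along the first (point-count) dim -> [B, max_N, POINT_DIM];
# the empty sample becomes all-padding rows.
padded = pad_sequence(batch, batch_first=True)
print(padded.shape)  # torch.Size([2, 2, 3])
```

The per-sample masks collected alongside (lengths 2 and 0 here) would then mark which padded rows are real points.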
