Conversation

weiyu0824 (Contributor) commented Oct 2, 2025

I'm adding a rank verifier for the AtenMatmul operation. The goal is to catch rank mismatches earlier in the pipeline.

Problem

Currently, an invalid result rank only surfaces during the lowering stage (e.g., when running --convert-torch-to-linalg), and the error (the tensor.cast error shown below) doesn't clearly point to the root cause.

Solution

Add a verifier that reports a clear, specific error directly on the AtenMatmul op itself, before lowering even starts.

  • Case: 1D x 1D
    I added a check for the 1D × 1D case: the PyTorch matmul documentation defines that a 1D × 1D matmul must produce a scalar (rank 0), so the verifier ensures the output rank is 0 (see the sketch below).
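For context, here is a minimal sketch of what such a verifier can look like in torch-mlir's C++ (an illustration of the technique, not necessarily the exact code in this PR; it assumes the op's TableGen definition sets hasVerifier = 1 so that the AtenMatmulOp::verify() hook is generated):

LogicalResult AtenMatmulOp::verify() {
  auto selfType = dyn_cast<BaseTensorType>(getSelf().getType());
  auto otherType = dyn_cast<BaseTensorType>(getOther().getType());
  auto resultType = dyn_cast<BaseTensorType>(getResult().getType());
  // Only verify when all three ranks are statically known.
  if (!selfType || !otherType || !resultType || !selfType.hasSizes() ||
      !otherType.hasSizes() || !resultType.hasSizes())
    return success();
  // PyTorch semantics: matmul of two 1D tensors is a dot product, so the
  // result must be rank 0.
  if (selfType.getSizes().size() == 1 && otherType.getSizes().size() == 1 &&
      !resultType.getSizes().empty())
    return emitOpError("1D x 1D matmul should produce a scalar (rank 0)");
  return success();
}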

Before and After

If we feed in invalid MLIR (1D x 1D producing a 1D result) like this:

func.func @invalid_matmul_1d_1d(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}
  • Before: running torch-mlir-opt input.mlir --convert-torch-to-linalg, we get "error: 'tensor.cast' op operand type 'tensor<f32>' and result type 'tensor<1xf32>' are cast incompatible" during lowering.
  • After: running torch-mlir-opt input.mlir alone, we get the verifier-reported error "1D x 1D matmul should produce a scalar (rank 0)" during IR verification.

Note: I also added the above-mentioned MLIR to invalid.mlir, and ninja check-torch-mlir completes without errors.
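For illustration, a negative test like this typically follows the standard MLIR -verify-diagnostics convention (a sketch; the exact RUN line and annotations in the actual test file may differ):

// RUN: torch-mlir-opt %s -split-input-file -verify-diagnostics

func.func @invalid_matmul_1d_1d(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1],f32> {
  // expected-error @below {{1D x 1D matmul should produce a scalar (rank 0)}}
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}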
