Add verifier for AtenMatmul to check that 1D x 1D outputs a scaler #4329
+35
−1
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I'm adding a rank verifier for the AtenMatmul operation. The goal is to catch rank mismatches earlier in the pipeline.
Problem
Currently, invalid result rank error only show up during the lowering stage (e.g., when running
--convert-torch-to-linalg
) and the error (torch.cast
error shown below) doesn't clearly point to the root cause.Solution
Adding a verifier to report a clear, specific error directly on the AtenMatmul op itself, before lowering even starts.
I add a check for 1D × 1D Matmul, as Pytorch Matmul doc define that 1D x 1D must output a scalar (rank 0). The verifier ensures the output rank is correct.
Before and After
If we feed an invalid (1D x 1D = 1D) MLIR like this:
torch-mlir-opt input.mlir --convert-torch-to-linalg
, we'll get "Error: 'tensor.cast' op operand type 'tensor' and result type 'tensor<1xf32>' are cast incompatible" during lowering.torch-mlir-opt input.mlir
, we'll get "Verifier-reported error: "1D x 1D matmul should produce a scalar (rank 0)" during IR verficationNote: I also added the above-metioned mlir to invalid.mlir, and
ninja check-torch-mlir
would complete without error.