From 6d3612afb20d657e74be5822376e6196e3571ac8 Mon Sep 17 00:00:00 2001
From: Wei-Yu
Date: Wed, 1 Oct 2025 22:19:46 -0700
Subject: [PATCH 1/2] Add verifier for AtenMatmul to check that 1D x 1D matmul
 outputs a scalar

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp             | 23 +++++++++++++++++++
 .../build_tools/torch_ods_gen.py              |  2 +-
 test/Dialect/Torch/invalid.mlir               | 10 ++++++++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4ad03f54313f..641bb7315425 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6670,6 +6670,7 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Torch_AtenMvOp : Torch_Op<"aten.mv", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bf2a605c950b..90e714cfa1ce 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5293,6 +5293,29 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenMatmulOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenMatmulOp::verify() {
+
+  auto lhsType = cast<BaseTensorType>(getSelf().getType());
+  auto rhsType = cast<BaseTensorType>(getOther().getType());
+  auto resultType = cast<BaseTensorType>(getResult().getType());
+
+  if (lhsType.hasSizes() && rhsType.hasSizes() && resultType.hasSizes()) {
+    // Get the rank
+    auto lhsRank = lhsType.getSizes().size();
+    auto rhsRank = rhsType.getSizes().size();
+    auto resultRank = resultType.getSizes().size(); 
+
+    if (lhsRank == 1 && rhsRank == 1 && resultRank != 0) {
+      return emitOpError("1D x 1D matmul should produce a scalar (rank 0)"); 
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AtenNormScalarOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 2d1d2d2390c4..e0ce52c24259 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -575,7 +575,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::_int_mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
-    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)", has_verifier=True)
     emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
     emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
     emit("aten::outer : (Tensor, Tensor) -> (Tensor)")
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index c863e93fa5fa..4e1eb550e19c 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -403,3 +403,13 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -
   torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %arg0 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+func.func @torch.matmul$1d_1d_result_not_scalar(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[4],f32>) 
+    -> !torch.vtensor<[1],f32> {
+  // expected-error @+1 {{1D x 1D matmul should produce a scalar (rank 0)}}
+  %0 = torch.aten.matmul %arg0, %arg1 
+      : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}

From 46069661ba20c873f10d459c49aa0452677d4816 Mon Sep 17 00:00:00 2001
From: Wei-Yu
Date: Thu, 2 Oct 2025 00:09:08 -0700
Subject: [PATCH 2/2] Remove trailing space

---
 lib/Dialect/Torch/IR/TorchOps.cpp | 4 ++--
 test/Dialect/Torch/invalid.mlir   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 90e714cfa1ce..8df4bf22708e 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5307,10 +5307,10 @@ LogicalResult AtenMatmulOp::verify() {
     // Get the rank
     auto lhsRank = lhsType.getSizes().size();
     auto rhsRank = rhsType.getSizes().size();
-    auto resultRank = resultType.getSizes().size(); 
+    auto resultRank = resultType.getSizes().size();
 
     if (lhsRank == 1 && rhsRank == 1 && resultRank != 0) {
-      return emitOpError("1D x 1D matmul should produce a scalar (rank 0)"); 
+      return emitOpError("1D x 1D matmul should produce a scalar (rank 0)");
     }
   }
   return success();
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index 4e1eb550e19c..3bbe91f62e57 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -406,10 +406,10 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -
 
 // -----
 
-func.func @torch.matmul$1d_1d_result_not_scalar(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[4],f32>) 
+func.func @torch.matmul$1d_1d_result_not_scalar(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[4],f32>)
     -> !torch.vtensor<[1],f32> {
   // expected-error @+1 {{1D x 1D matmul should produce a scalar (rank 0)}}
-  %0 = torch.aten.matmul %arg0, %arg1 
+  %0 = torch.aten.matmul %arg0, %arg1
       : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[1],f32>
   return %0 : !torch.vtensor<[1],f32>
 }