From 2fcd68ac1ff8b8fdb9af62cdbe9fcbcd6466f485 Mon Sep 17 00:00:00 2001
From: Vitalii Shutov
Date: Tue, 4 Nov 2025 16:58:59 +0000
Subject: [PATCH] [TOSA] Fix empty-dim reductions

Teach the TorchToTosa reduction lowering that an explicit empty dim list
means "reduce over all dims", matching PyTorch semantics, and cast the
result back to the requested output dtype when it differs. Add MLIR and
e2e regression tests and update the XFAIL sets accordingly.

Change-Id: Ibd1be38d219ad5c1986eb4a641efbb9ff0cb6a55
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    |  5 ++
 .../TorchToTosa/TosaLegalizeCommon.cpp        | 12 ++++-
 projects/pt1/e2e_testing/xfail_sets.py        |  4 +-
 .../test_suite/reduction.py                   | 46 +++++++++++++++++++
 test/Conversion/TorchToTosa/basic.mlir        | 23 ++++++++++
 5 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 0bc93f711ad6..0f19dad4cb1f 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1089,6 +1089,11 @@ class ConvertAtenMultipleDimsReductionOp
       for (int64_t i = 0; i < inputRank; i++)
         reduceDims.push_back(i);
     }
+    // PyTorch treats an explicit empty dim list the same as "reduce all dims".
+    if (reduceDims.empty()) {
+      for (int64_t i = 0; i < inputRank; i++)
+        reduceDims.push_back(i);
+    }
 
     int64_t N = reduceDims.size();
     for (unsigned i = 0; i < N; i++) {
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 02d1390ed148..444a2bdd2508 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -782,13 +782,23 @@ std::optional<Value> convertReduceOpCommon(
     // Optionally squeeze out the reduced axes.
     if (!keep_dims) {
+      auto squeezedType =
+          RankedTensorType::get(output_shape, reduce_element_type);
       auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
-          rewriter, op->getLoc(), output_type, val,
+          rewriter, op->getLoc(), squeezedType, val,
           tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape));
       val = reshape_op.getResult();
     }
   }
 
+  // Ensure the result element type matches the expected output type.
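+  // (The reductions above run in reduce_element_type, which can differ from
+  // the op's expected element type, e.g. aten.sum with an explicit dtype.)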
+  if (val.getType() != output_type) {
+    auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type);
+    if (!casted)
+      return std::nullopt;
+    val = casted.value();
+  }
+
   return val;
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 81071c6ab058..efbfaf259ac2 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3434,6 +3434,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this ''
+    "ReduceSumEmptyDimListInt8ToInt32Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3846,7 +3848,6 @@
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
-    "MeanDimEmptyDimModule_basic",
     "MlGroupNormManualModule_basic",
     "MlGroupNormModule_basic",
     "MlLayerNormManualModule_basic",
@@ -3901,7 +3902,6 @@
     "ReduceL3NormKeepDimComplexModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceSumDimIntListEmptyDimModule_basic",
     "RollModule_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 0eb0545e7f11..2e4ba9c4ccfc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[], dtype=torch.int32)
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module())
+def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
+class ReduceSumEmptyDimListInt8Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module())
+def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
 class ReduceSumElementTypeBoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index d100fe9dcfde..543dc09a65b2 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -311,6 +311,29 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !
 // -----
 
+// CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.none
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.const_shape
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
+// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<f32> -> !torch.vtensor<[],f32>
+// CHECK: return %[[VAL_9]] : !torch.vtensor<[],f32>
+// CHECK: }
+func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %empty = torch.prim.ListConstruct : () -> !torch.list<int>
+  %0 = torch.aten.sum.dim_IntList %arg0, %empty, %false, %none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_linalg_vector_norm$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>
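
Illustration (not part of the patch): the PyTorch behavior the new tests pin
down can be checked directly in eager mode; an explicit empty dim list reduces
over every dimension, and an explicit dtype fixes the result element type:

    import torch

    x = torch.arange(6, dtype=torch.int8).reshape(2, 3)
    # dim=[] reduces over all dims, same as torch.sum(x)
    assert torch.equal(torch.sum(x, dim=[]), torch.sum(x))
    # an explicit dtype pins the result element type (i8 input, i32 result)
    assert torch.sum(x, dim=[], dtype=torch.int32).dtype == torch.int32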