From fd4f1a5c4b2c330987c0e906ee7eb0bba9da7b78 Mon Sep 17 00:00:00 2001 From: Tao-Tao-real <2510737554@qq.com> Date: Fri, 13 Feb 2026 19:29:09 +0800 Subject: [PATCH 1/4] Fix subset subview offset/layout and add upstream test --- lib/PTO/IR/PTO.cpp | 21 +- lib/PTO/Transforms/PTOToEmitC.cpp | 9 +- lib/PTO/Transforms/PTOViewToMemref.cpp | 39 +- test/samples/Subset/subset_upstream.py | 707 +++++++++++++++++++++++++ 4 files changed, 759 insertions(+), 17 deletions(-) create mode 100644 test/samples/Subset/subset_upstream.py diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index bdf40165..2d9b1b54 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6661,21 +6661,22 @@ LogicalResult SubsetOp::inferReturnTypes( if (pv == ShapedType::kDynamic) { vdim = ShapedType::kDynamic; } else { + // Only refine when offset is a compile-time constant. + // If offset is dynamic, keep static valid dims equal to size to + // avoid type instability across uses. int64_t off = 0; - // operands: [source, offsets...] if (operands.size() > 1 + i) { auto offOpt = getConstIndexValue(operands[1 + i]); - if (!offOpt) { - vdim = ShapedType::kDynamic; - validShape.push_back(vdim); - continue; + if (offOpt) { + off = *offOpt; + int64_t diff = pv - off; + if (diff < 0) diff = 0; + vdim = std::min(sizeDim, diff); + } else { + vdim = sizeDim; } - off = *offOpt; - int64_t diff = pv - off; - if (diff < 0) diff = 0; - vdim = std::min(sizeDim, diff); } else { - vdim = ShapedType::kDynamic; + vdim = sizeDim; } } } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 4f5731ad..df92aa66 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2222,8 +2222,13 @@ struct SubviewToEmitCPattern : public OpConversionPattern { } if (auto ot = dyn_cast(tileCandidate.getType())) { auto tyStr = ot.getValue(); - if (tyStr.find("Tile<") != std::string::npos || - tyStr.find("ConvTile<") != std::string::npos) { + const bool isPtrLike = tyStr.ends_with("*"); + bool isTileLike = tyStr.find("Tile<") != std::string::npos || + tyStr.find("ConvTile<") != std::string::npos; + if (!isTileLike && !isPtrLike && tyStr.find("Tile") != std::string::npos) + isTileLike = true; + + if (isTileLike && !isPtrLike) { std::string elemTok = elemTypeToString(srcType.getElementType()); std::string qualifier = "__gm__"; if (auto asAttr = diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index d5a6a35c..8385f287 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -733,7 +733,12 @@ struct PTOViewToMemrefPass // 1. Source must be memref already Value src = op->getOperand(0); - auto srcMrTy = dyn_cast(src.getType()); + // If the source is a bound tile, subview the underlying memref to avoid + // materializing a tile->pointer cast in later lowering. + Value subviewSrc = src; + if (auto bind = src.getDefiningOp()) + subviewSrc = bind.getSource(); + auto srcMrTy = dyn_cast(subviewSrc.getType()); if (!srcMrTy) { op.emitError("pto.subset source must be lowered to memref first"); signalPassFailure(); @@ -754,6 +759,8 @@ struct PTOViewToMemrefPass // 3. Offsets (mixed) SmallVector mixedOffsets; + SmallVector staticOffsets; + staticOffsets.reserve(op.getOffsets().size()); for (Value o : op.getOffsets()) { IntegerAttr constAttr; bool isStatic = false; @@ -764,10 +771,13 @@ struct PTOViewToMemrefPass constAttr = rewriter.getIndexAttr(cInt.value()); isStatic = true; } - if (isStatic) + if (isStatic) { mixedOffsets.push_back(constAttr); - else + staticOffsets.push_back(constAttr.getInt()); + } else { mixedOffsets.push_back(ensureIndex(rewriter, loc, o, op)); + staticOffsets.push_back(ShapedType::kDynamic); + } } // 3.1 Layout-aware checks for boxed tiles (SLayout != NoneBox) @@ -869,7 +879,26 @@ struct PTOViewToMemrefPass } (void)srcOffset; - auto resultLayout = StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides); + // If source offset/strides and subset offsets are all static, preserve + // a static offset in the result type to satisfy memref.subview verifier. + int64_t resultOffset = ShapedType::kDynamic; + bool allOffsetsStatic = (srcOffset != ShapedType::kDynamic); + if (allOffsetsStatic) { + int64_t totalOffset = srcOffset; + for (size_t i = 0; i < staticSizes.size(); ++i) { + if (i >= static_cast(srcStrides.size()) || + srcStrides[i] == ShapedType::kDynamic || + staticOffsets[i] == ShapedType::kDynamic) { + allOffsetsStatic = false; + break; + } + totalOffset += staticOffsets[i] * srcStrides[i]; + } + if (allOffsetsStatic) + resultOffset = totalOffset; + } + + auto resultLayout = StridedLayoutAttr::get(ctx, resultOffset, srcStrides); auto resultMemRefType = MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout, srcMrTy.getMemorySpace()); @@ -881,7 +910,7 @@ struct PTOViewToMemrefPass mixedStrides.push_back(rewriter.getIndexAttr(1)); auto sv = rewriter.create( - loc, resultMemRefType, src, mixedOffsets, mixedSizes, mixedStrides); + loc, resultMemRefType, subviewSrc, mixedOffsets, mixedSizes, mixedStrides); // 6. Re-bind tile metadata (config + valid dims) Value parentVRow; diff --git a/test/samples/Subset/subset_upstream.py b/test/samples/Subset/subset_upstream.py new file mode 100644 index 00000000..f26d0bfb --- /dev/null +++ b/test/samples/Subset/subset_upstream.py @@ -0,0 +1,707 @@ +# Auto-generated test from upstream MLIR sample. + +def build(): + return r""" +module { + func.func @main_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: !pto.ptr, %arg6: !pto.ptr, %arg7: !pto.ptr, %arg8: i32, %arg9: i32, %arg10: i32) { + %0 = pto.get_block_idx + %1 = arith.trunci %0 : i64 to i32 + %2 = pto.alloc_tile : !pto.tile_buf + %3 = pto.alloc_tile : !pto.tile_buf + %4 = pto.alloc_tile : !pto.tile_buf + %5 = pto.alloc_tile : !pto.tile_buf + %6 = pto.alloc_tile : !pto.tile_buf + %7 = pto.alloc_tile : !pto.tile_buf + %8 = pto.alloc_tile : !pto.tile_buf + %9 = pto.alloc_tile : !pto.tile_buf + %10 = pto.alloc_tile : !pto.tile_buf + %11 = pto.alloc_tile : !pto.tile_buf + %12 = pto.alloc_tile : !pto.tile_buf + %13 = pto.alloc_tile : !pto.tile_buf + %14 = pto.alloc_tile : !pto.tile_buf + %15 = pto.alloc_tile : !pto.tile_buf + %16 = pto.alloc_tile : !pto.tile_buf + %17 = pto.alloc_tile : !pto.tile_buf + %18 = pto.alloc_tile : !pto.tile_buf + %19 = pto.alloc_tile : !pto.tile_buf + %20 = pto.get_subblock_idx + %21 = arith.trunci %20 : i64 to i32 + pto.section.cube { + %22 = arith.divsi %1, %arg9 : i32 + %23 = arith.remsi %22, %arg8 : i32 + %24 = arith.muli %23, %arg9 : i32 + %c8192_i32 = arith.constant 8192 : i32 + %25 = arith.muli %24, %c8192_i32 : i32 + %26 = arith.remsi %1, %arg9 : i32 + %c8192_i32_0 = arith.constant 8192 : i32 + %27 = arith.muli %26, %c8192_i32_0 : i32 + %28 = arith.addi %25, %27 : i32 + %29 = arith.divsi %1, %arg9 : i32 + %30 = arith.divsi %29, %arg8 : i32 + %c4_i32 = arith.constant 4 : i32 + %31 = arith.remsi %30, %c4_i32 : i32 + %c2048_i32 = arith.constant 2048 : i32 + %32 = arith.muli %31, %c2048_i32 : i32 + %33 = arith.addi %28, %32 : i32 + %c0_i32 = arith.constant 0 : i32 + %34 = arith.muli %arg8, %arg9 : i32 + %c8192_i32_1 = arith.constant 8192 : i32 + %35 = arith.muli %34, %c8192_i32_1 : i32 + %36 = arith.divsi %1, %arg9 : i32 + %37 = arith.divsi %36, %arg8 : i32 + %c4_i32_2 = arith.constant 4 : i32 + %38 = arith.remsi %37, %c4_i32_2 : i32 + %c2048_i32_3 = arith.constant 2048 : i32 + %39 = arith.muli %38, %c2048_i32_3 : i32 + %40 = arith.subi %35, %39 : i32 + %41 = arith.divsi %1, %arg9 : i32 + %42 = arith.remsi %41, %arg8 : i32 + %43 = arith.muli %42, %arg9 : i32 + %c8192_i32_4 = arith.constant 8192 : i32 + %44 = arith.muli %43, %c8192_i32_4 : i32 + %45 = arith.subi %40, %44 : i32 + %46 = arith.remsi %1, %arg9 : i32 + %c8192_i32_5 = arith.constant 8192 : i32 + %47 = arith.muli %46, %c8192_i32_5 : i32 + %48 = arith.subi %45, %47 : i32 + %c2048_i32_6 = arith.constant 2048 : i32 + %c1 = arith.constant 1 : index + %49 = arith.index_cast %arg8 : i32 to index + %50 = arith.index_cast %arg9 : i32 to index + %c64_i32 = arith.constant 64 : i32 + %51 = arith.index_cast %c64_i32 : i32 to index + %c128_i32 = arith.constant 128 : i32 + %52 = arith.index_cast %c128_i32 : i32 to index + %c1_7 = arith.constant 1 : index + %53 = arith.muli %c1_7, %52 : index + %54 = arith.muli %53, %51 : index + %55 = arith.muli %54, %50 : index + %56 = arith.muli %55, %49 : index + %57 = pto.make_tensor_view %arg0, shape = [%c1, %49, %50, %51, %52] strides = [%56, %55, %54, %53, %c1_7] : !pto.tensor_view + %c0 = arith.constant 0 : index + %58 = arith.index_cast %33 : i32 to index + %c1_8 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %59 = pto.partition_view %57, offsets = [%c0, %c0, %c0, %c0, %58], sizes = [%c1_8, %c1_8, %c1_8, %c16, %c128] : !pto.tensor_view -> !pto.partition_tensor_view<16x128xf16> + %60 = pto.tload ins(%59 : !pto.partition_tensor_view<16x128xf16>) outs(%2 : !pto.tile_buf) -> tensor<16x128xf16> + pto.barrier + %c0_i32_9 = arith.constant 0 : i32 + %61 = arith.index_cast %c0_i32_9 : i32 to index + %c32_i32 = arith.constant 32 : i32 + %62 = arith.index_cast %c32_i32 : i32 to index + %c1_10 = arith.constant 1 : index + %63 = arith.addi %61, %62 : index + scf.for %arg11 = %61 to %63 step %c1_10 { + %64 = arith.index_cast %arg11 : index to i32 + pto.sync.wait , 0 + pto.sync.wait , 16 + pto.barrier + %c8192_i32_11 = arith.constant 8192 : i32 + %65 = arith.muli %1, %c8192_i32_11 : i32 + %c0_i32_12 = arith.constant 0 : i32 + %c17891328_i32 = arith.constant 17891328 : i32 + %c8192_i32_13 = arith.constant 8192 : i32 + %66 = arith.muli %1, %c8192_i32_13 : i32 + %67 = arith.subi %c17891328_i32, %66 : i32 + %c8192_i32_14 = arith.constant 8192 : i32 + %c1_15 = arith.constant 1 : index + %c2184_i32 = arith.constant 2184 : i32 + %68 = arith.index_cast %c2184_i32 : i32 to index + %c64_i32_16 = arith.constant 64 : i32 + %69 = arith.index_cast %c64_i32_16 : i32 to index + %c128_i32_17 = arith.constant 128 : i32 + %70 = arith.index_cast %c128_i32_17 : i32 to index + %c1_18 = arith.constant 1 : index + %71 = arith.muli %c1_18, %70 : index + %72 = arith.muli %71, %69 : index + %73 = arith.muli %72, %68 : index + %74 = arith.muli %73, %c1_15 : index + %75 = pto.make_tensor_view %arg4, shape = [%c1_15, %c1_15, %68, %69, %70] strides = [%74, %73, %72, %71, %c1_18] : !pto.tensor_view + %c0_19 = arith.constant 0 : index + %76 = arith.index_cast %65 : i32 to index + %c1_20 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128_21 = arith.constant 128 : index + %77 = pto.partition_view %75, offsets = [%c0_19, %c0_19, %c0_19, %c0_19, %76], sizes = [%c1_20, %c1_20, %c1_20, %c64, %c128_21] : !pto.tensor_view -> !pto.partition_tensor_view<64x128xf16> + %78 = pto.tload ins(%77 : !pto.partition_tensor_view<64x128xf16>) outs(%3 : !pto.tile_buf) -> tensor<64x128xf16> + pto.barrier + pto.set_flag[, , ] + pto.wait_flag[, , ] + %79 = pto.alloc_tile : !pto.tile_buf + %80 = pto.tmov ins(%2 : !pto.tile_buf) outs(%79 : !pto.tile_buf) -> tensor<16x128xf16> + %81 = pto.alloc_tile : !pto.tile_buf + pto.treshape ins(%3 : !pto.tile_buf) outs(%81 : !pto.tile_buf) + %82 = pto.alloc_tile : !pto.tile_buf + %83 = pto.tmov ins(%81 : !pto.tile_buf) outs(%82 : !pto.tile_buf) -> tensor<128x64xf16> + pto.set_flag[, , ] + pto.wait_flag[, , ] + %84 = pto.tmatmul ins(%79, %82 : !pto.tile_buf, !pto.tile_buf) outs(%4 : !pto.tile_buf) -> tensor<16x64xf32> + pto.barrier + %c0_i32_22 = arith.constant 0 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %85 = arith.muli %1, %c1024_i32 : i32 + %c1024_i32_23 = arith.constant 1024 : i32 + %c2236416_i32 = arith.constant 2236416 : i32 + %c1024_i32_24 = arith.constant 1024 : i32 + %86 = arith.muli %1, %c1024_i32_24 : i32 + %87 = arith.subi %c2236416_i32, %86 : i32 + %c1_25 = arith.constant 1 : index + %c2184_i32_26 = arith.constant 2184 : i32 + %88 = arith.index_cast %c2184_i32_26 : i32 to index + %c16_i32 = arith.constant 16 : i32 + %89 = arith.index_cast %c16_i32 : i32 to index + %c64_i32_27 = arith.constant 64 : i32 + %90 = arith.index_cast %c64_i32_27 : i32 to index + %c1_28 = arith.constant 1 : index + %91 = arith.muli %c1_28, %90 : index + %92 = arith.muli %91, %89 : index + %93 = arith.muli %92, %88 : index + %94 = arith.muli %93, %c1_25 : index + %95 = pto.make_tensor_view %arg5, shape = [%c1_25, %c1_25, %88, %89, %90] strides = [%94, %93, %92, %91, %c1_28] : !pto.tensor_view + %c0_29 = arith.constant 0 : index + %96 = arith.index_cast %85 : i32 to index + %c1_30 = arith.constant 1 : index + %c16_31 = arith.constant 16 : index + %c64_32 = arith.constant 64 : index + %97 = pto.partition_view %95, offsets = [%c0_29, %c0_29, %c0_29, %c0_29, %96], sizes = [%c1_30, %c1_30, %c1_30, %c16_31, %c64_32] : !pto.tensor_view -> !pto.partition_tensor_view<16x64xf32> + %98 = pto.tstore ins(%4 : !pto.tile_buf) outs(%97 : !pto.partition_tensor_view<16x64xf32>) -> tensor<16x64xf32> + pto.barrier + pto.sync.set , 1 + pto.sync.set , 17 + pto.sync.wait , 2 + pto.sync.wait , 18 + pto.barrier + %c1024_i32_33 = arith.constant 1024 : i32 + %99 = arith.muli %1, %c1024_i32_33 : i32 + %c0_i32_34 = arith.constant 0 : i32 + %c2236416_i32_35 = arith.constant 2236416 : i32 + %c1024_i32_36 = arith.constant 1024 : i32 + %100 = arith.muli %1, %c1024_i32_36 : i32 + %101 = arith.subi %c2236416_i32_35, %100 : i32 + %c1024_i32_37 = arith.constant 1024 : i32 + %c1_38 = arith.constant 1 : index + %c2184_i32_39 = arith.constant 2184 : i32 + %102 = arith.index_cast %c2184_i32_39 : i32 to index + %c16_i32_40 = arith.constant 16 : i32 + %103 = arith.index_cast %c16_i32_40 : i32 to index + %c64_i32_41 = arith.constant 64 : i32 + %104 = arith.index_cast %c64_i32_41 : i32 to index + %c1_42 = arith.constant 1 : index + %105 = arith.muli %c1_42, %104 : index + %106 = arith.muli %105, %103 : index + %107 = arith.muli %106, %102 : index + %108 = arith.muli %107, %c1_38 : index + %109 = pto.make_tensor_view %arg6, shape = [%c1_38, %c1_38, %102, %103, %104] strides = [%108, %107, %106, %105, %c1_42] : !pto.tensor_view + %c0_43 = arith.constant 0 : index + %110 = arith.index_cast %99 : i32 to index + %c1_44 = arith.constant 1 : index + %c16_45 = arith.constant 16 : index + %c64_46 = arith.constant 64 : index + %111 = pto.partition_view %109, offsets = [%c0_43, %c0_43, %c0_43, %c0_43, %110], sizes = [%c1_44, %c1_44, %c1_44, %c16_45, %c64_46] : !pto.tensor_view -> !pto.partition_tensor_view<16x64xf16> + %112 = pto.tload ins(%111 : !pto.partition_tensor_view<16x64xf16>) outs(%5 : !pto.tile_buf) -> tensor<16x64xf16> + pto.barrier + pto.set_flag[, , ] + pto.wait_flag[, , ] + %113 = pto.alloc_tile : !pto.tile_buf + %114 = pto.tmov ins(%5 : !pto.tile_buf) outs(%113 : !pto.tile_buf) -> tensor<16x64xf16> + %115 = pto.alloc_tile : !pto.tile_buf + %116 = pto.tmov ins(%3 : !pto.tile_buf) outs(%115 : !pto.tile_buf) -> tensor<64x128xf16> + pto.set_flag[, , ] + pto.wait_flag[, , ] + %117 = pto.tmatmul ins(%113, %115 : !pto.tile_buf, !pto.tile_buf) outs(%6 : !pto.tile_buf) -> tensor<16x128xf32> + pto.barrier + %c0_i32_47 = arith.constant 0 : i32 + %c2048_i32_48 = arith.constant 2048 : i32 + %118 = arith.muli %1, %c2048_i32_48 : i32 + %c2048_i32_49 = arith.constant 2048 : i32 + %c4472832_i32 = arith.constant 4472832 : i32 + %c2048_i32_50 = arith.constant 2048 : i32 + %119 = arith.muli %1, %c2048_i32_50 : i32 + %120 = arith.subi %c4472832_i32, %119 : i32 + %c1_51 = arith.constant 1 : index + %c2184_i32_52 = arith.constant 2184 : i32 + %121 = arith.index_cast %c2184_i32_52 : i32 to index + %c16_i32_53 = arith.constant 16 : i32 + %122 = arith.index_cast %c16_i32_53 : i32 to index + %c128_i32_54 = arith.constant 128 : i32 + %123 = arith.index_cast %c128_i32_54 : i32 to index + %c1_55 = arith.constant 1 : index + %124 = arith.muli %c1_55, %123 : index + %125 = arith.muli %124, %122 : index + %126 = arith.muli %125, %121 : index + %127 = arith.muli %126, %c1_51 : index + %128 = pto.make_tensor_view %arg7, shape = [%c1_51, %c1_51, %121, %122, %123] strides = [%127, %126, %125, %124, %c1_55] : !pto.tensor_view + %c0_56 = arith.constant 0 : index + %129 = arith.index_cast %118 : i32 to index + %c1_57 = arith.constant 1 : index + %c16_58 = arith.constant 16 : index + %c128_59 = arith.constant 128 : index + %130 = pto.partition_view %128, offsets = [%c0_56, %c0_56, %c0_56, %c0_56, %129], sizes = [%c1_57, %c1_57, %c1_57, %c16_58, %c128_59] : !pto.tensor_view -> !pto.partition_tensor_view<16x128xf32> + %131 = pto.tstore ins(%6 : !pto.tile_buf) outs(%130 : !pto.partition_tensor_view<16x128xf32>) -> tensor<16x128xf32> + pto.barrier + pto.sync.set , 3 + pto.sync.set , 19 + pto.sync.wait , 4 + pto.sync.wait , 20 + } + pto.sync.wait , 8 + pto.sync.wait , 24 + } + pto.section.vector { + %cst = arith.constant 0.000000e+00 : f32 + pto.texpands ins(%cst : f32) outs(%7 : !pto.tile_buf) + %cst_0 = arith.constant 0.000000e+00 : f32 + pto.texpands ins(%cst_0 : f32) outs(%8 : !pto.tile_buf) + %cst_1 = arith.constant -1.07374182E+9 : f32 + pto.texpands ins(%cst_1 : f32) outs(%9 : !pto.tile_buf) + pto.barrier + %c0_i32 = arith.constant 0 : i32 + %22 = arith.index_cast %c0_i32 : i32 to index + %c32_i32 = arith.constant 32 : i32 + %23 = arith.index_cast %c32_i32 : i32 to index + %c1 = arith.constant 1 : index + %24 = arith.addi %22, %23 : index + scf.for %arg11 = %22 to %24 step %c1 { + %69 = arith.index_cast %arg11 : index to i32 + %70 = arith.divsi %1, %arg9 : i32 + %71 = arith.remsi %70, %arg8 : i32 + %72 = arith.muli %71, %arg9 : i32 + %c8192_i32_18 = arith.constant 8192 : i32 + %73 = arith.muli %72, %c8192_i32_18 : i32 + %74 = arith.remsi %1, %arg9 : i32 + %c8192_i32_19 = arith.constant 8192 : i32 + %75 = arith.muli %74, %c8192_i32_19 : i32 + %76 = arith.addi %73, %75 : i32 + %77 = arith.divsi %1, %arg9 : i32 + %78 = arith.divsi %77, %arg8 : i32 + %c4_i32_20 = arith.constant 4 : i32 + %79 = arith.remsi %78, %c4_i32_20 : i32 + %c2048_i32_21 = arith.constant 2048 : i32 + %80 = arith.muli %79, %c2048_i32_21 : i32 + %81 = arith.addi %76, %80 : i32 + %c64_i32_22 = arith.constant 64 : i32 + %82 = arith.muli %69, %c64_i32_22 : i32 + %83 = arith.addi %81, %82 : i32 + %c0_i32_23 = arith.constant 0 : i32 + %84 = arith.muli %arg8, %arg9 : i32 + %c8192_i32_24 = arith.constant 8192 : i32 + %85 = arith.muli %84, %c8192_i32_24 : i32 + %c64_i32_25 = arith.constant 64 : i32 + %86 = arith.muli %69, %c64_i32_25 : i32 + %87 = arith.subi %85, %86 : i32 + %88 = arith.divsi %1, %arg9 : i32 + %89 = arith.divsi %88, %arg8 : i32 + %c4_i32_26 = arith.constant 4 : i32 + %90 = arith.remsi %89, %c4_i32_26 : i32 + %c2048_i32_27 = arith.constant 2048 : i32 + %91 = arith.muli %90, %c2048_i32_27 : i32 + %92 = arith.subi %87, %91 : i32 + %93 = arith.divsi %1, %arg9 : i32 + %94 = arith.remsi %93, %arg8 : i32 + %95 = arith.muli %94, %arg9 : i32 + %c8192_i32_28 = arith.constant 8192 : i32 + %96 = arith.muli %95, %c8192_i32_28 : i32 + %97 = arith.subi %92, %96 : i32 + %98 = arith.remsi %1, %arg9 : i32 + %c8192_i32_29 = arith.constant 8192 : i32 + %99 = arith.muli %98, %c8192_i32_29 : i32 + %100 = arith.subi %97, %99 : i32 + %c64_i32_30 = arith.constant 64 : i32 + %c1_31 = arith.constant 1 : index + %101 = arith.index_cast %arg8 : i32 to index + %102 = arith.index_cast %arg9 : i32 to index + %c4_i32_32 = arith.constant 4 : i32 + %103 = arith.index_cast %c4_i32_32 : i32 to index + %c2048_i32_33 = arith.constant 2048 : i32 + %104 = arith.index_cast %c2048_i32_33 : i32 to index + %c1_34 = arith.constant 1 : index + %105 = arith.muli %c1_34, %104 : index + %106 = arith.muli %105, %103 : index + %107 = arith.muli %106, %102 : index + %108 = arith.muli %107, %101 : index + %109 = pto.make_tensor_view %arg2, shape = [%c1_31, %101, %102, %103, %104] strides = [%108, %107, %106, %105, %c1_34] : !pto.tensor_view + %c0_35 = arith.constant 0 : index + %110 = arith.index_cast %83 : i32 to index + %c1_36 = arith.constant 1 : index + %c1_37 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %111 = pto.partition_view %109, offsets = [%c0_35, %c0_35, %c0_35, %c0_35, %110], sizes = [%c1_36, %c1_36, %c1_36, %c1_37, %c64] : !pto.tensor_view -> !pto.partition_tensor_view<1x64xi32> + %112 = pto.tload ins(%111 : !pto.partition_tensor_view<1x64xi32>) outs(%10 : !pto.tile_buf) -> tensor<1x64xi32> + pto.barrier + %c0_i32_38 = arith.constant 0 : i32 + %113 = arith.index_cast %c0_i32_38 : i32 to index + %c32_i32_39 = arith.constant 32 : i32 + %114 = arith.index_cast %c32_i32_39 : i32 to index + %c1_40 = arith.constant 1 : index + %115 = arith.addi %113, %114 : index + scf.for %arg12 = %113 to %115 step %c1_40 { + %179 = arith.index_cast %arg12 : index to i32 + %180 = arith.divsi %1, %arg9 : i32 + %181 = arith.remsi %180, %arg8 : i32 + %182 = arith.muli %181, %arg10 : i32 + %c512_i32_96 = arith.constant 512 : i32 + %183 = arith.muli %182, %c512_i32_96 : i32 + %c32_i32_97 = arith.constant 32 : i32 + %184 = arith.muli %21, %c32_i32_97 : i32 + %185 = arith.addi %184, %179 : i32 + %186 = arith.index_cast %185 : i32 to index + %187 = pto.tgetval ins(%10, %186 : !pto.tile_buf, index) outs : i32 + %c512_i32_98 = arith.constant 512 : i32 + %188 = arith.muli %187, %c512_i32_98 : i32 + %189 = arith.addi %183, %188 : i32 + %190 = arith.divsi %1, %arg9 : i32 + %191 = arith.divsi %190, %arg8 : i32 + %c4_i32_99 = arith.constant 4 : i32 + %192 = arith.remsi %191, %c4_i32_99 : i32 + %c128_i32_100 = arith.constant 128 : i32 + %193 = arith.muli %192, %c128_i32_100 : i32 + %194 = arith.addi %189, %193 : i32 + %c0_i32_101 = arith.constant 0 : i32 + %195 = arith.muli %arg8, %arg10 : i32 + %c512_i32_102 = arith.constant 512 : i32 + %196 = arith.muli %195, %c512_i32_102 : i32 + %197 = arith.divsi %1, %arg9 : i32 + %198 = arith.divsi %197, %arg8 : i32 + %c4_i32_103 = arith.constant 4 : i32 + %199 = arith.remsi %198, %c4_i32_103 : i32 + %c128_i32_104 = arith.constant 128 : i32 + %200 = arith.muli %199, %c128_i32_104 : i32 + %201 = arith.subi %196, %200 : i32 + %202 = arith.divsi %1, %arg9 : i32 + %203 = arith.remsi %202, %arg8 : i32 + %204 = arith.muli %203, %arg10 : i32 + %c512_i32_105 = arith.constant 512 : i32 + %205 = arith.muli %204, %c512_i32_105 : i32 + %206 = arith.subi %201, %205 : i32 + %c32_i32_106 = arith.constant 32 : i32 + %207 = arith.muli %21, %c32_i32_106 : i32 + %208 = arith.addi %207, %179 : i32 + %209 = arith.index_cast %208 : i32 to index + %210 = pto.tgetval ins(%10, %209 : !pto.tile_buf, index) outs : i32 + %c512_i32_107 = arith.constant 512 : i32 + %211 = arith.muli %210, %c512_i32_107 : i32 + %212 = arith.subi %206, %211 : i32 + %c128_i32_108 = arith.constant 128 : i32 + %c1_109 = arith.constant 1 : index + %213 = arith.index_cast %arg8 : i32 to index + %214 = arith.index_cast %arg10 : i32 to index + %c4_i32_110 = arith.constant 4 : i32 + %215 = arith.index_cast %c4_i32_110 : i32 to index + %c128_i32_111 = arith.constant 128 : i32 + %216 = arith.index_cast %c128_i32_111 : i32 to index + %c1_112 = arith.constant 1 : index + %217 = arith.muli %c1_112, %216 : index + %218 = arith.muli %217, %215 : index + %219 = arith.muli %218, %214 : index + %220 = arith.muli %219, %213 : index + %221 = pto.make_tensor_view %arg1, shape = [%c1_109, %213, %214, %215, %216] strides = [%220, %219, %218, %217, %c1_112] : !pto.tensor_view + %c0_113 = arith.constant 0 : index + %222 = arith.index_cast %194 : i32 to index + %c1_114 = arith.constant 1 : index + %c1_115 = arith.constant 1 : index + %c128_116 = arith.constant 128 : index + %223 = pto.partition_view %221, offsets = [%c0_113, %c0_113, %c0_113, %c0_113, %222], sizes = [%c1_114, %c1_114, %c1_114, %c1_115, %c128_116] : !pto.tensor_view -> !pto.partition_tensor_view<1x128xf16> + %224 = pto.tload ins(%223 : !pto.partition_tensor_view<1x128xf16>) outs(%11 : !pto.tile_buf) -> tensor<1x128xf16> + pto.barrier + %c0_i32_117 = arith.constant 0 : i32 + %c8192_i32_118 = arith.constant 8192 : i32 + %225 = arith.muli %1, %c8192_i32_118 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %226 = arith.muli %21, %c4096_i32 : i32 + %227 = arith.addi %225, %226 : i32 + %c128_i32_119 = arith.constant 128 : i32 + %228 = arith.muli %179, %c128_i32_119 : i32 + %229 = arith.addi %227, %228 : i32 + %c128_i32_120 = arith.constant 128 : i32 + %c17891328_i32 = arith.constant 17891328 : i32 + %c128_i32_121 = arith.constant 128 : i32 + %230 = arith.muli %179, %c128_i32_121 : i32 + %231 = arith.subi %c17891328_i32, %230 : i32 + %c4096_i32_122 = arith.constant 4096 : i32 + %232 = arith.muli %21, %c4096_i32_122 : i32 + %233 = arith.subi %231, %232 : i32 + %c8192_i32_123 = arith.constant 8192 : i32 + %234 = arith.muli %1, %c8192_i32_123 : i32 + %235 = arith.subi %233, %234 : i32 + %c1_124 = arith.constant 1 : index + %c2184_i32_125 = arith.constant 2184 : i32 + %236 = arith.index_cast %c2184_i32_125 : i32 to index + %c64_i32_126 = arith.constant 64 : i32 + %237 = arith.index_cast %c64_i32_126 : i32 to index + %c128_i32_127 = arith.constant 128 : i32 + %238 = arith.index_cast %c128_i32_127 : i32 to index + %c1_128 = arith.constant 1 : index + %239 = arith.muli %c1_128, %238 : index + %240 = arith.muli %239, %237 : index + %241 = arith.muli %240, %236 : index + %242 = arith.muli %241, %c1_124 : index + %243 = pto.make_tensor_view %arg4, shape = [%c1_124, %c1_124, %236, %237, %238] strides = [%242, %241, %240, %239, %c1_128] : !pto.tensor_view + %c0_129 = arith.constant 0 : index + %244 = arith.index_cast %229 : i32 to index + %c1_130 = arith.constant 1 : index + %c1_131 = arith.constant 1 : index + %c128_132 = arith.constant 128 : index + %245 = pto.partition_view %243, offsets = [%c0_129, %c0_129, %c0_129, %c0_129, %244], sizes = [%c1_130, %c1_130, %c1_130, %c1_131, %c128_132] : !pto.tensor_view -> !pto.partition_tensor_view<1x128xf16> + %246 = pto.tstore ins(%11 : !pto.tile_buf) outs(%245 : !pto.partition_tensor_view<1x128xf16>) -> tensor<1x128xf16> + pto.barrier + } + pto.sync.set , 0 + %cst_41 = arith.constant 0.000000e+00 : f32 + pto.texpands ins(%cst_41 : f32) outs(%12 : !pto.tile_buf) + pto.barrier + %c0_i32_42 = arith.constant 0 : i32 + %c0_i32_43 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c8_i32_44 = arith.constant 8 : i32 + %116 = pto.tmov ins(%9 : !pto.tile_buf) outs(%13 : !pto.tile_buf) -> tensor<1x8xf32> + pto.barrier + pto.sync.wait , 1 + %c1024_i32_45 = arith.constant 1024 : i32 + %117 = arith.muli %1, %c1024_i32_45 : i32 + %c512_i32 = arith.constant 512 : i32 + %118 = arith.muli %21, %c512_i32 : i32 + %119 = arith.addi %117, %118 : i32 + %c0_i32_46 = arith.constant 0 : i32 + %c2236416_i32 = arith.constant 2236416 : i32 + %c512_i32_47 = arith.constant 512 : i32 + %120 = arith.muli %21, %c512_i32_47 : i32 + %121 = arith.subi %c2236416_i32, %120 : i32 + %c1024_i32_48 = arith.constant 1024 : i32 + %122 = arith.muli %1, %c1024_i32_48 : i32 + %123 = arith.subi %121, %122 : i32 + %c512_i32_49 = arith.constant 512 : i32 + %c1_50 = arith.constant 1 : index + %c2184_i32 = arith.constant 2184 : i32 + %124 = arith.index_cast %c2184_i32 : i32 to index + %c16_i32 = arith.constant 16 : i32 + %125 = arith.index_cast %c16_i32 : i32 to index + %c64_i32_51 = arith.constant 64 : i32 + %126 = arith.index_cast %c64_i32_51 : i32 to index + %c1_52 = arith.constant 1 : index + %127 = arith.muli %c1_52, %126 : index + %128 = arith.muli %127, %125 : index + %129 = arith.muli %128, %124 : index + %130 = arith.muli %129, %c1_50 : index + %131 = pto.make_tensor_view %arg5, shape = [%c1_50, %c1_50, %124, %125, %126] strides = [%130, %129, %128, %127, %c1_52] : !pto.tensor_view + %c0_53 = arith.constant 0 : index + %132 = arith.index_cast %119 : i32 to index + %c1_54 = arith.constant 1 : index + %c8_55 = arith.constant 8 : index + %c64_56 = arith.constant 64 : index + %133 = pto.partition_view %131, offsets = [%c0_53, %c0_53, %c0_53, %c0_53, %132], sizes = [%c1_54, %c1_54, %c1_54, %c8_55, %c64_56] : !pto.tensor_view -> !pto.partition_tensor_view<8x64xf32> + %134 = pto.tload ins(%133 : !pto.partition_tensor_view<8x64xf32>) outs(%14 : !pto.tile_buf) -> tensor<8x64xf32> + pto.barrier + pto.tadd ins(%12, %14 : !pto.tile_buf, !pto.tile_buf) outs(%12 : !pto.tile_buf) + %cst_57 = arith.constant 0.0883883461 : f32 + pto.tmuls ins(%12, %cst_57 : !pto.tile_buf, f32) outs(%12 : !pto.tile_buf) + %135 = pto.alloc_tile : !pto.tile_buf + %136 = pto.alloc_tile : !pto.tile_buf + pto.treshape ins(%9 : !pto.tile_buf) outs(%136 : !pto.tile_buf) + pto.trowmax ins(%12, %135 : !pto.tile_buf, !pto.tile_buf) outs(%136 : !pto.tile_buf) + pto.treshape ins(%136 : !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmax ins(%9, %13 : !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tsub ins(%13, %9 : !pto.tile_buf, !pto.tile_buf) outs(%13 : !pto.tile_buf) + pto.texp ins(%13 : !pto.tile_buf) outs(%13 : !pto.tile_buf) + %c0_i32_58 = arith.constant 0 : i32 + %137 = arith.index_cast %c0_i32_58 : i32 to index + %c8_i32_59 = arith.constant 8 : i32 + %138 = arith.index_cast %c8_i32_59 : i32 to index + %c1_60 = arith.constant 1 : index + %139 = arith.addi %137, %138 : index + scf.for %arg12 = %137 to %139 step %c1_60 { + %179 = arith.index_cast %arg12 : index to i32 + %c0_96 = arith.constant 0 : index + %c64_i32_97 = arith.constant 64 : i32 + %180 = arith.muli %179, %c64_i32_97 : i32 + %181 = arith.index_cast %180 : i32 to index + %182 = pto.subset %12[%c0_96, %181] sizes [8, 64] : !pto.tile_buf + %c0_98 = arith.constant 0 : index + %c64_i32_99 = arith.constant 64 : i32 + %183 = arith.muli %179, %c64_i32_99 : i32 + %184 = arith.index_cast %183 : i32 to index + %185 = pto.subset %12[%c0_98, %184] sizes [8, 64] : !pto.tile_buf + %186 = arith.index_cast %179 : i32 to index + %187 = pto.tgetval ins(%9, %186 : !pto.tile_buf, index) outs : f32 + pto.tsubs ins(%185, %187 : !pto.tile_buf, f32) outs(%182 : !pto.tile_buf) + } + pto.texp ins(%12 : !pto.tile_buf) outs(%12 : !pto.tile_buf) + %140 = pto.alloc_tile : !pto.tile_buf + %141 = pto.alloc_tile : !pto.tile_buf + pto.treshape ins(%16 : !pto.tile_buf) outs(%141 : !pto.tile_buf) + pto.trowsum ins(%12, %140 : !pto.tile_buf, !pto.tile_buf) outs(%141 : !pto.tile_buf) + pto.treshape ins(%141 : !pto.tile_buf) outs(%16 : !pto.tile_buf) + pto.tmul ins(%8, %13 : !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tadd ins(%8, %16 : !pto.tile_buf, !pto.tile_buf) outs(%8 : !pto.tile_buf) + %142 = pto.alloc_tile : !pto.tile_buf + pto.treshape ins(%13 : !pto.tile_buf) outs(%142 : !pto.tile_buf) + pto.trowexpandmul ins(%7, %142 : !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.treshape ins(%142 : !pto.tile_buf) outs(%13 : !pto.tile_buf) + %c0_i32_61 = arith.constant 0 : i32 + %c0_i32_62 = arith.constant 0 : i32 + %c512_i32_63 = arith.constant 512 : i32 + %c512_i32_64 = arith.constant 512 : i32 + pto.tcvt ins(%12 : !pto.tile_buf) outs(%17 : !pto.tile_buf) + pto.barrier + %c0_i32_65 = arith.constant 0 : i32 + %c1024_i32_66 = arith.constant 1024 : i32 + %143 = arith.muli %1, %c1024_i32_66 : i32 + %c512_i32_67 = arith.constant 512 : i32 + %144 = arith.muli %21, %c512_i32_67 : i32 + %145 = arith.addi %143, %144 : i32 + %c512_i32_68 = arith.constant 512 : i32 + %c2236416_i32_69 = arith.constant 2236416 : i32 + %c512_i32_70 = arith.constant 512 : i32 + %146 = arith.muli %21, %c512_i32_70 : i32 + %147 = arith.subi %c2236416_i32_69, %146 : i32 + %c1024_i32_71 = arith.constant 1024 : i32 + %148 = arith.muli %1, %c1024_i32_71 : i32 + %149 = arith.subi %147, %148 : i32 + %c1_72 = arith.constant 1 : index + %c2184_i32_73 = arith.constant 2184 : i32 + %150 = arith.index_cast %c2184_i32_73 : i32 to index + %c16_i32_74 = arith.constant 16 : i32 + %151 = arith.index_cast %c16_i32_74 : i32 to index + %c64_i32_75 = arith.constant 64 : i32 + %152 = arith.index_cast %c64_i32_75 : i32 to index + %c1_76 = arith.constant 1 : index + %153 = arith.muli %c1_76, %152 : index + %154 = arith.muli %153, %151 : index + %155 = arith.muli %154, %150 : index + %156 = arith.muli %155, %c1_72 : index + %157 = pto.make_tensor_view %arg6, shape = [%c1_72, %c1_72, %150, %151, %152] strides = [%156, %155, %154, %153, %c1_76] : !pto.tensor_view + %c0_77 = arith.constant 0 : index + %158 = arith.index_cast %145 : i32 to index + %c1_78 = arith.constant 1 : index + %c8_79 = arith.constant 8 : index + %c64_80 = arith.constant 64 : index + %159 = pto.partition_view %157, offsets = [%c0_77, %c0_77, %c0_77, %c0_77, %158], sizes = [%c1_78, %c1_78, %c1_78, %c8_79, %c64_80] : !pto.tensor_view -> !pto.partition_tensor_view<8x64xf16> + %160 = pto.tstore ins(%17 : !pto.tile_buf) outs(%159 : !pto.partition_tensor_view<8x64xf16>) -> tensor<8x64xf16> + pto.barrier + pto.sync.set , 2 + pto.sync.wait , 3 + pto.barrier + %c2048_i32_81 = arith.constant 2048 : i32 + %161 = arith.muli %1, %c2048_i32_81 : i32 + %c1024_i32_82 = arith.constant 1024 : i32 + %162 = arith.muli %21, %c1024_i32_82 : i32 + %163 = arith.addi %161, %162 : i32 + %c0_i32_83 = arith.constant 0 : i32 + %c4472832_i32 = arith.constant 4472832 : i32 + %c1024_i32_84 = arith.constant 1024 : i32 + %164 = arith.muli %21, %c1024_i32_84 : i32 + %165 = arith.subi %c4472832_i32, %164 : i32 + %c2048_i32_85 = arith.constant 2048 : i32 + %166 = arith.muli %1, %c2048_i32_85 : i32 + %167 = arith.subi %165, %166 : i32 + %c1024_i32_86 = arith.constant 1024 : i32 + %c1_87 = arith.constant 1 : index + %c2184_i32_88 = arith.constant 2184 : i32 + %168 = arith.index_cast %c2184_i32_88 : i32 to index + %c16_i32_89 = arith.constant 16 : i32 + %169 = arith.index_cast %c16_i32_89 : i32 to index + %c128_i32_90 = arith.constant 128 : i32 + %170 = arith.index_cast %c128_i32_90 : i32 to index + %c1_91 = arith.constant 1 : index + %171 = arith.muli %c1_91, %170 : index + %172 = arith.muli %171, %169 : index + %173 = arith.muli %172, %168 : index + %174 = arith.muli %173, %c1_87 : index + %175 = pto.make_tensor_view %arg7, shape = [%c1_87, %c1_87, %168, %169, %170] strides = [%174, %173, %172, %171, %c1_91] : !pto.tensor_view + %c0_92 = arith.constant 0 : index + %176 = arith.index_cast %163 : i32 to index + %c1_93 = arith.constant 1 : index + %c8_94 = arith.constant 8 : index + %c128_95 = arith.constant 128 : index + %177 = pto.partition_view %175, offsets = [%c0_92, %c0_92, %c0_92, %c0_92, %176], sizes = [%c1_93, %c1_93, %c1_93, %c8_94, %c128_95] : !pto.tensor_view -> !pto.partition_tensor_view<8x128xf32> + %178 = pto.tload ins(%177 : !pto.partition_tensor_view<8x128xf32>) outs(%18 : !pto.tile_buf) -> tensor<8x128xf32> + pto.barrier + pto.tadd ins(%7, %18 : !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.barrier + pto.sync.set , 4 + pto.barrier + } + %25 = pto.alloc_tile : !pto.tile_buf + pto.treshape ins(%8 : !pto.tile_buf) outs(%25 : !pto.tile_buf) + pto.trowexpanddiv ins(%7, %25 : !pto.tile_buf, !pto.tile_buf) outs(%7 : !pto.tile_buf) + pto.treshape ins(%25 : !pto.tile_buf) outs(%8 : !pto.tile_buf) + %c0_i32_2 = arith.constant 0 : i32 + %c0_i32_3 = arith.constant 0 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1024_i32_4 = arith.constant 1024 : i32 + pto.tcvt ins(%7 : !pto.tile_buf) outs(%19 : !pto.tile_buf) + pto.barrier + %c0_i32_5 = arith.constant 0 : i32 + %26 = arith.divsi %1, %arg9 : i32 + %27 = arith.remsi %26, %arg8 : i32 + %28 = arith.muli %27, %arg9 : i32 + %c8192_i32 = arith.constant 8192 : i32 + %29 = arith.muli %28, %c8192_i32 : i32 + %30 = arith.remsi %1, %arg9 : i32 + %c8192_i32_6 = arith.constant 8192 : i32 + %31 = arith.muli %30, %c8192_i32_6 : i32 + %32 = arith.addi %29, %31 : i32 + %33 = arith.divsi %1, %arg9 : i32 + %34 = arith.divsi %33, %arg8 : i32 + %c4_i32 = arith.constant 4 : i32 + %35 = arith.remsi %34, %c4_i32 : i32 + %c2048_i32 = arith.constant 2048 : i32 + %36 = arith.muli %35, %c2048_i32 : i32 + %37 = arith.addi %32, %36 : i32 + %c1024_i32_7 = arith.constant 1024 : i32 + %38 = arith.muli %21, %c1024_i32_7 : i32 + %39 = arith.addi %37, %38 : i32 + %c1024_i32_8 = arith.constant 1024 : i32 + %40 = arith.muli %arg8, %arg9 : i32 + %c8192_i32_9 = arith.constant 8192 : i32 + %41 = arith.muli %40, %c8192_i32_9 : i32 + %c1024_i32_10 = arith.constant 1024 : i32 + %42 = arith.muli %21, %c1024_i32_10 : i32 + %43 = arith.subi %41, %42 : i32 + %44 = arith.divsi %1, %arg9 : i32 + %45 = arith.divsi %44, %arg8 : i32 + %c4_i32_11 = arith.constant 4 : i32 + %46 = arith.remsi %45, %c4_i32_11 : i32 + %c2048_i32_12 = arith.constant 2048 : i32 + %47 = arith.muli %46, %c2048_i32_12 : i32 + %48 = arith.subi %43, %47 : i32 + %49 = arith.divsi %1, %arg9 : i32 + %50 = arith.remsi %49, %arg8 : i32 + %51 = arith.muli %50, %arg9 : i32 + %c8192_i32_13 = arith.constant 8192 : i32 + %52 = arith.muli %51, %c8192_i32_13 : i32 + %53 = arith.subi %48, %52 : i32 + %54 = arith.remsi %1, %arg9 : i32 + %c8192_i32_14 = arith.constant 8192 : i32 + %55 = arith.muli %54, %c8192_i32_14 : i32 + %56 = arith.subi %53, %55 : i32 + %c1_15 = arith.constant 1 : index + %57 = arith.index_cast %arg8 : i32 to index + %58 = arith.index_cast %arg9 : i32 to index + %c64_i32 = arith.constant 64 : i32 + %59 = arith.index_cast %c64_i32 : i32 to index + %c128_i32 = arith.constant 128 : i32 + %60 = arith.index_cast %c128_i32 : i32 to index + %c1_16 = arith.constant 1 : index + %61 = arith.muli %c1_16, %60 : index + %62 = arith.muli %61, %59 : index + %63 = arith.muli %62, %58 : index + %64 = arith.muli %63, %57 : index + %65 = pto.make_tensor_view %arg3, shape = [%c1_15, %57, %58, %59, %60] strides = [%64, %63, %62, %61, %c1_16] : !pto.tensor_view + %c0 = arith.constant 0 : index + %66 = arith.index_cast %39 : i32 to index + %c1_17 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %67 = pto.partition_view %65, offsets = [%c0, %c0, %c0, %c0, %66], sizes = [%c1_17, %c1_17, %c1_17, %c8, %c128] : !pto.tensor_view -> !pto.partition_tensor_view<8x128xf16> + %68 = pto.tstore ins(%19 : !pto.tile_buf) outs(%67 : !pto.partition_tensor_view<8x128xf16>) -> tensor<8x128xf16> + pto.barrier + pto.sync.set , 8 + } + return + } +} +""" + +if __name__ == "__main__": + print(build()) From 1c84994b47a5307f65505bfa08863162288f8d3b Mon Sep 17 00:00:00 2001 From: Tao-Tao-real <2510737554@qq.com> Date: Sat, 14 Feb 2026 09:38:06 +0800 Subject: [PATCH 2/4] Keep subset valid dims when parent valid equals size --- lib/PTO/IR/PTO.cpp | 11 ++++++++--- lib/PTO/Transforms/PTOViewToMemref.cpp | 3 +++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index b786689d..936a6750 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6735,9 +6735,14 @@ LogicalResult SubsetOp::inferReturnTypes( auto offOpt = getConstIndexValue(operands[1 + i]); if (offOpt) { off = *offOpt; - int64_t diff = pv - off; - if (diff < 0) diff = 0; - vdim = std::min(sizeDim, diff); + // If parent valid equals subset size, keep size regardless of offset. + if (pv == sizeDim) { + vdim = sizeDim; + } else { + int64_t diff = pv - off; + if (diff < 0) diff = 0; + vdim = std::min(sizeDim, diff); + } } else { vdim = sizeDim; } diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 01619925..e595b57c 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -295,6 +295,9 @@ static Value computeSubsetValidDim(IRRewriter &rewriter, Location loc, int64_t pvConst = 0, offConst = 0; if (getConstIndexValue(parentValid, pvConst) && getConstIndexValue(offset, offConst)) { + if (pvConst == size) { + return sizeVal; + } int64_t diff = pvConst - offConst; if (diff < 0) diff = 0; int64_t clipped = std::min(size, diff); From 2b64e7ae8c7c761a7d9aba0541c234f645553375 Mon Sep 17 00:00:00 2001 From: Tao-Tao-real <2510737554@qq.com> Date: Sat, 14 Feb 2026 10:21:15 +0800 Subject: [PATCH 3/4] Keep subset valid dims static when dynamic --- lib/PTO/IR/PTO.cpp | 8 +++----- lib/PTO/Transforms/PTOViewToMemref.cpp | 13 ++----------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 936a6750..df7f67ea 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -6724,18 +6724,16 @@ LogicalResult SubsetOp::inferReturnTypes( if (parentValid.size() == resultShape.size()) { int64_t pv = parentValid[i]; + // In current subset usage, valid dims are treated as static. + // Only refine when both parent valid and offset are compile-time constants. if (pv == ShapedType::kDynamic) { - vdim = ShapedType::kDynamic; + vdim = sizeDim; } else { - // Only refine when offset is a compile-time constant. - // If offset is dynamic, keep static valid dims equal to size to - // avoid type instability across uses. int64_t off = 0; if (operands.size() > 1 + i) { auto offOpt = getConstIndexValue(operands[1 + i]); if (offOpt) { off = *offOpt; - // If parent valid equals subset size, keep size regardless of offset. if (pv == sizeDim) { vdim = sizeDim; } else { diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index e595b57c..6755b8c6 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -303,17 +303,8 @@ static Value computeSubsetValidDim(IRRewriter &rewriter, Location loc, int64_t clipped = std::min(size, diff); return rewriter.create(loc, clipped); } - - Value pv = ensureIndex(rewriter, loc, parentValid, anchorOp); - Value off = ensureIndex(rewriter, loc, offset, anchorOp); - Value diff = rewriter.create(loc, pv, off); - Value zero = rewriter.create(loc, 0); - Value gt = - rewriter.create(loc, arith::CmpIPredicate::sgt, diff, zero); - Value nonNeg = rewriter.create(loc, gt, diff, zero); - Value lt = rewriter.create(loc, arith::CmpIPredicate::slt, - nonNeg, sizeVal); - return rewriter.create(loc, lt, nonNeg, sizeVal); + // Keep static valid dims when runtime values are not constant. + return sizeVal; } static void dumpPretty(Operation *op, llvm::raw_ostream &os) { From b2d115146c80e54ac3622f76e14806a9adea50b2 Mon Sep 17 00:00:00 2001 From: Tao-Tao-real <2510737554@qq.com> Date: Tue, 24 Feb 2026 11:42:42 +0800 Subject: [PATCH 4/4] Update MatMul sample outputs --- test/samples/MatMul/tmatmulk.cpp | 287 +++++++++++++++++++++---------- test/samples/MatMul/tmatmulk.pto | 34 ++-- 2 files changed, 212 insertions(+), 109 deletions(-) diff --git a/test/samples/MatMul/tmatmulk.cpp b/test/samples/MatMul/tmatmulk.cpp index af643f0b..ffa4709b 100644 --- a/test/samples/MatMul/tmatmulk.cpp +++ b/test/samples/MatMul/tmatmulk.cpp @@ -1,120 +1,223 @@ #include "pto/pto-inst.hpp" using namespace pto; + + template + static inline To ptoas_bitcast(From from) { + static_assert(sizeof(To) == sizeof(From), "ptoas_bitcast: size mismatch"); + To to; + __builtin_memcpy(&to, &from, sizeof(To)); + return to; + } + __global__ AICORE void RunTMATMULSplitK(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, __gm__ float* v4, bool v5) { - unsigned v6 = 1; - unsigned v7 = 0; - int32_t v8 = 8; - int32_t v9 = 256; - int32_t v10 = 32; - int32_t v11 = 1; - int32_t v12 = 0; - int64_t v13 = 0; - int64_t v14 = 4096; - int64_t v15 = 8192; - using T = float; - Tile v16; - TASSIGN(v16, v13); + int64_t v6; + int64_t v7; + int64_t v8; + int32_t v9; + size_t v10; + int32_t v11; + size_t v12; + int32_t v13; + int32_t v14; + int32_t v15; + size_t v16; Tile v17; - TASSIGN(v17, v14); - Tile v18; - TASSIGN(v18, v15); - Tile v19; - TASSIGN(v19, v13); - Tile v20; - TASSIGN(v20, v13); - Tile v21; - TASSIGN(v21, v13); - Tile v22; - TASSIGN(v22, v13); - for (int32_t v23 = v12; v23 < v8; v23 += v11) { - int32_t v24 = v23 * v10; - unsigned v25 = (unsigned) v9; - unsigned v26 = v7 * v25; - unsigned v27 = v7 + v26; - unsigned v28 = (unsigned) v24; - unsigned v29 = (unsigned) v11; - unsigned v30 = v28 * v29; - unsigned v31 = v27 + v30; - __gm__ float* v32 = v2 + v31; - using GTShape_94586210699536 = pto::Shape<1, 1, 1, 32, 32>;; - using GTStride_94586210699536 = pto::Stride<1024, 1024, 1024, 32, 1>;; - GTShape_94586210699536 v33 = GTShape_94586210699536(); - GTStride_94586210699536 v34 = GTStride_94586210699536(); - using GT_94586210699536 = GlobalTensor;; - GT_94586210699536 v35 = GT_94586210699536(v32, v33, v34); - unsigned v36 = (unsigned) v24; - unsigned v37 = (unsigned) v10; - unsigned v38 = v36 * v37; - unsigned v39 = v7 + v38; - unsigned v40 = (unsigned) v11; - unsigned v41 = v7 * v40; - unsigned v42 = v39 + v41; - __gm__ float* v43 = v3 + v42; - using GTShape_94586210700880 = pto::Shape<1, 1, 1, 32, 32>;; - using GTStride_94586210700880 = pto::Stride<1024, 1024, 1024, 32, 1>;; - GTShape_94586210700880 v44 = GTShape_94586210700880(); - GTStride_94586210700880 v45 = GTStride_94586210700880(); - using GT_94586210700880 = GlobalTensor;; - GT_94586210700880 v46 = GT_94586210700880(v43, v44, v45); - unsigned v47 = (unsigned) v10; - unsigned v48 = v7 * v47; - unsigned v49 = v7 + v48; - unsigned v50 = (unsigned) v11; - unsigned v51 = v7 * v50; - unsigned v52 = v49 + v51; - __gm__ float* v53 = v4 + v52; - using GTShape_94586210701088 = pto::Shape<1, 1, 1, 1, 32>;; - using GTStride_94586210701088 = pto::Stride<32, 32, 32, 32, 1>;; - GTShape_94586210701088 v54 = GTShape_94586210701088(); - GTStride_94586210701088 v55 = GTStride_94586210701088(); - using GT_94586210701088 = GlobalTensor;; - GT_94586210701088 v56 = GT_94586210701088(v53, v54, v55); - TLOAD(v16, v35); - TLOAD(v17, v46); + Tile v18; + Tile v19; + Tile v20; + Tile v21; + Tile v22; + Tile v23; + int32_t v24; + uint32_t v25; + uint32_t v26; + uint32_t v27; + int32_t v28; + unsigned v29; + unsigned v30; + unsigned v31; + unsigned v32; + unsigned v33; + unsigned v34; + unsigned v35; + unsigned v36; + unsigned v37; + unsigned v38; + unsigned v39; + __gm__ float* v40; + pto::Shape<1, 1, 1, 32, 32> v41; + pto::Stride<8192, 8192, 8192, 256, 1> v42; + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v43(nullptr); + unsigned v44; + unsigned v45; + unsigned v46; + unsigned v47; + unsigned v48; + unsigned v49; + unsigned v50; + unsigned v51; + unsigned v52; + unsigned v53; + unsigned v54; + __gm__ float* v55; + pto::Shape<1, 1, 1, 32, 32> v56; + pto::Stride<1024, 1024, 1024, 32, 1> v57; + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v58(nullptr); + unsigned v59; + unsigned v60; + unsigned v61; + unsigned v62; + unsigned v63; + unsigned v64; + unsigned v65; + unsigned v66; + unsigned v67; + unsigned v68; + unsigned v69; + __gm__ float* v70; + pto::Shape<1, 1, 1, 1, 32> v71; + pto::Stride<32, 32, 32, 32, 1> v72; + GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND> v73(nullptr); + bool v74; + unsigned v75; + unsigned v76; + unsigned v77; + unsigned v78; + unsigned v79; + unsigned v80; + unsigned v81; + unsigned v82; + unsigned v83; + unsigned v84; + unsigned v85; + __gm__ float* v86; + pto::Shape<1, 1, 1, 32, 32> v87; + pto::Stride<1024, 1024, 1024, 32, 1> v88; + GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND> v89(nullptr); + using T = float; + v6 = 8192; + v7 = 4096; + v8 = 0; + v9 = 0; + v10 = (size_t) v9; + v11 = 1; + v12 = (size_t) v11; + v13 = 32; + v14 = 256; + v15 = 8; + v16 = (size_t) v15; + ; + TASSIGN(v17, v8); + ; + TASSIGN(v18, v7); + ; + TASSIGN(v19, v6); + ; + TASSIGN(v20, v8); + ; + TASSIGN(v21, v8); + ; + TASSIGN(v22, v8); + ; + TASSIGN(v23, v8); + for (size_t v90 = v10; v90 < v16; v90 += v12) { + v24 = (int32_t) v90; + v25 = (uint32_t) v24; + v26 = (uint32_t) v13; + v27 = v25 * v26; + v28 = (int32_t) v27; + v29 = 0; + v30 = 0; + v31 = 1; + v32 = (unsigned) v14; + v33 = v30 * v32; + v34 = v29 + v33; + v35 = (unsigned) v28; + v36 = 1; + v37 = (unsigned) v11; + v38 = v35 * v37; + v39 = v34 + v38; + v40 = v2 + v39; + v41 = pto::Shape<1, 1, 1, 32, 32>(); + v42 = pto::Stride<8192, 8192, 8192, 256, 1>(); + v43 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v40, v41, v42); + v44 = 0; + v45 = (unsigned) v28; + v46 = 1; + v47 = (unsigned) v13; + v48 = v45 * v47; + v49 = v44 + v48; + v50 = 0; + v51 = 1; + v52 = (unsigned) v11; + v53 = v50 * v52; + v54 = v49 + v53; + v55 = v3 + v54; + v56 = pto::Shape<1, 1, 1, 32, 32>(); + v57 = pto::Stride<1024, 1024, 1024, 32, 1>(); + v58 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v55, v56, v57); + v59 = 0; + v60 = 0; + v61 = 1; + v62 = (unsigned) v13; + v63 = v60 * v62; + v64 = v59 + v63; + v65 = 0; + v66 = 1; + v67 = (unsigned) v11; + v68 = v65 * v67; + v69 = v64 + v68; + v70 = v4 + v69; + v71 = pto::Shape<1, 1, 1, 1, 32>(); + v72 = pto::Stride<32, 32, 32, 32, 1>(); + v73 = GlobalTensor, pto::Stride<32, 32, 32, 32, 1>, pto::Layout::ND>(v70, v71, v72); + TLOAD(v17, v43); + TLOAD(v18, v58); if (v5) { - TLOAD(v18, v56); + TLOAD(v19, v73); } else { }; set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); - TMOV(v19, v16); TMOV(v20, v17); + TMOV(v21, v18); if (v5) { - TMOV(v22, v18); + TMOV(v23, v19); } else { }; set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); - bool v57 = v23 == v12; - if (v57) { + v74 = v24 == v9; + if (v74) { if (v5) { - TMATMUL_BIAS(v21, v19, v20, v22); + TMATMUL_BIAS(v22, v20, v21, v23); } else { - TMATMUL(v21, v19, v20); + TMATMUL(v22, v20, v21); }; } else { - TMATMUL_ACC(v21, v21, v19, v20); + TMATMUL_ACC(v22, v22, v20, v21); }; set_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); wait_flag(PIPE_M, PIPE_MTE2, EVENT_ID0); } set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); - unsigned v58 = (unsigned) v10; - unsigned v59 = v7 * v58; - unsigned v60 = v7 + v59; - unsigned v61 = (unsigned) v11; - unsigned v62 = v7 * v61; - unsigned v63 = v60 + v62; - __gm__ float* v64 = v1 + v63; - using GTShape_94586210701264 = pto::Shape<1, 1, 1, 32, 32>; - using GTStride_94586210701264 = pto::Stride<1024, 1024, 1024, 32, 1>; - GTShape_94586210701264 v65 = GTShape_94586210701264(); - GTStride_94586210701264 v66 = GTStride_94586210701264(); - using GT_94586210701264 = GlobalTensor; - GT_94586210701264 v67 = GT_94586210701264(v64, v65, v66); - TSTORE(v67, v21); + v75 = 0; + v76 = 0; + v77 = 1; + v78 = (unsigned) v13; + v79 = v76 * v78; + v80 = v75 + v79; + v81 = 0; + v82 = 1; + v83 = (unsigned) v11; + v84 = v81 * v83; + v85 = v80 + v84; + v86 = v1 + v85; + v87 = pto::Shape<1, 1, 1, 32, 32>(); + v88 = pto::Stride<1024, 1024, 1024, 32, 1>(); + v89 = GlobalTensor, pto::Stride<1024, 1024, 1024, 32, 1>, pto::Layout::ND>(v86, v87, v88); + TSTORE(v89, v22); return; } - diff --git a/test/samples/MatMul/tmatmulk.pto b/test/samples/MatMul/tmatmulk.pto index fa8ca9c0..b54b4ba1 100644 --- a/test/samples/MatMul/tmatmulk.pto +++ b/test/samples/MatMul/tmatmulk.pto @@ -10,10 +10,10 @@ module attributes {"pto.device-spec" = "Ascend910B1"} { %c8 = arith.constant 8 : index %c32_3 = arith.constant 32 : index %c32_4 = arith.constant 32 : index - %0 = pto.make_tensor_view %arg1, shape = [%c32, %c256] strides = [%c256, %c1] : !pto.tensor_view<2xf32> - %1 = pto.make_tensor_view %arg2, shape = [%c256, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view<2xf32> - %2 = pto.make_tensor_view %arg0, shape = [%c32, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view<2xf32> - %3 = pto.make_tensor_view %arg3, shape = [%c1_0, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view<2xf32> + %0 = pto.make_tensor_view %arg1, shape = [%c32, %c256] strides = [%c256, %c1] : !pto.tensor_view + %1 = pto.make_tensor_view %arg2, shape = [%c256, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view + %2 = pto.make_tensor_view %arg0, shape = [%c32, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view + %3 = pto.make_tensor_view %arg3, shape = [%c1_0, %c32_1] strides = [%c32_1, %c1] : !pto.tensor_view %4 = pto.alloc_tile : !pto.tile_buf %5 = pto.alloc_tile : !pto.tile_buf %6 = pto.alloc_tile : !pto.tile_buf @@ -23,41 +23,41 @@ module attributes {"pto.device-spec" = "Ascend910B1"} { %10 = pto.alloc_tile : !pto.tile_buf scf.for %arg5 = %c0 to %c8 step %c1 { %12 = arith.muli %arg5, %c32_2 : index - %13 = pto.partition_view %0, offsets = [%c0, %12], sizes = [%c32_3, %c32_2] : !pto.tensor_view<2xf32> -> !pto.partition_tensor_view<32x32xf32> - %14 = pto.partition_view %1, offsets = [%12, %c0], sizes = [%c32_2, %c32_4] : !pto.tensor_view<2xf32> -> !pto.partition_tensor_view<32x32xf32> - %15 = pto.partition_view %3, offsets = [%c0, %c0], sizes = [%c1_0, %c32_4] : !pto.tensor_view<2xf32> -> !pto.partition_tensor_view<1x32xf32> + %13 = pto.partition_view %0, offsets = [%c0, %12], sizes = [%c32_3, %c32_2] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %14 = pto.partition_view %1, offsets = [%12, %c0], sizes = [%c32_2, %c32_4] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %15 = pto.partition_view %3, offsets = [%c0, %c0], sizes = [%c1_0, %c32_4] : !pto.tensor_view -> !pto.partition_tensor_view<1x32xf32> pto.tload ins(%13 : !pto.partition_tensor_view<32x32xf32>) outs(%4 : !pto.tile_buf) pto.tload ins(%14 : !pto.partition_tensor_view<32x32xf32>) outs(%5 : !pto.tile_buf) scf.if %arg4 { pto.tload ins(%15 : !pto.partition_tensor_view<1x32xf32>) outs(%6 : !pto.tile_buf) } else { } - pto.set_flag[, , ] - pto.wait_flag[, , ] + pto.record_event[, , ] + pto.wait_event[, , ] pto.tmov ins(%4 : !pto.tile_buf) outs(%7 : !pto.tile_buf) pto.tmov ins(%5 : !pto.tile_buf) outs(%8 : !pto.tile_buf) scf.if %arg4 { pto.tmov ins(%6 : !pto.tile_buf) outs(%10 : !pto.tile_buf) } else { } - pto.set_flag[, , ] - pto.wait_flag[, , ] + pto.record_event[, , ] + pto.wait_event[, , ] %16 = arith.cmpi eq, %arg5, %c0 : index scf.if %16 { scf.if %arg4 { - pto.tmatmul.bias ins(%7, %8, %10 : , , !pto.tile_buf) outs(%9 : !pto.tile_buf) + pto.tmatmul.bias ins(%7, %8, %10 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } else { pto.tmatmul ins(%7, %8 : !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } } else { pto.tmatmul.acc ins(%9, %7, %8 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) } - pto.set_flag[, , ] - pto.wait_flag[, , ] + pto.record_event[, , ] + pto.wait_event[, , ] } - pto.set_flag[, , ] - pto.wait_flag[, , ] - %11 = pto.partition_view %2, offsets = [%c0, %c0], sizes = [%c32_3, %c32_4] : !pto.tensor_view<2xf32> -> !pto.partition_tensor_view<32x32xf32> + pto.record_event[, , ] + pto.wait_event[, , ] + %11 = pto.partition_view %2, offsets = [%c0, %c0], sizes = [%c32_3, %c32_4] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> pto.tstore ins(%9 : !pto.tile_buf) outs(%11 : !pto.partition_tensor_view<32x32xf32>) return }