diff --git a/docs/dev/codegen/00-pto_codegen.md b/docs/dev/codegen/00-pto_codegen.md index 4ae097aa..23da3c71 100644 --- a/docs/dev/codegen/00-pto_codegen.md +++ b/docs/dev/codegen/00-pto_codegen.md @@ -8,9 +8,10 @@ The PTO Codegen (`PTOCodegen`) generates MLIR code in PTO-ISA dialect from PyPTO - **Automatic MLIR Generation**: Converts PyPTO IR to PTO-ISA MLIR dialect - **Structured Code Generation**: Outputs constants, tensor views, allocations in order -- **Implicit Lowering**: Automatically generates `pto.subview` from `block.load`/`block.store` +- **Implicit Lowering**: Automatically generates `pto.partition_view` from `block.load`/`block.store` - **MemRef-based Allocation**: Maps IR MemRef objects to `pto.alloc_tile` operations -- **Type-aware Conversion**: Handles TensorType, TileType, ScalarType appropriately +- **Type-aware Conversion**: Derives tile_buf/tensor_view types from TileType metadata +- **PTOAS Type Annotations**: Emits typed `ins`/`outs` clauses for all operations ### Generation Order @@ -25,18 +26,33 @@ The codegen generates MLIR in the following fixed order: ### Class Structure -**Header**: `include/pypto/codegen/pto_codegen.h` +**Header**: `include/pypto/codegen/pto/pto_codegen.h` ```cpp namespace pypto::codegen { -class PTOCodegen { +class PTOCodegen : public CodegenBase { public: - PTOCodegen() = default; - ~PTOCodegen() = default; + PTOCodegen(); + explicit PTOCodegen(const backend::Backend* backend); - // Generate PTO-ISA MLIR from program std::string Generate(const ir::ProgramPtr& program); + + // CodegenBase interface + std::string GetCurrentResultTarget() const override; + void Emit(const std::string& line) override; + std::string GetExprAsCode(const ir::ExprPtr& expr) override; + std::string GetTypeString(const DataType& dtype) const override; + + // PTO-specific helpers for operator codegen + std::string NewTemp(); + std::string GetOrCreateTensorView(const ir::VarPtr& tensor); + std::string GetIndexConstant(int64_t val); + std::string GetOrEmitFloatConstant(double value, const std::string& mlir_type = "f32"); + std::string GetTensorViewTypeString(const ir::TensorType* tensor_type) const; + std::string GetTileBufTypeString(const ir::MemRef* memref) const; + std::string GetExprTypeAnnotation(const ir::ExprPtr& expr); + std::string GetCurrentResultTileBufTypeString() const; }; } // namespace codegen @@ -44,13 +60,13 @@ class PTOCodegen { ### Implementation Components -**File**: `src/codegen/pto_codegen.cpp` +**File**: `src/codegen/pto/pto_codegen.cpp` | Component | Purpose | | --------- | ------- | -| `PTOMLIRCodegen` | Main visitor class for IR traversal | -| `MemRefCollectorVisitor` | Collects all MemRef objects for allocation | -| Helper functions | `DataTypeToMLIR()`, `MemorySpaceToMLIR()` | +| `PTOCodegen` | Main visitor class (inherits `CodegenBase`) for IR traversal | +| `MemRefCollectorVisitor` | Collects MemRef objects and their associated TileType for allocation | +| Helper functions | `DataTypeToMLIRImpl()`, `MemorySpaceToMLIR()` | ## Python API @@ -95,8 +111,8 @@ print(pto_code) | PyPTO Operation | Generated PTO-ISA | | --------------- | ----------------- | -| `block.load(tensor, [row, col], [h, w])` | `pto.subview` + `pto.tload` | -| `block.store(tile, [row, col], [h, w], tensor)` | `pto.subview` + `pto.tstore` | +| `block.load(tensor, [row, col], [h, w])` | `pto.partition_view` + `pto.tload` | +| `block.store(tile, [row, col], [h, w], tensor)` | `pto.partition_view` + `pto.tstore` | | `block.mul(lhs, rhs)` | `pto.tmul` | | `block.add(a, b, c)` | `pto.taddc` (3-operand add) | | `block.adds(tile, scalar)` | `pto.tadds` (tile + scalar) | @@ -119,7 +135,7 @@ For each `TensorType` parameter, the codegen generates: %0 = pto.make_tensor_view %arg0, shape = [%c32, %c32] strides = [%c32, %c1] - : !pto.tensor_view<2xf32> + : !pto.tensor_view ``` **Key aspects**: @@ -127,21 +143,22 @@ For each `TensorType` parameter, the codegen generates: - Shape from `TensorType.shape_` - Strides computed as row-major: `[dim1, 1]` for 2D tensors - Constants (`%c32`, `%c1`) auto-generated +- Tensor view type uses `?` for each dimension (e.g., `?x?xf32` for 2D) ### Allocation Generation -Based on MemRef objects attached to TileType variables: +Based on MemRef objects attached to TileType variables. The codegen derives tile dimensions and dtype from the associated TileType: ```mlir -%0 = pto.alloc_tile : ``` **MemRef → alloc_tile mapping**: -- Memory space (`MemRef.memory_space_`) → `loc` attribute -- Tile dimensions inferred from usage context +- Memory space (`MemRef.memory_space_`) → `loc` attribute (using PTO address space names) +- Tile dtype and dimensions derived from associated TileType metadata - One allocation per unique MemRef ### Load Operation Transformation @@ -155,21 +172,21 @@ tile_a = pl.load(tensor_a, [0, 0], [32, 32]) **Generated MLIR** (two operations): ```mlir -# 1. Create tile view -%3 = pto.subview %tensor_view, offsets = [%c0, %c0], +# 1. Create partition view +%3 = pto.partition_view %tensor_view, offsets = [%c0, %c0], sizes = [%c32, %c32] - : !pto.tensor_view<2xf32> -> !pto.tile_view<32x32xf32> + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> # 2. Load into tile buffer -pto.tload ins(%3 : !pto.tile_view<32x32xf32>) - outs(%tile_buf : !pto.tile_buf) +pto.tload ins(%3 : !pto.partition_tensor_view<32x32xf32>) + outs(%tile_buf : !pto.tile_buf) ``` **Key transformations**: - Tensor parameter → tensor_view lookup - Offsets/sizes from `block.load` arguments -- Output tile_buf from variable's MemRef +- Output tile_buf from variable's MemRef with type derived from TileType ### Store Operation Transformation @@ -182,14 +199,14 @@ pl.store(tile_c, [0, 0], [32, 32], tensor_out) **Generated MLIR**: ```mlir -# 1. Create tile view for output -%5 = pto.subview %output_view, offsets = [%c0, %c0], +# 1. Create partition view for output +%5 = pto.partition_view %output_view, offsets = [%c0, %c0], sizes = [%c32, %c32] - : !pto.tensor_view<2xf32> -> !pto.tile_view<32x32xf32> + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> # 2. Store from tile buffer -pto.tstore ins(%tile_buf : !pto.tile_buf) - outs(%5 : !pto.tile_view<32x32xf32>) +pto.tstore ins(%tile_buf : !pto.tile_buf) + outs(%5 : !pto.partition_tensor_view<32x32xf32>) ``` ### Compute Operations @@ -214,6 +231,7 @@ pto.tmul ins(%tile_a_buf : !pto.tile_buf<...>, - Result variable's MemRef determines output tile_buf - Input operands resolved through variable name lookup +- All `ins`/`outs` clauses include type annotations ## Complete Example @@ -254,27 +272,27 @@ module { // Tensor views %3 = pto.make_tensor_view %arg0, shape = [%c32, %c32] - strides = [%c32, %c1] : !pto.tensor_view<2xf32> + strides = [%c32, %c1] : !pto.tensor_view %4 = pto.make_tensor_view %arg1, shape = [%c32, %c32] - strides = [%c32, %c1] : !pto.tensor_view<2xf32> + strides = [%c32, %c1] : !pto.tensor_view %5 = pto.make_tensor_view %arg2, shape = [%c32, %c32] - strides = [%c32, %c1] : !pto.tensor_view<2xf32> + strides = [%c32, %c1] : !pto.tensor_view // Allocations - %0 = pto.alloc_tile : - %1 = pto.alloc_tile : - %2 = pto.alloc_tile : + %0 = pto.alloc_tile : !pto.tile_buf + %1 = pto.alloc_tile : !pto.tile_buf + %2 = pto.alloc_tile : !pto.tile_buf // Load tile_a - %6 = pto.subview %3, offsets = [%c0, %c0], sizes = [%c32, %c32] - : !pto.tensor_view<2xf32> -> !pto.tile_view<32x32xf32> - pto.tload ins(%6 : !pto.tile_view<32x32xf32>) + %6 = pto.partition_view %3, offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + pto.tload ins(%6 : !pto.partition_tensor_view<32x32xf32>) outs(%0 : !pto.tile_buf<...>) // Load tile_b - %7 = pto.subview %4, offsets = [%c0, %c0], sizes = [%c32, %c32] - : !pto.tensor_view<2xf32> -> !pto.tile_view<32x32xf32> - pto.tload ins(%7 : !pto.tile_view<32x32xf32>) + %7 = pto.partition_view %4, offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + pto.tload ins(%7 : !pto.partition_tensor_view<32x32xf32>) outs(%1 : !pto.tile_buf<...>) // Multiply @@ -282,10 +300,10 @@ module { outs(%2 : !pto.tile_buf<...>) // Store tile_c - %8 = pto.subview %5, offsets = [%c0, %c0], sizes = [%c32, %c32] - : !pto.tensor_view<2xf32> -> !pto.tile_view<32x32xf32> + %8 = pto.partition_view %5, offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> pto.tstore ins(%2 : !pto.tile_buf<...>) - outs(%8 : !pto.tile_view<32x32xf32>) + outs(%8 : !pto.partition_tensor_view<32x32xf32>) return } @@ -303,6 +321,7 @@ The codegen maintains several mappings to track MLIR variable names: | `var_to_mlir_` | IR variable → MLIR SSA name | `"tile_a"` → `"%0"` | | `tensor_to_view_` | Parameter → tensor_view | `"a"` → `"%3"` | | `memref_to_mlir_` | MemRef pointer → tile_buf | `memref.get()` → `"%0"` | +| `memref_to_tile_type_` | MemRef pointer → TileType | Used for deriving tile_buf types | **SSA value naming**: @@ -324,7 +343,8 @@ The codegen: 2. Resolves `tile_b` → `%1` via `var_to_mlir_` 3. Gets `tile_c`'s MemRef from its TileType 4. Maps MemRef → `%2` via `memref_to_mlir_` -5. Generates: `pto.tmul ins(%0, %1) outs(%2)` +5. Gets tile_buf type from `memref_to_tile_type_` +6. Generates: `pto.tmul ins(%0 : !pto.tile_buf<...>, %1 : !pto.tile_buf<...>) outs(%2 : !pto.tile_buf<...>)` ## Type Conversions @@ -342,42 +362,44 @@ The codegen: ### Memory Space Mapping -| PyPTO MemorySpace | PTO-ISA loc | -| ----------------- | ----------- | -| `MemorySpace::DDR` | `ddr` | -| `MemorySpace::UB` | `ub` (unified buffer) | -| `MemorySpace::L1` | `l1` | -| `MemorySpace::L0A` | `l0a` | -| `MemorySpace::L0B` | `l0b` | -| `MemorySpace::L0C` | `l0c` | +| PyPTO MemorySpace | PTO Address Space | +| ----------------- | ----------------- | +| `MemorySpace::DDR` | `gm` (global memory) | +| `MemorySpace::UB` | `vec` (vector buffer) | +| `MemorySpace::L1` | `mat` (matrix buffer) | +| `MemorySpace::L0A` | `left` | +| `MemorySpace::L0B` | `right` | +| `MemorySpace::L0C` | `acc` (accumulator) | ### Tile Buffer Attributes -Generated `alloc_tile` operations include: +Generated `alloc_tile` operations derive dtype and dimensions from TileType metadata, and layout/fractal/pad from the associated TileView (when available): ```mlir !pto.tile_buf< - loc=ub, // Memory space - dtype=f32, // Element data type - rows=32, // Tile height - cols=32, // Tile width - v_row=32, // Virtual row size - v_col=32, // Virtual column size - blayout=row_major, // Block layout - slayout=none_box, // Sub-layout - fractal=512, // Fractal parameter - pad=0 // Padding + loc=vec, // PTO address space (from MemorySpace) + dtype=f32, // Element data type (from TileType) + rows=32, // Tile height (from TileType shape) + cols=32, // Tile width (from TileType shape) + v_row=32, // Virtual row size (= rows) + v_col=32, // Virtual column size (= cols) + blayout=row_major, // Block layout (from TileView, default: row_major) + slayout=none_box, // Scatter layout (from TileView, default: none_box) + fractal=512, // Fractal size (from TileView, default: 512) + pad=0 // Pad mode as int (from TileView, default: 0/null) > ``` -## Limitations and Future Work +**TileView-derived attributes**: -### Current Limitations +| Attribute | Source | Enum Values | Default | +| --------- | ------ | ----------- | ------- | +| `blayout` | `TileView::blayout` | `none_box`, `row_major`, `col_major` | `row_major` | +| `slayout` | `TileView::slayout` | `none_box`, `row_major`, `col_major` | `none_box` | +| `fractal` | `TileView::fractal` | uint64 | `512` | +| `pad` | `TileView::pad` | `null(0)`, `zero(1)`, `max(2)`, `min(3)` | `null(0)` | -1. **Fixed Tile Attributes**: `rows`, `cols`, `blayout` etc. use default values (32x32, row_major) -2. **2D Tensors Only**: Shape/stride generation assumes 2D tensors -3. **Single Memory Space**: All allocations use `ub` (unified buffer) by default -4. **Limited Operations**: Only basic block operations supported +When no TileView is associated with the MemRef, the codegen falls back to the default values listed above. ## See Also diff --git a/include/pypto/codegen/pto/pto_codegen.h b/include/pypto/codegen/pto/pto_codegen.h index 870b1cba..cd434347 100644 --- a/include/pypto/codegen/pto/pto_codegen.h +++ b/include/pypto/codegen/pto/pto_codegen.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -38,7 +39,7 @@ namespace codegen { * * Generates PTO-ISA MLIR format code from PyPTO IR Program. * Traverses the IR using the visitor pattern (aligned with CCECodegen). - * Automatically generates make_tensor_view, subview, and alloc_tile instructions. + * Automatically generates make_tensor_view, partition_view, and alloc_tile instructions. */ class PTOCodegen : public CodegenBase { public: @@ -102,6 +103,26 @@ class PTOCodegen : public CodegenBase { */ std::string GetOrEmitFloatConstant(double value, const std::string& mlir_type = "f32"); + /** + * @brief Get tensor_view type string for a TensorType (e.g., "!pto.tensor_view") + */ + std::string GetTensorViewTypeString(const ir::TensorType* tensor_type) const; + + /** + * @brief Get tile_buf type string for a MemRef (e.g., "!pto.tile_buf") + */ + std::string GetTileBufTypeString(const ir::MemRef* memref) const; + + /** + * @brief Get type annotation for an expression (for ins/outs clauses) + */ + std::string GetExprTypeAnnotation(const ir::ExprPtr& expr); + + /** + * @brief Get tile_buf type string for the current assignment result target + */ + std::string GetCurrentResultTileBufTypeString() const; + protected: // Override visitor methods for code generation - Statements void VisitStmt_(const ir::AssignStmtPtr& op) override; @@ -156,6 +177,7 @@ class PTOCodegen : public CodegenBase { std::map tensor_to_view_; std::map memref_to_mlir_; std::map var_to_memref_; + std::map> memref_to_tile_type_; std::set emitted_constants_; std::set emitted_float_constants_; std::map float_const_names_; @@ -165,6 +187,7 @@ class PTOCodegen : public CodegenBase { // Current function context ir::FunctionPtr current_function_; std::string current_result_buf_; + std::shared_ptr current_result_tile_type_; const backend::Backend* backend_; ///< Backend instance for querying op info }; diff --git a/include/pypto/ir/scalar_expr.h b/include/pypto/ir/scalar_expr.h index 8f38e44d..8abb471d 100644 --- a/include/pypto/ir/scalar_expr.h +++ b/include/pypto/ir/scalar_expr.h @@ -388,7 +388,9 @@ inline BinaryOperands PromoteBinaryOperands(const ExprPtr& left, const ExprPtr& DataType left_dtype = GetScalarDtype(left); DataType right_dtype = GetScalarDtype(right); DataType promoted_dtype = PromoteSameCategoryDtype(left_dtype, right_dtype, op_name); - return {MaybeCast(left, promoted_dtype, span), MaybeCast(right, promoted_dtype, span), promoted_dtype}; + ExprPtr promoted_left = MaybeCast(left, promoted_dtype, span); + ExprPtr promoted_right = MaybeCast(right, promoted_dtype, span); + return {std::move(promoted_left), std::move(promoted_right), promoted_dtype}; } inline BinaryOperands PromoteIntBinaryOperands(const ExprPtr& left, const ExprPtr& right, @@ -400,7 +402,9 @@ inline BinaryOperands PromoteIntBinaryOperands(const ExprPtr& left, const ExprPt " and " + right_dtype.ToString()); } DataType promoted_dtype = PromoteSameCategoryDtype(left_dtype, right_dtype, op_name); - return {MaybeCast(left, promoted_dtype, span), MaybeCast(right, promoted_dtype, span), promoted_dtype}; + ExprPtr promoted_left = MaybeCast(left, promoted_dtype, span); + ExprPtr promoted_right = MaybeCast(right, promoted_dtype, span); + return {std::move(promoted_left), std::move(promoted_right), promoted_dtype}; } // ========== Binary Operator Construction Functions ========== diff --git a/src/backend/910B_PTO/backend_910b_pto_ops.cpp b/src/backend/910B_PTO/backend_910b_pto_ops.cpp index 8396f740..10b2fa15 100644 --- a/src/backend/910B_PTO/backend_910b_pto_ops.cpp +++ b/src/backend/910B_PTO/backend_910b_pto_ops.cpp @@ -42,11 +42,13 @@ const std::vector cmp_modes = {"EQ", "NE", "LT", "LE", "GT", "GE"}; const std::vector round_modes = {"NONE", "RINT", "ROUND", "FLOOR", "CEIL", "TRUNC", "ODD", "CAST_RINT"}; -// Helper function for input & output generation +// Helper function for input & output generation (with type annotations) static std::string GenerateInsOutsClause(const CallPtr& op, codegen::PTOCodegen& codegen, const std::string& config_attr = "") { size_t args_num = op->args_.size(); std::ostringstream oss; + + // Build ins clause with operand names oss << "ins("; for (size_t input_idx = 0; input_idx < args_num; ++input_idx) { std::string operand = codegen.GetExprAsCode(op->args_[input_idx]); @@ -56,10 +58,32 @@ static std::string GenerateInsOutsClause(const CallPtr& op, codegen::PTOCodegen& oss << ", " << operand; } } + + // Add type annotations after colon + std::string type_annot; + for (size_t input_idx = 0; input_idx < args_num; ++input_idx) { + std::string annot = codegen.GetExprTypeAnnotation(op->args_[input_idx]); + if (!annot.empty()) { + if (!type_annot.empty()) type_annot += ", "; + type_annot += annot; + } + } + if (!type_annot.empty()) { + oss << " : " << type_annot; + } + if (!config_attr.empty()) { oss << config_attr; } - oss << ") outs(" << codegen.GetCurrentResultTarget() << ")"; + + // Build outs clause with type annotation + std::string result_target = codegen.GetCurrentResultTarget(); + std::string result_type = codegen.GetCurrentResultTileBufTypeString(); + oss << ") outs(" << result_target; + if (!result_type.empty()) { + oss << " : " << result_type; + } + oss << ")"; return oss.str(); } @@ -136,15 +160,22 @@ static std::string MakeTileCvtCodegenPTO(const std::string& pto_op_name, const C static std::string MakeFullCodegenPTO(const std::string& pto_op_name, const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); - CHECK(op->args_.size() == 2) << "full op requires 3 arguments." - << op->args_.size(); // Actually 2 args, two of them are conbined! + CHECK(op->args_.size() == 2) << "full op requires 2 arguments, got " << op->args_.size(); std::string scalar = codegen.GetExprAsCode(op->args_[1]); + std::string scalar_type = codegen.GetExprTypeAnnotation(op->args_[1]); std::string dst = codegen.GetCurrentResultTarget(); - codegen.Emit(pto_op_name + " " + "ins(" + scalar + ") outs(" + dst + ")"); + std::string dst_type = codegen.GetCurrentResultTileBufTypeString(); + std::ostringstream oss; + oss << pto_op_name << " ins(" << scalar; + if (!scalar_type.empty()) oss << " : " << scalar_type; + oss << ") outs(" << dst; + if (!dst_type.empty()) oss << " : " << dst_type; + oss << ")"; + codegen.Emit(oss.str()); return ""; } -// block.load: emit pto.subview + pto.tload (same format as original IR layer codegen) +// block.load: emit pto.partition_view + pto.tload static std::string MakeBlockLoadCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); auto tensor = As(op->args_[0]); @@ -172,29 +203,29 @@ static std::string MakeBlockLoadCodegenPTO(const CallPtr& op, codegen::CodegenBa std::string tile_buf = codegen.GetCurrentResultTarget(); INTERNAL_CHECK(!tile_buf.empty()) << "block.load requires assignment target (tile_buf)"; - std::string tile_view = codegen.NewTemp(); - std::ostringstream subview_line; - subview_line << tile_view << " = pto.subview " << tensor_view; - subview_line << ", offsets = [" << codegen.GetIndexConstant(row_off) << ", "; - subview_line << codegen.GetIndexConstant(col_off) << "]"; - subview_line << ", sizes = [" << codegen.GetIndexConstant(height) << ", "; - subview_line << codegen.GetIndexConstant(width) << "]"; - subview_line << " : !pto.tensor_view<2x" << dtype_str << "> -> !pto.tile_view<"; - subview_line << height << "x" << width << "x" << dtype_str << ">"; - codegen.Emit(subview_line.str()); + std::string tensor_view_type = codegen.GetTensorViewTypeString(tensor_type.get()); + std::string tile_buf_type = codegen.GetCurrentResultTileBufTypeString(); + std::string partition_type = "!pto.partition_tensor_view<" + std::to_string(height) + "x" + + std::to_string(width) + "x" + dtype_str + ">"; + + std::string partition_view = codegen.NewTemp(); + std::ostringstream partition_line; + partition_line << partition_view << " = pto.partition_view " << tensor_view; + partition_line << ", offsets = [" << codegen.GetIndexConstant(row_off) << ", "; + partition_line << codegen.GetIndexConstant(col_off) << "]"; + partition_line << ", sizes = [" << codegen.GetIndexConstant(height) << ", "; + partition_line << codegen.GetIndexConstant(width) << "]"; + partition_line << " : " << tensor_view_type << " -> " << partition_type; + codegen.Emit(partition_line.str()); std::ostringstream tload_line; - tload_line << "pto.tload ins(" << tile_view; - tload_line << " : !pto.tile_view<" << height << "x" << width << "x" << dtype_str << ">) outs("; - tload_line << tile_buf << " : !pto.tile_buf)"; + tload_line << "pto.tload ins(" << partition_view << " : " << partition_type << ") outs("; + tload_line << tile_buf << " : " << tile_buf_type << ")"; codegen.Emit(tload_line.str()); return ""; // Multi-line emission } -// block.store: emit pto.subview + pto.tstore (same format as original IR layer codegen) +// block.store: emit pto.partition_view + pto.tstore static std::string MakeBlockStoreCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); auto tile = As(op->args_[0]); @@ -222,24 +253,35 @@ static std::string MakeBlockStoreCodegenPTO(const CallPtr& op, codegen::CodegenB std::string dtype_str = codegen.GetTypeString(tensor_type->dtype_); std::string tensor_view = codegen.GetOrCreateTensorView(output_tensor); std::string tile_buf = codegen.GetVarName(tile); - std::string tile_view = codegen.NewTemp(); - - std::ostringstream subview_line; - subview_line << tile_view << " = pto.subview " << tensor_view; - subview_line << ", offsets = [" << codegen.GetIndexConstant(row_off) << ", "; - subview_line << codegen.GetIndexConstant(col_off) << "]"; - subview_line << ", sizes = [" << codegen.GetIndexConstant(height) << ", "; - subview_line << codegen.GetIndexConstant(width) << "]"; - subview_line << " : !pto.tensor_view<2x" << dtype_str << "> -> !pto.tile_view<"; - subview_line << height << "x" << width << "x" << dtype_str << ">"; - codegen.Emit(subview_line.str()); + + std::string tensor_view_type = codegen.GetTensorViewTypeString(tensor_type.get()); + std::string partition_type = "!pto.partition_tensor_view<" + std::to_string(height) + "x" + + std::to_string(width) + "x" + dtype_str + ">"; + + // Get tile_buf type from the tile variable's TileType + std::string tile_buf_type; + if (auto tile_type = As(tile->GetType())) { + if (tile_type->memref_.has_value()) { + tile_buf_type = codegen.GetTileBufTypeString(tile_type->memref_.value().get()); + } + } + + std::string partition_view = codegen.NewTemp(); + std::ostringstream partition_line; + partition_line << partition_view << " = pto.partition_view " << tensor_view; + partition_line << ", offsets = [" << codegen.GetIndexConstant(row_off) << ", "; + partition_line << codegen.GetIndexConstant(col_off) << "]"; + partition_line << ", sizes = [" << codegen.GetIndexConstant(height) << ", "; + partition_line << codegen.GetIndexConstant(width) << "]"; + partition_line << " : " << tensor_view_type << " -> " << partition_type; + codegen.Emit(partition_line.str()); std::ostringstream tstore_line; tstore_line << "pto.tstore ins(" << tile_buf; - tstore_line << " : !pto.tile_buf) outs("; - tstore_line << tile_view << " : !pto.tile_view<" << height << "x" << width << "x" << dtype_str << ">)"; + if (!tile_buf_type.empty()) { + tstore_line << " : " << tile_buf_type; + } + tstore_line << ") outs(" << partition_view << " : " << partition_type << ")"; codegen.Emit(tstore_line.str()); return ""; // Multi-line emission } diff --git a/src/codegen/pto/pto_codegen.cpp b/src/codegen/pto/pto_codegen.cpp index a288c14b..4c556e8b 100644 --- a/src/codegen/pto/pto_codegen.cpp +++ b/src/codegen/pto/pto_codegen.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -73,20 +74,20 @@ static std::string DataTypeToMLIRImpl(::pypto::DataType dtype) { } } -// Helper function to convert MemorySpace to MLIR string +// Helper function to convert MemorySpace to PTO address space string static std::string MemorySpaceToMLIR(ir::MemorySpace space) { if (space == ir::MemorySpace::DDR) { - return "ddr"; + return "gm"; } else if (space == ir::MemorySpace::UB) { - return "ub"; + return "vec"; } else if (space == ir::MemorySpace::L1) { - return "l1"; + return "mat"; } else if (space == ir::MemorySpace::L0A) { - return "l0a"; + return "left"; } else if (space == ir::MemorySpace::L0B) { - return "l0b"; + return "right"; } else if (space == ir::MemorySpace::L0C) { - return "l0c"; + return "acc"; } else { throw pypto::ValueError("Invalid MemorySpace value"); } @@ -98,30 +99,36 @@ class MemRefCollectorVisitor : public ir::IRVisitor { MemRefCollectorVisitor() = default; [[nodiscard]] const std::vector& GetMemRefs() const { return memrefs_; } + [[nodiscard]] const std::map>& GetMemRefTileTypes() + const { + return memref_tile_types_; + } void VisitExpr_(const VarPtr& op) override { auto tile_type = As(op->GetType()); if (tile_type && tile_type->memref_.has_value()) { - AddMemRefIfUnique(tile_type->memref_.value()); + AddMemRefIfUnique(tile_type->memref_.value(), tile_type); } } void VisitExpr_(const ir::IterArgPtr& op) override { auto tile_type = As(op->GetType()); if (tile_type && tile_type->memref_.has_value()) { - AddMemRefIfUnique(tile_type->memref_.value()); + AddMemRefIfUnique(tile_type->memref_.value(), tile_type); } } private: std::vector memrefs_; std::set seen_ptrs_; + std::map> memref_tile_types_; - void AddMemRefIfUnique(const MemRefPtr& memref) { + void AddMemRefIfUnique(const MemRefPtr& memref, const std::shared_ptr& tile_type) { const ir::MemRef* raw_ptr = memref.get(); if (seen_ptrs_.find(raw_ptr) == seen_ptrs_.end()) { memrefs_.push_back(memref); seen_ptrs_.insert(raw_ptr); + memref_tile_types_[raw_ptr] = tile_type; } } }; @@ -176,6 +183,7 @@ void PTOCodegen::GenerateFunction(const FunctionPtr& func) { tensor_to_view_.clear(); memref_to_mlir_.clear(); var_to_memref_.clear(); + memref_to_tile_type_.clear(); emitted_constants_.clear(); emitted_float_constants_.clear(); float_const_names_.clear(); @@ -195,6 +203,7 @@ void PTOCodegen::GenerateFunction(const FunctionPtr& func) { std::string tile_buf = NewTemp(); memref_to_mlir_[memref.get()] = tile_buf; } + memref_to_tile_type_ = collector.GetMemRefTileTypes(); stream_ << " func.func @" << func->name_ << "("; @@ -314,8 +323,12 @@ void PTOCodegen::EmitMakeTensorViews(const FunctionPtr& func) { } stream_ << "]"; - stream_ << " : !pto.tensor_view<" << tensor_type->shape_.size() << "x"; - stream_ << GetTypeString(tensor_type->dtype_) << ">\n"; + stream_ << " : !pto.tensor_view<"; + for (size_t j = 0; j < tensor_type->shape_.size(); j++) { + if (j > 0) stream_ << "x"; + stream_ << "?"; + } + stream_ << "x" << GetTypeString(tensor_type->dtype_) << ">\n"; } } } @@ -324,11 +337,8 @@ void PTOCodegen::EmitAllocTiles(const ir::FunctionPtr& func, const std::vectormemory_space_); - - stream_ << GetIndent() << tile_buf << " = pto.alloc_tile : \n"; + stream_ << GetIndent() << tile_buf << " = pto.alloc_tile : " << GetTileBufTypeString(memref.get()) + << "\n"; } } @@ -361,14 +371,18 @@ void PTOCodegen::VisitStmt_(const AssignStmtPtr& op) { if (auto call = As(op->value_)) { if (backend_ != nullptr && backend_->GetOpInfo(call->op_->name_) != nullptr) { std::string result_buf; + std::shared_ptr result_tile_type; if (auto tile_type = As(op->var_->GetType())) { if (tile_type->memref_.has_value()) { result_buf = GetTileBufForMemRef(tile_type->memref_.value()); } + result_tile_type = tile_type; } current_result_buf_ = result_buf; + current_result_tile_type_ = result_tile_type; VisitExpr(op->value_); current_result_buf_.clear(); + current_result_tile_type_ = nullptr; return; } } @@ -472,5 +486,109 @@ std::string PTOCodegen::GetOrEmitFloatConstant(double value, const std::string& return float_const_names_[value]; } +std::string PTOCodegen::GetTensorViewTypeString(const ir::TensorType* tensor_type) const { + std::ostringstream oss; + oss << "!pto.tensor_view<"; + for (size_t i = 0; i < tensor_type->shape_.size(); i++) { + if (i > 0) oss << "x"; + oss << "?"; + } + oss << "x" << GetTypeString(tensor_type->dtype_) << ">"; + return oss.str(); +} + +std::string PTOCodegen::GetTileBufTypeString(const ir::MemRef* memref) const { + std::string loc = MemorySpaceToMLIR(memref->memory_space_); + + // Get dtype and dimensions from the associated TileType + std::string dtype_str = "f32"; + int64_t rows = 32; + int64_t cols = 32; + + // Extract blayout, slayout, fractal, pad from TileView if available, otherwise use defaults + ir::TileLayout blayout = ir::TileLayout::row_major; + ir::TileLayout slayout = ir::TileLayout::none_box; + uint64_t fractal = 512; + ir::TilePad pad = ir::TilePad::null; + + auto tile_it = memref_to_tile_type_.find(memref); + if (tile_it != memref_to_tile_type_.end()) { + const auto& tile_type = tile_it->second; + dtype_str = GetTypeString(tile_type->dtype_); + if (tile_type->shape_.size() >= 2) { + if (auto c0 = As(tile_type->shape_[0])) rows = c0->value_; + if (auto c1 = As(tile_type->shape_[1])) cols = c1->value_; + } else if (tile_type->shape_.size() == 1) { + if (auto c0 = As(tile_type->shape_[0])) { + rows = 1; + cols = c0->value_; + } + } + if (tile_type->tile_view_.has_value()) { + const auto& tv = *tile_type->tile_view_; + blayout = tv.blayout; + slayout = tv.slayout; + fractal = tv.fractal; + pad = tv.pad; + } + } + + auto layout_to_str = [](ir::TileLayout layout) -> const char* { + switch (layout) { + case ir::TileLayout::none_box: + return "none_box"; + case ir::TileLayout::row_major: + return "row_major"; + case ir::TileLayout::col_major: + return "col_major"; + } + return "row_major"; + }; + + std::ostringstream oss; + oss << "!pto.tile_buf"; + return oss.str(); +} + +std::string PTOCodegen::GetExprTypeAnnotation(const ir::ExprPtr& expr) { + if (auto var = As(expr)) { + // Check if this variable maps to a tile buffer via memref + auto memref_it = var_to_memref_.find(var->name_); + if (memref_it != var_to_memref_.end()) { + return GetTileBufTypeString(memref_it->second); + } + // Check if this is a scalar parameter + if (auto scalar_type = As(var->GetType())) { + return GetTypeString(scalar_type->dtype_); + } + // Check if variable has TileType with memref + if (auto tile_type = As(var->GetType())) { + if (tile_type->memref_.has_value()) { + return GetTileBufTypeString(tile_type->memref_.value().get()); + } + } + } + if (auto const_float = As(expr)) { + return "f32"; + } + if (auto const_int = As(expr)) { + return "index"; + } + return ""; +} + +std::string PTOCodegen::GetCurrentResultTileBufTypeString() const { + if (current_result_tile_type_ && current_result_tile_type_->memref_.has_value()) { + return GetTileBufTypeString(current_result_tile_type_->memref_.value().get()); + } + return ""; +} + } // namespace codegen } // namespace pypto diff --git a/tests/ut/codegen/test_pto_codegen.py b/tests/ut/codegen/test_pto_codegen.py index ff8b7fdb..91bec444 100644 --- a/tests/ut/codegen/test_pto_codegen.py +++ b/tests/ut/codegen/test_pto_codegen.py @@ -95,7 +95,7 @@ def tensor_param_func( assert "pto.make_tensor_view" in mlir_code assert "shape = [%c64, %c64]" in mlir_code or "shape = [%c32, %c32]" in mlir_code assert "strides = " in mlir_code - assert "!pto.tensor_view<2xf32>" in mlir_code + assert "!pto.tensor_view" in mlir_code def test_pto_codegen_alloc_tile(): @@ -120,13 +120,13 @@ def alloc_test(self, a: pl.Tensor[[32, 32], pl.FP32], b: pl.Tensor[[32, 32], pl. # Verify alloc_tile operations assert "pto.alloc_tile" in mlir_code - assert "loc=ub" in mlir_code # Unified buffer + assert "loc=vec" in mlir_code # Vector buffer (PTO address space) assert "dtype=f32" in mlir_code assert "rows=32, cols=32" in mlir_code def test_pto_codegen_block_load_lowering(): - """Test that block.load generates subview + tload.""" + """Test that block.load generates partition_view + tload.""" @pl.program class LoadProgram: @@ -141,11 +141,11 @@ def load_test(self, input: pl.Tensor[[64, 64], pl.FP32], output: pl.Tensor[[64, codegen = PTOCodegen() mlir_code = _get_mlir_code(codegen.generate(transformed_program)) - # Verify subview generation - assert "pto.subview" in mlir_code + # Verify partition_view generation + assert "pto.partition_view" in mlir_code assert "offsets = [%c0, %c0]" in mlir_code assert "sizes = [%c32, %c32]" in mlir_code - assert "!pto.tile_view<32x32xf32>" in mlir_code + assert "!pto.partition_tensor_view<32x32xf32>" in mlir_code # Verify tload generation assert "pto.tload" in mlir_code @@ -155,7 +155,7 @@ def load_test(self, input: pl.Tensor[[64, 64], pl.FP32], output: pl.Tensor[[64, def test_pto_codegen_block_store_lowering(): - """Test that block.store generates subview + tstore.""" + """Test that block.store generates partition_view + tstore.""" @pl.program class StoreProgram: diff --git a/tests/ut/codegen/test_pto_codegen_ops.py b/tests/ut/codegen/test_pto_codegen_ops.py index 18ff29e8..cd3fb3ec 100644 --- a/tests/ut/codegen/test_pto_codegen_ops.py +++ b/tests/ut/codegen/test_pto_codegen_ops.py @@ -142,7 +142,9 @@ def validate_kernel_codegen(kernel_name: str, mlir_code: str) -> None: # Validate memory operations are present assert "pto.tload" in mlir_code, f"Kernel {kernel_name} should contain pto.tload operation" assert "pto.tstore" in mlir_code, f"Kernel {kernel_name} should contain pto.tstore operation" - assert "pto.subview" in mlir_code, f"Kernel {kernel_name} should contain pto.subview operation" + assert "pto.partition_view" in mlir_code, ( + f"Kernel {kernel_name} should contain pto.partition_view operation" + ) # Category-specific validations if category == "binary_tile_tile":