From 1b27b4c82ce050eb367f019770c2a075f58ca342 Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Sun, 15 Jun 2025 09:28:09 +0000 Subject: [PATCH 01/62] [Frontend/Fusion] Prologue fusion implementation working --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 106 +++++++- .../mlir/mlir_codegen_backend.py | 1 - PyTorchSimFrontend/mlir/mlir_common.py | 3 + PyTorchSimFrontend/mlir/mlir_conv_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 33 ++- .../mlir/mlir_maxpool_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_scheduling.py | 86 ++++++- PyTorchSimFrontend/mlir/mlir_template.py | 238 +++++++++++++++--- tests/Fusion/test_prologue_fusion.py | 81 ++++++ 9 files changed, 503 insertions(+), 49 deletions(-) create mode 100644 tests/Fusion/test_prologue_fusion.py diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index d6917cad..85631adb 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -83,8 +83,77 @@ } """ +BMM_PROLOGUE_TEMPLATE = r""" +// BMM Prologue kernel +// BATCH = {{ B }} +// M = {{ M }} +// N = {{ N }} +// K = {{ K }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// SUB_TILE_M = {{ SUB_TILE_M }} +// SUB_TILE_N = {{ SUB_TILE_N }} +#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> +#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> +#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> +memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> +memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> +memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { + %c_mvin = arith.constant 2 : index + %c_mvin2 = arith.constant 1 : index{% if Bias %} + %c_mvin3 = arith.constant 14 : index{% 
endif %} + %c_mvout = arith.constant 3 : index + %vstride = arith.constant 1 : index + %axis = arith.constant 2 : index + %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %tag = memref.alloc() : memref<1xi32> + %tag0 = memref.alloc() : memref<1xi32> + %tag1 = memref.alloc() : memref<1xi32> + %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + %c0 = arith.constant 0 : index +{{ kernel.def_local_vars() }} + affine.for %b=0 to {{ B }} { + affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { + affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + + %index2 = affine.apply #map2(%b, %t_m, %t_n) + {% if Bias -%} + memref.dma_start %Bias[ + {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} + ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % + {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} + , %vstride : memref< + {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} + xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ 
SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {%- else -%} + affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { + %index0 = affine.apply #map0(%b, %t_m, %t_k) + %index1 = affine.apply #map1(%b, %t_k, %t_n) + {{kernel.prepare_input(indent_size=10)}} + linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) + } { accumulation_loop=true } + memref.dma_start %Y_buffer[%c0, %c0, %c0], %Y[%index2], %c_mvout, %tag[%c0], %axis, %vstride : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ B * M * N }}xf32>, memref<1xi32> { padding=0, sram_stride=[1, 1, {{ TILE_M }}] } + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + BMM_REDUCTION_TEMPLATE = r""" -// BMM kernel +// BMM Reduction kernel // BATCH = {{ B }} // M = {{ M }} // N = {{ N }} @@ -166,6 +235,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node @@ -192,13 +262,17 @@ def render(self, TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K if TILE_K < kernel.vector_lane else kernel.vector_lane + TILE_K = TILE_K // 2 if prologue_nodes else TILE_K + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else 
kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane if n_extra_node==1 and epilogue_nodes[0].is_reduction(): template = BMM_REDUCTION_TEMPLATE nr_rdim = 1 + elif prologue_nodes: + template = BMM_PROLOGUE_TEMPLATE + nr_rdim = 0 else: template = BMM_TEMPLATE nr_rdim = 0 @@ -229,7 +303,29 @@ def render(self, input_reorder = self.input_reorder ) - kernel.store_info = dict( + kernel.prologue_info = dict ( + input_sram_var = "X_buffer2D", + input_dram_var = "X", + input_index_var = "index0", + input_tag_var = "tag1", + input_numel = B * M * K, + input_tile_size = (TILE_M, TILE_K), + input_sram_stride = [1, TILE_M], + input_subtile_size = (SUB_TILE_M, SUB_TILE_K), + weight_sram_var = "W_buffer2D", + weight_dram_var = "W", + weight_index_var = "index1", + weight_tag_var = "tag2", + weight_numel = B * K * N, + weight_tile_size = (TILE_K, TILE_N), + weight_sram_stride = [1, TILE_K], + weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), + tile_size = (TILE_M, TILE_K), + vlane_split_axis = 1, + vlane_stride = 1, + is_bmm = True, + ) + kernel.epilogue_info = dict( output_node = self.output_node.name, dependent_buf = [], sram_var = "Y_buffer", diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 9a3c4148..1272a46e 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -845,7 +845,6 @@ def __init__(self, kernel_group, reason=None): self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() self.applys = IndentedBuffer() - self.body = IndentedBuffer() self.dma_loads = IndentedBuffer() self.dma_stores = IndentedBuffer() self.indexed_buffer = IndentedBuffer() diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 8ab94049..c3dc0c51 
100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -306,6 +306,9 @@ def __init__(self): def set_tile_info(self, tile_desc : MLIRMultiDimTile): self.tile_desc = tile_desc + def set_prologue_tile_info(self, tile_desc : MLIRMultiDimTile): + self.prologue_tile_desc = tile_desc + class BaseMLIRHardwareInfo(): def __init__(self): # Default HW setting diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 7a3b4b19..0b6d13ef 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -736,7 +736,7 @@ def render(self, input_reorder=self.input_reorder ) - kernel.store_info = dict( + kernel.epilogue_info = dict( output_node = self.output_node.name, dependent_buf = [], sram_var = "output_buffer", diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index a6b3423b..ec1dd9a8 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -64,10 +64,15 @@ affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { %index0 = affine.apply #map0(%t_m, %t_k) %index1 = affine.apply #map1(%t_k, %t_n) + {% if prologue_nodes -%} + // prologue nodes + {{kernel.prepare_input(indent_size=8)}} + {%- else -%} memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} memref.dma_start %W[%index1], %W_buffer[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} + {%- endif %} linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, 
memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) } { accumulation_loop=true } @@ -160,6 +165,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node @@ -236,10 +242,33 @@ def render(self, W_map = W_map, Y_numel = M * N, epilogue_nodes = epilogue_nodes, + prologue_nodes = prologue_nodes, input_reorder = self.input_reorder ) - - kernel.store_info = dict( + kernel.prologue_info = dict ( + input_sram_var = "X_buffer", + input_dram_var = "X", + input_index_var = "index0", + input_tag_var = "tag1", + input_numel = M * K, + input_tile_size = (TILE_M, TILE_K), + input_sram_stride = [1, TILE_M], + vector_sram_stride = [TILE_M, 1], + input_subtile_size = (SUB_TILE_M, SUB_TILE_K), + weight_sram_var = "W_buffer", + weight_dram_var = "W", + weight_index_var = "index1", + weight_tag_var = "tag2", + weight_numel = K * N, + weight_tile_size = (TILE_K, TILE_N), + weight_sram_stride = [1, TILE_K], + weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), + tile_size = (TILE_M, TILE_K), + vlane_split_axis = 1, + vlane_stride = 1, + is_bmm = False, + ) + kernel.epilogue_info = dict( output_node = self.output_node.name, dependent_buf = [], sram_var = "Y_buffer", diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 6a5aafa0..ff617eb4 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -75,7 +75,7 @@ def render(self, out_tile=out_tile, DATA_STYPE="f32", ) - kernel.store_info = dict( + kernel.epilogue_info = dict( output_node = self.output_node.name, dependent_buf = [], sram_var = "Y_buffer", diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py 
b/PyTorchSimFrontend/mlir/mlir_scheduling.py index ec8de5a1..a1f39543 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -1,5 +1,7 @@ import os import math +from functools import reduce +import operator from sympy import symbols, sympify from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel @@ -41,6 +43,22 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # We can't fuse dim=-1 possible = int(sympify(stride).coeff(target_symbol)) != 1 return size_match and possible + + # For prologue fusion case + if not node1.is_template() and len(node1.get_nodes())==1 and node2.is_template(): + # Return false if node2 is Convolution template + if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ + node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': + return False + if node2.get_nodes()[0].node.origin_node is not None and hasattr(node2.get_nodes()[0].node.origin_node.target, "_name") and node2.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': + return False + if node1.is_reduction(): + return False + if len(node1.read_writes.writes) != 1: + return False + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + return True + return self.scheduler.can_fuse_origin(node1, node2) def _set_flush_status(self, status: bool): @@ -167,13 +185,51 @@ def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) return kernel_name - def codegen_template_code(self, kernel, render, template_node, epilogue_nodes): + def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): with kernel: - for node in [template_node, *epilogue_nodes]: + for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() partial_code = 
render() - tile_desc = kernel.set_tile_size(kernel.store_info) + tile_desc = kernel.set_tile_size(kernel.epilogue_info) kernel.kernel_group.set_tile_info(tile_desc) + if prologue_nodes: + _, (group, reduction_group) = max( + prologue_nodes, key=lambda x: int(x.is_reduction()) + ).group + tile_desc = kernel.set_tile_size(kernel.prologue_info) + kernel.kernel_group.set_prologue_tile_info(tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + # Flush created varaibles, since template fusion doen't share variable + kernel.cse.cache.clear() + kernel.prologue_buffer_group.set_buffers() + kernel.call_ranges = None + kernel.load = kernel.load_prologue + kernel.store = kernel.store_prologue + for node in prologue_nodes: + # Reuse created spad + read_list = sorted(list(node.read_writes.reads)) + if reduce(operator.mul, read_list[-1].size, 1) == template_node.node.get_numel(): + prologue_input_arg = read_list[-1].name + else: + prologue_input_arg = read_list[0].name + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + if template_node.get_nodes()[0].node.origin_node.target._name == 'aten::bmm': + target_buf = f"{template_buf}_buffer2D" + else: + target_buf = f"{template_buf}_buffer" + + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + if epilogue_nodes: _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) @@ -181,9 +237,12 @@ def codegen_template_code(self, kernel, render, template_node, epilogue_nodes): vars, reduction_vars = kernel.set_ranges(group, reduction_group) # Flush created varaibles, since template fusion 
doen't share variable kernel.cse.cache.clear() + kernel.epilogue_buffer_group.set_buffers() + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue for node in epilogue_nodes: if template_node.node.name in [dep[0] for dep in list(node.read_writes.reads)]: - kernel.store_info['dependent_buf'].append(node.node.name) + kernel.epilogue_info['dependent_buf'].append(node.node.name) node.codegen((vars, reduction_vars)) with V.set_kernel_handler(kernel): src_code = ( @@ -194,18 +253,29 @@ def codegen_template_code(self, kernel, render, template_node, epilogue_nodes): return src_code def codegen_template(self, template_node, epilogue_nodes): + # Handle prologue pattern + prologue_nodes = [] + if not template_node.is_template(): + epilogue_nodes = [template_node] + epilogue_nodes + for i, node in enumerate(epilogue_nodes): + if node.is_template(): + template_node = node + prologue_nodes = epilogue_nodes[:i] + epilogue_nodes = epilogue_nodes[i+1:] + break + _, (numel, rnumel) = template_node.group template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - src_code = self.codegen_template_code(kernel, render, template_node, epilogue_nodes) + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined kernel_name = wrapper.src_to_kernel[src_code] - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = 
self.codegen_template_code(kernel, render, template_node, epilogue_nodes) + kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name + src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a0537201..6cd06a23 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -25,6 +25,30 @@ from . import mlir_common +class IndentedBufferGroup: + def __init__(self, kernel: 'MLIRTemplateKernel'): + self.kernel = kernel + self.body = IndentedBuffer() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.applys = IndentedBuffer() + self.dma_loads = IndentedBuffer() + self.dma_stores = IndentedBuffer() + self.spad_buffer = IndentedBuffer() + + def set_buffers(self): + self.kernel.loads = self.loads + self.kernel.compute = self.compute + self.kernel.stores = self.stores + self.kernel.dma_loads = self.dma_loads + self.kernel.dma_stores = self.dma_stores + self.kernel.spad_buffer = self.spad_buffer + + @contextlib.contextmanager + def as_local(self): + yield self + class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, kernel_name, @@ -54,6 +78,9 @@ def __init__(self, self.map_cse = CSE("#", self.suffix, name_prefix="template_map") self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_const") self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_alloc") + self.prologue_buffer_group = IndentedBufferGroup(self) + self.epilogue_buffer_group = IndentedBufferGroup(self) + self.global_vars = IndentedBuffer() 
self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False self.reduction_idx = None @@ -321,27 +348,48 @@ def call_kernel(self, kernel_name): kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) - def codegen_body(self): + def codegen_prologue_body(self): + with self.prologue_buffer_group.as_local() as buf: + buf.body.splice(buf.spad_buffer) + buf.body.splice(buf.applys) + buf.body.splice(buf.dma_loads) + + if (buf.loads.getvalue() != '' or buf.compute.getvalue() != '' or buf.stores.getvalue() != ''): + buf.body.writelines(self.compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(buf.loads) + compute_body.splice(buf.compute) + compute_body.splice(buf.stores) + buf.body.splice(compute_body) + + # Clear buffers + self.loads.clear() + self.compute.clear() + self.stores.clear() + + def codegen_epilogue_body(self): def template_store(): zero_cse = self.get_const_cse(0) - sram_var = self.store_info["sram_var"] - dram_var = self.store_info["dram_var"] - index_var = self.store_info["index_var"] - tag_var = self.store_info["tag_var"] - mlir_dtype = self.store_info["mlir_dtype"] - dram_shape = self.store_info["dram_shape"] + sram_var = self.epilogue_info["sram_var"] + dram_var = self.epilogue_info["dram_var"] + index_var = self.epilogue_info["index_var"] + tag_var = self.epilogue_info["tag_var"] + mlir_dtype = self.epilogue_info["mlir_dtype"] + dram_shape = self.epilogue_info["dram_shape"] vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.get_vlane_stride() - tile_stride = self.store_info["tile_stride"] + tile_stride = self.epilogue_info["tile_stride"] tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) sram_index_var = ",".join([f"%{zero_cse}"] * 
self.kernel_group.tile_desc.get_nr_dim()) code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, tag_var, dram_shape, tile_shape, tile_stride) self.cse.generate(self.dma_stores, code, assignment = False) - self.body.splice(self.spad_buffer) - self.body.splice(self.applys) - self.body.splice(self.dma_loads) - self.body.writelines(self.compute_body_loop.lines()) + self.epilogue_buffer_group.body.splice(self.spad_buffer) + self.epilogue_buffer_group.body.splice(self.applys) + self.epilogue_buffer_group.body.splice(self.dma_loads) + self.epilogue_buffer_group.body.writelines(self.compute_body_loop.lines()) compute_body = mlir_common.ParallelLoopBuffer() with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) @@ -349,10 +397,11 @@ def template_store(): compute_body.splice(self.compute) if len(self.stores._lines) == 0: template_store() - compute_body.splice(self.stores) - self.body.splice(compute_body) - self.body.splice(self.dma_stores) - self.body.splice(self.reduction_epilogue_suffix) + compute_body.splice(self.epilogue_buffer_group.stores) + if (compute_body.getvalue()): + self.epilogue_buffer_group.body.splice(compute_body) + self.epilogue_buffer_group.body.splice(self.dma_stores) + self.epilogue_buffer_group.body.splice(self.reduction_epilogue_suffix) # Clear buffers self.loads.clear() @@ -394,7 +443,7 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.store_info['sram_var'] + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) @@ -439,7 +488,7 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling 
store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.store_info['sram_var'] #TODO: Buffer name fixed + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -467,6 +516,50 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} + def prepare_input(self, indent_size: int = 0): + def emit_dma_start(buffer_name, index_var, tag_var, size, tile_size, subtile_size=None, async_flag=True, label="X"): + base = f"memref.dma_start %{label}[%{index_var}], %{buffer_name}[%c0, %c0], %c_mvin" + if label == "W": + base = base.replace("mvin", "mvin2") + + suffix = f"%{tag_var}[%c0], %axis, %vstride" + memref_shape = f"memref<{size}xf32>" + tile_shape = "x".join([str(x) for x in tile_size]) + tile_memref = f"memref<{tile_shape}xf32, 1>" + tag_memref = f"memref<1xi32>" + attrs = f"sram_stride=[1, {tile_size[0]}]" + async_flag = "true" if async_flag else "false" + if subtile_size: + subtile_shape = ", ".join([str(x) for x in subtile_size]) + attrs = f"subtile_size=[{subtile_shape}], async={async_flag}, {attrs}" + else: + subtile_shape = ", ".join([str(x) for x in tile_size]) + attrs = f"subtile_size=[{subtile_shape}], async={async_flag}, {attrs}" + attr_memref = f"{{ {attrs} }}" + return f"{base}, {suffix}: {memref_shape}, {tile_memref}, {tag_memref} {attr_memref}" + + def hook(): + code = IndentedBuffer() + self.codegen_prologue_body() + prologue_code = self.prologue_buffer_group.body + if prologue_code.getvalue(): + code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], + self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], subtile_size=self.prologue_info["input_subtile_size"], label="X")) 
+ code.writeline(emit_dma_start(self.prologue_info["weight_sram_var"], self.prologue_info["weight_index_var"], self.prologue_info["weight_tag_var"], + self.prologue_info["weight_numel"], self.prologue_info["weight_tile_size"], subtile_size=self.prologue_info["weight_subtile_size"], label="W")) + code.splice(prologue_code) + else: + code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], + self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], self.prologue_info["input_subtile_size"], async_flag=True, label="X")) + code.writeline(emit_dma_start(self.prologue_info["weight_sram_var"], self.prologue_info["weight_index_var"], self.prologue_info["weight_tag_var"], + self.prologue_info["weight_numel"], self.prologue_info["weight_tile_size"], self.prologue_info["weight_subtile_size"], async_flag=True, label="W")) + code = textwrap.indent(code.getvalue(), " "*indent_size).strip() + return code + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + def output_name(self): # Cannot know the output name from the template, so we need to hook it def hook(): @@ -481,8 +574,8 @@ def hook(): def store_output(self, indent_size: int = 0): def hook(): - self.codegen_body() - return textwrap.indent(self.body.getvalue(), " "*indent_size).strip() + self.codegen_epilogue_body() + return textwrap.indent(self.epilogue_buffer_group.body.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -569,11 +662,92 @@ def get_spad_size_per_lane(self, tile_m, tile_n): size = tile_m * ((tile_n + self.vector_lane - 1) // self.vector_lane) return max(size, 2) # vector load/store + def load_prologue(self, name: str, index: sympy.Expr): + load_dim = [] + if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: + load_dim = V.graph.graph_inputs[name].layout.size + if 
self.kernel_group.prologue_tile_desc.get_numel() == self.buffer_types[name][1]: + index_var = self.prologue_info['input_index_var'] if len(load_dim) != 1 else 'tile_n' + else: + # Broadcast pattern + zero_index = self.const_cse.generate(self.const_buffer, "arith.constant 0 : index") + if self.prologue_info['is_bmm']: # FIXME: hardcoded + idx = f"%b, %t_m, %{zero_index}" + map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1, d2) -> (d0 * 512 + d1)>") + vlane_split_axis = 2 + else: + idx = f"%t_m, %{zero_index}" + map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1) -> (d0)>") + vlane_split_axis = self.kernel_group.prologue_tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim + index_var = self.apply_cse.generate(self.dma_loads, f"affine.apply #{map_var}({idx})") + index = self.rename_indexing(index) + dram_var = self.kernel_group.args.input(name) + dtype = V.graph.get_dtype(name) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + vlane_stride = self.kernel_group.prologue_tile_desc.vlane_stride if len(load_dim) != 1 else 1 # FIXME: Fixed stride for 1d load dim + tile_numel_per_lane = self.kernel_group.prologue_tile_desc.get_numel_per_lane() + tile_shape = self.kernel_group.prologue_tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = self.prologue_info['input_sram_stride'] + + # Compute vector unit size + vshape = self.kernel_group.prologue_tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.prologue_tile_desc.get_compute_vec_size() + + if name not in self.buffer_names: + # Allocate sram buffer + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index, self.alloc_buffer) + self.buffer_names[name] = sram_var + code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, 
sram_var, sram_index_var, + f"{name}_tag", dram_shape, tile_shape, tile_stride) + self.cse.generate(self.dma_loads, code, assignment = False) + + # Load vector from sram + sram_var = self.buffer_names[name] + zero_var = self.get_const_cse(0) + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.prologue_tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) + + if compute_vec_size > 1: + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" + + out = self.cse.generate(self.loads, line) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) + return out + + def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): + dtype = V.graph.get_dtype(name) + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_shape = self.kernel_group.prologue_tile_desc.get_mlir_shape(mlir_dtype) + + # Compute vector unit size + vshape = self.kernel_group.prologue_tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.prologue_tile_desc.get_compute_vec_size() + + sram_var = self.buffer_names[name] + zero_var = self.get_const_cse(0) + + _, operand_type = self.var_info[value] + if mlir_dtype != operand_type: + value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.prologue_tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) + # Generate vector load instruction + if compute_vec_size > 1: + operation = "affine.vector_store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + else: + operation = "affine.store" + line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" + self.stores.writeline(line) + def load_epilogue(self, name: str, index: sympy.Expr): load_dim = [] if not isinstance(V.graph, NullHandler) and name in 
V.graph.graph_inputs: load_dim = V.graph.graph_inputs[name].layout.size - index_var = self.store_info['index_var'] if len(load_dim) != 1 else 'tile_n' + index_var = self.epilogue_info['index_var'] if len(load_dim) != 1 else 'tile_n' index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dtype = V.graph.get_dtype(name) @@ -582,7 +756,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): vlane_stride = self.kernel_group.tile_desc.vlane_stride if len(load_dim) != 1 else 1 # FIXME: Fixed stride for 1d load dim tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.store_info['tile_stride'] + tile_stride = self.epilogue_info['tile_stride'] # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) @@ -636,7 +810,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - index_var = self.store_info['index_var'] + index_var = self.epilogue_info['index_var'] dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] @@ -646,7 +820,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.store_info['tile_stride'] + tile_stride = self.epilogue_info['tile_stride'] # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) @@ -816,13 +990,13 @@ def store_reduction_epilogue(self, name, index, value): def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, buffer=None): return super().get_scratchpad_buffer(dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, 
True, buffer=buffer) - def set_tile_size(self, template_store_info): - tile_desc = mlir_common.MLIRMultiDimTile(template_store_info['tile_size'], + def set_tile_size(self, template_epilogue_info): + tile_desc = mlir_common.MLIRMultiDimTile(template_epilogue_info['tile_size'], self.vector_lane, - vlane_split_axis=template_store_info['vlane_split_axis'], - vlane_stride=template_store_info['vlane_stride']) + vlane_split_axis=template_epilogue_info['vlane_split_axis'], + vlane_stride=template_epilogue_info['vlane_stride']) - if 'nr_rdim' in template_store_info and template_store_info['nr_rdim']==1: + if 'nr_rdim' in template_epilogue_info and template_epilogue_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() reduction_axis_size = tile_desc.get_tile_size()[-2] @@ -832,7 +1006,7 @@ def set_tile_size(self, template_store_info): self.reduction_fusion = True self.reduction_axis_size = tile_desc.get_tile_size()[-2] self.reduction_nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size - self.reduction_idx = template_store_info["reduction_idx"] + self.reduction_idx = template_epilogue_info["reduction_idx"] self.compute_body_loop.size = reduction_axis_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop else: @@ -890,6 +1064,7 @@ def generate(self, **kwargs) -> ChoiceCaller: def make_kernel_render( template_node: TemplateBuffer, + prologue_nodes: Optional[List[IRNode]] = None, epilogue_nodes: Optional[List[IRNode]] = None, kernel_name: str = kernel_hash_name, kernel_group: Optional[mlir_common.MLIRWrapperKenrelGroup] = None @@ -910,7 +1085,8 @@ def make_kernel_render( kwargs = { 'kernel': kernel, 'template_buffer_node': template_node, - 'epilogue_nodes': epilogue_nodes + 'epilogue_nodes': epilogue_nodes, + 'prologue_nodes': prologue_nodes, } render = functools.partial( kernel.render, diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py new file 
mode 100644 index 00000000..12098b24 --- /dev/null +++ b/tests/Fusion/test_prologue_fusion.py @@ -0,0 +1,81 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_elem_broadcast_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Scalar Fusion Forward", res, y) + +def test_elem_fusion(device): + def matmul_fused(a, b, c): + return torch.matmul(c * a, b) + torch.manual_seed(0) + input = torch.randn(128, 128) + weight = torch.randn(128, 128) + c = torch.randn(128, 128, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + c1 = c.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + c2 = c.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c1) + y = matmul_fused(x2, w2, c2) + test_result("Matmul Element-wise Fusion Forward", res, y) + +def test_elem_bmm_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm((a - b)/c , d) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, m, 1).to(device=device) + c = 
torch.randn(batch_size, m, 1) * 1000 + c = c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + test_elem_broadcast_fusion(device) + test_elem_fusion(device) + test_elem_bmm_fusion(device, batch_size=12, m=512, n=64, k=512) \ No newline at end of file From 018078ec5ae1a8427a330886e856c4f925356d0d Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 17 Jun 2025 21:13:12 +0000 Subject: [PATCH 02/62] [Frontend/Fusion] Optimize BMM+Reduction fusion --- PyTorchSimFrontend/mlir/mlir_template.py | 218 ++++++++++------------- 1 file changed, 99 insertions(+), 119 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 6cd06a23..a9da6e9d 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -83,6 +83,7 @@ def __init__(self, self.global_vars = IndentedBuffer() self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False + self.reduction_body_loop = None self.reduction_idx = None # Overwrite ops @@ -386,6 +387,12 @@ def template_store(): code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, tag_var, dram_shape, tile_shape, tile_stride) self.cse.generate(self.dma_stores, code, assignment = False) + # Do dma store first to overlap epilogue nodes + if self.reduction_fusion: + if len(self.stores._lines) == 0: + template_store() + 
self.epilogue_buffer_group.body.splice(self.dma_stores) + self.dma_stores.clear() self.epilogue_buffer_group.body.splice(self.spad_buffer) self.epilogue_buffer_group.body.splice(self.applys) self.epilogue_buffer_group.body.splice(self.dma_loads) @@ -393,10 +400,18 @@ def template_store(): compute_body = mlir_common.ParallelLoopBuffer() with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) - compute_body.splice(self.loads) - compute_body.splice(self.compute) - if len(self.stores._lines) == 0: - template_store() + if self.reduction_fusion: + #if len(self.stores._lines) == 0: + # template_store() + compute_body.writelines(self.reduction_body_loop.lines()) + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + else: + compute_body.splice(self.loads) + compute_body.splice(self.compute) + if len(self.stores._lines) == 0: + template_store() compute_body.splice(self.epilogue_buffer_group.stores) if (compute_body.getvalue()): self.epilogue_buffer_group.body.splice(compute_body) @@ -783,30 +798,22 @@ def load_epilogue(self, name: str, index: sympy.Expr): operation = "affine.load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" out = self.cse.generate(self.loads, line) + self.register_var_info(out, [compute_vec_size, mlir_dtype]) else: # For reduction case reduce_size = self.reduction_nr_outer_loop vsize = compute_vec_size//reduce_size vshape = f"vector<{vsize}x{mlir_dtype}>" - flatten_tshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - init = self.cse.generate(self.loads, f"arith.constant 0.0 : {mlir_dtype}") - init_vec = self.cse.generate(self.loads, f"vector.broadcast %{init} : {mlir_dtype} to {flatten_tshape}") if compute_vec_size > 1: - out_list = [] - for i in range(reduce_size): - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0) -> 
(d0 + {i*(self.reduction_axis_size)})>(%{self.compute_idx})") - compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - out = self.cse.generate(self.loads, line) - out_list.append(out) - for idx, partial_out in enumerate(out_list): - init_vec = self.cse.generate(self.loads, f"vector.insert_strided_slice %{partial_out}, %{init_vec} {{offsets=[{vsize*idx}],strides=[1]}} : {vshape} into {flatten_tshape}") - out = init_vec + offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.reduction_axis_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) + operation = "affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) else: line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) + self.register_var_info(out, [self.compute_body_loop.step, mlir_dtype]) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): @@ -859,91 +866,39 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): if argmax_or_argmin or is_welford_reduction(reduction_type): raise NotImplementedError() #TODO: argmin, argmax - # Prepare reduction loop - reduction_key = src_dtype, reduction_type, value - acc = self.reduction_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - iterator = self.iterator_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init = self.init_cse.generate( - self.loads, f"reduction {reduction_key}", write=False - ) - init_vec = self.init_vec_cse.generate( 
- self.loads, f"reduction {reduction_key}", write=False - ) + # Reduction fusion codegen part type_name = mlir_common.DTYPE_TO_MLIR[dtype] - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - vec_len = self.kernel_group.tile_desc.get_compute_vec_size() - reduced_shape = self.kernel_group.tile_desc.get_mlir_vshape(type_name) - - # Set accumulation var - if vec_len == 1: # 1-D vector to scalar - # Edge case for scalar - init_vec = init - else: - # Adjust shape and inital value - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") - acc_var = init_vec - - # Reduction body prepare - body_acc = self.reduction_cse.generate( - self.compute, f"reduction {reduction_key}body_acc", write=False - ) - body_iter_arg = self.iterator_cse.generate( - self.compute, f"reduction {reduction_key}body_iter_arg", write=False - ) - self.register_var_info(body_iter_arg, [vec_len, type_name]) + vec_size = self.compute_body_loop.step + vshape = f"vector<{vec_size}x{type_name}>" + + tile_shape = f"memref<{self.reduction_body_loop.size * self.vector_lane}x{vec_size}x{type_name}, 1>" + name = f"{reduction_type}_buffer" + index = "dummy_index" # Not used + tile_numel_per_lane = self.compute_body_loop.step * self.reduction_body_loop.size + sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, None, index, self.const_buffer) + zero_var = self.get_const_cse(0) - self.reduction_vars[acc] = (reduction_type, iterator, acc_var, reduced_shape) - self.affine_yield[body_acc] = reduced_shape - self.reduction_cse.reduction_cache[reduction_key] = acc - self.iterator_cse.reduction_cache[reduction_key] = iterator - self.init_cse.reduction_cache[reduction_key] = init_vec + # Load partial result + operation = "affine.vector_load" + compute_index_var = ",".join([f"%{self.reduction_loop_idx}"] + [f"%{zero_var}"]) + line = 
f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + out = self.cse.generate(self.loads, line) + self.register_var_info(out, [self.compute_body_loop.step, type_name]) # Reduction body codegen - result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) - self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) - self.compute_body_loop.affine_yield[result] = reduced_shape - - # Final reduction - reduction_size = self.reduction_nr_outer_loop - if vec_len > reduction_size: - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") - if reduction_size == 1: - final_reduced_shape = f"{type_name}" - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, acc, init, axis=0, shape=reduced_shape, reduced_shape=final_reduced_shape)) - else: - final_reduced_shape = f"vector<{reduction_size}x{type_name}>" - init_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{init} : {type_name} to {final_reduced_shape}") - new_vshape= f"vector<{reduction_size}x{vec_len//reduction_size}x{type_name}>" - partial_vshape= f"vector<{vec_len//reduction_size}x{type_name}>" - value = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{acc} : {reduced_shape} to {new_vshape}") - # FIXME. I want to use N-Rank multi-reduciton, but we can't use it. It lowerd to scalar operations now... 
- for i in range(reduction_size): - partial_value = self.cse.generate(self.reductions_suffix, f"vector.extract %{value}[{i}] : {partial_vshape} from {new_vshape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(reduction_type, partial_value, init, axis=0, shape=partial_vshape, reduced_shape=type_name)) - init_vec = self.cse.generate(self.reductions_suffix, f"vector.insert %{out}, %{init_vec}[{i}] : {type_name} into {final_reduced_shape}") - out = init_vec - acc = out - - # reigster reduction output - var_info = [reduction_size, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(acc, var_info) - - # Specail handling for fusion - self.reduction_epilogue_suffix.writeline(f"affine.yield %{body_acc} : {self.affine_yield[body_acc]}") - return acc + result = reduction_partial_combine_vec(reduction_type, value, out) - def store_reduction_epilogue(self, name, index, value): - index = self.reduction_idx - tmp_cse = self.cse - self.cse = self.reduction_cse + # Store partial result + operation = "affine.vector_store" + line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + self.compute.writeline(line) # Need to be placed after partial reduction + self.reduction_info = {sram_var : reduction_type} + return sram_var + def store_reduction_epilogue(self, name, index, value): dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + type_name = mlir_common.DTYPE_TO_MLIR[dtype] index = self.rename_indexing(index) # Tile is always reuduced in inner loop @@ -953,40 +908,63 @@ def store_reduction_epilogue(self, name, index, value): vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis - 1 vlane_stride = self.kernel_group.tile_desc.vlane_stride - tile_numel_per_lane = vlane_stride * nr_outer_loop + tile_numel_per_lane = vlane_stride * nr_outer_loop * 2 dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - 
tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]}x{mlir_dtype}, 1>" + tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]*2}x{type_name}, 1>" tile_stride = [1] - compute_vec_size = self.var_info[value][0] - if compute_vec_size == 1: - vshape = f"{mlir_dtype}" - else: - vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index, index, buffer=self.const_buffer) + for i in range(self.reduction_body_loop.size): + vec_size = self.compute_body_loop.step + vshape = f"vector<{vec_size}x{type_name}>" + + partial_tile_shape = f"memref<{self.reduction_body_loop.size * self.vector_lane}x{vec_size}x{type_name}, 1>" + # Load partial result + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value], dtype)} : {type_name}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") + zero_var = self.const_cse.generate(self.const_buffer, f"arith.constant {0} : index") + index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") + compute_index_var = ",".join([f"%{index_var}"] + [f"%{zero_var}"]) - if self.welford_reduce_out is not None: - raise NotImplementedError() + operation = "affine.vector_load" + line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {vshape}" + out = self.cse.generate(self.reductions_suffix, line) + operation = "affine.vector_store" + line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {vshape}" + self.reductions_suffix.writeline(line) + + # 2 step reduction + new_vec_size = 2 + new_vshape = f"vector<{vec_size//new_vec_size}x{new_vec_size}x{type_name}>" + new_reduced_shape = f"vector<{new_vec_size}x{type_name}>" + out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {vshape} to {new_vshape}") + init_vec 
= self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {new_reduced_shape}") + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) + out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") + + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + self.register_var_info(out, [new_vec_size, type_name]) + self.register_var_info(out2, [new_vec_size, type_name]) + out = reduction_partial_combine_vec(self.reduction_info[value], out, out2) + self.compute, self.reductions_suffix = self.reductions_suffix, self.compute + + # Final reduction + #final_reduced_shape = type_name + #init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value], dtype)} : {type_name}") + #out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value], out, init, axis=0, shape=vshape, reduced_shape=final_reduced_shape)) - # Select src type - if compute_vec_size == 1: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}" - else: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{sram_index_var}] : {tile_shape}, {vshape}" - self.reductions_suffix.writeline(DeferredLine(name, line)) + operation = "affine.vector_store" + line = f"{operation} %{out}, %{sram_var}[%{index_var}] : {tile_shape}, {new_reduced_shape}" + self.reductions_suffix.writeline(DeferredLine(name, line)) # MVOUT Encoding # Generate DMA instruction - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + index_var = "red_idx" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, 
type_name, dram_var, index_var, sram_var, sram_index_var, + f"{name}_tag", dram_shape, tile_shape, tile_stride) self.reductions_suffix.writeline(DeferredLine(name, code)) - # Restore origin cse - self.cse = tmp_cse - def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, buffer=None): return super().get_scratchpad_buffer(dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, True, buffer=buffer) @@ -1001,14 +979,16 @@ def set_tile_size(self, template_epilogue_info): numel_per_lane = tile_desc.get_numel_per_lane() reduction_axis_size = tile_desc.get_tile_size()[-2] nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size - tile_desc.vec_size = nr_outer_loop * 2 # Why? Emprically selected, other option failed to functionality... + tile_desc.vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True self.reduction_axis_size = tile_desc.get_tile_size()[-2] self.reduction_nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size self.reduction_idx = template_epilogue_info["reduction_idx"] + self.reduction_loop_idx = "reduce_loop_idx" self.compute_body_loop.size = reduction_axis_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop + self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: tile_desc.vec_size=64 self.compute_body_loop.size = tile_desc.get_numel_per_lane() From 9ce931025a362563e28a06d1b1499d5e16068ba9 Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Wed, 18 Jun 2025 04:40:43 +0000 Subject: [PATCH 03/62] [Frontend] optimize attention kernel --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 4 ++-- PyTorchSimFrontend/mlir/mlir_scheduling.py | 6 +++--- experiments/BERT.py | 5 +++-- experiments/attention.py | 6 +++--- tests/Fusion/test_attention_fusion.py | 3 +-- tests/test_matmul.py | 21 ++++++++++++++++++++ 
tests/test_transformer.py | 10 ++++------ 7 files changed, 37 insertions(+), 18 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 85631adb..25858222 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -262,7 +262,7 @@ def render(self, TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - TILE_K = TILE_K // 2 if prologue_nodes else TILE_K + TILE_K = TILE_K // 4 if prologue_nodes else TILE_K SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane @@ -320,7 +320,7 @@ def render(self, weight_tile_size = (TILE_K, TILE_N), weight_sram_stride = [1, TILE_K], weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), - tile_size = (TILE_M, TILE_K), + tile_size = (TILE_K, TILE_N), vlane_split_axis = 1, vlane_stride = 1, is_bmm = True, diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index a1f39543..d41a9128 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -47,9 +47,9 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # For prologue fusion case if not node1.is_template() and len(node1.get_nodes())==1 and node2.is_template(): # Return false if node2 is Convolution template - if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ - node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': - return False + # if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ + # 
node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': + # return False if node2.get_nodes()[0].node.origin_node is not None and hasattr(node2.get_nodes()[0].node.origin_node.target, "_name") and node2.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': return False if node1.is_reduction(): diff --git a/experiments/BERT.py b/experiments/BERT.py index e111908e..7086ad9a 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -7,7 +7,8 @@ def run_BERT(size, input_seq, config): from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - from tests.test_transformer import DecoderBlock + # from tests.test_transformer import DecoderBlock + from tests.Fusion.test_transformer_fusion import DecoderBlock scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) device = scheduler.execution_engine.module.custom_device() @@ -35,7 +36,7 @@ def run_BERT(size, input_seq, config): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.json') + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name sys.path.append(base_dir) args = argparse.ArgumentParser() diff --git a/experiments/attention.py b/experiments/attention.py index acfed848..e8f89dac 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -10,9 +10,9 @@ def run_attention(size, config): def attention(query, key, value): import math d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - return torch.matmul(p_attn, value) + scores = torch.matmul(key, 
query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) device = scheduler.execution_engine.module.custom_device() diff --git a/tests/Fusion/test_attention_fusion.py b/tests/Fusion/test_attention_fusion.py index a513b0bb..95bdf165 100644 --- a/tests/Fusion/test_attention_fusion.py +++ b/tests/Fusion/test_attention_fusion.py @@ -47,8 +47,7 @@ def forward(self, query, key, value): x = torch.matmul(value.transpose(-1, -2), p_attn) # 3) "Concat" using a view and apply a final linear. x = ( - x.contiguous() - .view(-1, self.h * self.d_k) + x.view(-1, self.h * self.d_k) ) del query del key diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 44f70b69..bd219051 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -50,6 +50,27 @@ def custom_matmul(bias, a, b): y = custom_matmul(b2, x2, w2) test_result("Addmm Forward", res, y) +def test_linear(device, input_size=128, hidden_size=128, output_size=128): + def custom_linear(a, b, bias): + linear = torch.nn.Linear(hidden_size, output_size) + linear.weight = torch.nn.Parameter(b) + linear.bias = torch.nn.Parameter(bias) + return linear(a) + torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(output_size, hidden_size) + bias = torch.randn(output_size) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_linear) + res = opt_fn(x1, w1, b1) + y = custom_linear(x2, w2, b2) + test_result("Linear Forward", res, y) + if __name__ == "__main__": import os import sys diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 83ed5850..82773da2 100644 --- 
a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -41,14 +41,12 @@ def forward(self, query, key, value): ] # 2) Apply attention on all the projected vectors in batch. - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k) - p_attn = scores.softmax(dim=-1) - x = torch.matmul(p_attn, value) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(self.d_k) + p_attn = scores.softmax(dim=-2) + x = torch.matmul(value.transpose(-1, -2), p_attn) # 3) "Concat" using a view and apply a final linear. x = ( - x.transpose(0, 1) - .contiguous() - .view(-1, self.h * self.d_k) + x.view(-1, self.h * self.d_k) ) del query del key From 831dddf9f32c2ae068a0b5db261211875cd424da Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Wed, 18 Jun 2025 07:23:46 +0000 Subject: [PATCH 04/62] [Fix] BMM weight fused --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 51 +++++++++++--------- PyTorchSimFrontend/mlir/mlir_template.py | 4 +- tests/Fusion/test_prologue_fusion.py | 8 +-- tests/test_transformer.py | 10 ++-- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 25858222..41f90864 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -303,28 +303,35 @@ def render(self, input_reorder = self.input_reorder ) - kernel.prologue_info = dict ( - input_sram_var = "X_buffer2D", - input_dram_var = "X", - input_index_var = "index0", - input_tag_var = "tag1", - input_numel = B * M * K, - input_tile_size = (TILE_M, TILE_K), - input_sram_stride = [1, TILE_M], - input_subtile_size = (SUB_TILE_M, SUB_TILE_K), - weight_sram_var = "W_buffer2D", - weight_dram_var = "W", - weight_index_var = "index1", - weight_tag_var = "tag2", - weight_numel = B * K * N, - weight_tile_size = (TILE_K, TILE_N), - weight_sram_stride = [1, TILE_K], - weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), - tile_size = (TILE_K, 
TILE_N), - vlane_split_axis = 1, - vlane_stride = 1, - is_bmm = True, - ) + if prologue_nodes: + # if Input fused: + # tile_size = (TILE_M, TILE_K) + # input_sram_stride = [1, TILE_M] + # elif Weight fused: + tile_size = (TILE_K, TILE_N) + input_sram_stride = [1, TILE_K] + kernel.prologue_info = dict ( + input_sram_var = "X_buffer2D", + input_dram_var = "X", + input_index_var = "index0", + input_tag_var = "tag1", + input_numel = B * M * K, + input_tile_size = (TILE_M, TILE_K), + input_sram_stride = input_sram_stride, + input_subtile_size = (SUB_TILE_M, SUB_TILE_K), + weight_sram_var = "W_buffer2D", + weight_dram_var = "W", + weight_index_var = "index1", + weight_tag_var = "tag2", + weight_numel = B * K * N, + weight_tile_size = (TILE_K, TILE_N), + weight_sram_stride = [1, TILE_K], + weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), + tile_size = tile_size, + vlane_split_axis = 1, + vlane_stride = 1, + is_bmm = True, + ) kernel.epilogue_info = dict( output_node = self.output_node.name, dependent_buf = [], diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a9da6e9d..c6893e73 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -687,8 +687,8 @@ def load_prologue(self, name: str, index: sympy.Expr): # Broadcast pattern zero_index = self.const_cse.generate(self.const_buffer, "arith.constant 0 : index") if self.prologue_info['is_bmm']: # FIXME: hardcoded - idx = f"%b, %t_m, %{zero_index}" - map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1, d2) -> (d0 * 512 + d1)>") + idx = f"%b, %t_k, %t_n" + map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1, d2) -> (d0 * 512 + d2)>") vlane_split_axis = 2 else: idx = f"%t_m, %{zero_index}" diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py index 12098b24..926782be 100644 --- a/tests/Fusion/test_prologue_fusion.py +++ b/tests/Fusion/test_prologue_fusion.py @@ 
-55,11 +55,11 @@ def matmul_fused(a, b, c): def test_elem_bmm_fusion(device, batch_size=1, m=512, n=512, k=64): def bmm(a, b, c, d): - return torch.bmm((a - b)/c , d) + return torch.bmm(a , (d - b)/c) torch.manual_seed(0) a = torch.randn(batch_size, m, k).to(device=device) - b = torch.randn(batch_size, m, 1).to(device=device) - c = torch.randn(batch_size, m, 1) * 1000 + b = torch.randn(batch_size, 1, n).to(device=device) + c = torch.randn(batch_size, 1, n) * 1000 c = c.to(device=device) d = torch.randn(batch_size, k, n).to(device=device) opt_fn = torch.compile(dynamic=False)(bmm) @@ -78,4 +78,4 @@ def bmm(a, b, c, d): device = module.custom_device() test_elem_broadcast_fusion(device) test_elem_fusion(device) - test_elem_bmm_fusion(device, batch_size=12, m=512, n=64, k=512) \ No newline at end of file + test_elem_bmm_fusion(device, batch_size=12, m=64, n=512, k=512) \ No newline at end of file diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 82773da2..cfa2a622 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -87,9 +87,9 @@ def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): import math d_k = query.size(-1) - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-1) - return torch.matmul(p_attn, value), p_attn + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) torch.manual_seed(0) query = torch.randn(head, seq, d_k).to(device=device) @@ -97,9 +97,9 @@ def attention(query, key, value): value = torch.randn(head, seq, d_k).to(device=device) opt_fn = torch.compile(dynamic=False)(attention) - res, p_attn = opt_fn(query, key, value) + res = opt_fn(query, key, value) - cpu_res, cpu_p_attn = attention(query.cpu(), key.cpu(), value.cpu()) + cpu_res = attention(query.cpu(), key.cpu(), value.cpu()) test_result("Attention Forward", res, 
cpu_res) def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): From d0108fd891c31021a67b5e84118abb7da3c11588 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 18 Jun 2025 11:14:16 +0000 Subject: [PATCH 05/62] [Frontend/Fusion] Implement matmul+var_mean fusion for LayerNorm --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_scheduling.py | 7 +-- PyTorchSimFrontend/mlir/mlir_template.py | 33 +++++++++--- tests/Fusion/test_matmul_reduction.py | 53 ++++++++++++++++++- 4 files changed, 82 insertions(+), 13 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index ec1dd9a8..35132739 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -199,7 +199,7 @@ def render(self, if (M == 0) or (N == 0) or (K == 0): TILE_M, TILE_N, TILE_K = 1, 1, 1 template = EMPTY_TEMPLATE - elif n_extra_node==1 and epilogue_nodes[0].is_reduction(): + elif n_extra_node>=1 and epilogue_nodes[0].is_reduction(): TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node, min_tile=True) template = GEMM_REDUCTION_TEMPLATE nr_rdim = 1 diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index d41a9128..bc0e8560 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -31,13 +31,10 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if node1.get_device() == node2.get_device(): from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (node1.is_template() and len(node1.get_nodes())==1 and \ - (isinstance(node1.node.template, MLIRGemmTemplate) or isinstance(node1.node.template, MLIRBMMTemplate)) and \ + if (node1.is_template() and (isinstance(node1.get_nodes()[0].node.template, 
MLIRGemmTemplate) or isinstance(node1.node.template, MLIRBMMTemplate)) and \ node2.is_reduction() and len(node2.get_nodes())==1): # For matmul/bmm+reduction case - size_match = node1.node.get_size() == node2.node.get_size() + node2.node.get_reduction_size() - if len(node1.node.get_size()) == len(node2.node.get_size()): - size_match = node1.node.get_size() == [dim for dim in node2.node.get_size() if dim!=1] + node2.node.get_reduction_size() + size_match = reduce(operator.mul, node1.get_nodes()[0].node.get_size(), 1) == reduce(operator.mul, node2.node.get_size(), 1) * reduce(operator.mul, node2.node.get_reduction_size(), 1) stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") # We can't fuse dim=-1 diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index c6893e73..935510b6 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -81,10 +81,14 @@ def __init__(self, self.prologue_buffer_group = IndentedBufferGroup(self) self.epilogue_buffer_group = IndentedBufferGroup(self) self.global_vars = IndentedBuffer() + # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False self.reduction_body_loop = None self.reduction_idx = None + self.reduction_buffer_idx = 0 + self.reduction_info = {} + self.reduction_epilogue_result = {} # Overwrite ops self.load = self.load_epilogue @@ -863,8 +867,23 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): argmax_or_argmin = reduction_type in {"argmax", "argmin"} - if argmax_or_argmin or is_welford_reduction(reduction_type): + if argmax_or_argmin: raise NotImplementedError() #TODO: argmin, argmax + if is_welford_reduction(reduction_type): + if reduction_type == "welford_combine": + raise 
NotImplementedError("welford_combine") + else: + assert reduction_type == "welford_reduce" + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + reduction_key = src_dtype, reduction_type, value + sum = self.reduction_epilogue(dtype, src_dtype, "sum", value) + sqr_sum = self.reduction_epilogue(dtype, src_dtype, "sum", ops.mul(value, value)) + self.welford_reduce_out = (sum, sqr_sum, None) + return sum, sqr_sum, None + # Check duplicated reductions + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_epilogue_result: + return self.reduction_epilogue_result[reduction_key] # Reduction fusion codegen part type_name = mlir_common.DTYPE_TO_MLIR[dtype] @@ -872,13 +891,15 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): vshape = f"vector<{vec_size}x{type_name}>" tile_shape = f"memref<{self.reduction_body_loop.size * self.vector_lane}x{vec_size}x{type_name}, 1>" - name = f"{reduction_type}_buffer" + name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" + self.reduction_buffer_idx += 1 index = "dummy_index" # Not used tile_numel_per_lane = self.compute_body_loop.step * self.reduction_body_loop.size sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, None, index, self.const_buffer) - zero_var = self.get_const_cse(0) + self.reduction_epilogue_result[reduction_key] = sram_var # Load partial result + zero_var = self.get_const_cse(0) operation = "affine.vector_load" compute_index_var = ",".join([f"%{self.reduction_loop_idx}"] + [f"%{zero_var}"]) line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" @@ -892,7 +913,7 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): operation = "affine.vector_store" line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" self.compute.writeline(line) # Need to be placed after partial reduction - self.reduction_info = {sram_var : reduction_type} + 
self.reduction_info[sram_var] = reduction_type return sram_var def store_reduction_epilogue(self, name, index, value): @@ -911,7 +932,7 @@ def store_reduction_epilogue(self, name, index, value): tile_numel_per_lane = vlane_stride * nr_outer_loop * 2 dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]*2}x{type_name}, 1>" + tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]}x{type_name}, 1>" tile_stride = [1] sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index, index, buffer=self.const_buffer) @@ -960,7 +981,7 @@ def store_reduction_epilogue(self, name, index, value): # MVOUT Encoding # Generate DMA instruction - index_var = "red_idx" + index_var = self.reduction_idx code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, type_name, dram_var, index_var, sram_var, sram_index_var, f"{name}_tag", dram_shape, tile_shape, tile_stride) self.reductions_suffix.writeline(DeferredLine(name, code)) diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py index 9f2cc7f3..07dd914d 100644 --- a/tests/Fusion/test_matmul_reduction.py +++ b/tests/Fusion/test_matmul_reduction.py @@ -38,6 +38,55 @@ def matmul_fused(a, b, c): test_result("Matmul Reduction Fusion activation", res[0], y[0]) test_result("Matmul Reduction Fusion reduction", res[1], y[1]) +def test_matmul_var_mean(device, size=512): + def matmul_fused(a, b, c): + result = torch.matmul(a, b.T) + var, mean = torch.var_mean(result, dim=-2) + return result, var, mean + torch.manual_seed(0) + N = size + input = torch.randn(3072, 768) + weight = torch.randn(512, 768) + #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) + #weight = torch.eye(N, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = 
weight.to("cpu") + c = 7 + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, c) + y = matmul_fused(x2, w2, c) + test_result("Matmul var_mean Fusion activation", res[0], y[0]) + test_result("Matmul var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul var_mean Fusion reduction", res[2], y[2]) + +def test_matmul_add_var_mean(device, size=512): + def matmul_fused(a, b, c, d): + result = torch.matmul(a, b.T) + c.T + var, mean = torch.var_mean(result + d, dim=-2) + return result, var, mean + torch.manual_seed(0) + N = size + input = torch.randn(768, 3072) + weight = torch.randn(512, 3072) + bias = torch.randn(768, 512) + residual = torch.randn(768,512) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + r1 = residual.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + r2 = residual.to("cpu") + opt_fn = torch.compile(dynamic=False)(matmul_fused) + res = opt_fn(x1, w1, b1, r1) + y = matmul_fused(x2, w2, b2, r2) + test_result("Matmul+residual+var_mean Fusion activation", res[0], y[0]) + test_result("Matmul+residual+var_mean Fusion reduction", res[1], y[1]) + test_result("Matmul+residual+var_mean Fusion reduction", res[2], y[2]) + if __name__ == "__main__": import os import sys @@ -46,4 +95,6 @@ def matmul_fused(a, b, c): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_matmul_reduce(device) + #test_matmul_reduce(device) + test_matmul_var_mean(device) + #test_matmul_add_var_mean(device) From bb2a083e2f15b29902a27ee7acfe349af7b90c9e Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 18 Jun 2025 16:41:02 +0000 Subject: [PATCH 06/62] [Temporary] Make compile it force --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 34 ++++++++++++------- PyTorchSimFrontend/mlir/mlir_template.py | 6 ++-- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git 
a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 35132739..ed1361ff 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -24,6 +24,7 @@ #map0 = affine_map<(d0, d1) -> ({{ X_map }})> #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> +#map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> @@ -50,14 +51,11 @@ affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { %index2 = affine.apply #map2(%t_m, %t_n) + %index3 = affine.apply #map2(%t_m, %t_n) {%- if Bias %} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % + memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + , %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} @@ -102,6 +100,7 @@ #map0 = affine_map<(d0, d1) -> ({{ X_map }})> #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> +#map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> 
memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> @@ -128,14 +127,11 @@ affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { %index2 = affine.apply #map2(%t_m, %t_n) + %index3 = affine.apply #map2(%t_m, %t_n) {%- if Bias %} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % + memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + , %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} @@ -174,7 +170,6 @@ def render(self, X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] W_tensor = empty_strided(W.layout.size, W.layout.stride) X_tensor = empty_strided(X.layout.size, X.layout.stride) @@ -219,6 +214,18 @@ def render(self, TOG_latency = M if SUB_TILE_M > M else SUB_TILE_M kernel.loop_size =[TOG_latency, SUB_TILE_N, SUB_TILE_K] + # Extract Bias info + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + if Bias is not None: + if Bias.data.get_numel() == M*N: + Bias_idx 
= "%index2" + elif Bias.data.get_numel() == M: + Bias_idx = "%index3" + else: + Bias_idx = "%t_n" + else: + Bias_idx = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -237,6 +244,7 @@ def render(self, W = W, Y = Y, Bias = Bias, + Bias_idx = Bias_idx, Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, X_map = X_map, W_map = W_map, diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 935510b6..8017a3a5 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -766,13 +766,13 @@ def load_epilogue(self, name: str, index: sympy.Expr): load_dim = [] if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: load_dim = V.graph.graph_inputs[name].layout.size - index_var = self.epilogue_info['index_var'] if len(load_dim) != 1 else 'tile_n' + index_var = self.epilogue_info['index_var'] if len(load_dim) <= 1 else 'tile_n' index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim - vlane_stride = self.kernel_group.tile_desc.vlane_stride if len(load_dim) != 1 else 1 # FIXME: Fixed stride for 1d load dim + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if len(load_dim) <= 1 else 0 # FIXME: Fixed split axis for 1d load dim + vlane_stride = self.kernel_group.tile_desc.vlane_stride if len(load_dim) <= 1 else 1 # FIXME: Fixed stride for 1d load dim tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) tile_stride = self.epilogue_info['tile_stride'] From 519088521059fba16a98ff9b883872cdb7b364c8 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 18 Jun 2025 18:21:37 +0000 Subject: [PATCH 
07/62] [Frontend/Fusion] Fix&cleanup fusion policy --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 3 +- PyTorchSimFrontend/mlir/mlir_scheduling.py | 47 +++++++++++--------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 41f90864..043cd46b 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -267,7 +267,8 @@ def render(self, SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane - if n_extra_node==1 and epilogue_nodes[0].is_reduction(): + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: template = BMM_REDUCTION_TEMPLATE nr_rdim = 1 elif prologue_nodes: diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index bc0e8560..ffca0d8c 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -28,33 +28,40 @@ def __init__(self, scheduler): self.max_fusion_size = 5 def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: - if node1.get_device() == node2.get_device(): + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + if node1.get_device() != node2.get_device(): + return False + + if len(base_template_node1) == 1 and len(base_template_node2) == 0: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (node1.is_template() and (isinstance(node1.get_nodes()[0].node.template, MLIRGemmTemplate) or 
isinstance(node1.node.template, MLIRBMMTemplate)) and \ - node2.is_reduction() and len(node2.get_nodes())==1): + if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction() and len(node2.get_nodes())==1: # For matmul/bmm+reduction case size_match = reduce(operator.mul, node1.get_nodes()[0].node.get_size(), 1) == reduce(operator.mul, node2.node.get_size(), 1) * reduce(operator.mul, node2.node.get_reduction_size(), 1) stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") # We can't fuse dim=-1 - possible = int(sympify(stride).coeff(target_symbol)) != 1 - return size_match and possible - - # For prologue fusion case - if not node1.is_template() and len(node1.get_nodes())==1 and node2.is_template(): - # Return false if node2 is Convolution template - # if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ - # node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': - # return False - if node2.get_nodes()[0].node.origin_node is not None and hasattr(node2.get_nodes()[0].node.origin_node.target, "_name") and node2.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': - return False - if node1.is_reduction(): - return False - if len(node1.read_writes.writes) != 1: - return False - if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: - return True + layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 + dependecy_check = base_template_node1[0].node.name in node2.node.get_read_names() and len(node2.node.get_read_names()) == 1 + return size_match and layout_possible and dependecy_check + + # For prologue fusion case + if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: + # Return false if node2 is Convolution template + # if 
node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ + # node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': + # return False + target_node = base_template_node2[0].node + if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': + return False + if node1.is_reduction(): + return False + if len(node1.read_writes.writes) != 1: + return False + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + return True return self.scheduler.can_fuse_origin(node1, node2) From e555ab87a0c09f922f4b96a01de3a6ffaa0085b5 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 19 Jun 2025 03:02:52 +0000 Subject: [PATCH 08/62] [Frontend/Fusion] Fix prologue target buf selecting logic --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index ffca0d8c..4f3c159e 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -212,10 +212,18 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e for node in prologue_nodes: # Reuse created spad read_list = sorted(list(node.read_writes.reads)) - if reduce(operator.mul, read_list[-1].size, 1) == template_node.node.get_numel(): - prologue_input_arg = read_list[-1].name - else: - prologue_input_arg = read_list[0].name + candidate_found = False + # Why? 
There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.get_name() : val for val in V.graph.graph_inputs.values()}) + buf_dict.update({val.name : val for val in V.graph.buffers}) + for candidate_read in read_list: + if reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read.name + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) prologue_output_arg = list(node.read_writes.writes)[0].name template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] if template_node.get_nodes()[0].node.origin_node.target._name == 'aten::bmm': From 3fc33e192f2833cd85caee9f6edd17702aa6a988 Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Thu, 19 Jun 2025 06:29:44 +0000 Subject: [PATCH 09/62] [Frontend] Optimize fusion tile size --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_common.py | 2 +- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 9 +++++---- PyTorchSimFrontend/mlir/mlir_template.py | 9 +++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 043cd46b..91ba9ba1 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -262,7 +262,7 @@ def render(self, TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - TILE_K = TILE_K // 4 if prologue_nodes else TILE_K + TILE_K = TILE_K // 2 if prologue_nodes else TILE_K SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else 
kernel.vector_lane diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index c3dc0c51..4409ee8e 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -482,7 +482,7 @@ def dummy_tile_size(): tile_size[0] = 2 * vlane_stride * self.vector_lane elif len(tile_size) == 3: tile_size[-1] = self.vector_lane - tile_size[-2] = 2 * self.vector_lane + tile_size[-2] = 4 * self.vector_lane tile_size[-3] = 2 else: raise NotImplementedError("dummy tile size fail!") diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index ed1361ff..050624db 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -12,7 +12,7 @@ from PyTorchSimFrontend import extension_config GEMM_TEMPLATE = r""" -// GEMM kernel +// GEMM {% if prologue_nodes -%}prologue fused{%- endif %} {% if epilogue_nodes -%}eilogue fused{%- endif %} kernel // M = {{ M }} // N = {{ N }} // K = {{ K }} @@ -88,7 +88,7 @@ """ GEMM_REDUCTION_TEMPLATE = r""" -// GEMM kernel +// GEMM reduction kernel // M = {{ M }} // N = {{ N }} // K = {{ K }} @@ -190,16 +190,17 @@ def render(self, if self.output_node.name in n_extra_read: n_extra_read.remove(self.output_node.name) + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 nr_rdim = 0 if (M == 0) or (N == 0) or (K == 0): TILE_M, TILE_N, TILE_K = 1, 1, 1 template = EMPTY_TEMPLATE elif n_extra_node>=1 and epilogue_nodes[0].is_reduction(): - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node, min_tile=True) + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, len(n_extra_read), n_prologue_node, min_tile=True) template = GEMM_REDUCTION_TEMPLATE nr_rdim = 1 else: - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, len(n_extra_read), min_tile=True) + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, 
len(n_extra_read), n_prologue_node, min_tile=True) template = GEMM_TEMPLATE TILE_M = min(extension_config.CONFIG_FORCE_TILE_M, TILE_M) TILE_N = min(extension_config.CONFIG_FORCE_TILE_N, TILE_N) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 8017a3a5..5357979b 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -155,7 +155,7 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, pad_k=True, min_tile=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False): spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer @@ -183,13 +183,14 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, pad_k=True, min_tile tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) - input_size_per_lane = self.get_spad_size_per_lane(tile_M, tile_K) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision n_tile = math.ceil(M / tile_M) * math.ceil(N / tile_N) - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and maximize_i_j <= tile_M * 
tile_N and n_tile >= minimum_n_tile: + check_spad_size = (used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10: max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) From 66a4c41b6b24002ffd611de5eaf04721324a3ec4 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 19 Jun 2025 07:13:12 +0000 Subject: [PATCH 10/62] [Frontend/Fusion] Update 1D load epilogue --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 8 +++++--- PyTorchSimFrontend/mlir/mlir_template.py | 10 ++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 4f3c159e..14c36dc2 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -39,13 +39,15 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction() and len(node2.get_nodes())==1: # For matmul/bmm+reduction case - size_match = reduce(operator.mul, node1.get_nodes()[0].node.get_size(), 1) == reduce(operator.mul, node2.node.get_size(), 1) * reduce(operator.mul, node2.node.get_reduction_size(), 1) + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.node.get_size(), 1) * reduce(operator.mul, node2.node.get_reduction_size(), 1) stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") # We can't fuse dim=-1 layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 - dependecy_check = 
base_template_node1[0].node.name in node2.node.get_read_names() and len(node2.node.get_read_names()) == 1 - return size_match and layout_possible and dependecy_check + # Directed linked? + dependency_check = node2 in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check & dependency_size # For prologue fusion case if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 5357979b..201e046b 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -764,16 +764,14 @@ def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.stores.writeline(line) def load_epilogue(self, name: str, index: sympy.Expr): - load_dim = [] - if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: - load_dim = V.graph.graph_inputs[name].layout.size - index_var = self.epilogue_info['index_var'] if len(load_dim) <= 1 else 'tile_n' + is_1d_source = len(index.free_symbols) == 1 + index_var = self.epilogue_info['index_var'] if not is_1d_source else 'tile_n' index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if len(load_dim) <= 1 else 0 # FIXME: Fixed split axis for 1d load dim - vlane_stride = self.kernel_group.tile_desc.vlane_stride if len(load_dim) <= 1 else 1 # FIXME: Fixed stride for 1d load dim + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if not is_1d_source else 0 # FIXME: Fixed split axis for 1d load dim + vlane_stride = 
self.kernel_group.tile_desc.vlane_stride if not is_1d_source else 1 # FIXME: Fixed stride for 1d load dim tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) tile_stride = self.epilogue_info['tile_stride'] From ee5c1a9a2c57da24b7d28d55293c7cd485ba8688 Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Thu, 19 Jun 2025 14:19:46 +0000 Subject: [PATCH 11/62] [Frontend] Welford reduction fusion debug --- PyTorchSimFrontend/mlir/mlir_template.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 201e046b..d6cdaf06 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -89,6 +89,7 @@ def __init__(self, self.reduction_buffer_idx = 0 self.reduction_info = {} self.reduction_epilogue_result = {} + self.reduction_mean = [] # Overwrite ops self.load = self.load_epilogue @@ -974,6 +975,26 @@ def store_reduction_epilogue(self, name, index, value): #init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value], dtype)} : {type_name}") #out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value], out, init, axis=0, shape=vshape, reduced_shape=final_reduced_shape)) + if self.welford_reduce_out is not None: + # mean + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(768)} : f32") + if self.buffer_types[name][1] > 1: + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") + else: + divider_vec = divider + + if self.current_node.node.origin_node: # FIXME: This is a temporary solution + # mean = E(X) / N + self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) + out = self.reduction_mean[i] + else: + 
# m2 = (E(X^2) - E(X)^2) * N + sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") + mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") + variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") + m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") + out = m2 + operation = "affine.vector_store" line = f"{operation} %{out}, %{sram_var}[%{index_var}] : {tile_shape}, {new_reduced_shape}" self.reductions_suffix.writeline(DeferredLine(name, line)) From 5e70202bc5d63399af665b9f58ff5ded7b16a1cd Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Thu, 19 Jun 2025 17:15:12 +0000 Subject: [PATCH 12/62] [Fix] Matmul epilogue fusion --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 5 +++++ PyTorchSimFrontend/mlir/mlir_template.py | 10 +++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 050624db..310b92dd 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -25,6 +25,7 @@ #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> #map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> +#map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> @@ -52,6 +53,7 @@ affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { %index2 = affine.apply #map2(%t_m, %t_n) %index3 = affine.apply #map2(%t_m, %t_n) + %index4 = affine.apply #map4(%t_m, %t_n) {%- if Bias %} memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, 
%tag0[%c0], % {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} @@ -101,6 +103,7 @@ #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> #map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> +#map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> @@ -128,6 +131,7 @@ {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { %index2 = affine.apply #map2(%t_m, %t_n) %index3 = affine.apply #map2(%t_m, %t_n) + %index4 = affine.apply #map4(%t_m, %t_n) {%- if Bias %} memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} @@ -283,6 +287,7 @@ def render(self, sram_var = "Y_buffer", dram_var = "Y", index_var = "index2", + t_index_var = "index4", # FIXME: for epilogue transposed input tag_var = "tag", vlane_split_axis = 1, vlane_stride = 1, diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index d6cdaf06..503bc874 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -766,7 +766,15 @@ def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): def load_epilogue(self, name: str, index: sympy.Expr): is_1d_source = len(index.free_symbols) == 1 - index_var = self.epilogue_info['index_var'] if not is_1d_source else 'tile_n' + is_transpose = False # FIXME: Only works for 2d input + if len(index.args) == 2: + for expr in index.args: + if len(expr.args): + if expr.args[1].name == "index0" and expr.args[0] > 1: + is_transpose = True + break + key = 't_index_var' if is_transpose else 'index_var' + index_var = self.epilogue_info[key] if not is_1d_source else 'tile_n' index = 
self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dtype = V.graph.get_dtype(name) From 2c67e9be80e7478bbb0d76097986a5160468e4d2 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 19 Jun 2025 10:00:13 +0000 Subject: [PATCH 13/62] [Frontend] Add a spad reuse feature in the fusion kernel --- PyTorchSimFrontend/mlir/mlir_template.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 503bc874..166bfc3c 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -90,6 +90,7 @@ def __init__(self, self.reduction_info = {} self.reduction_epilogue_result = {} self.reduction_mean = [] + self.reuse_buffer_names = {} # Overwrite ops self.load = self.load_epilogue @@ -797,9 +798,15 @@ def load_epilogue(self, name: str, index: sympy.Expr): code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, f"{name}_tag", dram_shape, tile_shape, tile_stride) self.cse.generate(self.dma_loads, code, assignment = False) + elif name in self.reuse_buffer_names: + sram_var = self.reuse_buffer_names[name] + code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + f"{name}_tag", dram_shape, tile_shape, tile_stride) + self.cse.generate(self.dma_loads, code, assignment = False) + else: + sram_var = self.buffer_names[name] # Load vector from sram - sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) if not self.reduction_fusion: compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) @@ -1023,6 +1030,9 @@ def set_tile_size(self, template_epilogue_info): vlane_split_axis=template_epilogue_info['vlane_split_axis'], vlane_stride=template_epilogue_info['vlane_stride']) + if "reuse_buffer_names" in 
template_epilogue_info: + self.reuse_buffer_names.update(template_epilogue_info["reuse_buffer_names"]) + if 'nr_rdim' in template_epilogue_info and template_epilogue_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() From c559bdceb0d92f900c3406c5e2ba7e1105c84cb6 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 20 Jun 2025 02:22:18 +0000 Subject: [PATCH 14/62] [Frontend] Fix transposed 1D bias --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 310b92dd..b1d597a0 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -24,7 +24,7 @@ #map0 = affine_map<(d0, d1) -> ({{ X_map }})> #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -#map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> +#map3 = affine_map<(d0, d1) -> (d0)> #map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> @@ -52,12 +52,10 @@ affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { %index2 = affine.apply #map2(%t_m, %t_n) - %index3 = affine.apply #map2(%t_m, %t_n) + %index3 = affine.apply #map3(%t_m, %c0) %index4 = affine.apply #map4(%t_m, %t_n) {%- if Bias %} - memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], {{ Bias_axis 
}}, %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} @@ -102,7 +100,7 @@ #map0 = affine_map<(d0, d1) -> ({{ X_map }})> #map1 = affine_map<(d0, d1) -> ({{ W_map }})> #map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -#map3 = affine_map<(d0, d1) -> (d0 * {{ N }})> +#map3 = affine_map<(d0, d1) -> (d0)> #map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> @@ -130,12 +128,10 @@ affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { %index2 = affine.apply #map2(%t_m, %t_n) - %index3 = affine.apply #map2(%t_m, %t_n) + %index3 = affine.apply #map3(%t_m, %c0) %index4 = affine.apply #map4(%t_m, %t_n) {%- if Bias %} - memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], {{ Bias_axis }}, %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } {%- else %} affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> 
{%- endif %} @@ -224,12 +220,16 @@ def render(self, if Bias is not None: if Bias.data.get_numel() == M*N: Bias_idx = "%index2" + Bias_axis = "%axis" elif Bias.data.get_numel() == M: Bias_idx = "%index3" + Bias_axis = "%axis" else: Bias_idx = "%t_n" + Bias_axis = "%c0" else: Bias_idx = None + Bias_axis = None kernel.render_options = dict( KERNEL_NAME=self.name, @@ -250,7 +250,7 @@ def render(self, Y = Y, Bias = Bias, Bias_idx = Bias_idx, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, + Bias_axis = Bias_axis, X_map = X_map, W_map = W_map, Y_numel = M * N, From d2aa73d8763e1b3bd26e88c1f6b063aec45a2f2a Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Fri, 20 Jun 2025 05:03:12 +0000 Subject: [PATCH 15/62] [fix] prologue fusion args & shape --- PyTorchSimFrontend/mlir/mlir_common.py | 1 + PyTorchSimFrontend/mlir/mlir_scheduling.py | 9 ++++----- PyTorchSimFrontend/mlir/mlir_template.py | 19 ++++++++++++------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 4409ee8e..29ef65c9 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -342,6 +342,7 @@ def __init__(self, kernel_group, reason=None): self.buffer_types : dict = None # format: dtype, numel, size, stride self.compute_idx = "compute_idx" self.compute_body_loop = LoopLevel(self.compute_idx, 1) + self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 14c36dc2..307d5afe 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -200,10 +200,10 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e kernel.kernel_group.set_tile_info(tile_desc) if prologue_nodes: _, (group, 
reduction_group) = max( - prologue_nodes, key=lambda x: int(x.is_reduction()) + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) ).group - tile_desc = kernel.set_tile_size(kernel.prologue_info) - kernel.kernel_group.set_prologue_tile_info(tile_desc) + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_prologue_tile_info(prologue_tile_desc) vars, reduction_vars = kernel.set_ranges(group, reduction_group) # Flush created varaibles, since template fusion doen't share variable kernel.cse.cache.clear() @@ -217,10 +217,9 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e candidate_found = False # Why? There is a case that memdep.get_size() != data.get_size() buf_dict = {} - buf_dict.update({val.get_name() : val for val in V.graph.graph_inputs.values()}) buf_dict.update({val.name : val for val in V.graph.buffers}) for candidate_read in read_list: - if reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): + if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): prologue_input_arg = candidate_read.name candidate_found = True break diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 166bfc3c..3334d991 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -363,7 +363,7 @@ def codegen_prologue_body(self): buf.body.splice(buf.dma_loads) if (buf.loads.getvalue() != '' or buf.compute.getvalue() != '' or buf.stores.getvalue() != ''): - buf.body.writelines(self.compute_body_loop.lines()) + buf.body.writelines(self.prologue_compute_body_loop.lines()) compute_body = mlir_common.ParallelLoopBuffer() with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) @@ -688,19 +688,20 @@ def load_prologue(self, name: str, index: 
sympy.Expr): load_dim = [] if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: load_dim = V.graph.graph_inputs[name].layout.size - if self.kernel_group.prologue_tile_desc.get_numel() == self.buffer_types[name][1]: + if self.ranges == self.buffer_types[name][2]: index_var = self.prologue_info['input_index_var'] if len(load_dim) != 1 else 'tile_n' + vlane_split_axis = self.kernel_group.prologue_tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim else: # Broadcast pattern zero_index = self.const_cse.generate(self.const_buffer, "arith.constant 0 : index") if self.prologue_info['is_bmm']: # FIXME: hardcoded idx = f"%b, %t_k, %t_n" map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1, d2) -> (d0 * 512 + d2)>") - vlane_split_axis = 2 + vlane_split_axis = 2 # 3D GEMM prologue should be loaded by axis 2 else: idx = f"%t_m, %{zero_index}" map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1) -> (d0)>") - vlane_split_axis = self.kernel_group.prologue_tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim + vlane_split_axis = 1 # 2D GEMM prologue should be loaded by axis 1 index_var = self.apply_cse.generate(self.dma_loads, f"affine.apply #{map_var}({idx})") index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) @@ -1024,7 +1025,7 @@ def store_reduction_epilogue(self, name, index, value): def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, buffer=None): return super().get_scratchpad_buffer(dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, True, buffer=buffer) - def set_tile_size(self, template_epilogue_info): + def set_tile_size(self, template_epilogue_info, prologue=False): tile_desc = mlir_common.MLIRMultiDimTile(template_epilogue_info['tile_size'], self.vector_lane, vlane_split_axis=template_epilogue_info['vlane_split_axis'], @@ 
-1050,8 +1051,12 @@ def set_tile_size(self, template_epilogue_info): self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: tile_desc.vec_size=64 - self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() + if prologue: + self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() + self.prologue_compute_body_loop.step = tile_desc.get_compute_vec_size() + else: + self.compute_body_loop.size = tile_desc.get_numel_per_lane() + self.compute_body_loop.step = tile_desc.get_compute_vec_size() return tile_desc class MLIRTemplateCaller(CUDATemplateCaller): From fa7e57a293ac5e197111d24daf2c08e70a98c3b0 Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Sat, 21 Jun 2025 05:54:28 +0000 Subject: [PATCH 16/62] [fix] prologue prohibit subtile --- PyTorchSimFrontend/mlir/mlir_template.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 3334d991..ed2eb504 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -550,7 +550,7 @@ def emit_dma_start(buffer_name, index_var, tag_var, size, tile_size, subtile_siz tile_memref = f"memref<{tile_shape}xf32, 1>" tag_memref = f"memref<1xi32>" attrs = f"sram_stride=[1, {tile_size[0]}]" - async_flag = "true" if async_flag else "false" + async_flag = "false" if subtile_size: subtile_shape = ", ".join([str(x) for x in subtile_size]) attrs = f"subtile_size=[{subtile_shape}], async={async_flag}, {attrs}" @@ -567,9 +567,9 @@ def hook(): if prologue_code.getvalue(): code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], subtile_size=self.prologue_info["input_subtile_size"], label="X")) + code.splice(prologue_code) 
code.writeline(emit_dma_start(self.prologue_info["weight_sram_var"], self.prologue_info["weight_index_var"], self.prologue_info["weight_tag_var"], self.prologue_info["weight_numel"], self.prologue_info["weight_tile_size"], subtile_size=self.prologue_info["weight_subtile_size"], label="W")) - code.splice(prologue_code) else: code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], self.prologue_info["input_subtile_size"], async_flag=True, label="X")) From 75de3d4939f95f390c7d387c55dd18f56f181eb4 Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Sat, 21 Jun 2025 10:17:49 +0000 Subject: [PATCH 17/62] [Validation] manual gemm tile size & fix tiling for double buffering --- PyTorchSimFrontend/extension_config.py | 8 +++++ .../mlir/mlir_codegen_backend.py | 2 +- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 31 ++++++++++++++--- PyTorchSimFrontend/mlir/mlir_template.py | 34 ++++++++++++++++--- validation/gemm_tpuv3_cheatsheet.json | 17 ++++++++++ 5 files changed, 82 insertions(+), 10 deletions(-) create mode 100644 validation/gemm_tpuv3_cheatsheet.json diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 17fa74d9..e461cc85 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -57,6 +57,14 @@ CONFIG_FORCE_TILE_N = int(os.environ.get("TORCHSIM_FORCE_TIME_N", default=sys.maxsize)) CONFIG_FORCE_TILE_K = int(os.environ.get("TORCHSIM_FORCE_TIME_K", default=sys.maxsize)) +# For GEMM tile size +CONFIG_MANUAL_TILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_TILE_SIZE', default=False)) +CONFIG_TILE_M = int(os.environ.get('TORCHSIM_TILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_N = int(os.environ.get('TORCHSIM_TILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_TILE_K = int(os.environ.get('TORCHSIM_TILE_K', default=CONFIG_VECTOR_LANE)) 
+CONFIG_GEMM_CHEATSHEET_PATH = os.environ.get('TORCHSIM_GEMM_CHEATSHEET_PATH', + default=f"{CONFIG_TORCHSIM_DIR}/validation/gemm_tpuv3_cheatsheet.json") + # SRAM Buffer allocation plan def load_plan_from_module(module_path): if module_path is None: diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 1272a46e..d091b3eb 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1393,7 +1393,7 @@ def make_choices(self, nodes, kernel_name): for vlane_stride in [2, 4, 8]: os.environ['TORCHSIM_VECTOR_LANE_STRIDE'] = str(vlane_stride) previous_tile_size = initial_tile_size - increase_dim = 0 # increase the first dimension + increase_dim = -2 # increase the first dimension while previous_tile_size[increase_dim] * 2 <= previous_ranges[increase_dim] and previous_tile_size[increase_dim] <= 2 ** 13 and prevent_infinite_loop < 10: incrase_dim = -1 # only increase the last dimension prevent_infinite_loop += 1 diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index b1d597a0..a70efb21 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -1,4 +1,5 @@ import os +import json from torch import empty_strided from typing import List, Optional, cast @@ -192,20 +193,42 @@ def render(self, n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 nr_rdim = 0 - if (M == 0) or (N == 0) or (K == 0): + # Determine tile size + # case 1: use cheat sheet + if extension_config.CONFIG_GEMM_CHEATSHEET_PATH is not None: + try: + with open(extension_config.CONFIG_GEMM_CHEATSHEET_PATH, "r") as f: + data = json.load(f) + except FileNotFoundError: + data = {} + gemm_shape = f"{M}_{K}_{N}" + if gemm_shape in data: + tile_info = data[gemm_shape] + TILE_M = tile_info["TILE_M"] + TILE_N = tile_info["TILE_N"] + TILE_K = tile_info["TILE_K"] + else: # case 
2: use gemm_combination_mapping + min_tile = (n_extra_node + n_prologue_node) == 0 + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=min_tile) + # case 3: use manual tile size + if extension_config.CONFIG_MANUAL_TILE_SIZE: + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K + + if (M == 0) or (N == 0) or (K == 0): # exception for MoE TILE_M, TILE_N, TILE_K = 1, 1, 1 template = EMPTY_TEMPLATE elif n_extra_node>=1 and epilogue_nodes[0].is_reduction(): - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, len(n_extra_read), n_prologue_node, min_tile=True) template = GEMM_REDUCTION_TEMPLATE nr_rdim = 1 else: - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, len(n_extra_read), n_prologue_node, min_tile=True) template = GEMM_TEMPLATE + TILE_M = min(extension_config.CONFIG_FORCE_TILE_M, TILE_M) TILE_N = min(extension_config.CONFIG_FORCE_TILE_N, TILE_N) TILE_K = min(extension_config.CONFIG_FORCE_TILE_K, TILE_K) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane if (TILE_M == M and TILE_N == N): SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane else: # Avoid Row Conflict of weights diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index ed2eb504..a72342bc 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -23,6 +23,7 @@ from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR from . 
import mlir_common class IndentedBufferGroup: @@ -162,8 +163,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer max_spad_per_lane = spad_size_per_lane // 2 # double buffer - force_double_buffer = 2 if n_extra_node > 0 else 1 # In fusion case, double buffer should be forced - minimum_n_tile = self.num_cores * force_double_buffer if min_tile else 1 + minimum_n_tile = self.num_cores * 2 if min_tile else 1 m_pad_factor = self.vector_lane if M > self.vector_lane else 8 n_pad_factor = self.vector_lane if N > self.vector_lane else 8 k_pad_factor = self.vector_lane if K > self.vector_lane else (8 if pad_k else 1) @@ -179,7 +179,31 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_N_range = sympy.divisors(indexJ) if N > self.vector_lane else [1] tile_K_range = sympy.divisors(indexK) if K > self.vector_lane else [1] maximize_i_j = 1 # reuse weight - for k in tile_K_range: + for k in tile_K_range: # store tile candidates for manual mapping + tile_K = k * self.vector_lane if K > self.vector_lane else K_padded + for i in tile_M_range: + tile_M = i * self.vector_lane if M > self.vector_lane else M_padded + for j in tile_N_range: + tile_N = j * self.vector_lane if N > self.vector_lane else N_padded + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) + output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + 
file_path = f"{CONFIG_TORCHSIM_DIR}/validation/gemm_candidates/gemm_{M}_{K}_{N}.txt" + line_to_write = f"{tile_M} {tile_K} {tile_N}\n" + try: + with open(file_path, "r") as f: + lines = f.readlines() + except FileNotFoundError: + lines = [] + if line_to_write not in lines: + with open(file_path, "a") as f: + f.write(line_to_write) + + for k in tile_K_range: # heuristic search tile_K = k * self.vector_lane if K > self.vector_lane else K_padded for i in tile_M_range: tile_M = i * self.vector_lane if M > self.vector_lane else M_padded @@ -191,8 +215,8 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision n_tile = math.ceil(M / tile_M) * math.ceil(N / tile_N) - check_spad_size = (used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane) - if check_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10: + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10: max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) diff --git a/validation/gemm_tpuv3_cheatsheet.json b/validation/gemm_tpuv3_cheatsheet.json new file mode 100644 index 00000000..76a26e1a --- /dev/null +++ b/validation/gemm_tpuv3_cheatsheet.json @@ -0,0 +1,17 @@ +{ + "512_2048_8192" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "512_2048_2048" : { + "TILE_M" : 512, + "TILE_K" : 512, + "TILE_N" : 1024 + }, + "2048_2048_512" : { + "TILE_M" : 1024, + "TILE_K" : 512, + "TILE_N" : 512 + } +} \ No newline at end of file From 
ed77bbaf0b1fb99973b953b3dbe9638350a03929 Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Mon, 30 Jun 2025 04:07:40 +0000 Subject: [PATCH 18/62] [experiments] FG DMA experiments --- PyTorchSimFrontend/extension_config.py | 5 ++++ PyTorchSimFrontend/mlir/mlir_gemm_template.py | 26 +++++++++++++---- scripts/CompilerOpt_experiment/DMAopt.sh | 28 +++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 scripts/CompilerOpt_experiment/DMAopt.sh diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index e461cc85..15413103 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -64,6 +64,11 @@ CONFIG_TILE_K = int(os.environ.get('TORCHSIM_TILE_K', default=CONFIG_VECTOR_LANE)) CONFIG_GEMM_CHEATSHEET_PATH = os.environ.get('TORCHSIM_GEMM_CHEATSHEET_PATH', default=f"{CONFIG_TORCHSIM_DIR}/validation/gemm_tpuv3_cheatsheet.json") +CONFIG_SUBTILE = int(os.environ.get('TORCHSIM_SUBTILE', default=True)) +CONFIG_MANUAL_SUBTILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_SUBTILE_SIZE', default=False)) +CONFIG_SUBTILE_M = int(os.environ.get('TORCHSIM_SUBTILE_M', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_N = int(os.environ.get('TORCHSIM_SUBTILE_N', default=CONFIG_VECTOR_LANE)) +CONFIG_SUBTILE_K = int(os.environ.get('TORCHSIM_SUBTILE_K', default=CONFIG_VECTOR_LANE)) # SRAM Buffer allocation plan def load_plan_from_module(module_path): diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index a70efb21..3ac8154a 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -228,13 +228,27 @@ def render(self, TILE_M = min(extension_config.CONFIG_FORCE_TILE_M, TILE_M) TILE_N = min(extension_config.CONFIG_FORCE_TILE_N, TILE_N) TILE_K = min(extension_config.CONFIG_FORCE_TILE_K, TILE_K) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else 
kernel.vector_lane - if (TILE_M == M and TILE_N == N): - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - else: # Avoid Row Conflict of weights + + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K if TILE_K > 1024 else kernel.vector_lane + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M SUB_TILE_N = TILE_N - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N # FIXME: hardcoded & 126 line has same feature - SUB_TILE_K = TILE_K + SUB_TILE_K = TILE_K + TOG_latency = M if SUB_TILE_M > M else SUB_TILE_M kernel.loop_size =[TOG_latency, SUB_TILE_N, SUB_TILE_K] diff --git a/scripts/CompilerOpt_experiment/DMAopt.sh b/scripts/CompilerOpt_experiment/DMAopt.sh new file mode 100644 index 00000000..469cf766 --- /dev/null +++ b/scripts/CompilerOpt_experiment/DMAopt.sh @@ -0,0 +1,28 @@ +#!/bin/bash +export TORCHSIM_CONFIG="/root/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c2_simple_noc_tpuv2.json" + +# None FG DMA +export TORCHSIM_SUBTILE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# FG DMA +export TORCHSIM_SUBTILE=1 +export TORCHSIM_MANUAL_SUBTILE_SIZE=1 +python experiments/gemm.py --size 
128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 + +# SFG DMA +export TORCHSIM_SUBTILE=1 +export TORCHSIM_MANUAL_SUBTILE_SIZE=0 +python experiments/gemm.py --size 128 128 128 +python experiments/gemm.py --size 256 256 256 +python experiments/gemm.py --size 512 512 512 +python experiments/gemm.py --size 1024 1024 1024 +python experiments/gemm.py --size 2048 2048 2048 \ No newline at end of file From db2d505ef8a0827ff986a669a8e1be4f2370673a Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Mon, 30 Jun 2025 04:08:33 +0000 Subject: [PATCH 19/62] [Fix] prohibit multi-thread for CI --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index d091b3eb..fee5702a 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1442,7 +1442,7 @@ def get_cycle(choice): if len(choices) == 0: # can't autotune return None - with ThreadPoolExecutor(max_workers=5) as executor: + with ThreadPoolExecutor(max_workers=1) as executor: results = list(executor.map(get_cycle, choices)) max_idx = results.index(min(results)) print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vlane_stride}, cycles: {results[max_idx]}") From 3e6daf3e27db24917afbf993f5b394ff34bfd66a Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Mon, 30 Jun 2025 04:31:34 +0000 Subject: [PATCH 20/62] [Fix] minimum tile size and subtile K --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_template.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py 
b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 3ac8154a..bfd0633b 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -237,7 +237,7 @@ def render(self, SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane else: # Avoid Row Conflict of weights SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K if TILE_K > 1024 else kernel.vector_lane + SUB_TILE_K = TILE_K # Case 2: use manual sub tile size (FG-DMA) if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: SUB_TILE_M = extension_config.CONFIG_SUBTILE_M diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a72342bc..1db14e27 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -163,7 +163,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer max_spad_per_lane = spad_size_per_lane // 2 # double buffer - minimum_n_tile = self.num_cores * 2 if min_tile else 1 + minimum_n_tile = self.num_cores if min_tile else 1 m_pad_factor = self.vector_lane if M > self.vector_lane else 8 n_pad_factor = self.vector_lane if N > self.vector_lane else 8 k_pad_factor = self.vector_lane if K > self.vector_lane else (8 if pad_k else 1) @@ -214,9 +214,9 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - n_tile = math.ceil(M / tile_M) * math.ceil(N / tile_N) + n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) 
- if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and tile_N // tile_M < 10: + if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) From 29ee378a24999e3d5a4e3adb1857ec2cb22f24b1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 30 Jun 2025 04:55:44 +0000 Subject: [PATCH 21/62] [Frontend] Make fusion optionable --- PyTorchSimFrontend/extension_config.py | 4 ++++ PyTorchSimFrontend/mlir/mlir_scheduling.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 15413103..b8c2b3b4 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -70,6 +70,10 @@ CONFIG_SUBTILE_N = int(os.environ.get('TORCHSIM_SUBTILE_N', default=CONFIG_VECTOR_LANE)) CONFIG_SUBTILE_K = int(os.environ.get('TORCHSIM_SUBTILE_K', default=CONFIG_VECTOR_LANE)) +# Advanced fusion options +CONFIG_FUSION_REDUCTION = int(os.environ.get('TORCHSIM_FUSION_REDUCTION', default=False)) +CONFIG_FUSION_PROLOGUE = int(os.environ.get('TORCHSIM_FUSION_PROLOGUE', default=False)) + # SRAM Buffer allocation plan def load_plan_from_module(module_path): if module_path is None: diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 307d5afe..a526d17c 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -34,7 +34,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if node1.get_device() != node2.get_device(): return False - if len(base_template_node1) == 1 and len(base_template_node2) == 0: + if len(base_template_node1) == 1 and 
len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction() and len(node2.get_nodes())==1: @@ -50,7 +50,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return size_match and layout_possible and dependency_check & dependency_size # For prologue fusion case - if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: + if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: # Return false if node2 is Convolution template # if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ # node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': From af6e63d633e9bd418241408173f877fd57f3bf03 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 30 Jun 2025 08:27:22 +0000 Subject: [PATCH 22/62] [Frontend] Use kernel name from define_kernel --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index a526d17c..41264a74 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -138,10 +138,10 @@ def codegen_nodes(self, nodes): ex_kernel = self.target_kernel(kernel_group=self.kernel_group) ex_kernel.kernel_group = self.kernel_group - kernel_name = f"extension_kernel_{MLIRScheduling.count}" + kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name) - self.define_kernel(src_code, 
kernel_name, ex_kernel.vector_lane, + src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() From b82bf94566ebd74763d7d26dab9dd13220239daa Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 30 Jun 2025 08:29:35 +0000 Subject: [PATCH 23/62] [Frontend] Don't use buffer's unique name to reuse kernels --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index fee5702a..f6fe6a76 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -871,6 +871,7 @@ def __init__(self, kernel_group, reason=None): self.tags = dict() self.dma_read_cache = dict() self.dma_write_cache = dict() + self.spadbuf_counter = 0 self.dma_read_counter = 1 self.dma_write_counter = 1 self.affine_yield = {} @@ -958,10 +959,11 @@ def load(self, name: str, index: sympy.Expr): index = self.convert_indirect_indexing(index) padding = self.get_padding_type() dram_var = self.kernel_group.args.input(name) - dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + local_tile_desc, index_var = self.get_dma_info(name, index) + vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride tile_numel_per_lane = local_tile_desc.get_numel_per_lane() @@ -1305,7 +1307,7 @@ def index_expr(self, index, dtype): self.header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__ ((section(\".spad\")));") self.gem5_header.writeline(f"{c_type} {new_name}_spad[{compute_vec_size}] __attribute__((aligned(64)));") self.global_vars.writeline(f"memref.global 
@{new_name}_spad : {tile_shape}") - self.global_vars_dict[new_name] = [] + self.global_vars_dict[new_name] = dict() sram_var = self.spad_cse.generate(self.spad_buffer, f"memref.get_global @{new_name}_spad : {tile_shape}") # Initialize base vector if not self.base_vector_initialized: @@ -1673,17 +1675,18 @@ def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape buffer = self.spad_buffer if name not in self.global_vars_dict: - self.global_vars_dict[name] = list() + self.global_vars_dict[name] = dict() if str(raw_index) not in self.global_vars_dict[name]: - new_name = f"{name}_{len(self.global_vars_dict[name])}" + new_name = f"buf{self.spadbuf_counter}" + self.spadbuf_counter+=1 # Add definition to header self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}] __attribute__((aligned(64)));") self.global_vars.writeline(f"memref.global @{new_name}_spad : {dram_tile_shape}") - self.global_vars_dict[name].append(str(raw_index)) + self.global_vars_dict[name][str(raw_index)] = new_name else: - new_name = f"{name}_{self.global_vars_dict[name].index(str(raw_index))}" + new_name = self.global_vars_dict[name][str(raw_index)] sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name}_spad : {dram_tile_shape}") zero_cse = self.get_const_cse(0) From c915f343473949f781e24d596da2d057f2ec7ab7 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 30 Jun 2025 08:32:34 +0000 Subject: [PATCH 24/62] [Frontend] Add manual tile_stride for DimTile --- PyTorchSimFrontend/mlir/mlir_common.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 29ef65c9..92af0570 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -180,8 +180,10 @@ def set_info(outer, inner, 
arg_type): class MLIRMultiDimTile(): def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=None, vec_size=None): self._tile_size = list(tile_size) + self._tile_stride = None self.tile_axis_order = list(range(len(tile_size))) self.vec_size = vec_size + self.update_tile_stride() # Vector lane mapping config self.vector_lane = vector_lane @@ -196,6 +198,11 @@ def set_tile_size(self, tile_size, tile_axis_order=None): self.tile_axis_order = list(range(len(tile_size))) else: self.tile_axis_order = tile_axis_order + self.update_tile_stride() + + def set_tile_size_stride(self, tile_size, tile_stride): + self._tile_size = tile_size + self._tile_stride = tile_stride def get_tile_size(self): return self._tile_size @@ -216,7 +223,7 @@ def get_numel_per_lane(self): size *= dim_size return size - def get_tile_stride(self): + def update_tile_stride(self): strides = [1] * len(self._tile_size) init = 1 @@ -228,7 +235,10 @@ def get_tile_stride(self): for _, size, original_indices in sorted_pairs: strides[original_indices] = init init *= size - return strides + self._tile_stride = strides + + def get_tile_stride(self): + return self._tile_stride def get_tile_size_per_lane(self): tile_size_per_lane = list(self._tile_size) From 491b91111e0ba987a965140b06724fd5422df1af Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 1 Jul 2025 06:33:24 +0000 Subject: [PATCH 25/62] [Frontend] Add utility method for kernel class --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index f6fe6a76..6f38b08a 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1477,6 +1477,12 @@ def _prepare_simulator_headers(self, src_code): write_atomic(spike_write_path, self.header.getvalue() + spad_end_symbol + spad_section_end_symbol) write_atomic(gem5_write_path, 
self.gem5_header.getvalue()) + def get_arg_info(self, name): + arg_info = dict() + arg_info.update(V.graph.graph_inputs) + arg_info.update({i.get_name(): i for i in V.graph.buffers}) + return arg_info[name] + def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffer=None): # Need more argument? """ A tile descriptor exists that is configured on a kernel group From d86dc3adddbe35969f3ac2faa1d8baedddc43043 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 3 Jul 2025 05:43:57 +0000 Subject: [PATCH 26/62] [Frontend/Template] Rework template codegen There are a lot of changes. The fusion mechanism has been refactored. The major change is keeping consistency between the template and the fusion nodes. To do this, I changed the loop order and added a Revert() function to revert the squeezed size of point/reduction nodes. --- PyTorchSimFrontend/common_diff.py | 1031 ----------------- PyTorchSimFrontend/extension_config.py | 6 +- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 372 +++--- .../mlir/mlir_codegen_backend.py | 164 ++- PyTorchSimFrontend/mlir/mlir_common.py | 10 +- .../mlir/mlir_conv_mt_template.py | 346 ++++++ .../mlir/mlir_conv_sb_template.py | 342 ++++++ .../mlir/mlir_conv_sbs_template.py | 343 ++++++ PyTorchSimFrontend/mlir/mlir_conv_template.py | 669 ++--------- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 422 ++++--- PyTorchSimFrontend/mlir/mlir_lowering.py | 17 +- .../mlir/mlir_maxpool_template.py | 25 +- PyTorchSimFrontend/mlir/mlir_scheduling.py | 179 +-- PyTorchSimFrontend/mlir/mlir_template.py | 602 +++++----- 14 files changed, 2009 insertions(+), 2519 deletions(-) delete mode 100644 PyTorchSimFrontend/common_diff.py create mode 100644 PyTorchSimFrontend/mlir/mlir_conv_mt_template.py create mode 100644 PyTorchSimFrontend/mlir/mlir_conv_sb_template.py create mode 100644 PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py diff --git a/PyTorchSimFrontend/common_diff.py b/PyTorchSimFrontend/common_diff.py deleted file mode 100644 index 6c1c875c..00000000 --- 
a/PyTorchSimFrontend/common_diff.py +++ /dev/null @@ -1,1031 +0,0 @@ -import contextlib -import dataclasses -import functools -import itertools -import logging -import operator -import re -from collections import namedtuple -from itertools import chain -from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Set, Union - -import sympy -from sympy.printing.printer import Printer - -import torch -import torch.fx -from torch.utils._sympy.value_ranges import ValueRanges - -from .. import metrics -from ..utils import ( - DeferredLineBase, - free_symbol_startswith, - get_sympy_Expr_dtype, - IndentedBuffer, - sympy_dot, - sympy_subs, - unique, -) -from ..virtualized import ops, OpsValue, V - -schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") - - -def data_type_logger(msg): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Data type propagation: %s", msg) - - -TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"]) -SizeArg = namedtuple("SizeArg", ["name", "expr"]) - -DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"]) -device_codegens: Dict[str, DeviceCodegen] = {} - - -# The code generated by Inductor consists of two main parts: kernel code and wrapper code. -# For any new backend looking to integrate with Inductor, customization of these two main -# parts are necessary to generate its specific code. -# -# Kernel code generation is determined by different Scheduling. Consequently, a new -# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, -# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. -# -# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code -# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, -# and override specific member functions to create backend-specific Python wrapper code. 
-# -# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part -# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces -# provide flexibility to the backend. A backend can choose to implement these classes from scratch, -# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, -# register_backend_for_device, to equip a new backend at runtime. -# -# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. -# This backend can be used as a reference: -# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 -def register_backend_for_device( - device: str, device_scheduling: type, device_wrapper_codegen: type -): - device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen) - - -def get_scheduling_for_device(device: str): - return device_codegens[device].scheduling if device in device_codegens else None - - -def get_wrapper_codegen_for_device(device: str): - return ( - device_codegens[device].wrapper_codegen if device in device_codegens else None - ) - - -def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): - from ..ir import FlexibleLayout - - # added contiguous index prevents reordering - return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] - - -@functools.lru_cache(None) -def boolean_ops(): - return ( - "is_inf", - "is_nan", - "bitwise_xor", - "logical_not", - "signbit", - "le", - "lt", - "ge", - "gt", - "eq", - "ne", - ) - - -DTYPE_TO_COMPUTATION_DTYPE = { - torch.bfloat16: torch.float, - torch.float16: torch.float, - **{ - dtype: dtype - for dtype in [ - torch.bool, - torch.float32, - torch.float64, - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - ] - }, -} - - -class DataTypePropagation: - def 
__init__(self, body) -> None: - self.body = body - self.graphs: Dict[Union[Callable[..., Any], str], Any] = { - "root": body.root_block.graph - } - for k, v in body.subblocks.items(): - self.graphs[k] = v.graph - - def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): - inputs = node.all_input_nodes - input_nodes = [ - n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" - ] - if len(input_nodes) == 0: - return None - - all_input_nodes_propogated = all( - OptimizationContext.key in n.meta - and n.meta[OptimizationContext.key].dtype is not None - for n in input_nodes - ) - if not all_input_nodes_propogated: - return None - - return functools.reduce( - torch.promote_types, - [n.meta[OptimizationContext.key].dtype for n in input_nodes], - ) - - def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): - sub_graph = self.graphs[node.target] - dtype = self.propagate_graph(sub_graph) - assert dtype - return dtype - - def deduce_node_dtype(self, node: torch.fx.Node): - if node.target in boolean_ops(): - return torch.bool - - if node.op == "placeholder": - return None - - if node.target == "output": - # we can infer output node if it only have 1 arg - if len(node.args) != 1: - return None - - if node.target in ( - "to_dtype", - "index_expr", - ): - return node.args[-1] - - if node.target in ( - "rand", - "randn", - ): - return torch.float - - if node.target in ( - "get_index", - "index_expr", - ): - return torch.int64 - - if node.target in ( - "load", - "store", - "store_reduction", - ): - buf_name = node.args[1] - return V.graph.get_dtype(buf_name) - - if node.target == operator.getitem: - return self.deduce_node_dtype(node.args[0]) - - assert isinstance(node.target, str) - - if node.target == "reduction": - return node.args[1] - - if node.target == "constant": - return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]] - - if node.target.startswith("masked_subblock"): - return self.deduce_node_dtype_by_subgraph(node) - - return 
self.deduce_node_dtype_by_inputs(node) - - def propagate_graph(self, graph: torch.fx.Graph): - assert graph.nodes - graph_dtype = None - # For masked_subblock, we use output's dtype to represent - # the dtype of this subgraph. For other cases, graph_dtype - # might be None - for node in graph.nodes: - if OptimizationContext.key in node.meta: - opt_ctx = node.meta[OptimizationContext.key] - else: - opt_ctx = OptimizationContext() - - opt_ctx.dtype = self.deduce_node_dtype(node) - node.meta[OptimizationContext.key] = opt_ctx - if node.target == "output": - graph_dtype = opt_ctx.dtype - return graph_dtype - - def propagate(self): - self.propagate_graph(self.graphs["root"]) - - @classmethod - def propagate_loopbody(cls, body): - return cls(body).propagate() - - @classmethod - def propagate_scheduler_node(cls, node): - from ..ir import LoopBody - from ..scheduler import SchedulerNode - - assert isinstance(node, SchedulerNode) - assert isinstance(node._body, LoopBody) - DataTypePropagation.propagate_loopbody(node._body) - - -class ExprPrinter(Printer): - @staticmethod - def paren(string): - def all_in_parens(string): - if string[0] != "(" or len(string) < 2: - return False - count = 1 - for i, char in enumerate(string[1:]): - if char == "(": - count += 1 - elif char == ")": - count -= 1 - if count == 0 and i != len(string) - 2: - return False - assert count == 0 - return True - - if ( - isinstance(string, CSEVariable) - or re.match(r"^[a-z0-9_.]+$", string, re.I) - or re.match(r"^\([^)]*\)$", string, re.I) - or string == "" - ): - return string - # don't put extra parens for strings that are already wrapped in parens - if all_in_parens(string): - return string - return f"({string})" - - def _print_Pow(self, expr): - # Pow() confuses triton - base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. 
Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) # type: ignore[attr-defined] - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) # type: ignore[attr-defined] - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" - - def _print_Unequality(self, expr): - return " != ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mul(self, expr): - return "*".join(map(self.paren, map(self._print, expr.args))) - - def _print_Add(self, expr): - return " + ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CleanDiv(self, expr): - return self._print_FloorDiv(expr) # type: ignore[attr-defined] - - def _print_GreaterThan(self, expr): - # GreaterThan: >= - # StrictlyGreaterThan: > - # Go figure... 
- return " >= ".join(map(self.paren, map(self._print, expr.args))) - - -class PythonPrinter(ExprPrinter): - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - mod = self.paren(self.doprint(mod)) - if div != "1": - x = f"({x} // {div})" - return f"{x} % {mod}" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - def _helper_sqrt(self, expr): - return f"math.sqrt({self._print(expr)})" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - - -class OpOverrides: - def __init__(self, parent): - super().__init__() - self._parent = parent - - def __getattr__(self, item): - return getattr(self._parent, item) - - @staticmethod - def identity(value): - # used to trigger cse - return value - - @staticmethod - def constant(value, dtype): - return repr(value) - - @staticmethod - def reciprocal(x): - return ops.div("1", x) - - @staticmethod - def square(x): - return ops.mul(x, x) - - @staticmethod - def bitwise_not(x): - return f"~{ExprPrinter.paren(x)}" - - @staticmethod - def logical_not(a): - return f"{ExprPrinter.paren(a)} == 0" - - @staticmethod - def bitwise_and(x, y): - return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_or(x, y): - return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_xor(x, y): - return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" - - @staticmethod - def bitwise_left_shift(x, y): - return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" - - # TODO(fdrocha): this is currently not being used anywhere, - # pending on moving triton pin past 972b761 - @staticmethod - def bitwise_right_shift(x, y): - return 
f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" - - @staticmethod - def remainder(a, b): - r = ops.mod(a, b) - return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r) - - @staticmethod - def load_seed(name, offset): - return ops.load(name, sympy.Integer(offset)) - - -class DeferredLine(DeferredLineBase): - """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" - - def __init__(self, name, line): - super().__init__(line) - self.name = name - - def __call__(self): - if ( - self.name not in V.graph.removed_buffers - and self.name not in V.graph.inplaced_to_remove - ): - return self.line - return None - - def _new_line(self, line): - return DeferredLine(self.name, line) - - -class BracesBuffer(IndentedBuffer): - def indent(self, offset=1): - @contextlib.contextmanager - def ctx(): - for _ in range(offset): - self.writeline("{") - self._indent += 1 - for _ in range(-offset): - self._indent -= 1 - self.writeline("}") - yield - for _ in range(-offset): - self.writeline("{") - self._indent += 1 - for _ in range(offset): - self._indent -= 1 - self.writeline("}") - - return ctx() - - -class InplacedBuffer(NamedTuple): - inner_name: str - other_names: List[str] - - -class KernelArgs: - @staticmethod - def _lookup(prefix, odict, name): - assert isinstance(name, (str, sympy.Symbol)) - if name not in odict: - odict[name] = f"{prefix}{len(odict)}" - return odict[name] - - def __init__(self, sizevars=None): - self.input_buffers = dict() - self.output_buffers = dict() - self.inplace_buffers = dict() - self.sizevars = sizevars or dict() - - def __repr__(self): - return "KernelArgs({})".format( - ", ".join( - map( - repr, - [ - self.input_buffers, - self.output_buffers, - self.inplace_buffers, - self.sizevars, - ], - ) - ) - ) - - def _buffer_is_marked_removed(self, name): - return isinstance(name, str) and name.startswith("REMOVED") - - def input(self, name): - if V.graph.scheduler: - name = 
V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.output_buffers: - return self.output_buffers[name] - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - if name.startswith("seed"): - return self._lookup("seed", self.input_buffers, name) - return self._lookup("in_ptr", self.input_buffers, name) - - def output(self, name): - if V.graph.scheduler: - name = V.graph.scheduler.mutation_real_name.get(name, name) - assert name not in V.graph.removed_buffers, name - if name in self.inplace_buffers: - return self.inplace_buffers[name].inner_name - return self._lookup("out_ptr", self.output_buffers, name) - - def make_inplace(self, input_name, output_name): - assert output_name not in self.inplace_buffers - if input_name in self.inplace_buffers: - buf = self.inplace_buffers[input_name] - buf.other_names.append(output_name) - self.inplace_buffers[output_name] = buf - else: - buf = InplacedBuffer( - f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", - [input_name, output_name], - ) - self.inplace_buffers[input_name] = buf - self.inplace_buffers[output_name] = buf - - def seed_offset(self, name, value): - if value in self.sizevars: - return self.sizevars[value] - if name in self.sizevars.values(): - name = ( - f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" - ) - self.sizevars[value] = name - return name - - def size(self, name): - if str(name) == "seed": - self.sizevars["seed"] = "seed" - return "seed" - return self._lookup("ks", self.sizevars, name) - - def call_names(self): - return chain( - self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() - ) - - def wrap_ptr_arg(self, buf, dtype): - return f"c_void_p({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"c_long({size})" - - def cpp_argdefs(self): - from .cpp import DTYPE_TO_CPP, INDEX_TYPE - - # TODO(jansel): replace this with data from scheduler - 
buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers} - for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Expr): - buffer_types[name] = get_sympy_Expr_dtype(val) - else: - buffer_types[name] = val.get_dtype() - buffer_types.update( - {name: val.dtype for name, val in V.graph.constants.items()} - ) - - call_args = [] - arg_defs = [] - arg_types = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - outer = inplaced.other_names[-1] - inner = inplaced.inner_name - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.input_buffers.items(): - if outer in self.inplace_buffers: - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"const {cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"const {cpp_dtype}*") - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - dtype = buffer_types[outer] - cpp_dtype = DTYPE_TO_CPP[dtype] - arg_defs.append(f"{cpp_dtype}* {inner}") - call_args.append(self.wrap_ptr_arg(outer, dtype)) - arg_types.append(f"{cpp_dtype}*") - for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") - call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") - return arg_defs, call_args, arg_types - - def python_argdefs(self): - arg_defs = [] - call_args = [] - precompile_args: List[Union[TensorArg, SizeArg]] = [] - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - arg_defs.append(inplaced.inner_name) - call_args.append(inplaced.other_names[-1]) - precompile_args.append( - TensorArg( - 
inplaced.inner_name, - inplaced.other_names[-1], - V.graph.get_dtype(inplaced.other_names[-1]), - ) - ) - for outer, inner in chain( - self.input_buffers.items(), self.output_buffers.items() - ): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer))) - for outer, inner in self.sizevars.items(): - arg_defs.append(inner) - call_args.append(outer) - precompile_args.append(SizeArg(inner, outer)) - - return arg_defs, call_args, precompile_args - - def aliases(self): - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - for other in inplaced.other_names: - if other in V.graph.inplaced_to_remove: - continue - if other in self.input_buffers: - yield self.input_buffers[other], inplaced.inner_name - if other in self.output_buffers: - yield self.output_buffers[other], inplaced.inner_name - - def is_removed(self, name): - def _is_removed(name, buffers): - return name not in buffers or self._buffer_is_marked_removed(buffers[name]) - - return _is_removed(name, self.output_buffers) and _is_removed( - name, self.inplace_buffers - ) - - # Includes inplace buffers, excludes removed buffers. Essentially, - # after you do a call into this kernel, which buffers actually contain - # updated data? Modeled off of python_argdefs. - def live_output_buffers(self): - live_outs = set() - for inplaced in unique(self.inplace_buffers.values()): - if self._buffer_is_marked_removed(inplaced): - continue - live_outs.add(inplaced.other_names[-1]) - for outer, inner in self.output_buffers.items(): - if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): - continue - live_outs.add(outer) - return live_outs - - -class CSEVariable: - """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. 
- To do so, the backends can simply overload `Kernel.create_cse_var` - The "CSEVariable.update_on_args" method gives you a hook for annotations - See example of TritonCSEVariable in triton.py - """ - - def __init__(self, name, bounds: ValueRanges): - assert isinstance(bounds, ValueRanges) - self.name = name - self.bounds = bounds - - def __str__(self): - return self.name - - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other) -> bool: - return type(other) == type(self) and other.name == self.name - - def update_on_args(self, name, args, kwargs): - pass - - -class CppWrapperKernelArgs(KernelArgs): - def wrap_ptr_arg(self, buf, dtype): - from .cpp import DTYPE_TO_CPP - - return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" - - def wrap_size_arg(self, size): - return f"{size}" - - -class CSE: - """Common subexpression elimination""" - - def __init__( - self, - prefix="", - suffix="", - name_prefix="tmp", - iter_buffers=None, - store_cache=None, - reduction_cache=None, - varname_map=None, - ): - self.prefix = prefix - self.suffix = suffix - self.cache = {} - self.name_prefix = name_prefix - self.store_cache = store_cache or {} - self.reduction_cache = reduction_cache or {} - self.iter_buffer_ids = iter_buffers or itertools.count() - self.invalidated_stores = set() - self.varname_map = varname_map or {} - - def invalidate(self, keep_vars: Set[str]): - for name, tmp in list(self.store_cache.items()): - if tmp not in keep_vars: - del self.store_cache[name] - self.invalidated_stores.add(name) - self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} - - def clone(self): - # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional - return CSE( - prefix=self.prefix, - suffix=self.suffix, - name_prefix=self.name_prefix, - iter_buffers=self.iter_buffer_ids, - store_cache=self.store_cache, - varname_map=self.varname_map, - ) - - def generate( - self, - buffer: IndentedBuffer, - expr: Union[str, CSEVariable, 
OpsValue], - *, - bounds: ValueRanges = ValueRanges.unknown(), - write=True, - assignment=True, - ) -> CSEVariable: - if isinstance(expr, OpsValue): - expr = expr.value - - assert isinstance(expr, (str, CSEVariable)), type(expr) - assert write or assignment - if isinstance(expr, CSEVariable): - # If the expressions were always created with all the information, we could - # assert expr.bounds == bounds, but sometimes the expression is created - # with the loose ValueRanges.unknown(), so we need to tighten the bounds - expr.bounds = expr.bounds.tighten(bounds) - return expr - cache_key = expr - var = self.cache.get(cache_key, None) - if not var: - var = self.newvar(bounds) if assignment else None - self.cache[cache_key] = var - if write: - if V.kernel.current_node: - V.kernel.current_node.codegen_originating_info( - buffer, only_once=True - ) - if assignment: - line = f"{self.prefix}{var} = {expr}{self.suffix}" - else: - line = f"{expr}{self.suffix}" - buffer.writeline(line) - else: - var.bounds = var.bounds.tighten(bounds) - - return var - - def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable: - var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" - var = V.kernel.create_cse_var(var_name, bounds) - self.varname_map[var_name] = var - return var - - -class CodeGen: - def __init__(self): - super().__init__() - self.exit_stack = contextlib.ExitStack() - - def __enter__(self): - self.exit_stack.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - -class Kernel(CodeGen): - newvar_prefix = "" - suffix = "" - overrides = None - load_format = None - store_format = None - - def __init__(self, args=None): - super().__init__() - metrics.generated_kernel_count += 1 - self.args = args or KernelArgs() - self.loads = IndentedBuffer() - self.compute = IndentedBuffer() - self.stores = IndentedBuffer() - self.cse = CSE(self.newvar_prefix, self.suffix) - 
self.must_keep_buffers = set() - self.store_buffer_names = set() - # set in set_current_node - self.current_node = None - self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None - - @contextlib.contextmanager - def set_current_node(self, node): - prior = self.current_node - self.current_node = node - self.node_to_bounds = node._body.bounds().get_bounds() - try: - yield - finally: - self.current_node = prior - - @contextlib.contextmanager - def swap_buffers(self, lb, cb=None, sb=None): - if cb is None: - cb = lb - loads = self.loads - compute = self.compute - stores = self.stores - cse = self.cse - self.loads = lb - self.compute = cb - self.stores = sb - self.cse = cse.clone() - try: - yield - finally: - self.loads = loads - self.compute = compute - self.stores = stores - self.cse = cse - - def load(self, name: str, index: sympy.Expr): - raise NotImplementedError() - - def indirect_load(self, name: str, index: sympy.Expr): - """A load the depends on an index we have read""" - prior = self.loads - try: - # put the load in the compute section as it might have deps - self.loads = self.compute - return self.load(name, index) - finally: - self.loads = prior - - def store_reduction(self, name, index, value): - raise NotImplementedError() - - def store(self, name, index, value, mode=None): - raise NotImplementedError() - - def reduction(self, dtype, src_dtype, reduction_type, value): - raise NotImplementedError() - - def bucketize( - self, - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - """ - See [Note: Inductor bucketize op] - """ - raise NotImplementedError() - - def __enter__(self): - class CSEProxy: - self.name = "CSEProxy" - - @staticmethod - def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] - def inner(*args, **kwargs): - # TritonTemplateKernel has no current_node - buf_bounds = ValueRanges.unknown() - if hasattr(V.interpreter, "current_node"): - fx_node = 
V.interpreter.current_node - assert isinstance(self.node_to_bounds, dict) - buf_bounds = self.node_to_bounds.get( - fx_node, ValueRanges.unknown() - ) - - csevar = self.cse.generate( - self.compute, - getattr(parent_handler, name)(*args, **kwargs), # type: ignore[has-type] - bounds=buf_bounds, - ) - csevar.update_on_args(name, args, kwargs) - return csevar - - return inner - - @staticmethod - def indirect_indexing(index_var, size, check=True): - # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) # type: ignore[attr-defined] - - @staticmethod - def load(name: str, index: sympy.Expr): - if name in self.cse.invalidated_stores: - # A load from an invalidated store requires us to - # keep the actual buffer around - V.kernel.must_keep_buffers.add(name) - if free_symbol_startswith(index, "tmp"): - return self.indirect_load(name, index) - store_cache = self.cse.store_cache - if name in store_cache: - return store_cache[name] - return self.load(name, index) - - @staticmethod - def store(name, index, value, mode=None): - self.store_buffer_names.add(name) - if mode is None: - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - if name not in V.graph.removed_buffers: - return self.store(name, index, value, mode=mode) - - @staticmethod - def store_reduction(name, index, value): - self.store_buffer_names.add(name) - self.cse.store_cache[name] = value - if self.current_node: - for other_name in self.current_node.get_mutations(): - self.cse.store_cache[other_name] = value - - if name not in V.graph.removed_buffers: - return self.store_reduction(name, index, value) - - @staticmethod - def reduction(dtype, src_dtype, reduction_type, value): - return self.reduction(dtype, src_dtype, reduction_type, value) - - @staticmethod - def bucketize( - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: 
torch.dtype, - right: bool, - ): - """ - [Note: Inductor bucketize op] - - Given values (tensor) and offsets_name (reference to the name of a 1D - tensor), calculate the bucket that each value belongs to. - - e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True - return = [ 0, 1, 1, 1, 1, 3, 3, 4]. - - When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. - When right == True, bucket i refers to range [offsets[i], offsets[i+1]). - - Offsets must be non-decreasing or the result is undefined. - """ - return self.bucketize( - values, offsets_name, offsets_size, indexing_dtype, right - ) - - super().__enter__() - assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) - self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if V.graph.scheduler: - V.graph.scheduler.remove_kernel_local_buffers() - super().__exit__(exc_type, exc_val, exc_tb) - - def rename_indexing(self, index) -> sympy.Expr: - # adds the necessary kernel args for index expressions - # and renames variables in index expressions to kernel arg names - if isinstance(index, (list, tuple)): - return [self.rename_indexing(x) for x in index] - index = V.graph.sizevars.simplify(index) - sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) - replacements = { - x: self.args.size(x) - for x in sorted_symbols - if x.name.startswith("s") or x.name.startswith("ps") - } - return sympy_subs(index, replacements) - - def create_cse_var(self, *args, **kwargs): - return CSEVariable(*args, **kwargs) - - -@dataclasses.dataclass -class OptimizationContext: - key: ClassVar[str] = "opt_ctx" - - # Load value as mask - is_load_as_mask: bool = False - - dtype: torch.dtype = None - ops_name: str = "" - is_most_inner_loop_irrevelant: bool = False - - # Load uint8 value as float32 - is_load_uint8_as_float: bool = False \ No 
newline at end of file diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index b8c2b3b4..1761e05c 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -37,7 +37,7 @@ # Backendsim config CONFIG_TORCHSIM_BACKEND_CONFIG = os.environ.get('TORCHSIM_CONFIG', default=f'{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') -CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", False)) +CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", True)) CONFIG_BACKENDSIM_EAGER_MODE = int(os.environ.get("BACKENDSIM_EAGER_MODE", default=False)) CONFIG_BACKENDSIM_DRYRUN = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) CONFIG_BACKENDSIM_DEBUG_LEVEL = os.environ.get("BACKENDSIM_DEBUG_LEVEL", "") @@ -71,8 +71,8 @@ CONFIG_SUBTILE_K = int(os.environ.get('TORCHSIM_SUBTILE_K', default=CONFIG_VECTOR_LANE)) # Advanced fusion options -CONFIG_FUSION_REDUCTION = int(os.environ.get('TORCHSIM_FUSION_REDUCTION', default=False)) -CONFIG_FUSION_PROLOGUE = int(os.environ.get('TORCHSIM_FUSION_PROLOGUE', default=False)) +CONFIG_FUSION_REDUCTION = int(os.environ.get('TORCHSIM_FUSION_REDUCTION', default=True)) +CONFIG_FUSION_PROLOGUE = int(os.environ.get('TORCHSIM_FUSION_PROLOGUE', default=True)) # SRAM Buffer allocation plan def load_plan_from_module(module_path): diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 91ba9ba1..b81b3862 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -1,14 +1,14 @@ import os from torch import empty_strided -from typing import List, Optional, cast +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from 
torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common BMM_TEMPLATE = r""" // BMM kernel @@ -21,63 +21,38 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 2 : index - %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : 
vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - affine.for %b=0 to {{ B }} { - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - - %index2 = affine.apply #map2(%b, %t_m, %t_n) + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {% if Bias -%} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % - 
{%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer2D[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer2D[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE 
}}, 1>) - } { accumulation_loop=true } + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } } { outer_loop=true } return } @@ -94,59 +69,36 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 2 : index - %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : 
vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - affine.for %b=0 to {{ B }} { - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ B }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - - %index2 = affine.apply #map2(%b, %t_m, %t_n) {% if Bias -%} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} 
{ - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - {{kernel.prepare_input(indent_size=10)}} + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{kernel.load_input(indent_size=10)}} linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } - memref.dma_start %Y_buffer[%c0, %c0, %c0], %Y[%index2], %c_mvout, %tag[%c0], %axis, %vstride : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<{{ B * M * N }}xf32>, memref<1xi32> { padding=0, sram_stride=[1, 1, {{ TILE_M }}] } - } { outer_loop=true } - } { outer_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } } { outer_loop=true } return } @@ -163,65 +115,39 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1, d2) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1, d2) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1, d2) -> (d0 * {{ M * N }} + d1 * {{ N }} + d2)> -memref.global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - 
%vstride = arith.constant 1 : index - %axis = arith.constant 2 : index - %X_buffer = memref.get_global @X_spad : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - affine.for %b=0 to {{ B }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %red_idx = affine.apply affine_map<(d0, d1) -> ({{M}}*d0 + d1)>(%b, %t_n) - {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { + {{ kernel.def_local_vars() }} + affine.for %index0=0 to {{ B }} { + affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], 
strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %index2 = affine.apply #map2(%b, %t_m, %t_n) {% if Bias -%} - memref.dma_start %Bias[ - {%- if Bias_rank == 2 -%} %index2 {%- else -%} %t_n {%- endif -%} - ], %Y_buffer2D[0, 0], %c_mvin3, %tag0[%c0], % - {%- if Bias_rank == 2 -%} axis {%- else -%} c0 {%- endif -%} - , %vstride : memref< - {%- if Bias_rank == 2 -%} {{ M * N }} {%- else -%} {{ N }} {%- endif -%} - xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1 , {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... 
{%- else -%} - affine.vector_store %v0, %Y_buffer2D[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%b, %t_m, %t_k) - %index1 = affine.apply #map1(%b, %t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer2D[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ B * M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer2D[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ B * K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[1, SUB_TILE_K, SUB_TILE_N], indent_size=10) }} linalg.matmul ins(%X_buffer2D, %W_buffer2D : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) outs(%Y_buffer2D : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true, loop_k=true } + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=8)}} - } { outer_loop=true, loop_m=true } + } { outer_loop=true, subtile_loop="m" } {{kernel.reduction_output(indent_size=6)}} - } { outer_loop=true, loop_n=true} + } { outer_loop=true, subtile_loop="n" } } { outer_loop=true } return } @@ -239,9 +165,8 @@ def render(self, **kwargs): if template_buffer_node 
is not None: self.output_node = template_buffer_node - #if epilogue_nodes is not None and len(epilogue_nodes) > 0: - # self.output_node = cast(Buffer, epilogue_nodes[-1]) + # Extract input arguments info X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] @@ -252,113 +177,150 @@ def render(self, W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) if len(X_tensor.size()) > 3: X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) + B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + W_stride = W_tensor.stride() X_stride = X_tensor.stride() - W_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(W_stride)]) - X_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(X_stride)]) - B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + # Select tile size n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) - TOG_latency = M if TILE_M > M else TILE_M - kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - TILE_K = TILE_K // 2 if prologue_nodes else TILE_K SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TOG_latency = M if TILE_M > M else TILE_M + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + TILE_K = TILE_K // 2 if prologue_nodes else TILE_K + + # Select template code nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] if nr_reduction_nodes: template = BMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} 
nr_rdim = 1 elif prologue_nodes: template = BMM_PROLOGUE_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} nr_rdim = 0 else: template = BMM_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} nr_rdim = 0 + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 2 + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + X_tile_size = [1, TILE_M, TILE_K] + X_tile_stride = [0, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X_tensor.stride() + X_idx = [loop_dim[0]*X_stride[0], loop_dim[1]*X_stride[1], loop_dim[3]*X_stride[2]] # To keep index arguemnt order, we used index_list + + W_tile_size = [1, TILE_K, TILE_N] + W_tile_stride = [0, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W_tensor.stride() + W_idx = [loop_dim[0]*W_stride[0], loop_dim[3]*W_stride[1], loop_dim[2]*W_stride[2]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 1 + Y_tile_size = [1, TILE_M, TILE_N] if nr_rdim == 0 else [1, TILE_N, TILE_M] + Y_tile_stride=[0, 1, TILE_M] if nr_rdim == 0 else [0, TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[1]*Y_stride[1], loop_dim[2]*Y_stride[2]] + else: + Y_idx = [loop_dim[0]*Y_stride[0], loop_dim[2]*Y_stride[2], loop_dim[1]*Y_stride[1]] + + # Extract Bias info + if Bias is not 
None: + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[1]*Bias_stride[1], loop_dim[2]*Bias_stride[2]] + else: + Bias_idx = [loop_dim[0]*Bias_stride[0], loop_dim[2]*Bias_stride[2], loop_dim[1]*Bias_stride[1]] + else: + Bias_idx = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - B=B, - M=M, - N=N, - K=K, - TILE_M=TILE_M, - TILE_N=TILE_N, - TILE_K=TILE_K, + B=B, M=M, N=N, K=K, + TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, - Bias = Bias, - Bias_rank = len(Bias.data.get_size()) if Bias is not None else 0, - X_map = X_map, - W_map = W_map, - Y_numel = B * M * N, + X = X, W = W,Y = Y, Bias = Bias, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, input_reorder = self.input_reorder ) if prologue_nodes: - # if Input fused: - # tile_size = (TILE_M, TILE_K) - # input_sram_stride = [1, TILE_M] - # elif Weight fused: - tile_size = (TILE_K, TILE_N) - input_sram_stride = [1, TILE_K] + prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"} + is_input_fused = False + kernel.prologue_info = dict ( - input_sram_var = "X_buffer2D", input_dram_var = "X", - input_index_var = "index0", - input_tag_var = "tag1", - input_numel = B * M * K, - input_tile_size = (TILE_M, TILE_K), - input_sram_stride = 
input_sram_stride, - input_subtile_size = (SUB_TILE_M, SUB_TILE_K), - weight_sram_var = "W_buffer2D", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [1, TILE_M, TILE_K], # TODO. Curently, Subtiling is not supported for prologue template + input_dim_aliasing = {"index0":"index0", "index1":"index1", "index2":"index3"}, + weight_dram_var = "W", - weight_index_var = "index1", - weight_tag_var = "tag2", - weight_numel = B * K * N, - weight_tile_size = (TILE_K, TILE_N), - weight_sram_stride = [1, TILE_K], - weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), - tile_size = tile_size, - vlane_split_axis = 1, - vlane_stride = 1, + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [1, TILE_K, TILE_N], # TODO. Curently, Subtiling is not supported for prologue template + weight_dim_aliasing = {"index0":"index0", "index1":"index3", "index2":"index2"}, + + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, is_bmm = True, + is_input_fused = is_input_fused ) + kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "Y_buffer", dram_var = "Y", - index_var = "index2", - tag_var = "tag", - vlane_split_axis = 2, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{kernel.render_options['Y_numel']}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (1, TILE_M, TILE_N), - tile_stride = [1, 1, TILE_M], + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, - reduction_idx = "red_idx" + dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], 
kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) - - self.header = f"float X_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_K)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{kernel.get_spad_size_per_lane(TILE_K, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" - return code def codegen_header(self, code, extra_headers): @@ -368,6 +330,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6f38b08a..6dbe9047 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -877,7 +877,6 @@ def __init__(self, kernel_group, reason=None): self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} - self.is_template_kernel = False self.spad_buffer_dict = dict() self.base_vector_initialized = False @@ -919,7 +918,7 @@ def convert_index(self, expr, buffer): index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") return index - def 
parse_indices(self, expr, buffer=None) -> common.CSEVariable: + def parse_indices(self, expr, buffer=None, comments="") -> common.CSEVariable: if buffer is None: buffer = self.applys @@ -951,6 +950,40 @@ def parse_indices(self, expr, buffer=None) -> common.CSEVariable: args = ", ".join(map(str, indices)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[] {comments}") + return index + + def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: + if buffer is None: + buffer = self.applys + expr_list = [arg for arg in expr_list if arg != sympy.Number(0)] + + if len(expr_list) == 1 and expr_list[0].is_number: + # Constant case + return self.get_const_cse(int(expr_list[0])) + elif len(expr_list) == 1 and expr_list[0].is_symbol: + # Identity case + return expr_list[0] + + indices = [] + new_expr_list = [0] * len(expr_list) + for idx, arg in enumerate(expr_list): + if arg.is_Mul and arg.args[0].is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + new_expr_list[idx] = arg.subs(arg.args[1], new_arg) + indices.append(str(new_arg)) + elif not arg.is_number: + new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + new_expr_list[idx] = new_arg + indices.append(str(new_arg)) + else: + new_expr_list[idx] = arg + + # Extract index var + expr_str = str(sum(new_expr_list)) + args = ", ".join(map(str, indices)) + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") + args = ", ".join([f"%{i}" for i in indices]) index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") return index @@ -958,16 +991,18 @@ def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) index = self.convert_indirect_indexing(index) padding = self.get_padding_type() + + # Extract dram info dram_var = 
self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - local_tile_desc, index_var = self.get_dma_info(name, index) - + # Extract sram info + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride tile_numel_per_lane = local_tile_desc.get_numel_per_lane() - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() @@ -976,11 +1011,12 @@ def load(self, name: str, index: sympy.Expr): compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() # Define scratch pad buffer - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) # MVIN Encoding + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride, padding) + f"{name}_tag", dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction @@ -1018,16 +1054,14 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Prepare dma instruction - local_tile_desc, index_var = self.get_dma_info(name, index) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index) vlane_split_axis = 
local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride - tile_numel_per_lane = local_tile_desc.get_numel_per_lane() dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) tile_stride = local_tile_desc.get_tile_stride() tile_size = local_tile_desc.get_tile_size() - # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() @@ -1039,7 +1073,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): if require_store: # Define scratch pad buffer - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector store instruction store_size, operand_type = self.var_info[value] @@ -1058,8 +1092,9 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): sram_index_var = self.spad_buffer_dict[str(value)][3] # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + f"{name}_tag", dram_shape, tile_shape, attribute) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -1152,7 +1187,9 @@ def store_reduction(self, name, index, value): # Store reduction can't share cached value stored in cse, # since it is not innermost loop body. 
tmp_cse = self.cse + tmp_apply_cse = self.apply_cse self.cse = self.reduction_cse + self.apply_cse = self.reduction_cse dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) @@ -1160,10 +1197,9 @@ def store_reduction(self, name, index, value): index = self.rename_indexing(index) # Tile is always reuduced in inner loop - local_tile_desc, index_var = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) + local_tile_desc, index_var, dram_stride = self.get_dma_info(name, index, broadcast=False, store_reduction=True, buffer=self.reductions_suffix) vlane_split_axis = local_tile_desc.vlane_split_axis vlane_stride = local_tile_desc.vlane_stride - tile_numel_per_lane = local_tile_desc.get_numel_per_lane() dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = local_tile_desc.get_mlir_shape(mlir_dtype) @@ -1173,7 +1209,7 @@ def store_reduction(self, name, index, value): vshape = f"{mlir_dtype}" else: vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index) if self.welford_reduce_out is not None: # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out @@ -1206,12 +1242,14 @@ def store_reduction(self, name, index, value): # MVOUT Encoding # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + f"{name}_tag", dram_shape, tile_shape, attribute) self.reductions_suffix.writeline(common.DeferredLine(name, code)) # Restore origin cse self.cse = tmp_cse + self.apply_cse = tmp_apply_cse def 
indirect_indexing(self, index_var, size, check=True): return str(index_var) @@ -1496,7 +1534,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # TODO. kg_tile_desc = self.kernel_group.tile_desc - buffer_info = self.buffer_types[name] # Note: index could contain symbols that represent dynamic axies # Extract dimension of index(e.g, index0, index1) local_dims = [int(str(i)[5:]) for i in index.free_symbols if "index" in str(i)] @@ -1505,23 +1542,16 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc = mlir_common.MLIRMultiDimTile([1], self.vector_lane) local_dims.sort() # Assume that smaller index is placed in the outer loop indirect_dims = [f"{i}" for i in index.free_symbols if "tmp" in str(i)] - indirect_arg_dims = [f"%{i}" for i in index.free_symbols if "tmp" in str(i)] for indirect_dim in indirect_dims: index = index.replace(sympy.Symbol(indirect_dim), 0) # Reduction can have two type of tile size if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): - # We have to create custom apply map to provide dram stride - # ex) (d0, d1, ... dn, dn+1, dn+2, dk) -> (s0*d0 + s1*d1 + ... dn*0+ dn+1*0 + ... 
dk*0 + const) - fake_dim = self.get_const_cse(0) - input_expr = ",".join(["d"+str(i) for i in total_dims]) - output_expr = str(index).replace('index', 'd') - input_argument = ",".join(["%index" + str(i) if i in local_dims else f"%{fake_dim}" for i in total_dims]) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({input_expr})[{','.join(indirect_dims)}] -> ({output_expr})>") - index_var = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({input_argument})[{','.join(indirect_arg_dims)}]") local_dims = total_dims # Brodatcast tile shape - else: - index_var = self.parse_indices(index, buffer=buffer) + + index_var = self.parse_indices(index, buffer=buffer) + input_argument = [f"index{str(i)}" for i in local_dims] + dram_stride = [index.coeff(sympy.Symbol(arg)) for arg in input_argument] if kg_tile_desc.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vlane_split_axis) @@ -1533,6 +1563,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + dram_stride = [0] # Edge case # Case 1. Tile is 1-D vector type elif len(local_dims) == 1 and len(local_dims) <= self.reduction_depth: local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(local_dims[0])]) @@ -1565,6 +1596,14 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) local_tile_desc.vlane_split_axis = local_vlane_split_axis local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + # Case 4. 
Tile is 4-D tile (e.g., Convolution epilogue) + elif len(local_dims) == 4: + is_reduction = self.reduction_depth < 3 and not store_reduction + if is_reduction: + raise NotImplementedError("Currently not implemented... ;)") + local_tile_desc.set_tile_size([kg_tile_desc.get_dim_size(dim) for dim in local_dims]) + local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride else: raise NotImplementedError("Currently not implemented... ;)") @@ -1580,27 +1619,26 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Update local_tile_desc.set_tile_size(new_tile_size) local_tile_desc.vlane_split_axis = new_vlane_split_axis + return local_tile_desc, index_var, dram_stride - return local_tile_desc, index_var - - def get_dma_code(self, dma_type_name, attribute1, attribute2, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - tag_name, dram_shape, tile_shape, tile_stride, padding_type=0): - dma_key = (attribute1, attribute2, mlir_dtype) + def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, + tag_name, dram_shape, tile_shape, attribute): + dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: - dma_type, attribute1, attribute2 = self.dma_read_cache[dma_key] + dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: - dma_type, attribute1, attribute2 = self.dma_write_cache[dma_key] + dma_type, vlane_split_axis, vlane_stride = self.dma_write_cache[dma_key] else: - attribute1 = self.get_const_cse(attribute1) - attribute2 = self.get_const_cse(attribute2) + vlane_split_axis = self.get_const_cse(vlane_split_axis) + vlane_stride = self.get_const_cse(vlane_stride) if dma_type_name == "MVIN": dma_type = 
self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) self.dma_read_counter += 1 - self.dma_read_cache[dma_key] = [dma_type, attribute1, attribute2] + self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] else: dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) # self.dma_write_counter += 1 Is it okay? - self.dma_write_cache[dma_key] = [dma_type, attribute1, attribute2] + self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] tag = self.get_tag_cse(tag_name) zero_cse = self.get_const_cse(0) @@ -1608,7 +1646,7 @@ def get_dma_code(self, dma_type_name, attribute1, attribute2, mlir_dtype, dram_v dram_operand = f"%{dram_var}[%{dram_index_var}]" sram_operand = f"%{sram_var}[{sram_index_var}]" # Use string tag_var = f"%{tag}[%{zero_cse}]" - dma_attribute = f"%{attribute1}, %{attribute2}" + dma_attribute = f"%{vlane_split_axis}, %{vlane_stride}" sram_shape = tile_shape tag_shape = "memref<1xi32>" @@ -1619,9 +1657,7 @@ def get_dma_code(self, dma_type_name, attribute1, attribute2, mlir_dtype, dram_v src_operand, dst_operand = sram_operand, dram_operand src_shape, dst_shape = sram_shape, dram_shape - code = f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape}" - code = code + f" {{padding={padding_type}, sram_stride={tile_stride}}}" - return code + return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" def adjust_tile_size(self): if self.read_writes is not None: @@ -1672,34 +1708,44 @@ def adjust_tile_size(self): if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): raise NotImplementedError() - def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, indices, raw_index, is_template=False, buffer=None): + def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, 
buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_numel_per_lane = tile_desc.get_numel_per_lane() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) # Make sure each lane's buffer has at least two element - tile_size = max(tile_size_per_lane, 2) * self.vector_lane + tile_size = max(tile_numel_per_lane, 2) * self.vector_lane if buffer is None: buffer = self.spad_buffer - if name not in self.global_vars_dict: - self.global_vars_dict[name] = dict() + if dram_name not in self.global_vars_dict: + self.global_vars_dict[dram_name] = dict() - if str(raw_index) not in self.global_vars_dict[name]: - new_name = f"buf{self.spadbuf_counter}" + if str(raw_index) not in self.global_vars_dict[dram_name]: + new_name = f"buf{self.spadbuf_counter}_spad" if forced_name is None else f"{forced_name}_spad" self.spadbuf_counter+=1 # Add definition to header - self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") - self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}] __attribute__((aligned(64)));") - self.global_vars.writeline(f"memref.global @{new_name}_spad : {dram_tile_shape}") - self.global_vars_dict[name][str(raw_index)] = new_name + self.header.writeline(f"{c_type} {new_name}[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") + self.gem5_header.writeline(f"{c_type} {new_name}[{tile_size}] __attribute__((aligned(64)));") + self.global_vars.writeline(f"memref.global @{new_name} : {tile_shape}") + self.global_vars_dict[dram_name][str(raw_index)] = new_name else: - new_name = self.global_vars_dict[name][str(raw_index)] - sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name}_spad : {dram_tile_shape}") + new_name = self.global_vars_dict[dram_name][str(raw_index)] + return new_name - zero_cse = self.get_const_cse(0) - sram_dims = len(dram_tile_shape.split("x")) - 1 - sram_index_var = 
",".join([f"%{zero_cse}"] * sram_dims) + def get_scratchpad_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None): + if buffer is None: + buffer = self.spad_buffer - return sram_var, indices, sram_index_var + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + new_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, raw_index, buffer=buffer) + sram_var = self.spad_cse.generate(buffer, f"memref.get_global @{new_name} : {tile_shape}") + + zero_cse = self.get_const_cse(0) + sram_index_var = ",".join([f"%{zero_cse}"] * tile_desc.get_nr_dim()) + return sram_var, sram_index_var def get_const_cse(self, value, dtype="index") -> common.CSEVariable: # Type convert diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 92af0570..00bf4169 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -179,6 +179,7 @@ def set_info(outer, inner, arg_type): class MLIRMultiDimTile(): def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=None, vec_size=None): + self.name = "" self._tile_size = list(tile_size) self._tile_stride = None self.tile_axis_order = list(range(len(tile_size))) @@ -192,6 +193,9 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N self.implicit_dim_size = None self.nr_rdim = 0 + def set_name(self, name: str): + self.name = name + def set_tile_size(self, tile_size, tile_axis_order=None): self._tile_size = tile_size if tile_axis_order is None: @@ -204,6 +208,9 @@ def set_tile_size_stride(self, tile_size, tile_stride): self._tile_size = tile_size self._tile_stride = tile_stride + def get_name(self) -> str: + return self.name + def get_tile_size(self): return self._tile_size @@ -316,9 +323,6 @@ def __init__(self): def set_tile_info(self, tile_desc : MLIRMultiDimTile): self.tile_desc = tile_desc - def set_prologue_tile_info(self, tile_desc : MLIRMultiDimTile): - 
self.prologue_tile_desc = tile_desc - class BaseMLIRHardwareInfo(): def __init__(self): # Default HW setting diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py new file mode 100644 index 00000000..7968f813 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -0,0 +1,346 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Multi Channel Tile Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(1 * TILE_K, TILE_N) }} + d1 * {{ 
kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ 
kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to 1 { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ TILE_O_W }} { + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + // Store output matrix + 
{{kernel.store_output(indent_size=10)}} + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(2, 0, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, BATCH, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvMultiTileTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return 
math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_O_W, TILE_M, TILE_K] + X_tile_stride = [TILE_O_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, 
X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("o_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*STRIDE_W, X_dim[2]*I_C*(I_W+2*PADDING_W), X_dim[3]] + + W_tile_size = [TILE_K_H, 1, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , Symbol("c0"), W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + 
STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_K = TILE_K + + TOG_latency = O_W if TILE_M > O_W else TILE_M + return 
TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + 
self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py new file mode 100644 index 00000000..f2df1e43 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -0,0 +1,342 @@ +import os +import math +from sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ 
kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W, TILE_K) }} + d1)> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + %index_i_w = affine.apply #map_I_W(%tile_m, %k_w) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_I_H, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, 
SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. + affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { 
outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + 
template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [1, TILE_I_H, TILE_I_W, TILE_K] + X_tile_stride = [TILE_I_H * TILE_I_W * TILE_K , TILE_I_W * TILE_K, 1, TILE_I_W] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("c0"), 
Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C), X_dim[1]*((I_W+2*PADDING_W)*I_C), X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N, TILE_M, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = 
X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) 
== 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py new file mode 100644 index 00000000..3b60dcbc --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -0,0 +1,343 @@ +import os +import math +from 
sympy import Symbol, Number +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common +from torch._inductor.codecache import get_hash +from PyTorchSimFrontend import extension_config + +CONV_TEMPLATE = r""" +// Single Batch Conv2D (Stride != 1) kernel +// BATCH = {{ BATCH }} +// I_C = {{ I_C }} +// I_H = {{ I_H }} +// I_W = {{ I_W }} +// O_C = {{ O_C }} +// K_H = {{ K_H }} +// K_W = {{ K_W }} +// O_H = {{ O_H }} +// O_W = {{ O_W }} +// TILE_M = {{ TILE_M }} +// TILE_N = {{ TILE_N }} +// TILE_K = {{ TILE_K }} +// TILE_I_H={{ TILE_I_H }}, +// TILE_I_W={{ TILE_I_W }}, +// TILE_O_H={{ TILE_O_H }}, +// TILE_O_W={{ TILE_O_W }}, +// TILE_K_H={{ TILE_K_H }}, +// TILE_K_W={{ TILE_K_W }}, +// SUB_TILE_M={{ SUB_TILE_M }}, +// SUB_TILE_N={{ SUB_TILE_N }}, +// SUB_TILE_I_W={{ SUB_TILE_I_W }}, +// SUB_TILE_K_H={{ SUB_TILE_K_H }}, +// SUB_TILE_K_W={{ SUB_TILE_K_W }}, +// PADDING_H = {{ PADDING_H }} +// PADDING_W = {{ PADDING_W }} +// STRIDE_H = {{ STRIDE_H }} +// STRIDE_W = {{ STRIDE_W }} +// DATA_STYPE = {{ DATA_STYPE }} + +#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> +#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> +#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M * TILE_K_W, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> +#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> + 
+{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %c0 = arith.constant 0 : index + {{- kernel.def_local_vars(indent_size=2) }} + + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { + // Initialize output + {%- if BIAS %} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} + {%- else %} + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + {%- endif %} + affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { + affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { + affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { + %index_i_h = affine.apply #map_I_H(%o_h, %k_h) + // Load input & weight matrix + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_K], indent_size=14) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=14) }} + // Compute body part + affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
+ affine.for %tile_k_w = 0 to {{ TILE_K_W }} { + %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + affine.for %tile_o_h = 0 to {{ TILE_O_H }} { + affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W + %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) + %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) + %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { inner_loop=true } + } { accumulation_loop=true, subtile_loop="k" } + } { accumulation_loop=true } + } { accumulation_loop=true } + // Store output matrix + {{kernel.store_output(indent_size=8)}} + } { outer_loop=true, subtile_loop="m" } + } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + return +} +""" + +WRAPPER_TEMPLATE = r""" +def {{ FUNC_NAME 
}}{{kernel.def_wrapper()}}: + # Padding input + padded_shape = list(X.shape) + padded_shape[2] += 2 * {{ PADDING_H }} + padded_shape[3] += 2 * {{ PADDING_W }} + X_padding = torch.zeros(padded_shape, device=X.device) + X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X + + # Tanspose inputs + {%- for buf, name in kernel.get_conv_inputs().items() %} + {%- if name == "X" %} + {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) + {%- elif name == "W" %} + {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) + {%- elif name == "Bias" %} + {{ name }} = {{ name }} + {%- endif %} + {%- endfor %} + + # Launch kernel + {{ KERNEL_NAME }} + {%- if BACKENDSIM_EAGER_MODE %} + yield ({{KERNEL_NAME}}, ) + {%- endif %} +""" + +class MLIRConvSingleBatchStridedTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [i for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + if template_buffer_node is not None: + 
self.output_node = template_buffer_node + self.kernel = kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + + # Select tile size adn template + conv_template = CONV_TEMPLATE + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_K_H, TILE_M, TILE_K] + X_tile_stride = [TILE_K_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("k_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*((I_W+2*PADDING_W)*I_C), X_dim[1]*I_C, X_dim[2]*(I_C*STRIDE_W), X_dim[3]] + + W_tile_size = 
[TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] + Y_tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_idx = [Number(0), Symbol("tile_n")*O_H*O_W, Symbol("o_h")*O_W, Symbol("tile_m")] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=X, W=W, Y=Y, BIAS=Bias, + PADDED_INPUT_SIZE=self.get_padded_input_size(X), + BATCH=BATCH, + I_C=I_C, + I_H=I_H, + I_W=I_W, + O_C=O_C, + K_H=K_H, + K_W=K_W, + O_H=O_H, + O_W=O_W, + TILE_M=TILE_M, + TILE_N=TILE_N, + TILE_K=TILE_K, + TILE_I_H=TILE_I_H, + TILE_I_W=TILE_I_W, + TILE_O_H=TILE_O_H, + TILE_O_W=TILE_O_W, + TILE_K_H=TILE_K_H, + TILE_K_W=TILE_K_W, + SUB_TILE_M=SUB_TILE_M, + SUB_TILE_N=SUB_TILE_N, + SUB_TILE_K=SUB_TILE_K, + SUB_TILE_I_H=SUB_TILE_I_H, + SUB_TILE_I_W=SUB_TILE_I_W, + SUB_TILE_K_H=SUB_TILE_K_H, + SUB_TILE_K_W=SUB_TILE_K_W, + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, + DATA_STYPE="f32", + input_reorder=self.input_reorder + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = 
"output_buffer", + dram_var = "Y", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} + code = self._template_from_string(conv_template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TOG_latency = O_W if TILE_M > O_W else TILE_M + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + 
FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes + + def codegen_header(self, code, extra_headers): + write_path = extension_codecache.get_write_path(code) + if not os.path.exists(write_path): + os.makedirs(write_path) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, extra_headers[0]) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, extra_headers[1]) + self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 0b6d13ef..cd4ddf82 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -1,16 +1,15 @@ import os import math -from sympy import divisors, Range -from typing import List, Optional, cast +from sympy import Symbol, Number +from typing import List, Optional from 
PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache +from PyTorchSimFrontend.mlir import mlir_common from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config @@ -43,56 +42,30 @@ // PADDING_W = {{ PADDING_W }} // STRIDE_H = {{ STRIDE_H }} // STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} // DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * BATCH * O_C }} + d1 * {{ BATCH * O_C }} + d2 * {{ O_C }} + d3)> // output (O_H, O_W, BATCH, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * BATCH * I_C }} + d1 * {{ BATCH * I_C }} + d2 * {{ I_C }} + d3)> // input (I_H, I_W, BATCH, I_C) -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) #map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> #map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> #offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> #offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> #offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ 
TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} + {{- kernel.def_local_vars(indent_size=2) }} - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %tile_m = 0 to {{ 
BATCH }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%o_h, %o_w, %tile_m, %tile_n) + affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { + affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { + affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { + affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { // Initialize output {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride - : memref<{{ O_C }}xf32>, memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_O_H }}, {{ TILE_O_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_W * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> {%- endif %} @@ -101,406 +74,37 @@ affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { %index_i_h = affine.apply #map_I_H(%o_h, %k_h) %index_i_w = affine.apply #map_I_W(%o_w, %k_w) - %index1 = affine.apply #map1(%index_i_h, %index_i_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_I_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_I_W * TILE_M * TILE_K }}, {{ TILE_M * 
TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_M, SUB_TILE_K], indent_size=16) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_K, SUB_TILE_N], indent_size=16) }} + // Compute body part affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: 
[{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } - } { accumulation_loop=true } + } { accumulation_loop=true, subtile_loop="k" } } { accumulation_loop=true } } { accumulation_loop=true } // Store output matrix {{kernel.store_output(indent_size=10)}} } { outer_loop=true } } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -MULTI_TILE_CONV_TEMPLATE = r""" -// Multi Channel Tile Conv2D kernel -// BATCH = {{ BATCH }} -// I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ 
O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * BATCH * O_C }} + d1 * {{ BATCH * O_C }} + d2 * {{ O_C }} + d3)> // output (O_H, O_W, BATCH, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * BATCH * I_C }} + d1 * {{ I_C * STRIDE_W }} + d2 * {{ I_C * (I_W + 2 * PADDING_W) }} + d3)> // input (I_H, BATCH, I_W, I_C) -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(1 * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_O_W * TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_O_H }}x{{ 
TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%o_h, %o_w, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride - : memref<{{ O_C }}xf32>, memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ TILE_O_H }}, {{ TILE_O_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, 
sram_stride=[{{ TILE_O_W * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index1 = affine.apply #map1(%index_i_h, %o_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %c0, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_I_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_O_W * TILE_M * TILE_K }}, {{ TILE_M * TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to 1 { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ 1 }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ TILE_O_W }} { - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_O_H }}x{{ TILE_O_W }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=10)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -SINGLE_BATCH_CONV_TEMPLATE = r""" -// Single Batch Conv2D kernel -// BATCH = {{ BATCH }} -// 
I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * O_H * O_C }} + d1 * {{ O_W * O_C }} + d2 * {{ O_C }} + d3)> // output (BATCH, O_H, O_W, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * (I_H + 2 * PADDING_W) * I_C }} + d1 * {{ (I_W + 2 * PADDING_W) * I_C }} + d2 * {{ I_C }} + d3)> // input (BATCH, I_H, I_W, I_C) Stride should be changed if kernel stride > 1 -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#map_I_W = affine_map<(d0, d1) -> (d0 * {{ STRIDE_W }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_I_W, TILE_K) }} + d1)> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> -memref.global @X_spad : memref<{{ 1 }}x{{ TILE_I_H 
}}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - affine.for %o_w = 0 to {{ O_W }} step {{ TILE_O_W }} { - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%c0, %o_h, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride // not implemented yet - : memref<{{ O_C }}xf32>, memref<{{ 1 
}}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ TILE_O_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_H * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { - affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index_i_w = affine.apply #map_I_W(%o_w, %k_w) - %index1 = affine.apply #map1(%c0, %index_i_h, %index_i_w, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ SUB_TILE_I_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_I_H * TILE_I_W * TILE_K }}, {{ TILE_I_W * TILE_K }}, 1, {{ TILE_I_W }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to {{ TILE_K_W }} { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ 1 }}x{{ TILE_I_H }}x{{ TILE_I_W }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - -SINGLE_BATCH_CONV_STRIDE_TEMPLATE = r""" -// Single 
Batch Conv2D (Stride != 1) kernel -// BATCH = {{ BATCH }} -// I_C = {{ I_C }} -// I_H = {{ I_H }} -// I_W = {{ I_W }} -// O_C = {{ O_C }} -// K_H = {{ K_H }} -// K_W = {{ K_W }} -// O_H = {{ O_H }} -// O_W = {{ O_W }} -// TILE_M = {{ TILE_M }} -// TILE_N = {{ TILE_N }} -// TILE_K = {{ TILE_K }} -// TILE_I_H={{ TILE_I_H }}, -// TILE_I_W={{ TILE_I_W }}, -// TILE_O_H={{ TILE_O_H }}, -// TILE_O_W={{ TILE_O_W }}, -// TILE_K_H={{ TILE_K_H }}, -// TILE_K_W={{ TILE_K_W }}, -// SUB_TILE_M={{ SUB_TILE_M }}, -// SUB_TILE_N={{ SUB_TILE_N }}, -// SUB_TILE_I_W={{ SUB_TILE_I_W }}, -// SUB_TILE_K_H={{ SUB_TILE_K_H }}, -// SUB_TILE_K_W={{ SUB_TILE_K_W }}, -// PADDING_H = {{ PADDING_H }} -// PADDING_W = {{ PADDING_W }} -// STRIDE_H = {{ STRIDE_H }} -// STRIDE_W = {{ STRIDE_W }} -// DILATION_H = {{ DILATION_H }} -// DILATION_W = {{ DILATION_W }} -// DATA_STYPE = {{ DATA_STYPE }} -// DATA_SIZE = {{ DATA_SIZE }} - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ O_W * O_H * O_C }} + d1 * {{ O_W * O_C }} + d2 * {{ O_C }} + d3)> // output (BATCH, O_H, O_W, O_C) -#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ (I_W + 2 * PADDING_W) * I_C }} + d1 * {{ I_C }} + d2 * {{ I_C * STRIDE_W }} + d3)> // input (I_H, (k_w), I_W, I_C) // duplicate for k_w -#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * {{ K_W * I_C * O_C }} + d1 * {{ I_C * O_C }} + d2 * {{ O_C }} + d3)> // weight (K_H, K_W, I_C, O_C) -#map_I_H = affine_map<(d0, d1) -> (d0 * {{ STRIDE_H }} + d1)> -#offset_w_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_K_W * TILE_K, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_K, TILE_N) }})> -#offset_x_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M * TILE_K_W, TILE_K) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_K) }})> -#offset_y_map = affine_map<(d0, d1) -> (d0 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }} + d1 * {{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }})> - -memref.global @X_spad : memref<{{ TILE_I_H }}x{{ 
TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_conv_kernel(inputs=[X, W, BIAS], outputs=[Y], names_str="X, W, Bias, Y", padded_input_size=PADDED_INPUT_SIZE, input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index - %c_mvin3 = arith.constant 14 : index - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %input_axis = arith.constant 3 : index - %weight_axis = arith.constant 2 : index - %input_buffer = memref.get_global @X_spad : memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %weight_buffer = memref.get_global @W_spad : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %output_buffer = memref.get_global @Y_spad : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32> - %tag3 = memref.alloc() : memref<1xi32> - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - {{- kernel.def_local_vars() }} - - affine.for %o_h = 0 to {{ O_H }} step {{ TILE_O_H }} { - affine.for %tile_m = 0 to {{ O_W }} step {{ TILE_M }} { - affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { - %index0 = affine.apply #map0(%c0, %o_h, %tile_m, %tile_n) - // Initialize output - {%- if BIAS %} - memref.dma_start %Bias[%tile_n], %output_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag0[%c0], %c0, %vstride // not implemented yet - : memref<{{ O_C }}xf32>, memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N 
}}xf32, 1>, memref<1xi32> { subtile_size=[{{ 1 }}, {{ TILE_O_H }}, {{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_O_H * TILE_M * TILE_N }}, {{ TILE_M * TILE_N }}, 1, {{ TILE_M }}]} - {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> - {%- endif %} - affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { - affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { - affine.for %tile_k = 0 to {{ I_C }} step {{ TILE_K }} { - %index_i_h = affine.apply #map_I_H(%o_h, %k_h) - %index1 = affine.apply #map1(%index_i_h, %k_w, %tile_m, %tile_k) // input index - %index2 = affine.apply #map2(%k_h, %k_w, %tile_k, %tile_n) // weight index - // Load input matrix - memref.dma_start %X[%index1], %input_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag1[%c0], %input_axis, %vstride - : memref<{{ BATCH * I_C * (I_H + 2 * PADDING_H) * (I_W + 2 * PADDING_W) }}xf32>, memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_I_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[{{ TILE_K_W * TILE_M * TILE_K }}, {{ TILE_M * TILE_K }}, 1, {{ TILE_M }}]} - // Load kernel matrix - memref.dma_start %W[%index2], %weight_buffer[%c0, %c0, %c0, %c0], %c_mvin, %tag2[%c0], %input_axis, %vstride - : memref<{{ O_C * I_C * K_H * K_W }}xf32>, memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K_H }}, {{ SUB_TILE_K_W }}, {{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[{{ TILE_K_W * TILE_K * TILE_N }}, {{ TILE_K * TILE_N }}, 1, {{ TILE_K }}]} - affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. 
- affine.for %tile_k_w = 0 to {{ TILE_K_W }} { - %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - affine.for %tile_o_h = 0 to {{ TILE_O_H }} { - affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W - %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) - %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) - %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<{{ TILE_I_H }}x{{ TILE_K_W }}x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ 1 }}x{{ TILE_O_H }}x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { inner_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - // Store output matrix - {{kernel.store_output(indent_size=8)}} - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return 
} """ @@ -514,40 +118,14 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: X_padding = torch.zeros(padded_shape, device=X.device) X_padding[:, :, {{ PADDING_H }}:X.shape[2] + {{ PADDING_H }}, {{ PADDING_W }}:X.shape[3] + {{ PADDING_W }}] = X - # Holding original output tensor - {%- for buf, name in kernel.get_conv_outputs().items() %} - {{ name }}_t = {{ name }} - {%- endfor %} - # Tanspose inputs {%- for buf, name in kernel.get_conv_inputs().items() %} {%- if name == "X" %} - {%- if MULTI_TILE %} - {{ name }} = {{ name }}_padding.permute(2, 0, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, BATCH, I_W, I_C) - {%- elif SINGLE_BATCH %} - {{ name }} = {{ name }}_padding.permute(0, 2, 3, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (BATCH, I_H, I_W, I_C) - {%- else %} {{ name }} = {{ name }}_padding.permute(2, 3, 0, 1).contiguous() # (BATCH, I_C, I_H, I_W) -> (I_H, I_W, BATCH, I_C) - {%- endif %} {%- elif name == "W" %} {{ name }} = {{ name }}.permute(2, 3, 1, 0).contiguous() # (O_C, I_C, K_H, K_W) -> (K_H, K_W, I_C, O_C) {%- elif name == "Bias" %} {{ name }} = {{ name }} - {%- else %} - {%- if SINGLE_BATCH %} - {{ name }} = {{ name }}.permute(0, 2, 3, 1).contiguous() if {{ name }}.dim() == 4 else {{ name }} # (BATCH, O_C, O_H, O_W) -> (BATCH, O_H, O_W, O_C) - {%- else %} - {{ name }} = {{ name }}.permute(2, 3, 0, 1).contiguous() if {{ name }}.dim() == 4 else {{ name }} # (BATCH, O_C, O_H, O_W) -> (O_H, O_W, BATCH, O_C) - {%- endif %} - {%- endif %} - {%- endfor %} - - # Transpose outputs - {%- for buf, name in kernel.get_conv_outputs().items() %} - {%- if SINGLE_BATCH %} - {{ name }} = {{ name }}.permute(0, 2, 3, 1).contiguous() # (BATCH, O_C, O_H, O_W) -> (BATCH, O_H, O_W, O_C) - {%- else %} - {{ name }} = {{ name }}.permute(2, 3, 0, 1).contiguous() # (BATCH, O_C, O_H, O_W) -> (O_H, O_W, BATCH, O_C) {%- endif %} {%- endfor %} @@ -556,15 +134,6 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: {%- if BACKENDSIM_EAGER_MODE %} yield ({{KERNEL_NAME}}, ) {%- endif %} - - 
# Transpose back outputs - {%- for buf, name in kernel.get_conv_outputs().items() %} - {%- if SINGLE_BATCH %} - {{ name }}_t.copy_({{ name }}.permute(0, 3, 1, 2).contiguous()) # (BATCH, O_H, O_W, O_C) -> (BATCH, O_C, O_H, O_W) - {%- else %} - {{ name }}_t.copy_({{ name }}.permute(2, 3, 0, 1).contiguous()) # (O_H, O_W, BATCH, O_C) -> (BATCH, O_C, O_H, O_W) - {%- endif %} - {%- endfor %} """ class MLIRConvTemplate(MLIRTemplate): @@ -581,21 +150,6 @@ def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + "_" + "_".join([str(i) for i in self.dilation]) self.kernel_args = ['X', 'W', 'Bias', 'Y'] - def is_transposed(self, node): - if isinstance(node, ReinterpretView): - if node.layout.stride != node.data.layout.stride: - if node.layout.stride[-2] == node.data.layout.stride[-1] and node.layout.stride[-1] == node.data.layout.stride[-2]: - return True - else: - raise NotImplementedError("If the stride is not equal to the original stride, it should have been transposed.") - return False - - def is_multi_tile(self, I_C): - return I_C < (self.kernel.vector_lane // 8) # 8 is hard-coded for now. This should be changed to a better heuristic. 
- - def is_single_batch(self, BATCH): - return BATCH == 1 - def get_padded_input_size(self, X): input_padded = list(X.layout.size) input_padded[2] += 2 * self.padding[0] @@ -607,6 +161,7 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): + # Extract input arguments info if template_buffer_node is not None: self.output_node = template_buffer_node self.kernel = kernel @@ -617,93 +172,68 @@ def render(self, Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - BATCH = X.layout.size[0] - I_C = X.layout.size[1] - O_C = W.layout.size[0] - K_H = W.layout.size[2] - K_W = W.layout.size[3] + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * 
self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_O_W * TILE_M * TILE_N + # Select tile size adn template conv_template = CONV_TEMPLATE - TOG_latency = BATCH if TILE_M > BATCH else TILE_M - if self.is_single_batch(BATCH) and self.stride[0] != 1: - conv_template = SINGLE_BATCH_CONV_STRIDE_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) - x_spad_size = TILE_K_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_M * TILE_N - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - elif self.is_single_batch(BATCH) and self.stride[0] == 1: - conv_template = SINGLE_BATCH_CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 
1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_W * TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_K - w_spad_size = TILE_K_W * TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_M * TILE_N - TOG_latency = O_W if TILE_M > O_W else TILE_M - elif self.is_multi_tile(I_C): - conv_template = MULTI_TILE_CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - SUB_TILE_K = TILE_K - x_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_I_W * TILE_I_H * TILE_M, TILE_K) - w_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_K_H * TILE_K, TILE_N) - y_spad_size_per_lane = kernel.get_spad_size_per_lane(TILE_O_H * TILE_O_W * TILE_M, TILE_N) - x_spad_size = TILE_I_W * TILE_I_H * TILE_M * TILE_K - w_spad_size = TILE_K_H * TILE_K * TILE_N - y_spad_size = TILE_O_H * TILE_O_W * TILE_M * TILE_N + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N TOG_latency = 8 if TOG_latency < 8 else TOG_latency 
kernel.loop_size = [TOG_latency, TILE_N, TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_I_H, TILE_I_W, TILE_M, TILE_K ] + X_tile_stride = [TILE_I_W*TILE_M*TILE_K, TILE_M*TILE_K, 1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("input_buffer") + X_dim = [Symbol("index_i_h"), Symbol("index_i_w"), Symbol("tile_m"), Symbol("tile_k")] + X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*BATCH, X_dim[2]*I_C, X_dim[3]] + + W_tile_size = [TILE_K_H, TILE_K_W, TILE_K, TILE_N] + W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("weight_buffer") + W_dim = [Symbol("k_h"), Symbol("k_w"), Symbol("tile_k"), Symbol("tile_n")] + W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] + + Y_tile_size = [TILE_M, TILE_N, TILE_O_H, TILE_O_W] + Y_tile_stride = [1, TILE_M, TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N] # N, C, H, W + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("output_buffer") + Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] + Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] + + # Extract Bias info + Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - X=X, - W=W, - BIAS=Bias, - Y=Y, + X=X, W=W, Y=Y, BIAS=Bias, PADDED_INPUT_SIZE=self.get_padded_input_size(X), - BATCH=X.layout.size[0], - I_C=X.layout.size[1], - I_H=X.layout.size[2], - I_W=X.layout.size[3], + BATCH=BATCH, + 
I_C=I_C, + I_H=I_H, + I_W=I_W, O_C=O_C, K_H=K_H, K_W=K_W, @@ -725,43 +255,46 @@ def render(self, SUB_TILE_I_W=SUB_TILE_I_W, SUB_TILE_K_H=SUB_TILE_K_H, SUB_TILE_K_W=SUB_TILE_K_W, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - STRIDE_H=self.stride[0], - STRIDE_W=self.stride[1], - DILATION_H=self.dilation[0], - DILATION_W=self.dilation[1], + PADDING_H=PADDING_H, + PADDING_W=PADDING_W, + STRIDE_H=STRIDE_H, + STRIDE_W=STRIDE_W, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, + X_idx = X_idx, + W_idx = W_idx, + Bias_idx = Bias_idx, DATA_STYPE="f32", - DATA_SIZE=4, input_reorder=self.input_reorder ) kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "output_buffer", dram_var = "Y", - index_var = "index0", - tag_var = "tag", - vlane_split_axis = 3, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{BATCH * O_C * O_H * O_W}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (TILE_O_H, TILE_O_W, TILE_M, TILE_N) if conv_template in (CONV_TEMPLATE, MULTI_TILE_CONV_TEMPLATE) else (1, TILE_O_H, TILE_M, TILE_N), - tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N, 1, TILE_M] + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, + dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} ) + kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} code = self._template_from_string(conv_template).render(**kernel.render_options) - self.header = f"float X_spad[{x_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{w_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{y_spad_size_per_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{x_spad_size}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{w_spad_size}] __attribute__ 
((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{y_spad_size}] __attribute__ ((section(\".spad\")));\n" - kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) - return code + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = BATCH if TILE_M > BATCH else TILE_M + TOG_latency = 8 if TOG_latency < 8 else TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node @@ -778,8 +311,6 @@ def outer_func_render(self, kernel_name, input_args): OUTPUT=Y, PADDING_H=self.padding[0], PADDING_W=self.padding[1], - MULTI_TILE=self.is_multi_tile(self.input_shape[1]), - SINGLE_BATCH=self.is_single_batch(self.input_shape[0]), VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, BACKENDSIM_EAGER_MODE=eager_mode, input_reorder=self.input_reorder 
@@ -813,7 +344,7 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index bfd0633b..ace6ea9d 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -1,16 +1,17 @@ import os import json +from pathlib import Path from torch import empty_strided -from typing import List, Optional, cast +from typing import List, Optional +import sympy from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel -from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode -from torch._inductor.ir import ReinterpretView from torch._inductor.codecache import write_atomic import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common GEMM_TEMPLATE = r""" // GEMM {% if prologue_nodes -%}prologue fused{%- endif %} {% if epilogue_nodes -%}eilogue fused{%- endif %} kernel @@ -22,62 +23,36 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -#map3 = affine_map<(d0, d1) -> (d0)> -#map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> -memref.global @X_spad : memref<{{ 
TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 1 : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - - affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { - affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { - %index2 = affine.apply #map2(%t_m, %t_n) - %index3 = affine.apply #map3(%t_m, %c0) - %index4 = affine.apply #map4(%t_m, %t_n) + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} - memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], {{ Bias_axis }}, %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 
1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - %index1 = affine.apply #map1(%t_k, %t_n) + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {% if prologue_nodes -%} // prologue nodes - {{kernel.prepare_input(indent_size=8)}} + {{kernel.load_input(indent_size=8)}} {%- else -%} - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ M * K }}xf32>, memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} {%- endif %} - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true } + linalg.matmul ins(%X_buffer, %W_buffer : {{ 
X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_buffer : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}) + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=6)}} - } { outer_loop=true } - } { outer_loop=true } + } { outer_loop=true, subtile_loop="n" } + } { outer_loop=true, subtile_loop="m" } return } """ @@ -98,58 +73,34 @@ // TILE_K = {{ TILE_K }} // SUB_TILE_M = {{ SUB_TILE_M }} // SUB_TILE_N = {{ SUB_TILE_N }} -#map0 = affine_map<(d0, d1) -> ({{ X_map }})> -#map1 = affine_map<(d0, d1) -> ({{ W_map }})> -#map2 = affine_map<(d0, d1) -> (d0 * {{ N }} + d1)> -#map3 = affine_map<(d0, d1) -> (d0)> -#map4 = affine_map<(d0, d1) -> (d0 + d1 * {{ M }})> -memref.global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> -memref.global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> -memref.global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} { - %c_mvin = arith.constant 2 : index - %c_mvin2 = arith.constant 1 : index{% if Bias %} - %c_mvin3 = arith.constant 14 : index{% endif %} - %c_mvout = arith.constant 3 : index - %vstride = arith.constant 1 : index - %axis = arith.constant 1 : index - %X_buffer = memref.get_global @X_spad : memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> - %tag = memref.alloc() : memref<1xi32> - %tag0 = memref.alloc() : memref<1xi32> - %tag1 = memref.alloc() : memref<1xi32> - %tag2 = memref.alloc() : memref<1xi32>{% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} - %c0 = arith.constant 0 : index -{{ kernel.def_local_vars() }} - - affine.for %t_n = 0 to {{ N }} step 
{{ TILE_N }} { - {{kernel.reduction_acc()}} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {{kernel.reduction_iter_arg()}} { - %index2 = affine.apply #map2(%t_m, %t_n) - %index3 = affine.apply #map3(%t_m, %c0) - %index4 = affine.apply #map4(%t_m, %t_n) + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {% if not Bias %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + {% endif %} + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> {%- if Bias %} - memref.dma_start %Bias[{{ Bias_idx }}], %Y_buffer[%c0, %c0], %c_mvin3, %tag0[%c0], {{ Bias_axis }}, %vstride : memref<{{ Bias.data.get_numel() }}xf32>, memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_M }}] } + {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {%- endif %} - affine.for %t_k = 0 to {{ K }} step {{ TILE_K }} { - %index0 = affine.apply #map0(%t_m, %t_k) - %index1 = affine.apply #map1(%t_k, %t_n) - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag1[%c0], %axis, %vstride - : memref<{{ M * K }}xf32>, memref<{{ TILE_M 
}}x{{ TILE_K }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_M }}, {{ SUB_TILE_K }}], async=1, sram_stride=[1, {{ TILE_M }}]} - memref.dma_start %W[%index1], %W_buffer[%c0, %c0], %c_mvin2, %tag2[%c0], %axis, %vstride - : memref<{{ K * N }}xf32>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>, memref<1xi32> { subtile_size=[{{ SUB_TILE_K }}, {{ SUB_TILE_N }}], async=1, sram_stride=[1, {{ TILE_K }}]} - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1>) - } { accumulation_loop=true, loop_k=true } + affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} + {{ kernel.def_dma_op("MVIN", "W", W_idx, W_tile_desc, subtile_size=[SUB_TILE_K, SUB_TILE_N], indent_size=8) }} + linalg.matmul ins(%X_buffer, %W_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }}) + outs(%Y_bufferT : memref<{{TILE_M}}x{{TILE_N}}x{{DATA_STYPE}}, 1>) + } { accumulation_loop=true, subtile_loop="k" } {{kernel.store_output(indent_size=6)}} - } { outer_loop=true, loop_m=true} + } { outer_loop=true, subtile_loop="m" } {{kernel.reduction_output(indent_size=4)}} - } { outer_loop=true, loop_n=true } + } { outer_loop=true, subtile_loop="n" } return } """ @@ -166,114 +117,90 @@ def render(self, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node - # if epilogue_nodes is not None and len(epilogue_nodes) > 0: - # self.output_node = cast(Buffer, epilogue_nodes[-1]) #FIXME: Temperary solution - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - - W_tensor = empty_strided(W.layout.size, W.layout.stride) - X_tensor = empty_strided(X.layout.size, X.layout.stride) + # Extract input arguments info + X, W, Y = self.input_nodes[0], 
self.input_nodes[1], self.output_node + X_tensor = empty_strided(X.layout.size, X.layout.stride) + W_tensor = empty_strided(W.layout.size, W.layout.stride) if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: raise NotImplementedError("Please report this case to us...") - W_stride = W_tensor.stride() - X_stride = X_tensor.stride() - W_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(W_stride)]) - X_map = " + ".join([f"d{idx}*{s}" for idx, s in enumerate(X_stride)]) - M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] - n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - # Caculate extra reads + # Extract fusion info + n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 n_extra_read = set() if epilogue_nodes is not None: - for enode in epilogue_nodes: - n_extra_read.update(enode.node.get_read_names()) - if self.output_node.name in n_extra_read: - n_extra_read.remove(self.output_node.name) + for enode in epilogue_nodes: + n_extra_read.update(enode.node.get_read_names()) + if self.output_node.name in n_extra_read: + n_extra_read.remove(self.output_node.name) - n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 - nr_rdim = 0 - # Determine tile size - # case 1: use cheat sheet - if extension_config.CONFIG_GEMM_CHEATSHEET_PATH is not None: - try: - with open(extension_config.CONFIG_GEMM_CHEATSHEET_PATH, "r") as f: - data = json.load(f) - except FileNotFoundError: - data = {} - gemm_shape = f"{M}_{K}_{N}" - if gemm_shape in data: - tile_info = data[gemm_shape] - TILE_M = tile_info["TILE_M"] - TILE_N = tile_info["TILE_N"] - TILE_K = tile_info["TILE_K"] - else: # case 2: use gemm_combination_mapping - min_tile = (n_extra_node + n_prologue_node) == 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=min_tile) - # case 3: use manual 
tile size - if extension_config.CONFIG_MANUAL_TILE_SIZE: - TILE_M = extension_config.CONFIG_TILE_M - TILE_N = extension_config.CONFIG_TILE_N - TILE_K = extension_config.CONFIG_TILE_K + # Select tile size + M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + # Select template code if (M == 0) or (N == 0) or (K == 0): # exception for MoE - TILE_M, TILE_N, TILE_K = 1, 1, 1 template = EMPTY_TEMPLATE - elif n_extra_node>=1 and epilogue_nodes[0].is_reduction(): + nr_rdim = 0 + elif n_epilogue_node>=1 and epilogue_nodes[0].is_reduction(): template = GEMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index1", "index1":"index0"} nr_rdim = 1 else: template = GEMM_TEMPLATE - - TILE_M = min(extension_config.CONFIG_FORCE_TILE_M, TILE_M) - TILE_N = min(extension_config.CONFIG_FORCE_TILE_N, TILE_N) - TILE_K = min(extension_config.CONFIG_FORCE_TILE_K, TILE_K) - - # Calculate Sub Tile Size for fine-grained DMA - if extension_config.CONFIG_SUBTILE: - # Case 1: adjust selective fine-grained DMA (SFG-DMA) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane - if (TILE_M == M and TILE_N == N and TILE_N <= 512): - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - else: # Avoid Row Conflict of weights - SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K - # Case 2: use manual sub tile size (FG-DMA) - if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: - SUB_TILE_M = extension_config.CONFIG_SUBTILE_M - SUB_TILE_N = extension_config.CONFIG_SUBTILE_N - SUB_TILE_K = extension_config.CONFIG_SUBTILE_K - # Case 3: None Subtile - else: - SUB_TILE_M = TILE_M - SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1"} + nr_rdim = 0 TOG_latency = M if SUB_TILE_M > M else SUB_TILE_M kernel.loop_size =[TOG_latency, 
SUB_TILE_N, SUB_TILE_K] + # Prepare tile descriptors + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_M, TILE_K] + X_tile_stride = [1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index2") * X_stride[1]] # To keep index arguemnt order, we used index_list + + W_tile_size = [TILE_K, TILE_N] + W_tile_stride = [1, TILE_K] + W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) + W_tile_desc.set_name("W_buffer") + W_stride = W.get_layout().stride + W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] + + vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0 + Y_tile_size = [TILE_M, TILE_N] if nr_rdim == 0 else [TILE_N, TILE_M] + Y_tile_stride=[1, TILE_M] if nr_rdim == 0 else [TILE_M, 1] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + if nr_rdim == 0: + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] + else: + Y_idx = [sympy.Symbol("index1") * Y_stride[1], sympy.Symbol("index0") * Y_stride[0]] + # Extract Bias info Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] if Bias is not None: - if Bias.data.get_numel() == M*N: - Bias_idx = "%index2" - Bias_axis = "%axis" - elif Bias.data.get_numel() == M: - Bias_idx = "%index3" - Bias_axis = "%axis" + Bias_stride = Bias.get_layout().stride + if nr_rdim == 0: + Bias_idx = [sympy.Symbol("index0") * Bias_stride[0], sympy.Symbol("index1") * Bias_stride[1]] 
else: - Bias_idx = "%t_n" - Bias_axis = "%c0" + Bias_idx = [sympy.Symbol("index1") * Bias_stride[1], sympy.Symbol("index0") * Bias_stride[0]] else: Bias_idx = None - Bias_axis = None kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - M=M, - N=N, - K=K, + M=M, N=N, K=K, TILE_M=TILE_M, TILE_N=TILE_N, TILE_K=TILE_K, @@ -281,73 +208,120 @@ def render(self, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, DATA_STYPE="f32", - DATA_SIZE=4, - X = X, - W = W, - Y = Y, + X = X, W = W, Y = Y, Bias = Bias, + X_idx = X_idx, + W_idx = W_idx, Bias_idx = Bias_idx, - Bias_axis = Bias_axis, - X_map = X_map, - W_map = W_map, - Y_numel = M * N, + X_tile_desc = X_tile_desc, + W_tile_desc = W_tile_desc, + Y_tile_desc = Y_tile_desc, epilogue_nodes = epilogue_nodes, prologue_nodes = prologue_nodes, input_reorder = self.input_reorder ) - kernel.prologue_info = dict ( - input_sram_var = "X_buffer", - input_dram_var = "X", - input_index_var = "index0", - input_tag_var = "tag1", - input_numel = M * K, - input_tile_size = (TILE_M, TILE_K), - input_sram_stride = [1, TILE_M], - vector_sram_stride = [TILE_M, 1], - input_subtile_size = (SUB_TILE_M, SUB_TILE_K), - weight_sram_var = "W_buffer", - weight_dram_var = "W", - weight_index_var = "index1", - weight_tag_var = "tag2", - weight_numel = K * N, - weight_tile_size = (TILE_K, TILE_N), - weight_sram_stride = [1, TILE_K], - weight_subtile_size = (SUB_TILE_K, SUB_TILE_N), - tile_size = (TILE_M, TILE_K), - vlane_split_axis = 1, - vlane_stride = 1, - is_bmm = False, - ) + if prologue_nodes: + prologue_output_name = list(prologue_nodes[0].read_writes.writes)[0].name + if prologue_output_name == X.get_name(): + # Input fusion case + prologue_var = "X" + prologue_sram_var = "X_buffer" + prologue_tile_desc = X_tile_desc + prologue_dim_aliasing = {"index0":"index0", "index1":"index2"} + is_input_fused = True + else: + # Weight fusion case + prologue_var = "W" + prologue_sram_var = "W_buffer" + prologue_tile_desc = W_tile_desc + 
prologue_dim_aliasing = {"index0":"index2", "index1":"index1"} + is_input_fused = False + + kernel.prologue_info = dict ( + input_dram_var = "X", + input_sram_var = "X_buffer", + input_tile_desc = X_tile_desc, + input_idx = X_idx, + input_subtile_size = [TILE_M, TILE_K], + input_dim_aliasing = {"index0":"index0", "index1":"index2"}, + + weight_dram_var = "W", + weight_sram_var = "W_buffer", + weight_tile_desc = W_tile_desc, + weight_idx = W_idx, + weight_subtile_size = [TILE_K, TILE_N], + weight_dim_aliasing = {"index0":"index2", "index1":"index1"}, + + # Descriptor for fusion + dram_var = prologue_var, + sram_var = prologue_sram_var, + dram_tile_desc = prologue_tile_desc, + dim_aliasing = prologue_dim_aliasing, + is_bmm = False, + is_input_fused = is_input_fused + ) kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], - sram_var = "Y_buffer", dram_var = "Y", - index_var = "index2", - t_index_var = "index4", # FIXME: for epilogue transposed input - tag_var = "tag", - vlane_split_axis = 1, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - dram_shape = f"memref<{kernel.render_options['Y_numel']}x{kernel.render_options['DATA_STYPE']}>", - tile_size = (TILE_M, TILE_N), - tile_stride = [1, TILE_M], + sram_var = "Y_buffer", + dram_idx = Y_idx, + dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, - reduction_idx = "t_n" + dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + return code - self.header = f"float X_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_K)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float W_spad[{kernel.get_spad_size_per_lane(TILE_K, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float 
Y_spad[{kernel.get_spad_size_per_lane(TILE_M, TILE_N)}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{TILE_M * TILE_K}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float W_spad[{TILE_K * TILE_N}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{TILE_M * TILE_N}] __attribute__ ((section(\".spad\")));\n" + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + # Check cheat sheet + cheatsheet_path = extension_config.CONFIG_GEMM_CHEATSHEET_PATH + data = {} + if extension_config.CONFIG_GEMM_CHEATSHEET_PATH is not None: + path = Path(cheatsheet_path) + if path.is_file(): + with path.open("r") as f: + data = json.load(f) - kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) + gemm_shape = f"{M}_{K}_{N}" + if gemm_shape in data: + tile_info = data[gemm_shape] + TILE_M = tile_info["TILE_M"] + TILE_N = tile_info["TILE_N"] + TILE_K = tile_info["TILE_K"] + else: # case 2: use gemm_combination_mapping + min_tile = (n_extra_node + n_prologue_node) == 0 + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=min_tile) + # case 3: use manual tile size + if extension_config.CONFIG_MANUAL_TILE_SIZE: + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K - return code + # Edge case + if (M == 0) or (N == 0) or (K == 0): + TILE_M, TILE_N, TILE_K = 1, 1, 1 + + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + SUB_TILE_N = TILE_N if TILE_N < 
kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) @@ -356,6 +330,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) + write_atomic(gem5_write_path, extra_headers[1]) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index b1e1ba0e..aa3cf16e 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -11,7 +11,11 @@ from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_conv_template import MLIRConvTemplate +from PyTorchSimFrontend.mlir.mlir_conv_mt_template import MLIRConvMultiTileTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.extension_config import 
CONFIG_VECTOR_LANE aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") @@ -96,9 +100,20 @@ def convolution( x.realize() weight.realize() x = ir.ExternKernel.require_channels_last(x) + BATCH = x.layout.size[0] + I_C = x.layout.size[1] weight = ir.ExternKernel.require_channels_last(weight) layout = conv_layout(x, weight, None, **kwargs) - mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) + + # Select conv kernel + if BATCH == 1 and stride[0] == 1: + mlir_template = MLIRConvSingleBatchTemplate([x, weight, bias], layout, **kwargs) + elif BATCH == 1 and stride[0] != 1: + mlir_template = MLIRConvSingleBatchStridedTemplate([x, weight, bias], layout, **kwargs) + elif I_C < CONFIG_VECTOR_LANE // 8: # 8 is hard-coded for now. This should be changed to a better heuristic. + mlir_template = MLIRConvMultiTileTemplate([x, weight, bias], layout, **kwargs) + else: + mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs) return mlir_template.generate().output_node() def maxpool_layout( diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index ff617eb4..5395efb2 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -26,8 +26,8 @@ affine.for %i = 0 to {{ BCH }} step {{ out_tile }} { affine.for %j = 0 to {{ W }} step {{ out_tile }} { %index0 = affine.apply #map0(%i, %j) - memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag[%c0], %axis, %vstride : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> - memref.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %c_mvout, %tag[%c0], %axis, %vstride : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> + memref.dma_start %X[%index0], %X_buffer[%c0, %c0], %c_mvin, %tag[%c0], %axis, %vstride : memref<{{ IN }}xf32>, memref<{{ in_tile }}x{{ in_tile }}xf32, 1>, memref<1xi32> 
{dram_stride=[{{W}}, 1]} + memref.dma_start %Y_buffer[%c0, %c0], %Y[%index0], %c_mvout, %tag[%c0], %axis, %vstride : memref<{{ out_tile }}x{{ out_tile }}xf32, 1>, memref<{{ OUT }}xf32>, memref<1xi32> {dram_stride=[{{W}}, 1]} } { outer_loop=true } } { outer_loop=true } return @@ -62,6 +62,7 @@ def render(self, W = Y.get_size()[3] BCH = B * C * H kernel.loop_size = None + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -77,26 +78,10 @@ def render(self, ) kernel.epilogue_info = dict( output_node = self.output_node.name, - dependent_buf = [], sram_var = "Y_buffer", dram_var = "Y", - index_var = "index0", - tag_var = "tag", - vlane_split_axis = 1, - vlane_stride = 1, - mlir_dtype = kernel.render_options['DATA_STYPE'], - tile_nr_dim = 2, - dram_shape = f"memref<{kernel.render_options['OUT']}x{kernel.render_options['DATA_STYPE']}>", - tile_shape = f"memref<{out_tile}x{out_tile}x{kernel.render_options['DATA_STYPE']}, 1>", - tile_size = (out_tile, out_tile), - tile_stride = [1, out_tile] ) code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - self.header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header = f"float X_spad[{in_tile * in_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - self.gem5_header += f"float Y_spad[{out_tile * out_tile // kernel.vector_lane}] __attribute__ ((section(\".spad\")));\n" - kernel.add_loop_info([kernel.render_options["IN"]], [kernel.vector_lane, kernel.vector_lane]) return code @@ -107,6 +92,6 @@ def codegen_header(self, code, extra_headers): spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, self.header+extra_headers[0]) + 
write_atomic(spike_write_path, extra_headers[0]) if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, self.gem5_header+extra_headers[1]) \ No newline at end of file + write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 41264a74..84418ec7 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -2,7 +2,7 @@ import math from functools import reduce import operator -from sympy import symbols, sympify +from sympy import symbols, sympify, Symbol from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel @@ -10,6 +10,8 @@ from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V +from torch._inductor.ir import LoopBody +from torch._inductor import dependencies from . import mlir_common from . import mlir_lowering @@ -47,7 +49,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # Directed linked? 
dependency_check = node2 in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) - return size_match and layout_possible and dependency_check & dependency_size + return size_match and layout_possible and dependency_check and dependency_size # For prologue fusion case if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: @@ -63,6 +65,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if len(node1.read_writes.writes) != 1: return False if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + node1 = self.revert_group(node1) return True return self.scheduler.can_fuse_origin(node1, node2) @@ -84,12 +87,7 @@ def can_fuse_horizontal(self, node1, node2): return False # Can't fuse two template node - nr_template = 0 - for node in node1.get_nodes() + node2.get_nodes(): - if node.is_template(): - nr_template += 1 - - if nr_template > 1: + if node1.is_template() and node2.is_template(): return False # Check template node fusion @@ -100,34 +98,48 @@ def can_fuse_horizontal(self, node1, node2): node2.is_template() and len(node1.get_nodes())==1 and isinstance(node2.node.template, MLIRMaxPoolTemplate): return False - # Different layout is not supported - if node1.get_nodes()[0].node.layout.dtype != node2.get_nodes()[0].node.layout.dtype: - return False - - # Convolution is currently not supported - # if node1.is_template() and node1.get_nodes()[0].node.origin_node is not None and hasattr(node1.get_nodes()[0].node.origin_node.target, "_name") and node1.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': - # return False - - # if node2.is_template() and node2.get_nodes()[0].node.origin_node is not None and hasattr(node2.get_nodes()[0].node.origin_node.target, 
"_name") and node2.get_nodes()[0].node.origin_node.target._name == 'aten::convolution': - # return False - + # Pointwise check v1_total = math.prod(vars1) if len(vars1) else 0 v2_total = math.prod(vars2) if len(vars2) else 0 if v1_total != v2_total: return False - has_depedency = False - template_node = node1 if node1.is_template() else node2 - act_node = node2 if node1.is_template() else node1 - for write_buf in template_node.read_writes.writes: - has_depedency = has_depedency or (write_buf in act_node.read_writes.reads) - return has_depedency + # Pattern check + template_node, act_node = (node1, node2) if node1.is_template() else (node2, node1) + has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) + if not has_depedency: + return False + + # Revert act_node.group : simplify_and_reorder() modified _body, _size, group + if template_node.group != act_node.group: + self.revert_group(act_node) + return True # Check elementwise fusion if vars1 == vars2 and reduce1 == reduce2: return True return False + def revert_group(self, act_node): + args, var_ranges = dependencies.index_vars_no_squeeze( + act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" + ) + body = LoopBody( + act_node.node.get_store_function(), + (args if act_node.node.get_reduction_type() else args[:1]), + var_ranges, + ) + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + index_size.append(s) + else: + reduce_size.append(s) + node_device = act_node.get_device() + ranges = (index_size, reduce_size) + act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) + def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) @@ -196,74 +208,73 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() partial_code = render() - tile_desc 
= kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) if prologue_nodes: - _, (group, reduction_group) = max( - [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) - ).group - prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) - kernel.kernel_group.set_prologue_tile_info(prologue_tile_desc) - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - # Flush created varaibles, since template fusion doen't share variable - kernel.cse.cache.clear() - kernel.prologue_buffer_group.set_buffers() - kernel.call_ranges = None - kernel.load = kernel.load_prologue - kernel.store = kernel.store_prologue - for node in prologue_nodes: - # Reuse created spad - read_list = sorted(list(node.read_writes.reads)) - candidate_found = False - # Why? There is a case that memdep.get_size() != data.get_size() - buf_dict = {} - buf_dict.update({val.name : val for val in V.graph.buffers}) - for candidate_read in read_list: - if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): - prologue_input_arg = candidate_read.name - candidate_found = True - break - assert(candidate_found) - assert(len(node.read_writes.writes)==1) - prologue_output_arg = list(node.read_writes.writes)[0].name - template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] - if template_node.get_nodes()[0].node.origin_node.target._name == 'aten::bmm': - target_buf = f"{template_buf}_buffer2D" - else: - target_buf = f"{template_buf}_buffer" - - # To skip the dma code gen - kernel.buffer_names[prologue_input_arg] = target_buf - kernel.buffer_names[prologue_output_arg] = target_buf - - # Edge delete - kernel.kernel_group.args.input_buffers = { - (arg if buf != template_buf else prologue_input_arg): buf - for arg, buf in kernel.kernel_group.args.input_buffers.items() - } - node.codegen((vars, reduction_vars)) + # Flush created varaibles, since template fusion 
doen't share variable + with kernel.prologue_buffer_group.as_local(): + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_prologue + _, (group, reduction_group) = max( + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) + ).group + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_tile_info(prologue_tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in prologue_nodes: + # Reuse created spad + read_list = sorted(list(node.read_writes.reads)) + candidate_found = False + # Why? There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.name : val for val in V.graph.buffers}) + buf_dict.update(V.graph.graph_inputs) + for candidate_read in read_list: + if candidate_read.name in buf_dict and reduce(operator.mul, buf_dict[candidate_read.name].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read.name + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? 
+ + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) if epilogue_nodes: - _, (group, reduction_group) = max( - epilogue_nodes, key=lambda x: int(x.is_reduction()) - ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - # Flush created varaibles, since template fusion doen't share variable - kernel.cse.cache.clear() - kernel.epilogue_buffer_group.set_buffers() - kernel.load = kernel.load_epilogue - kernel.store = kernel.store_epilogue - for node in epilogue_nodes: - if template_node.node.name in [dep[0] for dep in list(node.read_writes.reads)]: - kernel.epilogue_info['dependent_buf'].append(node.node.name) - node.codegen((vars, reduction_vars)) + with kernel.epilogue_buffer_group.as_local(): + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + _, (group, reduction_group) = max( + epilogue_nodes, key=lambda x: int(x.is_reduction()) + ).group + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in epilogue_nodes: + node.codegen((vars, reduction_vars)) + with V.set_kernel_handler(kernel): src_code = ( partial_code if isinstance(partial_code, str) else partial_code.finalize() ) - return src_code + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + return buffer.getvalue() def codegen_template(self, template_node, epilogue_nodes): # Handle prologue pattern diff --git a/PyTorchSimFrontend/mlir/mlir_template.py 
b/PyTorchSimFrontend/mlir/mlir_template.py index 1db14e27..ccb9b0d1 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -11,7 +11,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, Pointwise +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, View from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -22,12 +22,13 @@ from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode +from torch._inductor.codegen import common from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR from . 
import mlir_common class IndentedBufferGroup: - def __init__(self, kernel: 'MLIRTemplateKernel'): + def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): self.kernel = kernel self.body = IndentedBuffer() self.loads = IndentedBuffer() @@ -37,18 +38,51 @@ def __init__(self, kernel: 'MLIRTemplateKernel'): self.dma_loads = IndentedBuffer() self.dma_stores = IndentedBuffer() self.spad_buffer = IndentedBuffer() + self.cse = common.CSE("%", "", name_prefix=f"{prefix}") + self.apply_cse = common.CSE("%", "", name_prefix=f"{prefix}apply") + # Original buffers will be saved later in the 'with' block + self.original_buffers = {} def set_buffers(self): self.kernel.loads = self.loads self.kernel.compute = self.compute self.kernel.stores = self.stores + self.kernel.applys = self.applys self.kernel.dma_loads = self.dma_loads self.kernel.dma_stores = self.dma_stores self.kernel.spad_buffer = self.spad_buffer + self.kernel.cse = self.cse + self.kernel.apply_cse = self.apply_cse + + def restore_buffers(self): + self.kernel.loads = self.original_buffers['loads'] + self.kernel.compute = self.original_buffers['compute'] + self.kernel.stores = self.original_buffers['stores'] + self.kernel.applys = self.original_buffers['applys'] + self.kernel.dma_loads = self.original_buffers['dma_loads'] + self.kernel.dma_stores = self.original_buffers['dma_stores'] + self.kernel.spad_buffer = self.original_buffers['spad_buffer'] + self.kernel.cse = self.original_buffers['cse'] + self.kernel.apply_cse = self.original_buffers['apply_cse'] @contextlib.contextmanager def as_local(self): - yield self + self.original_buffers = { + 'loads': self.kernel.loads, + 'compute': self.kernel.compute, + 'stores': self.kernel.stores, + 'applys': self.kernel.applys, + 'dma_loads': self.kernel.dma_loads, + 'dma_stores': self.kernel.dma_stores, + 'spad_buffer': self.kernel.spad_buffer, + 'cse': self.kernel.cse, + 'apply_cse': self.kernel.apply_cse, + } + try: + self.set_buffers() + yield self + finally: + 
self.restore_buffers() class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, @@ -65,8 +99,6 @@ def __init__(self, self.call_size = call_size self.named_nodes = {} self.loop_info = {} - self.load_desc = {} - self.store_desc = {} self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes @@ -75,29 +107,23 @@ def __init__(self, self.render_options = dict() self.tile_size = [] self.loop_size = None - self.is_template_kernel = True - self.map_cse = CSE("#", self.suffix, name_prefix="template_map") - self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_const") - self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_alloc") - self.prologue_buffer_group = IndentedBufferGroup(self) - self.epilogue_buffer_group = IndentedBufferGroup(self) + self.map_cse = CSE("#", self.suffix, name_prefix="t_map") + self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_const") + self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") + self.prologue_buffer_group = IndentedBufferGroup(self, prefix="prologue_") + self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() + self.exception_nodes = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False self.reduction_body_loop = None - self.reduction_idx = None self.reduction_buffer_idx = 0 self.reduction_info = {} self.reduction_epilogue_result = {} self.reduction_mean = [] - self.reuse_buffer_names = {} - - # Overwrite ops - self.load = self.load_epilogue - self.store = self.store_epilogue - self.store_reduction = self.store_reduction_epilogue - self.reduction = self.reduction_epilogue + # Dim info + self.dim_aliasing = {} def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -368,8 +394,6 @@ def 
meta_kernel(self): wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') # Dump loop and load/store information wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"load_tile_info = {self.load_desc}") - wrapper.add_import_once(f"store_tile_info = {self.store_desc}") wrapper.add_import_once(f"arg_attributes = {arg_attributes}") def call_kernel(self, kernel_name): @@ -381,78 +405,62 @@ def call_kernel(self, kernel_name): call_args, cuda=False) def codegen_prologue_body(self): - with self.prologue_buffer_group.as_local() as buf: - buf.body.splice(buf.spad_buffer) - buf.body.splice(buf.applys) - buf.body.splice(buf.dma_loads) - - if (buf.loads.getvalue() != '' or buf.compute.getvalue() != '' or buf.stores.getvalue() != ''): - buf.body.writelines(self.prologue_compute_body_loop.lines()) + body = IndentedBuffer() + with self.prologue_buffer_group.as_local(): + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + + if (self.loads.getvalue() != '' or self.compute.getvalue() != '' or self.stores.getvalue() != ''): + body.writelines(self.prologue_compute_body_loop.lines()) compute_body = mlir_common.ParallelLoopBuffer() with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) - compute_body.splice(buf.loads) - compute_body.splice(buf.compute) - compute_body.splice(buf.stores) - buf.body.splice(compute_body) - - # Clear buffers - self.loads.clear() - self.compute.clear() - self.stores.clear() + compute_body.splice(self.loads) + compute_body.splice(self.compute) + compute_body.splice(self.stores) + body.splice(compute_body) + return body def codegen_epilogue_body(self): def template_store(): - zero_cse = self.get_const_cse(0) - sram_var = self.epilogue_info["sram_var"] dram_var = self.epilogue_info["dram_var"] - index_var = self.epilogue_info["index_var"] - tag_var = self.epilogue_info["tag_var"] - mlir_dtype = 
self.epilogue_info["mlir_dtype"] - dram_shape = self.epilogue_info["dram_shape"] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis - vlane_stride = self.kernel_group.tile_desc.get_vlane_stride() - tile_stride = self.epilogue_info["tile_stride"] - tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - sram_index_var = ",".join([f"%{zero_cse}"] * self.kernel_group.tile_desc.get_nr_dim()) - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - tag_var, dram_shape, tile_shape, tile_stride) + index_list = self.epilogue_info["dram_idx"] + tile_desc = self.epilogue_info["dram_tile_desc"] + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) self.cse.generate(self.dma_stores, code, assignment = False) - # Do dma store first to overlap epilogue nodes - if self.reduction_fusion: - if len(self.stores._lines) == 0: - template_store() - self.epilogue_buffer_group.body.splice(self.dma_stores) - self.dma_stores.clear() - self.epilogue_buffer_group.body.splice(self.spad_buffer) - self.epilogue_buffer_group.body.splice(self.applys) - self.epilogue_buffer_group.body.splice(self.dma_loads) - self.epilogue_buffer_group.body.writelines(self.compute_body_loop.lines()) - compute_body = mlir_common.ParallelLoopBuffer() - with contextlib.ExitStack() as stack: - stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + + body = IndentedBuffer() + with self.epilogue_buffer_group.as_local(): + # Do dma store first to overlap epilogue nodes if self.reduction_fusion: - #if len(self.stores._lines) == 0: - # template_store() - compute_body.writelines(self.reduction_body_loop.lines()) - stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) - compute_body.splice(self.loads) - compute_body.splice(self.compute) - else: - compute_body.splice(self.loads) - compute_body.splice(self.compute) if 
len(self.stores._lines) == 0: template_store() - compute_body.splice(self.epilogue_buffer_group.stores) - if (compute_body.getvalue()): - self.epilogue_buffer_group.body.splice(compute_body) - self.epilogue_buffer_group.body.splice(self.dma_stores) - self.epilogue_buffer_group.body.splice(self.reduction_epilogue_suffix) - - # Clear buffers - self.loads.clear() - self.compute.clear() - self.stores.clear() + body.splice(self.dma_stores) + self.dma_stores.clear() + body.splice(self.spad_buffer) + body.splice(self.applys) + body.splice(self.dma_loads) + body.writelines(self.compute_body_loop.lines()) + compute_body = mlir_common.ParallelLoopBuffer() + with contextlib.ExitStack() as stack: + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + if self.reduction_fusion: + compute_body.writelines(self.reduction_body_loop.lines()) + stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) + compute_body.splice(self.loads) + compute_body.splice(self.compute) + else: + compute_body.splice(self.loads) + compute_body.splice(self.compute) + if len(self.stores._lines) == 0: + template_store() + compute_body.splice(self.stores) + if (compute_body.getvalue()): + body.splice(compute_body) + body.splice(self.dma_stores) + body.splice(self.reduction_epilogue_suffix) + return body def def_kernel( self, @@ -562,66 +570,42 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def prepare_input(self, indent_size: int = 0): - def emit_dma_start(buffer_name, index_var, tag_var, size, tile_size, subtile_size=None, async_flag=True, label="X"): - base = f"memref.dma_start %{label}[%{index_var}], %{buffer_name}[%c0, %c0], %c_mvin" - if label == "W": - base = base.replace("mvin", "mvin2") - - suffix = f"%{tag_var}[%c0], %axis, %vstride" - memref_shape = f"memref<{size}xf32>" - tile_shape = "x".join([str(x) for x in 
tile_size]) - tile_memref = f"memref<{tile_shape}xf32, 1>" - tag_memref = f"memref<1xi32>" - attrs = f"sram_stride=[1, {tile_size[0]}]" - async_flag = "false" - if subtile_size: - subtile_shape = ", ".join([str(x) for x in subtile_size]) - attrs = f"subtile_size=[{subtile_shape}], async={async_flag}, {attrs}" - else: - subtile_shape = ", ".join([str(x) for x in tile_size]) - attrs = f"subtile_size=[{subtile_shape}], async={async_flag}, {attrs}" - attr_memref = f"{{ {attrs} }}" - return f"{base}, {suffix}: {memref_shape}, {tile_memref}, {tag_memref} {attr_memref}" - + def load_input(self, indent_size: int = 0): def hook(): code = IndentedBuffer() - self.codegen_prologue_body() - prologue_code = self.prologue_buffer_group.body + prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): - code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], - self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], subtile_size=self.prologue_info["input_subtile_size"], label="X")) - code.splice(prologue_code) - code.writeline(emit_dma_start(self.prologue_info["weight_sram_var"], self.prologue_info["weight_index_var"], self.prologue_info["weight_tag_var"], - self.prologue_info["weight_numel"], self.prologue_info["weight_tile_size"], subtile_size=self.prologue_info["weight_subtile_size"], label="W")) + input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + if (self.prologue_info["is_input_fused"]): + code.splice(input_dma_code) + code.splice(prologue_code) + 
code.splice(weight_dma_code) + else: + code.splice(weight_dma_code) + code.splice(prologue_code) + code.splice(input_dma_code) else: - code.writeline(emit_dma_start(self.prologue_info["input_sram_var"], self.prologue_info["input_index_var"], self.prologue_info["input_tag_var"], - self.prologue_info["input_numel"], self.prologue_info["input_tile_size"], self.prologue_info["input_subtile_size"], async_flag=True, label="X")) - code.writeline(emit_dma_start(self.prologue_info["weight_sram_var"], self.prologue_info["weight_index_var"], self.prologue_info["weight_tag_var"], - self.prologue_info["weight_numel"], self.prologue_info["weight_tile_size"], self.prologue_info["weight_subtile_size"], async_flag=True, label="W")) + dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + code.splice(dma_code) + dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks self.render_hooks[""] = hook + self.render_hooks.move_to_end("", last=False) # Force order to be triggered first return "" - def output_name(self): - # Cannot know the output name from the template, so we need to hook it - def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs() - output = arg_defs[3] #FIXME: Constant index used - pattern = r"%(\w+):" - output = re.search(pattern, output).group(1) - return output - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - def store_output(self, indent_size: int = 0): def hook(): - self.codegen_epilogue_body() - return textwrap.indent(self.epilogue_buffer_group.body.getvalue(), 
" "*indent_size).strip() + epilogue_code = self.codegen_epilogue_body() + return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -636,29 +620,6 @@ def hook(): self.render_hooks[""] = hook return "" - def reduction_iter_arg(self): - def hook(): - if len(self.reduction_vars): - args = ', '.join([f"%{iter.name} = %{init.name}" for (_, iter, init, _) in self.reduction_vars.values()]) - dtype = ', '.join([f"{dtype}" for (_, _, _, dtype) in self.reduction_vars.values()]) - return f"iter_args({args}) -> ({dtype})" - return "" - - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - - def reduction_acc(self): - def hook(): - if len(self.reduction_vars): - acc = ', '.join([f"%{acc.name}" for acc in self.reduction_vars.keys()]) - return f"{acc} =" - return "" - - assert "" not in self.render_hooks - self.render_hooks[""] = hook - return "" - def def_function(self): _, call_args, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: @@ -679,26 +640,76 @@ def hook(): self.render_hooks[key] = hook return key - def def_local_vars(self): + def def_local_vars(self, indent_size=0): key = "" def hook(): code = IndentedBuffer() - code.tabwidth = 2 - code.splice("\n") - with code.indent(): - code.splice(self.const_buffer) - code.splice(self.alloc_buffer) - return code.getvalue() + code.tabwidth = 1 + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) + return textwrap.indent(code.getvalue(), " "*indent_size).strip() assert key not in self.render_hooks self.render_hooks[key] = hook return key + def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, + subtile_size:list=[], async_type=None, indent_size=0): + # Prepare code block + local_code = IndentedBuffer() + with V.set_kernel_handler(self): + tag = f"mvint_{self.dma_read_counter}" if dma_type == "MVIN" else 
f"mvoutt_{self.dma_write_counter}" + index_var = self.parse_index_list(index_list, local_code) + node_layout = self.named_nodes[dram_var].get_layout() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + dram_stride = [] + for idx in index_list: + if idx.is_Mul: + dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + dram_stride.append(0) + elif not idx.is_Number: + dram_stride.append(1) + else: + dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vlane_split_axis + vlane_stride = tile_desc.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + tag, dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): + # Prepare code block + with V.set_kernel_handler(self): + dtype = self.named_nodes[dram_name].get_layout().dtype + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) + buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) + code = f"%{tile_desc.name} = memref.get_global 
@{buffer_name} : {tile_shape}" + return textwrap.indent(code, " "*indent_size).strip() + def render(self, template, kwargs, define_function=None): - # self.render_hooks = {} code = template.render(**kwargs) if define_function is not None: define_function(self) + return PartialRender( code, self.render_hooks, @@ -708,71 +719,14 @@ def get_spad_size_per_lane(self, tile_m, tile_n): size = tile_m * ((tile_n + self.vector_lane - 1) // self.vector_lane) return max(size, 2) # vector load/store - def load_prologue(self, name: str, index: sympy.Expr): - load_dim = [] - if not isinstance(V.graph, NullHandler) and name in V.graph.graph_inputs: - load_dim = V.graph.graph_inputs[name].layout.size - if self.ranges == self.buffer_types[name][2]: - index_var = self.prologue_info['input_index_var'] if len(load_dim) != 1 else 'tile_n' - vlane_split_axis = self.kernel_group.prologue_tile_desc.vlane_split_axis if len(load_dim) != 1 else 0 # FIXME: Fixed split axis for 1d load dim - else: - # Broadcast pattern - zero_index = self.const_cse.generate(self.const_buffer, "arith.constant 0 : index") - if self.prologue_info['is_bmm']: # FIXME: hardcoded - idx = f"%b, %t_k, %t_n" - map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1, d2) -> (d0 * 512 + d2)>") - vlane_split_axis = 2 # 3D GEMM prologue should be loaded by axis 2 - else: - idx = f"%t_m, %{zero_index}" - map_var = self.map_cse.generate(self.global_vars, f"affine_map<(d0, d1) -> (d0)>") - vlane_split_axis = 1 # 2D GEMM prologue should be loaded by axis 1 - index_var = self.apply_cse.generate(self.dma_loads, f"affine.apply #{map_var}({idx})") - index = self.rename_indexing(index) - dram_var = self.kernel_group.args.input(name) - dtype = V.graph.get_dtype(name) - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_stride = self.kernel_group.prologue_tile_desc.vlane_stride if len(load_dim) != 1 else 1 # FIXME: Fixed stride for 1d load dim - tile_numel_per_lane = 
self.kernel_group.prologue_tile_desc.get_numel_per_lane() - tile_shape = self.kernel_group.prologue_tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.prologue_info['input_sram_stride'] - - # Compute vector unit size - vshape = self.kernel_group.prologue_tile_desc.get_mlir_vshape(mlir_dtype) - compute_vec_size = self.kernel_group.prologue_tile_desc.get_compute_vec_size() - - if name not in self.buffer_names: - # Allocate sram buffer - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index, self.alloc_buffer) - self.buffer_names[name] = sram_var - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) - self.cse.generate(self.dma_loads, code, assignment = False) - - # Load vector from sram - sram_var = self.buffer_names[name] - zero_var = self.get_const_cse(0) - compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.prologue_tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) - - if compute_vec_size > 1: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" - - out = self.cse.generate(self.loads, line) - self.register_var_info(out, [compute_vec_size, mlir_dtype]) - return out - def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - tile_shape = self.kernel_group.prologue_tile_desc.get_mlir_shape(mlir_dtype) + tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) # Compute vector unit size - vshape = self.kernel_group.prologue_tile_desc.get_mlir_vshape(mlir_dtype) - compute_vec_size = 
self.kernel_group.prologue_tile_desc.get_compute_vec_size() + vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) @@ -780,7 +734,7 @@ def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): _, operand_type = self.var_info[value] if mlir_dtype != operand_type: value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) - compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.prologue_tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) + compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) # Generate vector load instruction if compute_vec_size > 1: operation = "affine.vector_store" @@ -791,25 +745,19 @@ def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.stores.writeline(line) def load_epilogue(self, name: str, index: sympy.Expr): - is_1d_source = len(index.free_symbols) == 1 - is_transpose = False # FIXME: Only works for 2d input - if len(index.args) == 2: - for expr in index.args: - if len(expr.args): - if expr.args[1].name == "index0" and expr.args[0] > 1: - is_transpose = True - break - key = 't_index_var' if is_transpose else 'index_var' - index_var = self.epilogue_info[key] if not is_1d_source else 'tile_n' index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis if not is_1d_source else 0 # FIXME: Fixed split axis for 1d load dim - vlane_stride = self.kernel_group.tile_desc.vlane_stride if not is_1d_source else 1 # FIXME: Fixed stride for 1d load dim - tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() + 
+ # Want to use tile_desc from epilogue_info + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + vlane_stride = self.kernel_group.tile_desc.vlane_stride tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.epilogue_info['tile_stride'] + tile_stride = self.kernel_group.tile_desc.get_tile_stride() # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) @@ -818,16 +766,12 @@ def load_epilogue(self, name: str, index: sympy.Expr): if name not in self.buffer_names: # Allocate sram buffer dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) - self.buffer_names[name] = sram_var + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) - self.cse.generate(self.dma_loads, code, assignment = False) - elif name in self.reuse_buffer_names: - sram_var = self.reuse_buffer_names[name] - code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + f"{name}_tag", dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) + self.buffer_names[name] = sram_var else: sram_var = self.buffer_names[name] @@ -861,24 +805,25 @@ def load_epilogue(self, name: str, index: sympy.Expr): return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, 
**kwargs): - index_var = self.epilogue_info['index_var'] + index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + + index_var = self.parse_indices(index) + dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vlane_stride - tile_numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() - - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = self.epilogue_info['tile_stride'] + tile_stride = self.kernel_group.tile_desc.get_tile_stride() # Compute vector unit size vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() if name not in self.buffer_names: - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index_var, index) + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) self.buffer_names[name] = sram_var else: zero_cse = self.get_const_cse(0) @@ -901,8 +846,9 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.stores.writeline(DeferredLine(name, line)) # Generate DMA instruction + attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + f"{name}_tag", dram_shape, tile_shape, attribute) self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, 
value): @@ -920,28 +866,35 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): sqr_sum = self.reduction_epilogue(dtype, src_dtype, "sum", ops.mul(value, value)) self.welford_reduce_out = (sum, sqr_sum, None) return sum, sqr_sum, None + # Check duplicated reductions reduction_key = src_dtype, reduction_type, value if reduction_key in self.reduction_epilogue_result: return self.reduction_epilogue_result[reduction_key] # Reduction fusion codegen part - type_name = mlir_common.DTYPE_TO_MLIR[dtype] vec_size = self.compute_body_loop.step - vshape = f"vector<{vec_size}x{type_name}>" + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + new_tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + [vec_size] + new_vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis + new_vlane_stride = self.kernel_group.tile_desc.vlane_stride + local_tile_desc = mlir_common.MLIRMultiDimTile(new_tile_size, self.vector_lane, new_vlane_split_axis, new_vlane_stride, vec_size) + + tile_shape = local_tile_desc.get_mlir_shape(type_name) + vshape = local_tile_desc.get_mlir_vshape(type_name) - tile_shape = f"memref<{self.reduction_body_loop.size * self.vector_lane}x{vec_size}x{type_name}, 1>" name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" self.reduction_buffer_idx += 1 index = "dummy_index" # Not used - tile_numel_per_lane = self.compute_body_loop.step * self.reduction_body_loop.size - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, None, index, self.const_buffer) + tile_numel_per_lane = self.compute_body_loop.step * self.reduction_body_loop.size # ??? 
+ sram_var, _ = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index, self.const_buffer) self.reduction_epilogue_result[reduction_key] = sram_var # Load partial result - zero_var = self.get_const_cse(0) + zero_var_list = [f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim() + zero_var_list[-2] = f"%{self.reduction_loop_idx}" + compute_index_var = ", ".join(zero_var_list) operation = "affine.vector_load" - compute_index_var = ",".join([f"%{self.reduction_loop_idx}"] + [f"%{zero_var}"]) line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" out = self.cse.generate(self.loads, line) self.register_var_info(out, [self.compute_body_loop.step, type_name]) @@ -953,78 +906,85 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): operation = "affine.vector_store" line = f"{operation} %{result}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" self.compute.writeline(line) # Need to be placed after partial reduction - self.reduction_info[sram_var] = reduction_type + self.reduction_info[sram_var] = [reduction_type, local_tile_desc] return sram_var def store_reduction_epilogue(self, name, index, value): + index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) + dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - index = self.rename_indexing(index) - - # Tile is always reuduced in inner loop - numel_per_lane = self.kernel_group.tile_desc.get_numel_per_lane() - reduction_axis_size = self.kernel_group.tile_desc.get_tile_size()[-2] - nr_outer_loop = numel_per_lane // reduction_axis_size + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis - 1 + index_var = self.parse_indices(index, self.reductions_suffix, comments="// Store reduction") + dram_stride = [index.coeff(sympy.Symbol(val)) for val in 
self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis + vlane_split_axis = self.kernel_group.tile_desc.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vlane_stride - tile_numel_per_lane = vlane_stride * nr_outer_loop * 2 - dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) - tile_shape = f"memref<{self.kernel_group.tile_desc.get_tile_size()[1]}x{type_name}, 1>" - tile_stride = [1] - sram_var, index_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, tile_numel_per_lane, tile_shape, index, - index, buffer=self.const_buffer) + # Create final buffer descriptor + nr_outer_loop = self.reduction_nr_outer_loop + tile_size = self.kernel_group.tile_desc.get_tile_size()[:-1] + final_tile_desc = mlir_common.MLIRMultiDimTile(tile_size, self.vector_lane, vlane_split_axis, vlane_stride*nr_outer_loop*2) + final_tile_shape = final_tile_desc.get_mlir_shape(mlir_dtype) + final_tile_stride = final_tile_desc.get_tile_stride() + sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, final_tile_desc, index, buffer=self.const_buffer) + + # Set partial buffer descriptor + partial_tile_desc = self.reduction_info[value][1] + partial_vec_size = partial_tile_desc.get_compute_vec_size() + partial_vshape = partial_tile_desc.get_mlir_vshape(mlir_dtype) + partial_tile_shape = partial_tile_desc.get_mlir_shape(mlir_dtype) + + # Prepare constant + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value][0], dtype)} : {mlir_dtype}") + partial_zero_var_list = [f"%{self.get_const_cse(0)}"] * partial_tile_desc.get_nr_dim() + final_zero_var_list = [f"%{self.get_const_cse(0)}"] * final_tile_desc.get_nr_dim() for i in range(self.reduction_body_loop.size): - vec_size = self.compute_body_loop.step - vshape = f"vector<{vec_size}x{type_name}>" - - partial_tile_shape = f"memref<{self.reduction_body_loop.size * self.vector_lane}x{vec_size}x{type_name}, 1>" # Load partial 
result - init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value], dtype)} : {type_name}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") - zero_var = self.const_cse.generate(self.const_buffer, f"arith.constant {0} : index") - index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") - compute_index_var = ",".join([f"%{index_var}"] + [f"%{zero_var}"]) + body_index_var = self.const_cse.generate(self.const_buffer, f"arith.constant {i} : index") + partial_zero_var_list[-2] = f"%{body_index_var}" + compute_index_var = ",".join(partial_zero_var_list) operation = "affine.vector_load" - line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {vshape}" + line = f"{operation} %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" out = self.cse.generate(self.reductions_suffix, line) operation = "affine.vector_store" - line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {vshape}" + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {partial_vshape}") + line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" self.reductions_suffix.writeline(line) # 2 step reduction new_vec_size = 2 - new_vshape = f"vector<{vec_size//new_vec_size}x{new_vec_size}x{type_name}>" - new_reduced_shape = f"vector<{new_vec_size}x{type_name}>" - out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {vshape} to {new_vshape}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {new_reduced_shape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) + new_vshape = 
f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" + new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" + out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") + out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - self.register_var_info(out, [new_vec_size, type_name]) - self.register_var_info(out2, [new_vec_size, type_name]) - out = reduction_partial_combine_vec(self.reduction_info[value], out, out2) + self.register_var_info(out, [new_vec_size, mlir_dtype]) + self.register_var_info(out2, [new_vec_size, mlir_dtype]) + out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - # Final reduction - #final_reduced_shape = type_name - #init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(self.reduction_info[value], dtype)} : {type_name}") - #out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value], out, init, axis=0, shape=vshape, reduced_shape=final_reduced_shape)) - if self.welford_reduce_out is not None: - # mean - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(768)} : f32") + # NOTE: It not a real welford algorithm... 
We just used E(X^2) - E(X)^2 + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size)} : f32") + if self.reduction_axis_size - 1 > 0: + divider2 = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size-1)} : f32") + else: + divider2 = divider + if self.buffer_types[name][1] > 1: divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") + divider_vec2 = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider2} : f32 to {new_reduced_shape}") else: divider_vec = divider + divider_vec2 = divider2 if self.current_node.node.origin_node: # FIXME: This is a temporary solution - # mean = E(X) / N + # mean = SUM(X) / N self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) out = self.reduction_mean[i] else: @@ -1032,43 +992,37 @@ def store_reduction_epilogue(self, name, index, value): sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") + m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec2} : {new_reduced_shape}") out = m2 + final_zero_var_list[-1] = f"%{body_index_var}" + final_compute_index_var = ",".join(final_zero_var_list) operation = "affine.vector_store" - line = f"{operation} %{out}, %{sram_var}[%{index_var}] : {tile_shape}, {new_reduced_shape}" + line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" 
self.reductions_suffix.writeline(DeferredLine(name, line)) # MVOUT Encoding # Generate DMA instruction - index_var = self.reduction_idx - code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, type_name, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, tile_stride) + attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" + code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + f"{name}_tag", dram_shape, final_tile_shape, attribute) self.reductions_suffix.writeline(DeferredLine(name, code)) - def get_scratchpad_buffer(self, dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, buffer=None): - return super().get_scratchpad_buffer(dtype, name, tile_size_per_lane, dram_tile_shape, index_var, raw_index, True, buffer=buffer) - - def set_tile_size(self, template_epilogue_info, prologue=False): - tile_desc = mlir_common.MLIRMultiDimTile(template_epilogue_info['tile_size'], - self.vector_lane, - vlane_split_axis=template_epilogue_info['vlane_split_axis'], - vlane_stride=template_epilogue_info['vlane_stride']) + def set_tile_size(self, template_fusion_info, prologue=False): + tile_desc = template_fusion_info["dram_tile_desc"] + if "dim_aliasing" in template_fusion_info: + self.dim_aliasing = template_fusion_info["dim_aliasing"] - if "reuse_buffer_names" in template_epilogue_info: - self.reuse_buffer_names.update(template_epilogue_info["reuse_buffer_names"]) - - if 'nr_rdim' in template_epilogue_info and template_epilogue_info['nr_rdim']==1: + if 'nr_rdim' in template_fusion_info and template_fusion_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() - reduction_axis_size = tile_desc.get_tile_size()[-2] + reduction_axis_size = tile_desc.get_tile_size()[-1] nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size tile_desc.vec_size = 
nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True - self.reduction_axis_size = tile_desc.get_tile_size()[-2] - self.reduction_nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size - self.reduction_idx = template_epilogue_info["reduction_idx"] + self.reduction_axis_size = tile_desc.get_tile_size()[-1] + self.reduction_nr_outer_loop = nr_outer_loop self.reduction_loop_idx = "reduce_loop_idx" self.compute_body_loop.size = reduction_axis_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop @@ -1083,6 +1037,14 @@ def set_tile_size(self, template_epilogue_info, prologue=False): self.compute_body_loop.step = tile_desc.get_compute_vec_size() return tile_desc + def rename_indexing(self, index) -> sympy.Expr: + for dim_name, dim_aliased_name in self.dim_aliasing.items(): + index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) + # To avoid this case ({"index0":"index1", "index1":"index0"}) + for dim_aliased_name in self.dim_aliasing.values(): + index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name)) + return index + class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" From 4fd7b6949f302ad995c539369c5f1b96d2cea9ad Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 04:42:20 +0000 Subject: [PATCH 27/62] [CI+Test] Add fusion test + update test case --- .github/workflows/pull-request.yml | 71 +++++++++++++++++------ .github/workflows/pull-request_mobile.yml | 71 +++++++++++++++++------ tests/Fusion/test_matmul_reduction.py | 35 +++++------ tests/Fusion/test_prologue_fusion.py | 28 +++++++-- tests/test_conv2d.py | 10 +++- tests/test_matmul.py | 23 +++++++- tests/test_reduce.py | 11 ++-- 7 files changed, 178 insertions(+), 71 deletions(-) diff --git a/.github/workflows/pull-request.yml 
b/.github/workflows/pull-request.yml index 3dbb3e36..9d440df6 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -493,12 +493,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_activation.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -508,12 +503,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_scalar.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -523,12 +513,57 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_matmul_layernorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_layernorm.py" + docker 
run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_layernorm.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + - name: Run test_conv_fusion.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} diff --git a/.github/workflows/pull-request_mobile.yml b/.github/workflows/pull-request_mobile.yml index 945bac3b..45d73fa8 100644 --- a/.github/workflows/pull-request_mobile.yml +++ b/.github/workflows/pull-request_mobile.yml @@ -493,12 +493,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_addmm_residual.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - 
name: Run test_matmul_activation.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -508,12 +503,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_activation.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_matmul_scalar.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -523,12 +513,7 @@ jobs: -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_scalar.py - - name: Log in to GitHub Container Registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GIT_ACCESS_TOKEN }} + - name: Run test_conv_fusion.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} @@ -539,6 +524,56 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump -e TORCHSIM_VECTOR_LANE=8 -e TORCHSIM_SPAD_SIZE=32 \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_conv_fusion.py + - name: Run test_matmul_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py + + - name: Run test_matmul_layernorm.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_matmul_layernorm.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} 
python3 PyTorchSim/tests/Fusion/test_matmul_layernorm.py + + - name: Run test_bmm_reduction.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_bmm_reduction.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_bmm_reduction.py + + - name: Run test_prologue_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_prologue_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_prologue_fusion.py + + - name: Run test_transformer_fusion.py + env: + GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} + run: | + echo "Running test_transformer_fusion.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_transformer_fusion.py + test_moe: name: Run test_moe runs-on: self-hosted diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py index 07dd914d..31ea1b0d 100644 --- a/tests/Fusion/test_matmul_reduction.py +++ b/tests/Fusion/test_matmul_reduction.py @@ -17,24 +17,22 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_matmul_reduce(device, size=512): - def matmul_fused(a, b, c): +def test_matmul_reduce(device, M=512, N=512, K=512): + def matmul_fused(a, b): result = torch.matmul(a, b) return result, result.max(dim=-2).values torch.manual_seed(0) - N = size - input = torch.randn(N, N) - weight = torch.randn(N, N) - #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) - #weight = torch.eye(N, dtype=torch.float32) + input = torch.randn(M, K) + weight = torch.randn(K, N) + 
#input = torch.arange(1, M * K + 1, dtype=torch.float32).reshape(M, K).to(dtype=torch.float32) + #weight = torch.eye(K, dtype=torch.float32) x1 = input.to(device=device) w1 = weight.to(device=device) x2 = input.to("cpu") w2 = weight.to("cpu") - c = 7 opt_fn = torch.compile(dynamic=False)(matmul_fused) - res = opt_fn(x1, w1, c) - y = matmul_fused(x2, w2, c) + res = opt_fn(x1, w1) + y = matmul_fused(x2, w2) test_result("Matmul Reduction Fusion activation", res[0], y[0]) test_result("Matmul Reduction Fusion reduction", res[1], y[1]) @@ -45,7 +43,7 @@ def matmul_fused(a, b, c): return result, var, mean torch.manual_seed(0) N = size - input = torch.randn(3072, 768) + input = torch.randn(1024, 768) weight = torch.randn(512, 768) #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) #weight = torch.eye(N, dtype=torch.float32) @@ -61,17 +59,16 @@ def matmul_fused(a, b, c): test_result("Matmul var_mean Fusion reduction", res[1], y[1]) test_result("Matmul var_mean Fusion reduction", res[2], y[2]) -def test_matmul_add_var_mean(device, size=512): +def test_matmul_add_var_mean(device, M=768, N=512, K=3072): def matmul_fused(a, b, c, d): result = torch.matmul(a, b.T) + c.T var, mean = torch.var_mean(result + d, dim=-2) return result, var, mean torch.manual_seed(0) - N = size - input = torch.randn(768, 3072) - weight = torch.randn(512, 3072) - bias = torch.randn(768, 512) - residual = torch.randn(768,512) + input = torch.randn(M, K) + weight = torch.randn(N, K) + bias = torch.zeros(N, M) + residual = torch.randn(M,N) x1 = input.to(device=device) w1 = weight.to(device=device) b1 = bias.to(device=device) @@ -95,6 +92,6 @@ def matmul_fused(a, b, c, d): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - #test_matmul_reduce(device) + test_matmul_reduce(device, 3072, 512, 768) test_matmul_var_mean(device) - #test_matmul_add_var_mean(device) + test_matmul_add_var_mean(device) 
diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py index 926782be..d5d1cdb1 100644 --- a/tests/Fusion/test_prologue_fusion.py +++ b/tests/Fusion/test_prologue_fusion.py @@ -53,13 +53,28 @@ def matmul_fused(a, b, c): y = matmul_fused(x2, w2, c2) test_result("Matmul Element-wise Fusion Forward", res, y) -def test_elem_bmm_fusion(device, batch_size=1, m=512, n=512, k=64): +def test_elem_bmm_weight_fusion(device, batch_size=1, m=512, n=512, k=64): def bmm(a, b, c, d): - return torch.bmm(a , (d - b)/c) + return torch.bmm(a , (d+b)*c) torch.manual_seed(0) a = torch.randn(batch_size, m, k).to(device=device) b = torch.randn(batch_size, 1, n).to(device=device) - c = torch.randn(batch_size, 1, n) * 1000 + c = torch.randn(batch_size, 1, n) + c = c.to(device=device) + d = torch.randn(batch_size, k, n).to(device=device) + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(a, b, c, d) + out = bmm(a.cpu(), b.cpu(), c.cpu(), d.cpu()) + print(torch.max(torch.abs(res.cpu() - out))) + test_result("BMM Element-wise Fusion Forward", res, out) + +def test_elem_bmm_input_fusion(device, batch_size=1, m=512, n=512, k=64): + def bmm(a, b, c, d): + return torch.bmm((a+b)*c , d) + torch.manual_seed(0) + a = torch.randn(batch_size, m, k).to(device=device) + b = torch.randn(batch_size, 1, k).to(device=device) + c = torch.randn(batch_size, 1, k) c = c.to(device=device) d = torch.randn(batch_size, k, n).to(device=device) opt_fn = torch.compile(dynamic=False)(bmm) @@ -76,6 +91,7 @@ def bmm(a, b, c, d): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_elem_broadcast_fusion(device) - test_elem_fusion(device) - test_elem_bmm_fusion(device, batch_size=12, m=64, n=512, k=512) \ No newline at end of file + #test_elem_broadcast_fusion(device) + #test_elem_fusion(device) + #test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) + test_elem_bmm_weight_fusion(device, 
batch_size=12, m=512, n=512, k=64) \ No newline at end of file diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 9d8b855a..8667792a 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -43,6 +43,10 @@ def custom_conv2d(a, b, bias): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_conv2d(device, batch_size=1, in_channels=128, out_channels=128, input_size=28, kernel_size=3, stride=1, padding=1) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64, kernel_size=7, stride=2, padding=3) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=32, input_size=32, kernel_size=3, stride=1, padding=1) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index bd219051..6f41468b 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -50,6 +50,24 @@ def custom_matmul(bias, a, b): y = custom_matmul(b2, x2, w2) test_result("Addmm Forward", res, y) +def test_addmm2(device, input_size=128, hidden_size=128, output_size=128): + def custom_matmul(bias, a, b): + return torch.matmul(a, b) #+ bias + 
torch.manual_seed(0) + input = torch.randn(input_size, hidden_size) + weight = torch.randn(hidden_size, output_size) + bias = torch.randn(input_size, 1, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + b1 = bias.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + b2 = bias.to("cpu") + opt_fn = torch.compile(dynamic=False)(custom_matmul) + res = opt_fn(b1, x1, w1) + y = custom_matmul(b2, x2, w2) + test_result("Addmm2 Forward", res, y) + def test_linear(device, input_size=128, hidden_size=128, output_size=128): def custom_linear(a, b, bias): linear = torch.nn.Linear(hidden_size, output_size) @@ -83,7 +101,10 @@ def custom_linear(a, b, bias): test_matmul(device, 128, 128, 128) test_matmul(device, 256, 256, 256) test_matmul(device, 128, 256, 256) - test_matmul(device, 129, 61, 56) + test_matmul(device, 128, 63, 56) test_addmm(device, 128, 256, 512) test_addmm(device, 128, 256, 512, bias_rank=2) test_addmm(device, 129, 61, 56) + test_addmm2(device, 129, 61, 56) + test_addmm(device, 129*4, 61*4, 56*4) + test_addmm2(device, 129*4, 61*4, 56*4) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index c1556787..e1a84b7f 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -50,9 +50,8 @@ def reduce_sum(a, dim, keepdim): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - #test_reduce_sum(device, (29, 47), 1, keepdim=True) - #test_reduce_sum(device, (17, 68), 0, keepdim=True) - #test_reduce_sum(device, (327, 447), 1, keepdim=True) - #test_reduce_sum(device, (327, 447), 0, keepdim=True) - test_reduce_sum2(device, shape) - + test_reduce_sum(device, (29, 47), 1, keepdim=True) + test_reduce_sum(device, (17, 68), 0, keepdim=True) + test_reduce_sum(device, (327, 447), 1, keepdim=True) + test_reduce_sum(device, (327, 447), 0, keepdim=True) + test_reduce_sum2(device, shape) \ No newline at end of file From 9ae7a0849526f1585408c1b1124b4e716a5493f4 
Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 05:32:09 +0000 Subject: [PATCH 28/62] [Fix] Fix var_mean codegen + cheatsheet folder issue --- PyTorchSimFrontend/mlir/mlir_template.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index ccb9b0d1..b5aa2593 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -2,6 +2,7 @@ import itertools import textwrap import re +import os import contextlib import math import sympy @@ -218,7 +219,9 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: - file_path = f"{CONFIG_TORCHSIM_DIR}/validation/gemm_candidates/gemm_{M}_{K}_{N}.txt" + dir_path = f"{CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" + os.makedirs(dir_path, exist_ok=True) + file_path = f"{dir_path}/gemm_{M}_{K}_{N}.txt" line_to_write = f"{tile_M} {tile_K} {tile_N}\n" try: with open(file_path, "r") as f: @@ -978,10 +981,8 @@ def store_reduction_epilogue(self, name, index, value): if self.buffer_types[name][1] > 1: divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") - divider_vec2 = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider2} : f32 to {new_reduced_shape}") else: divider_vec = divider - divider_vec2 = divider2 if self.current_node.node.origin_node: # FIXME: This is a temporary solution # mean = SUM(X) / N @@ -992,7 +993,7 @@ def store_reduction_epilogue(self, name, index, value): sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") mean_sqr = self.cse.generate(self.reductions_suffix, 
f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec2} : {new_reduced_shape}") + m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") out = m2 final_zero_var_list[-1] = f"%{body_index_var}" From 9e0e2d46bf8130caad01f5fe5ba9a20cd082d7e3 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 06:12:38 +0000 Subject: [PATCH 29/62] [Frontend] Add exception handling for reduction loop only kernel --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 84418ec7..615925d1 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -147,6 +147,15 @@ def codegen_nodes(self, nodes): _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group + + # There is no normal loop, then revert simplified group + if len(group) == 0: + for idx, node in enumerate(nodes): + self.revert_group(node) + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + ex_kernel = self.target_kernel(kernel_group=self.kernel_group) ex_kernel.kernel_group = self.kernel_group From 2109244aca1591ef1b5bde9df40d51d6daaa7d32 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 06:15:00 +0000 Subject: [PATCH 30/62] [Template] Fix a minor bug in GEMM template --- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 1 + 1 file changed, 1 insertion(+) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index ace6ea9d..f706c2e5 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ 
b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -143,6 +143,7 @@ def render(self, if (M == 0) or (N == 0) or (K == 0): # exception for MoE template = EMPTY_TEMPLATE nr_rdim = 0 + epilogue_dim_aliasing = {} elif n_epilogue_node>=1 and epilogue_nodes[0].is_reduction(): template = GEMM_REDUCTION_TEMPLATE epilogue_dim_aliasing = {"index0":"index1", "index1":"index0"} From 3a8e0f8a0cb485f18fa6a6a14ca90c46b1221893 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 07:31:31 +0000 Subject: [PATCH 31/62] [Frontend] Fix dram stride calculate logic --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6dbe9047..68aa1b11 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1550,8 +1550,6 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_dims = total_dims # Brodatcast tile shape index_var = self.parse_indices(index, buffer=buffer) - input_argument = [f"index{str(i)}" for i in local_dims] - dram_stride = [index.coeff(sympy.Symbol(arg)) for arg in input_argument] if kg_tile_desc.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vlane_split_axis) @@ -1619,6 +1617,16 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Update local_tile_desc.set_tile_size(new_tile_size) local_tile_desc.vlane_split_axis = new_vlane_split_axis + + # Calculate dram stride + if index.is_Symbol: + dram_stride = [0] * local_tile_desc.get_nr_dim() + dim_idx = int(str(index)[5:]) + dram_stride[dim_idx] = 1 + elif index.is_Number: + dram_stride = [0] * local_tile_desc.get_nr_dim() + else: + dram_stride = [arg.as_coeff_mul()[0] for arg in index.as_ordered_terms()] return local_tile_desc, index_var, dram_stride def get_dma_code(self, 
dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, From ae09ef2066e7187f3473ae11320881baa1a51a40 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 08:31:43 +0000 Subject: [PATCH 32/62] [Frontend] Fix dram_stride --- .../mlir/mlir_codegen_backend.py | 20 ++++++++++++++++--- PyTorchSimFrontend/mlir/mlir_scheduling.py | 1 + 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 68aa1b11..6dd4f66a 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -4,6 +4,7 @@ import os import math import torch +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning @@ -1619,14 +1620,27 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.vlane_split_axis = new_vlane_split_axis # Calculate dram stride + dram_stride = [0] * local_tile_desc.get_nr_dim() if index.is_Symbol: - dram_stride = [0] * local_tile_desc.get_nr_dim() dim_idx = int(str(index)[5:]) dram_stride[dim_idx] = 1 elif index.is_Number: - dram_stride = [0] * local_tile_desc.get_nr_dim() + pass else: - dram_stride = [arg.as_coeff_mul()[0] for arg in index.as_ordered_terms()] + dram_dict = defaultdict(list) + # Assume that div will have high priority than mod + for arg in index.as_ordered_terms(): + coeff, dim = arg.as_coeff_mul() + real_dim = list(dim[0].free_symbols)[0] + dram_dict[str(real_dim)].append(coeff) + # Add missing dims if not added + max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 + for i in range(max_dim): + target_dim = f"index{i}" + if target_dim not in str(index): + dram_dict[target_dim] = [0] + sorted_keys = sorted(dram_dict.keys()) + 
dram_stride = sum((dram_dict[key] for key in sorted_keys), []) return local_tile_desc, index_var, dram_stride def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 615925d1..20528d80 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -260,6 +260,7 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e tile_desc = kernel.set_tile_size(kernel.epilogue_info) kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None if epilogue_nodes: with kernel.epilogue_buffer_group.as_local(): kernel.load = kernel.load_epilogue From 39405e7f843c467cc3f5c7e8db7f7a1ffe29309c Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 11:09:54 +0000 Subject: [PATCH 33/62] [Frontend] Fix 1 --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 6dd4f66a..ab04d3e6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1631,6 +1631,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Assume that div will have high priority than mod for arg in index.as_ordered_terms(): coeff, dim = arg.as_coeff_mul() + if len(dim) == 0: + continue real_dim = list(dim[0].free_symbols)[0] dram_dict[str(real_dim)].append(coeff) # Add missing dims if not added From 5776d03eab30a9e70afa2ffeee85a38df1c3a0b9 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 11:35:02 +0000 Subject: [PATCH 34/62] [Frontend] Fix wip --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git 
a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 20528d80..67a8b026 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -113,6 +113,8 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != act_node.group: self.revert_group(act_node) + if template_node.group != act_node.group: + return False return True # Check elementwise fusion @@ -214,6 +216,7 @@ def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): with kernel: + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() partial_code = render() @@ -275,12 +278,12 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e for node in epilogue_nodes: node.codegen((vars, reduction_vars)) - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) # For consistency, white space could make wrong write_path buffer = IndentedBuffer() buffer.splice(src_code) @@ -301,7 +304,6 @@ def codegen_template(self, template_node, epilogue_nodes): _, (numel, rnumel) = template_node.group template_buffer = template_node.node kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) - _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) wrapper = 
V.graph.wrapper_code From 7699887f87cf65298f851ac2bc55dc5037ea7b44 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 13:37:54 +0000 Subject: [PATCH 35/62] Revert final render position --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 67a8b026..e5d28779 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -278,12 +278,12 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e for node in epilogue_nodes: node.codegen((vars, reduction_vars)) - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) # For consistency, white space could make wrong write_path buffer = IndentedBuffer() buffer.splice(src_code) @@ -304,6 +304,7 @@ def codegen_template(self, template_node, epilogue_nodes): _, (numel, rnumel) = template_node.group template_buffer = template_node.node kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) wrapper = V.graph.wrapper_code From b8868671d0f9513ac151ac8d82e85b92f925e49f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 14:47:41 +0000 Subject: [PATCH 36/62] [Frontend] Do not fuse for edge case --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 3 +++ PyTorchSimFrontend/mlir/mlir_template.py | 3 +++ 2 files changed, 6 insertions(+) diff --git 
a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index e5d28779..a5d8bd3d 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -115,6 +115,9 @@ def can_fuse_horizontal(self, node1, node2): self.revert_group(act_node) if template_node.group != act_node.group: return False + # We don't fuse this case... + if template_node.group[1][0][0] == 1: + return False return True # Check elementwise fusion diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b5aa2593..9e7a104e 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -1030,6 +1030,9 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: tile_desc.vec_size=64 + if tile_desc.get_numel_per_lane() < tile_desc.vec_size: + tile_desc.vec_size = tile_desc.get_numel_per_lane() + if prologue: self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() self.prologue_compute_body_loop.step = tile_desc.get_compute_vec_size() From 0bded100531682bf61e3c7210a7ab1a308ac5f45 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 11 Jul 2025 15:24:44 +0000 Subject: [PATCH 37/62] [Frontend] Fusion condition change --- PyTorchSimFrontend/mlir/mlir_common.py | 5 +++++ PyTorchSimFrontend/mlir/mlir_scheduling.py | 11 +++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 00bf4169..9151ac0b 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -499,6 +499,11 @@ def dummy_tile_size(): tile_size[-1] = self.vector_lane tile_size[-2] = 4 * self.vector_lane tile_size[-3] = 2 + elif len(tile_size) == 4: + tile_size[-1] = self.vector_lane + tile_size[-2] = 4 * self.vector_lane + 
tile_size[-3] = 2 + tile_size[-4] = 1 else: raise NotImplementedError("dummy tile size fail!") return tile_size diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index a5d8bd3d..f1a2513e 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -94,6 +94,8 @@ def can_fuse_horizontal(self, node1, node2): if node1.is_template() or node2.is_template(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate if node1.is_template() and len(node1.get_nodes())==1 and isinstance(node1.node.template, MLIRMaxPoolTemplate) or \ node2.is_template() and len(node1.get_nodes())==1 and isinstance(node2.node.template, MLIRMaxPoolTemplate): return False @@ -112,12 +114,13 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != act_node.group: - self.revert_group(act_node) - if template_node.group != act_node.group: - return False # We don't fuse this case... 
- if template_node.group[1][0][0] == 1: + if (isinstance(template_node, MLIRBMMTemplate) or isinstance(template_node, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: return False + + if template_node.group[1][0] != act_node.get_nodes()[0].node.data.get_size(): + return False + self.revert_group(act_node) return True # Check elementwise fusion From b3c5d9ce87456bb984eb58abe5b185b1b4838b7f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 12 Jul 2025 06:54:47 +0000 Subject: [PATCH 38/62] [Frontend/Fusion] Add prologue fusion condition --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index f1a2513e..9b07d3c7 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -64,6 +64,8 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if len(node1.read_writes.writes) != 1: return False + if len(node1.users) != 1: + return False if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) return True From b2e7110946af3c0b7e64ff13257e11ae10eef517 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 12 Jul 2025 10:09:53 +0000 Subject: [PATCH 39/62] [Frontend] Fix dram_stride + tile_size for reduction only case --- .../mlir/mlir_codegen_backend.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index ab04d3e6..377eec7a 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1212,7 +1212,6 @@ def store_reduction(self, name, index, value): vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, local_tile_desc, 
index) if self.welford_reduce_out is not None: - # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") @@ -1559,9 +1558,14 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Case 0. Tile is 0-D scalar if len(local_dims) == 0: - local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. - local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis - local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + if not store_reduction: + local_tile_desc.set_tile_size([kg_tile_desc.get_used_vlane() * kg_tile_desc.vlane_stride]) # Force it to use vector instruction. + local_tile_desc.vlane_split_axis = local_vlane_split_axis # last axis + local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride + else: + local_tile_desc.set_tile_size([1]) + local_tile_desc.vlane_split_axis = 0 + local_tile_desc.vlane_stride = 1 dram_stride = [0] # Edge case # Case 1. Tile is 1-D vector type elif len(local_dims) == 1 and len(local_dims) <= self.reduction_depth: @@ -1571,7 +1575,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe # Case 2. Tile is 1-D vector type with reduction elif len(local_dims) == 1 and len(local_dims) == self.reduction_depth + 1: local_tile_desc.set_tile_size([1, kg_tile_desc.get_dim_size(local_dims[0])]) - local_tile_desc.vlane_split_axis = local_vlane_split_axis + local_tile_desc.vlane_split_axis = local_vlane_split_axis + 1 local_tile_desc.vlane_stride = kg_tile_desc.vlane_stride # Case 3. Tile is 2-D tile elif len(local_dims) == 2: @@ -1643,6 +1647,11 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe dram_dict[target_dim] = [0] sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) + + # FIXME. 
It will be nice to modify node instead of this exception handling... + if len(self.itervars) == 1 and self.reduction_depth == 0: + # In case of reduction loop only case, we will add dummy loop so shift it once + dram_stride = [0] + dram_stride[:-1] return local_tile_desc, index_var, dram_stride def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, From dfd7809a8c1b62687b773c6501b62024d39ff2f9 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 12 Jul 2025 12:00:29 +0000 Subject: [PATCH 40/62] [Frontend/Fusion] Add nop op fusion condition --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 9b07d3c7..3b354b44 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -35,6 +35,8 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] if node1.get_device() != node2.get_device(): return False + if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): + return False if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate From 237905f0c6597d60cde8fb2f8aca6dc217421e7f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 12 Jul 2025 12:14:28 +0000 Subject: [PATCH 41/62] [Frontend] Handle edge case of parse_index_list --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 377eec7a..4c20fced 100644 --- 
a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -957,7 +957,8 @@ def parse_indices(self, expr, buffer=None, comments="") -> common.CSEVariable: def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: if buffer is None: buffer = self.applys - expr_list = [arg for arg in expr_list if arg != sympy.Number(0)] + zero_var = self.get_const_cse(0) + expr_list = [arg if arg != sympy.Number(0) else sympy.Symbol(str(zero_var)) for arg in expr_list] if len(expr_list) == 1 and expr_list[0].is_number: # Constant case From 0b9c0c39fc921a6ccb865c048d7ba803952c2d32 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 12 Jul 2025 12:34:31 +0000 Subject: [PATCH 42/62] Fix 2 --- PyTorchSimFrontend/llvm/llvm_caller_codegen.py | 2 +- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py index 835d9b80..3690f533 100644 --- a/PyTorchSimFrontend/llvm/llvm_caller_codegen.py +++ b/PyTorchSimFrontend/llvm/llvm_caller_codegen.py @@ -231,6 +231,6 @@ def get_spad_size(self, binary_path): spad_end = int(parts[1], 16) if spad_start is None or spad_end is None: - raise ValueError("Could not find .spad addresses") + return 0 spad_size = spad_end - spad_start return spad_size \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 4c20fced..5c344d1d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -959,6 +959,7 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: buffer = self.applys zero_var = self.get_const_cse(0) expr_list = [arg if arg != sympy.Number(0) else sympy.Symbol(str(zero_var)) for arg in expr_list] + dim_list = [f"d{i}" for i in range(len(expr_list))] if 
len(expr_list) == 1 and expr_list[0].is_number: # Constant case @@ -972,18 +973,18 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: for idx, arg in enumerate(expr_list): if arg.is_Mul and arg.args[0].is_number: new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) - new_expr_list[idx] = arg.subs(arg.args[1], new_arg) + new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) - new_expr_list[idx] = new_arg + new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: new_expr_list[idx] = arg # Extract index var expr_str = str(sum(new_expr_list)) - args = ", ".join(map(str, indices)) + args = ", ".join(map(str, dim_list)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") From 5bcc9693d8ea6cae042cf6340569e2f05a444eba Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 14 Jul 2025 05:40:13 +0000 Subject: [PATCH 43/62] [Frontend] Fix apply gen code --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 5 ++++- PyTorchSimFrontend/mlir/mlir_conv_template.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 5c344d1d..99a48fb6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -958,7 +958,7 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: if buffer is None: buffer = self.applys zero_var = self.get_const_cse(0) - expr_list = [arg if arg != sympy.Number(0) else sympy.Symbol(str(zero_var)) for arg in expr_list] + expr_list = [arg for arg in expr_list] dim_list = [f"d{i}" for i in 
range(len(expr_list))] if len(expr_list) == 1 and expr_list[0].is_number: @@ -980,7 +980,10 @@ def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: + const_var = self.get_const_cse(int(arg)) + new_arg = sympy.Symbol(f"{const_var}") new_expr_list[idx] = arg + indices.append(str(new_arg)) # Extract index var expr_str = str(sum(new_expr_list)) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index cd4ddf82..4792c6ac 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -57,7 +57,7 @@ {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> %c0 = arith.constant 0 : index - {{- kernel.def_local_vars(indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} affine.for %tile_m = 0 to {{ BATCH }} step {{ TILE_M }} { affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { From 831fa9f8e2e4a8f1a297d837d6e4bd52e5ccb09b Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 14 Jul 2025 12:42:12 +0000 Subject: [PATCH 44/62] [Frontend] Indirect access fix --- .../mlir/mlir_codegen_backend.py | 18 +++++++++++------- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 99a48fb6..51a79ebd 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -919,19 +919,22 @@ def convert_index(self, expr, buffer): index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") return index - def parse_indices(self, expr, buffer=None, comments="") -> common.CSEVariable: + def parse_indices(self, expr, buffer=None, 
comments="", indirect_dims=[]) -> common.CSEVariable: if buffer is None: buffer = self.applys # Constant case - if expr.is_number: + if expr.is_number and len(indirect_dims) == 0: return self.get_const_cse(int(expr)) # Identity case - if len(expr.args) == 0: + if len(expr.args) == 0 and len(indirect_dims) == 0: return expr - args = list(expr.args) + if len(expr.args) == 0: + args = [expr] + else: + args = list(expr.args) # Sort index variable.. ex) (%index1, %index0) args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols} sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term])) @@ -947,11 +950,12 @@ def parse_indices(self, expr, buffer=None, comments="") -> common.CSEVariable: indices.append(str(new_arg)) # Extract index var + indirect_args = [f"%{i}" for i in indirect_dims] expr_str = str(expr) args = ", ".join(map(str, indices)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") + map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[] {comments}") + index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") return index def parse_index_list(self, expr_list:list, buffer=None) -> common.CSEVariable: @@ -1554,7 +1558,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape - index_var = self.parse_indices(index, buffer=buffer) + index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims) if kg_tile_desc.vlane_split_axis in local_dims: local_vlane_split_axis = 
local_dims.index(kg_tile_desc.vlane_split_axis) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 3b354b44..c8ed9efc 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -122,7 +122,7 @@ def can_fuse_horizontal(self, node1, node2): if (isinstance(template_node, MLIRBMMTemplate) or isinstance(template_node, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: return False - if template_node.group[1][0] != act_node.get_nodes()[0].node.data.get_size(): + if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): return False self.revert_group(act_node) return True From 7abca4d11127462e3e0d1fd15ff3f1a6b166cfb7 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 14 Jul 2025 15:20:38 +0000 Subject: [PATCH 45/62] [Frontend/Fusion] Add something OMG --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 50 ++++++++++++---------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index c8ed9efc..e63df4fb 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -100,8 +100,11 @@ def can_fuse_horizontal(self, node1, node2): from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - if node1.is_template() and len(node1.get_nodes())==1 and isinstance(node1.node.template, MLIRMaxPoolTemplate) or \ - node2.is_template() and len(node1.get_nodes())==1 and isinstance(node2.node.template, MLIRMaxPoolTemplate): + template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) + template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) + + if template_node1 and len(node1.get_nodes()) == 1 and 
isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ + template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): return False # Pointwise check @@ -111,7 +114,7 @@ def can_fuse_horizontal(self, node1, node2): return False # Pattern check - template_node, act_node = (node1, node2) if node1.is_template() else (node2, node1) + template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) if not has_depedency: return False @@ -119,7 +122,7 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != act_node.group: # We don't fuse this case... - if (isinstance(template_node, MLIRBMMTemplate) or isinstance(template_node, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: return False if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): @@ -132,25 +135,26 @@ def can_fuse_horizontal(self, node1, node2): return True return False - def revert_group(self, act_node): - args, var_ranges = dependencies.index_vars_no_squeeze( - act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" - ) - body = LoopBody( - act_node.node.get_store_function(), - (args if act_node.node.get_reduction_type() else args[:1]), - var_ranges, - ) - index_size = [] - reduce_size = [] - for v, s in var_ranges.items(): - if v in args[0]: - index_size.append(s) - else: - reduce_size.append(s) - node_device = act_node.get_device() - ranges = (index_size, reduce_size) - act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) + def revert_group(self, act_nodes): + 
for act_node in act_nodes.get_nodes(): + args, var_ranges = dependencies.index_vars_no_squeeze( + act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" + ) + body = LoopBody( + act_node.node.get_store_function(), + (args if act_node.node.get_reduction_type() else args[:1]), + var_ranges, + ) + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + index_size.append(s) + else: + reduce_size.append(s) + node_device = act_node.get_device() + ranges = (index_size, reduce_size) + act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) From ed3130742b21c5b3e83fa999329891c228e3b54b Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 15 Jul 2025 01:58:20 +0000 Subject: [PATCH 46/62] [Frontend] Fix dima_alising for conv_template --- PyTorchSimFrontend/mlir/mlir_conv_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_template.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 4792c6ac..73cf710f 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -275,7 +275,7 @@ def render(self, dram_var = "Y", dram_idx = Y_idx, dram_tile_desc = Y_tile_desc, - dim_aliasing = {"index0":"c0", "index1":"tile_n", "index2":"o_h", "index3":"tile_m"} + dim_aliasing = {"index0":"tile_m", "index1":"tile_n", "index2":"o_h", "index3":"o_w"} ) kernel.exception_nodes["X"] = {"numel" : (I_W+2*PADDING_W)*(I_H+2*PADDING_H)*I_C*BATCH} code = self._template_from_string(conv_template).render(**kernel.render_options) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 9e7a104e..f802f8e8 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ 
b/PyTorchSimFrontend/mlir/mlir_template.py @@ -1030,8 +1030,6 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: tile_desc.vec_size=64 - if tile_desc.get_numel_per_lane() < tile_desc.vec_size: - tile_desc.vec_size = tile_desc.get_numel_per_lane() if prologue: self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() From 80b7a85c233577f32d6b8ad80fe5279cf1dd3104 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 15 Jul 2025 05:19:05 +0000 Subject: [PATCH 47/62] [Frontend/Scheduling] Fix reduction fusion condition --- .github/workflows/pull-request.yml | 10 ---------- .github/workflows/pull-request_mobile.yml | 10 ---------- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 ++ 3 files changed, 2 insertions(+), 20 deletions(-) diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index 9d440df6..bc5c9dab 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -524,16 +524,6 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_reduction.py - - name: Run test_matmul_layernorm.py - env: - GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} - run: | - echo "Running test_matmul_layernorm.py" - docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_layernorm.py - - name: Run test_bmm_reduction.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} diff --git a/.github/workflows/pull-request_mobile.yml b/.github/workflows/pull-request_mobile.yml index 45d73fa8..0043eaf4 100644 --- a/.github/workflows/pull-request_mobile.yml +++ b/.github/workflows/pull-request_mobile.yml @@ -534,16 +534,6 @@ jobs: -e TORCHSIM_DUMP_PATH=/dump \ ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 
PyTorchSim/tests/Fusion/test_matmul_reduction.py - - name: Run test_matmul_layernorm.py - env: - GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} - run: | - echo "Running test_matmul_layernorm.py" - docker run --rm \ - -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ - -e TORCHSIM_DUMP_PATH=/dump \ - ghcr.io/psal-postech/torchsim-ci:${GITHUB_SHA} python3 PyTorchSim/tests/Fusion/test_matmul_layernorm.py - - name: Run test_bmm_reduction.py env: GIT_ACCESS_TOKEN: ${{ secrets.GIT_ACCESS_TOKEN }} diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index e63df4fb..ffc001da 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -87,6 +87,8 @@ def can_fuse_horizontal(self, node1, node2): _, (vars2, reduce2) = node2.group # Reduction is currently not supported + if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template(): + return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users if node1.is_reduction() or node2.is_reduction(): return False From 06741953f65e99cc24c524eadbb5693037697353 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 17 Jul 2025 01:43:44 +0000 Subject: [PATCH 48/62] [Frontend/template] Fix tile stride in convolution templates Also, update mlir version(refactored fine-grained dma pass) --- PyTorchSimFrontend/mlir/mlir_conv_mt_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_conv_sb_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py | 2 +- tests/Fusion/test_prologue_fusion.py | 6 +++--- tests/Fusion/test_transformer_fusion.py | 1 + tests/test_indirect_access.py | 2 +- 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 7968f813..8cd57077 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ 
b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -202,7 +202,7 @@ def render(self, X_idx = [X_dim[0]*(I_W+2*PADDING_W)*BATCH*I_C, X_dim[1]*I_C*STRIDE_W, X_dim[2]*I_C*(I_W+2*PADDING_W), X_dim[3]] W_tile_size = [TILE_K_H, 1, TILE_K, TILE_N] - W_tile_stride = [TILE_K_W * TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] + W_tile_stride = [TILE_K * TILE_N, TILE_K * TILE_N, 1, TILE_K] W_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, 3, vlane_stride) W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) W_tile_desc.set_name("weight_buffer") diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index f2df1e43..6c31776d 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -210,7 +210,7 @@ def render(self, W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] - Y_tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M * TILE_N, TILE_M, 1] # N, C, H, W + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("output_buffer") diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 3b60dcbc..a4ea0b20 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -211,7 +211,7 @@ def render(self, W_idx = [W_dim[0]*K_W*I_C*O_C , W_dim[1]*I_C*O_C, W_dim[2]*O_C, W_dim[3]] Y_tile_size = [1, TILE_N, TILE_O_H, TILE_M] - Y_tile_stride = [TILE_O_W * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W + Y_tile_stride = [TILE_O_H * TILE_M * TILE_N, TILE_M, TILE_M * TILE_N, 1] # N, C, H, W Y_tile_desc = 
mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("output_buffer") diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py index d5d1cdb1..797f9e76 100644 --- a/tests/Fusion/test_prologue_fusion.py +++ b/tests/Fusion/test_prologue_fusion.py @@ -91,7 +91,7 @@ def bmm(a, b, c, d): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - #test_elem_broadcast_fusion(device) - #test_elem_fusion(device) - #test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) + test_elem_broadcast_fusion(device) + test_elem_fusion(device) + test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) test_elem_bmm_weight_fusion(device, batch_size=12, m=512, n=512, k=64) \ No newline at end of file diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py index 15bacb39..0f68948e 100644 --- a/tests/Fusion/test_transformer_fusion.py +++ b/tests/Fusion/test_transformer_fusion.py @@ -206,6 +206,7 @@ def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() + test_MHA(device) test_DecoderBlock(device) # test_DecoderBlock_validation() # test_Attention(device, head=16, seq=512, d_k=64) diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py index 6d16c9d0..b7b20074 100644 --- a/tests/test_indirect_access.py +++ b/tests/test_indirect_access.py @@ -27,7 +27,7 @@ def vectoradd(a, idx, b): opt_fn = torch.compile(dynamic=False)(vectoradd) res = opt_fn(x, idx, y) out = vectoradd(x.cpu(), idx.cpu(), y.cpu()) - test_result("VectorAdd", res, out) + test_result("Indirect VectorAdd", res, out) def test_embedding(device, vocab_size, dim): emb = torch.nn.Embedding(vocab_size, 
dim) From 6cd7b8be35922e9707a421dfac6a645ddae1e1b4 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 17 Jul 2025 07:49:59 +0000 Subject: [PATCH 49/62] [Frontend] Update fusion condition --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index ffc001da..8ea995df 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -14,7 +14,7 @@ from torch._inductor import dependencies from . import mlir_common -from . import mlir_lowering +from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering class MLIRScheduling(BaseScheduling): count = 0 @@ -41,15 +41,15 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction() and len(node2.get_nodes())==1: + if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): # For matmul/bmm+reduction case - size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.node.get_size(), 1) * reduce(operator.mul, node2.node.get_reduction_size(), 1) - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.node).split("\n") if "r0" in i][1] + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + stride 
= [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] target_symbol = symbols("r0") # We can't fuse dim=-1 layout_possible = int(sympify(stride).coeff(target_symbol)) != 1 # Directed linked? - dependency_check = node2 in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 + dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) return size_match and layout_possible and dependency_check and dependency_size @@ -66,7 +66,7 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if len(node1.read_writes.writes) != 1: return False - if len(node1.users) != 1: + if len([node for node in node1.users if node.get_name() != "OUTPUT"]) != 1: # FIXME. Any good way to check this? return False if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) From 4771bcbb5f0e40956ce04fb25a24e3e68380486a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 17 Jul 2025 07:51:15 +0000 Subject: [PATCH 50/62] [Test] Add test_bmm_reduction fusion --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 2 +- tests/Fusion/test_bmm_reduction.py | 52 ++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 tests/Fusion/test_bmm_reduction.py diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index b81b3862..9a9785e1 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -125,7 +125,7 @@ %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> {% endif %} %c0 = arith.constant 0 : index - {{ kernel.def_local_vars() }} + {{ 
kernel.def_local_vars(indent_size=2) }} affine.for %index0=0 to {{ B }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { diff --git a/tests/Fusion/test_bmm_reduction.py b/tests/Fusion/test_bmm_reduction.py new file mode 100644 index 00000000..42e38095 --- /dev/null +++ b/tests/Fusion/test_bmm_reduction.py @@ -0,0 +1,52 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_bmm_reduce(device, batch=12, size=512): + def bmm(a, b): + result = torch.bmm(a, b.transpose(1,2)) + return result, result.max(dim=1).values + torch.manual_seed(0) + N = size + input = torch.randn(batch, N, 64) + weight = torch.randn(batch, N, 64) + #input = torch.arange(1, N * N + 1, dtype=torch.float32).reshape(N, N).to(dtype=torch.float32) + #weight = torch.eye(N, dtype=torch.float32) + x1 = input.to(device=device) + w1 = weight.to(device=device) + x2 = input.to("cpu") + w2 = weight.to("cpu") + opt_fn = torch.compile(dynamic=False)(bmm) + res = opt_fn(x1, w1) + y = bmm(x2, w2) + test_result("BMM Reduction Fusion activation", res[0], y[0]) + test_result("BMM Reduction Fusion reduction", res[1], y[1]) + +if __name__ == "__main__": + import os + import sys + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + from Scheduler.scheduler import ExecutionEngine + module = ExecutionEngine.setup_device() + device = module.custom_device() + #test_bmm_reduce(device) + test_bmm_reduce(device, 12, 512) + test_bmm_reduce(device, 4, 256) + test_bmm_reduce(device, 6, 
768) + test_bmm_reduce(device, 2, 128) From 22e167d2d677b186f00d8ca1cd8a84607e5e8db9 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 17 Jul 2025 13:57:53 +0000 Subject: [PATCH 51/62] [Frontend/Fusion] Add prologue fusion condition --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 8ea995df..f81c7b05 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -55,10 +55,8 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # For prologue fusion case if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: - # Return false if node2 is Convolution template - # if node2.get_nodes()[0].node.origin_node.target._name == 'aten::mm' or \ - # node2.get_nodes()[0].node.origin_node.target._name == 'aten::addmm': - # return False + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node2[0].node if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': return False @@ -66,8 +64,11 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if len(node1.read_writes.writes) != 1: return False - if len([node for node in node1.users if node.get_name() != "OUTPUT"]) != 1: # FIXME. Any good way to check this? + if len(node1.users) != 1: return False + # We don't fuse this case... 
+ if (isinstance(target_node.template, MLIRBMMTemplate) or isinstance(target_node.template, MLIRGemmTemplate)) and base_template_node2[0].group[1][0][0] == 1: + return False if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) return True From 46614427d04cb74ad4e9be5be338aa43c05bf15c Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 18 Jul 2025 06:14:57 +0000 Subject: [PATCH 52/62] [Frontend] Fix reverting the group when ther is no loop --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 26 +++++++++++++++++----- tests/Mixtral_8x7B/test_attention.py | 24 ++++++++++++++++++-- tests/MoE/test_moe.py | 5 ----- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index f81c7b05..773414d5 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -1,5 +1,6 @@ import os import math +import sympy from functools import reduce import operator from sympy import symbols, sympify, Symbol @@ -138,11 +139,12 @@ def can_fuse_horizontal(self, node1, node2): return True return False - def revert_group(self, act_nodes): + def revert_group(self, act_nodes, args=None, var_ranges=None): for act_node in act_nodes.get_nodes(): - args, var_ranges = dependencies.index_vars_no_squeeze( - act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" - ) + if args is None or var_ranges is None: + args, var_ranges = dependencies.index_vars_no_squeeze( + act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" + ) body = LoopBody( act_node.node.get_store_function(), (args if act_node.node.get_reduction_type() else args[:1]), @@ -167,10 +169,22 @@ def codegen_nodes(self, nodes): nodes, key=lambda x: int(x.is_reduction()) ).group - # There is no normal loop, then revert simplified group + # Note: We assume that ther is at least one 
loop in the nodes + # But, inductor simplifies the group, there could be no loop + # In that case, we add dummy loop(size=1) to the group if len(group) == 0: for idx, node in enumerate(nodes): - self.revert_group(node) + if len(node.node.data.get_size()) == 0: + continue + if len(reduction_group) != 0: + sym0, sym1 = sympy.Symbol("q0"), sympy.Symbol("q1") + args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), [sym1]] + var_ranges = {sym0: sympy.Number(1), sym1: reduction_group[0]} + else: + sym0 = sympy.Symbol("q0") + args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), []] + var_ranges = {sym0: sympy.Number(1)} + self.revert_group(node, args, var_ranges) _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py index cc2adc96..aa1af651 100644 --- a/tests/Mixtral_8x7B/test_attention.py +++ b/tests/Mixtral_8x7B/test_attention.py @@ -2,7 +2,7 @@ import torch import torch._dynamo import torch.utils.cpp_extension -from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, precompute_freqs_cis, sample +from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, RMSNorm, precompute_freqs_cis, sample def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -139,6 +139,25 @@ def concat_tensors(a, b): test_result("ConcatTensors", res, out) +def test_rmsnorm(device, seq=32): + dim = 512 + eps = 1e-5 + T = seq + rmsnorm = RMSNorm(dim=dim, eps=eps) + rmsnorm = rmsnorm.to(device=device) + + x = torch.randn([1, T, dim], dtype=torch.float32) + cpu_x = copy.deepcopy(x) + x = x.to(device) + + cpu_model = copy.deepcopy(rmsnorm).to("cpu") + opt_fn = torch.compile(dynamic=False)(rmsnorm) + + res = opt_fn(x) + cpu_res = cpu_model(cpu_x) + + test_result("RMSNorm", res, cpu_res) + if __name__ == 
"__main__": import os import sys @@ -147,7 +166,8 @@ def concat_tensors(a, b): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() + test_rmsnorm(device, seq=1) + test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) test_decode(device, 32, 3) - #test_concat(device, size1=(1, 8, 32, 64), size2=(1,8,1,64), dim=2) #test_attention(device) #test_ffn(device) diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index cf2f37f4..c5ab8107 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -1,12 +1,7 @@ # Owner(s): ["module: inductor"] import os -import shutil import sys -import time -import contextlib -import unittest import copy -import numpy as np import matplotlib.pyplot as plt From 2bea699200cb0e17b84b452bf28a3aaf7bf1a2f2 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 18 Jul 2025 14:24:05 +0000 Subject: [PATCH 53/62] [Frontend] Add mask in the reduction if needed --- .../mlir/mlir_codegen_backend.py | 92 +++++-------------- PyTorchSimFrontend/mlir/mlir_template.py | 8 +- 2 files changed, 32 insertions(+), 68 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 51a79ebd..79d735a3 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -846,6 +846,7 @@ def __init__(self, kernel_group, reason=None): self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() self.applys = IndentedBuffer() + self.masks = IndentedBuffer() self.dma_loads = IndentedBuffer() self.dma_stores = IndentedBuffer() self.indexed_buffer = IndentedBuffer() @@ -859,6 +860,7 @@ def __init__(self, kernel_group, reason=None): self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") self.apply_cse = common.CSE(self.newvar_prefix, 
self.suffix, name_prefix="apply") + self.mask_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="mask") self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter") self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init") self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") @@ -1030,25 +1032,9 @@ def load(self, name: str, index: sympy.Expr): self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction - needs_mask = self.compute_body_loop.size % self.compute_body_loop.step != 0 and len(index.free_symbols) == len(self.ranges) if compute_vec_size > 1: - if needs_mask: - index_shape = f"vector<{self.compute_body_loop.step}xindex>" - mask_shape = f"vector<{compute_vec_size}xi1>" - step_vec = self.cse.generate(self.loads, f"vector.step : {index_shape}") - upper_bound = self.get_const_cse(self.compute_body_loop.size, "index") - gap = self.cse.generate(self.loads, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") - gap_vec = self.cse.generate(self.loads, f"vector.broadcast %{gap} : index to {index_shape}") - mask_var = self.cse.generate(self.loads, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") - if padding: - pad_val = self.const_cse.generate(self.const_buffer, f"arith.constant 0x{mlir_common.MLIR_INF['-inf'][mlir_dtype]:x} : {mlir_dtype}") - else: - pad_val = self.get_const_cse(0, mlir_dtype) - pad_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{pad_val} : {mlir_dtype} to {vshape}") - line = f"vector.maskedload %{sram_var}[{compute_index_var}], %{mask_var}, %{pad_vec} : {tile_shape}, {mask_shape}, {vshape} into {vshape}" - else: - operation = "affine.vector_load" - line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" + operation = 
"affine.vector_load" + line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" else: operation = "affine.load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}" @@ -1149,6 +1135,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): else: # Adjust shape and inital value init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {reduced_shape}") + self.register_var_info(init_vec, [vec_len, type_name]) acc_var = init_vec # Reduction body prepare @@ -1167,6 +1154,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.init_cse.reduction_cache[reduction_key] = init_vec # Reduction body codegen + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) result = reduction_partial_combine_vec(reduction_type, value, body_iter_arg) self.compute_body_loop.reduction_vars[body_acc] = (reduction_type, body_iter_arg, iterator, reduced_shape) self.compute_body_loop.affine_yield[result] = reduced_shape @@ -1423,6 +1413,7 @@ def codegen_loops(self): code.writelines(self.compute_body_loop.lines()) with contextlib.ExitStack() as stack: stack.enter_context(code.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) + code.splice(self.masks) code.splice(self.loads) code.splice(self.compute) code.splice(self.stores) @@ -1701,55 +1692,6 @@ def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype return f"memref.dma_start {src_operand}, {dst_operand}, %{dma_type}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape} {attribute}" - def adjust_tile_size(self): - if self.read_writes is not None: - read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) - cv_list = [] - for node in read_writes: - if len(node) > 1: - cv_list.append(self.get_constant_vector2(node[1])) - max_element = max(cv_list, key=len) - max_nr_dim = len(max_element) - - 
sorted_max_element = sorted(max_element, key=lambda x:x[0]) - # Force vector tile size when 3D node is originated from view - if max_nr_dim == 3 and max_nr_dim != len(self.itervars): - self.tile_desc.n_col = min(self.tile_desc.get_tile_size(), sorted_max_element[1][0]) - self.tile_desc.n_row = 1 - return - - # Case 1. vector kernel - if len(self.itervars) == 1: - tile_size = self.tile_desc.get_tile_size() if self.tile_desc.get_tile_size() < self.ranges[0] else self.ranges[0] - min_tile_size_unit = self.vector_lane * self.vlen // (8 * self.precision) # TODO: VCIX widening is not implemented - self.tile_desc.n_col = math.ceil(tile_size / min_tile_size_unit) * min_tile_size_unit # padding - self.tile_desc.n_row = 1 - elif len(self.itervars) == 0: - self.tile_desc.n_col = 1 - self.tile_desc.n_row = 1 - - # Case 2. 2-D tensor (e.g., softmax) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars): - # Avoid too much padding - if (self.ranges[0] <= self.vector_lane and self.ranges[0] <= self.tile_desc.n_row): - self.tile_desc.n_row = self.ranges[0] - self.tile_desc.used_vector_lane = self.ranges[0] - - # Case 2. 2-D reduction (e.g., batchnorm) - if len(self.itervars) == 2 and self.reduction_depth == len(self.itervars) - 1: - if (((self.ranges[0] + 1) // 2) <= self.vector_lane and ((self.ranges[0] + 1) // 2) <= self.tile_desc.n_row): - self.tile_desc.n_row = ((self.ranges[0] + 1) // 2) * 2 - self.tile_desc.used_vector_lane = (self.ranges[0] + 1) // 2 - - # Case 2. 3-D tensor kernel without reduction. Access vector granule! - if len(self.itervars) == 3 and self.reduction_depth == len(self.itervars): - self.tile_desc.n_col = self.ranges[-1] - self.tile_desc.n_row = 1 - - # Case 3. N-D tensor kernel with reduction. Not implemented. Need this? 
- if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): - raise NotImplementedError() - def allocate_sram_buffer(self, dtype, dram_name, tile_desc, raw_index, buffer=None, forced_name=None): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] @@ -1805,6 +1747,22 @@ def get_tag_cse(self, value, shape="memref<1xi32>"): self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape}") return self.tags[value] + def get_mask(self): + if self.compute_body_loop.size % self.compute_body_loop.step == 0: + return None, None + compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() + index_shape = f"vector<{self.compute_body_loop.step}xindex>" + mask_shape = f"vector<{compute_vec_size}xi1>" + + upper_bound = self.get_const_cse(self.compute_body_loop.size) + step_vec = self.const_cse.generate(self.const_buffer, f"vector.step : {index_shape}") + + gap = self.mask_cse.generate(self.masks, f"arith.subi %{upper_bound}, %{self.compute_idx} : index") + gap_vec = self.mask_cse.generate(self.masks, f"vector.broadcast %{gap} : index to {index_shape}") + mask_var = self.mask_cse.generate(self.masks, f"arith.cmpi ult, %{step_vec}, %{gap_vec} : {index_shape}") + self.register_var_info(mask_var, [compute_vec_size, "i1"]) + return mask_shape, mask_var + def convert_indirect_indexing(self, index :sympy.Expr): if "tmp" not in str(index): return index diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index f802f8e8..c6cd4a7e 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -450,6 +450,7 @@ def template_store(): stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) if self.reduction_fusion: compute_body.writelines(self.reduction_body_loop.lines()) + compute_body.splice(self.masks) 
stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) compute_body.splice(self.loads) compute_body.splice(self.compute) @@ -889,7 +890,6 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): name = f"{reduction_type}_buffer{self.reduction_buffer_idx}" self.reduction_buffer_idx += 1 index = "dummy_index" # Not used - tile_numel_per_lane = self.compute_body_loop.step * self.reduction_body_loop.size # ??? sram_var, _ = self.get_scratchpad_buffer(dtype, name, local_tile_desc, index, self.const_buffer) self.reduction_epilogue_result[reduction_key] = sram_var @@ -903,6 +903,12 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): self.register_var_info(out, [self.compute_body_loop.step, type_name]) # Reduction body codegen + init = self.const_cse.generate(self.const_buffer, f"arith.constant {reduction_init(reduction_type, dtype)} : {type_name}") + init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {type_name} to {vshape}") + self.register_var_info(init_vec, [local_tile_desc.get_compute_vec_size(), type_name]) + mask_shape, mask_var = self.get_mask() + if mask_var is not None: + value = ops.where(mask_var, value, init_vec) result = reduction_partial_combine_vec(reduction_type, value, out) # Store partial result From 9b235108760d765663d0ddcc33f2f9bc92b9aa2a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 19 Jul 2025 07:24:01 +0000 Subject: [PATCH 54/62] [Rename] Use encoder instead of decoder --- experiments/BERT.py | 8 ++--- test_extension_backend.py | 4 +-- tests/Fusion/test_transformer_fusion.py | 46 ++++++++++++------------- tests/test_conv2d.py | 6 +++- tests/test_pool.py | 6 ++-- tests/test_resnet.py | 6 ++-- tests/test_scheduler.py | 2 +- tests/test_sparsity.py | 34 +++++++++--------- tests/test_spmm_scheduler.py | 2 +- tests/test_transformer.py | 18 +++++----- 10 files changed, 68 insertions(+), 64 deletions(-) diff --git a/experiments/BERT.py b/experiments/BERT.py 
index 7086ad9a..3534505d 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -7,8 +7,8 @@ def run_BERT(size, input_seq, config): from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - # from tests.test_transformer import DecoderBlock - from tests.Fusion.test_transformer_fusion import DecoderBlock + # from tests.test_transformer import EncoderBlock + from tests.Fusion.test_transformer_fusion import EncoderBlock scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, backend_config=config) device = scheduler.execution_engine.module.custom_device() @@ -16,10 +16,10 @@ def run_BERT(size, input_seq, config): embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 cpu_query = torch.randn(input_seq, hidden_dim[size]) - decoder_block = DecoderBlock(embedding_size[size], heads[size]).eval() + encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() query = cpu_query.clone().to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block.to(device=device)) + opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) diff --git a/test_extension_backend.py b/test_extension_backend.py index 10bc9854..f0a9353a 100644 --- a/test_extension_backend.py +++ b/test_extension_backend.py @@ -12,7 +12,7 @@ from tests.test_matmul import test_matmul from tests.test_bmm import test_BMM from tests.test_cnn import test_CNN -from tests.test_transformer import test_DecoderBlock +from tests.test_transformer import test_EncoderBlock from tests.test_resnet import test_resnet from tests.test_mlp import test_mlp, test_mlp_inf from tests.MoE.test_moe import test_moe @@ -46,7 +46,7 @@ #test_matmul(device, 33, 45, 68) #test_BMM(device) #test_CNN(device) - #test_DecoderBlock(device) + 
#test_EncoderBlock(device) #test_resnet(device) #test_mlp(device) #test_mlp_inf(device, batch_size=64, input_size=256, hidden_size=512, output_size=256, sparsity=0.97) diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py index 0f68948e..0e500b5b 100644 --- a/tests/Fusion/test_transformer_fusion.py +++ b/tests/Fusion/test_transformer_fusion.py @@ -53,9 +53,9 @@ def forward(self, query, key, value): del value return self.linears[-1](x) -class DecoderBlock_origin(torch.nn.Module): +class EncoderBlock_origin(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock_origin, self).__init__() + super(EncoderBlock_origin, self).__init__() self.multihead_attn = my_MultiheadAttention_origin(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -111,9 +111,9 @@ def forward(self, x, residual): out = torch.matmul(self.weight, x.transpose(-1, -2)) + self.bias[:, None] # (1, 768, 512) return self.layer_norm(out.transpose(-1, -2) + residual) -class DecoderBlock(torch.nn.Module): +class EncoderBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock, self).__init__() + super(EncoderBlock, self).__init__() self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -130,18 +130,18 @@ def forward(self, x): act_result = self.act(ffn1_result) return self.matmulln2(act_result, result) -def test_DecoderBlock(device, head=12, embed_dim=768, input_seq=512): +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): cpu_query = torch.randn(input_seq, embed_dim) - decoder_block = DecoderBlock(embed_dim, head) - cpu_res = decoder_block(cpu_query) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) query = cpu_query.clone().to(device=device) - decoder_block.to(device=device) + 
encoder_block.to(device=device) with torch.no_grad(): - opt_fn = torch.compile(dynamic=False)(decoder_block) + opt_fn = torch.compile(dynamic=False)(encoder_block) res = opt_fn(query) - test_result("Decoder Block Forwrad", res, cpu_res) + test_result("Encoder Block Forwrad", res, cpu_res) def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): @@ -165,18 +165,18 @@ def attention(query, key, value): def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): MHA = my_MultiheadAttention(num_heads, embed_dim) cpu_query = torch.randn(input_seq, embed_dim) - cpu_res = MHA(cpu_query, cpu_query, cpu_query) - - query = cpu_query.clone().to(device=device) - MHA.to(device=device) - opt_fn = torch.compile(dynamic=False)(MHA) - res = opt_fn(query, query, query) + with torch.no_grad(): + cpu_res = MHA(cpu_query, cpu_query, cpu_query) + query = cpu_query.clone().to(device=device) + MHA.to(device=device) + opt_fn = torch.compile(dynamic=False)(MHA) + res = opt_fn(query, query, query) test_result("MHA Forward", res, cpu_res) -def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): - bert_origin = DecoderBlock_origin(embed_dim, head) - bert = DecoderBlock(embed_dim, head) +def test_EncoderBlock_validation(head=12, embed_dim=768, input_seq=512): + bert_origin = EncoderBlock_origin(embed_dim, head) + bert = EncoderBlock(embed_dim, head) bert.multihead_attn.linears[0].weight = bert_origin.multihead_attn.linears[0].weight bert.multihead_attn.linears[0].bias = bert_origin.multihead_attn.linears[0].bias @@ -196,7 +196,7 @@ def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): origin_res = bert_origin(origin_query) res = bert(query) - test_result("Decoder Block Validation", res, origin_res) + test_result("Encoder Block Validation", res, origin_res) if __name__ == "__main__": import os @@ -206,8 +206,8 @@ def test_DecoderBlock_validation(head=12, embed_dim=768, input_seq=512): from Scheduler.scheduler import 
ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_MHA(device) - test_DecoderBlock(device) - # test_DecoderBlock_validation() + #test_MHA(device) + test_EncoderBlock(device) + # test_EncoderBlock_validation() # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 8667792a..96ee05eb 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -43,10 +43,14 @@ def custom_conv2d(a, b, bias): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_conv2d(device, batch_size=1, in_channels=3, out_channels=32, input_size=32, kernel_size=3, stride=1, padding=1) + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + 
test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) diff --git a/tests/test_pool.py b/tests/test_pool.py index e94df65b..304a5e7c 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -50,6 +50,6 @@ def avgpool(a): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_maxpool(device, b=1, c=8, h=16, w=16) - test_maxpool(device, b=1, c=8, h=112, w=112) - test_avgpool(device) + #test_maxpool(device, b=1, c=8, h=16, w=16) + #test_maxpool(device, b=1, c=8, h=112, w=112) + test_avgpool(device, b=1, c=512, h=7, w=7) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 5e96b922..f54ce9be 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -18,13 +18,13 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) exit(1) -def test_resnet(device): +def test_resnet(device, batch=1): from torchvision.models import resnet - # model = resnet._resnet(resnet.BasicBlock, [1, 1, 0, 0], weights=None, progress=False).eval() with torch.no_grad(): + #model = resnet._resnet(resnet.BasicBlock, [1, 1, 1, 1], weights=None, progress=False).eval() model = resnet18().eval() model.to(device, memory_format=torch.channels_last) - input = torch.randn(1, 3, 224, 224) + input = torch.randn(batch, 3, 224, 224) x1 = input.to(device=device, memory_format=torch.channels_last) x2 = input.cpu().to(memory_format=torch.channels_last) opt_fn = torch.compile(dynamic=False)(model) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e05fa392..c64093a0 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,7 +2,7 @@ import sys import torch from torchvision.models import resnet18 as model1 -from test_transformer import DecoderBlock as model2 +from test_transformer import EncoderBlock as model2 base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') 
sys.path.append(base_path) diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index b3945520..3e079f83 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -8,7 +8,7 @@ import torch._dynamo import torch.utils.cpp_extension sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) -from test_transformer import DecoderBlock, test_result +from test_transformer import EncoderBlock, test_result from test_mlp import MLP def apply_random_zero(tensor, zero_prob, block_size=8): @@ -35,30 +35,30 @@ def count_zeros_in_tensor_list(tensor_list): def test_dec_inf(device, sparsity=0.0, block=8): torch.manual_seed(0) - decoder_block = DecoderBlock(768, 12) + encoder_block = EncoderBlock(768, 12) cpu_query = torch.randn(512, 768) query = cpu_query.clone().to(device=device) - cpu_y = decoder_block(cpu_query) + cpu_y = encoder_block(cpu_query) with torch.no_grad(): - decoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[1].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) - decoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(decoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) - decoder_block.ffn1.weight.copy_(apply_random_zero(decoder_block.ffn1.weight, sparsity, block_size=block)) - decoder_block.ffn2.weight.copy_(apply_random_zero(decoder_block.ffn2.weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[0].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[0].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[1].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[1].weight, 
sparsity, block_size=block)) + encoder_block.multihead_attn.linears[2].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[2].weight, sparsity, block_size=block)) + encoder_block.multihead_attn.linears[3].weight.copy_(apply_random_zero(encoder_block.multihead_attn.linears[3].weight, sparsity, block_size=block)) + encoder_block.ffn1.weight.copy_(apply_random_zero(encoder_block.ffn1.weight, sparsity, block_size=block)) + encoder_block.ffn2.weight.copy_(apply_random_zero(encoder_block.ffn2.weight, sparsity, block_size=block)) count_zeros_in_tensor_list([ - decoder_block.multihead_attn.linears[0].weight, - decoder_block.multihead_attn.linears[1].weight, - decoder_block.multihead_attn.linears[2].weight, - decoder_block.multihead_attn.linears[3].weight, - decoder_block.ffn1.weight, - decoder_block.ffn2.weight + encoder_block.multihead_attn.linears[0].weight, + encoder_block.multihead_attn.linears[1].weight, + encoder_block.multihead_attn.linears[2].weight, + encoder_block.multihead_attn.linears[3].weight, + encoder_block.ffn1.weight, + encoder_block.ffn2.weight ]) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) y = opt_fn(query) test_result("MLP Forward", y, cpu_y) diff --git a/tests/test_spmm_scheduler.py b/tests/test_spmm_scheduler.py index 73bbdbae..1cf0d3b3 100644 --- a/tests/test_spmm_scheduler.py +++ b/tests/test_spmm_scheduler.py @@ -5,7 +5,7 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request from test_sparse_core import SparseMLP as model1 -from test_transformer import DecoderBlock as model2 +from test_transformer import EncoderBlock as model2 CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') if __name__ == "__main__": diff --git a/tests/test_transformer.py 
b/tests/test_transformer.py index cfa2a622..4d45707e 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -53,9 +53,9 @@ def forward(self, query, key, value): del value return self.linears[-1](x) -class DecoderBlock(torch.nn.Module): +class EncoderBlock(torch.nn.Module): def __init__(self, embed_dim, num_heads): - super(DecoderBlock, self).__init__() + super(EncoderBlock, self).__init__() self.multihead_attn = my_MultiheadAttention(num_heads, embed_dim) self.layer_norm = torch.nn.LayerNorm(embed_dim) self.ffn1 = torch.nn.Linear(embed_dim, embed_dim*4) @@ -71,17 +71,17 @@ def forward(self, x): ffn2_result = self.ffn2(act_result) return self.layer_norm(ffn2_result + result) -def test_DecoderBlock(device, head=12, embed_dim=768, input_seq=512): +def test_EncoderBlock(device, head=12, embed_dim=768, input_seq=512): cpu_query = torch.randn(1, input_seq, embed_dim) - decoder_block = DecoderBlock(embed_dim, head) - cpu_res = decoder_block(cpu_query) + encoder_block = EncoderBlock(embed_dim, head) + cpu_res = encoder_block(cpu_query) query = cpu_query.clone().to(device=device) - decoder_block.to(device=device) - opt_fn = torch.compile(dynamic=False)(decoder_block) + encoder_block.to(device=device) + opt_fn = torch.compile(dynamic=False)(encoder_block) res = opt_fn(query) - test_result("Decoder Block Forwrad", res, cpu_res) + test_result("Encoder Block Forwrad", res, cpu_res) def test_Attention(device, head=16, seq=512, d_k=64): def attention(query, key, value): @@ -122,6 +122,6 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() - test_DecoderBlock(device) + test_EncoderBlock(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) From 4d1d0f50e4ebeb36c3bae70e975b275c9d027096 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 19 Jul 2025 08:13:08 +0000 Subject: 
[PATCH 55/62] [Frotend/Fusion] Relax the prologue fusion condition --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 20 +++++++++------- PyTorchSimFrontend/mlir/mlir_template.py | 28 ++-------------------- 2 files changed, 13 insertions(+), 35 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 773414d5..3eff0ddc 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -65,8 +65,6 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if len(node1.read_writes.writes) != 1: return False - if len(node1.users) != 1: - return False # We don't fuse this case... if (isinstance(target_node.template, MLIRBMMTemplate) or isinstance(target_node.template, MLIRGemmTemplate)) and base_template_node2[0].group[1][0][0] == 1: return False @@ -250,12 +248,19 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() + # Partial codgen template nodes partial_code = render() + + # Swap load/store functions + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + # Codegen prologue nodes if prologue_nodes: # Flush created varaibles, since template fusion doen't share variable with kernel.prologue_buffer_group.as_local(): - kernel.load = kernel.load_epilogue - kernel.store = kernel.store_prologue _, (group, reduction_group) = max( [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) ).group @@ -292,16 +297,12 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e } node.codegen((vars, reduction_vars)) + # Codegen epilogue nodes tile_desc = kernel.set_tile_size(kernel.epilogue_info) 
kernel.kernel_group.set_tile_info(tile_desc) kernel.call_ranges = None if epilogue_nodes: with kernel.epilogue_buffer_group.as_local(): - kernel.load = kernel.load_epilogue - kernel.store = kernel.store_epilogue - kernel.store_reduction = kernel.store_reduction_epilogue - kernel.reduction = kernel.reduction_epilogue - _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) ).group @@ -315,6 +316,7 @@ def codegen_template_code(self, kernel, render, template_node, prologue_nodes, e if isinstance(partial_code, str) else partial_code.finalize() ) + # For consistency, white space could make wrong write_path buffer = IndentedBuffer() buffer.splice(src_code) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index c6cd4a7e..66d6b578 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -423,6 +423,7 @@ def codegen_prologue_body(self): compute_body.splice(self.compute) compute_body.splice(self.stores) body.splice(compute_body) + body.splice(self.dma_stores) return body def codegen_epilogue_body(self): @@ -723,31 +724,6 @@ def get_spad_size_per_lane(self, tile_m, tile_n): size = tile_m * ((tile_n + self.vector_lane - 1) // self.vector_lane) return max(size, 2) # vector load/store - def store_prologue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - dtype = V.graph.get_dtype(name) - mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - tile_shape = self.kernel_group.tile_desc.get_mlir_shape(mlir_dtype) - - # Compute vector unit size - vshape = self.kernel_group.tile_desc.get_mlir_vshape(mlir_dtype) - compute_vec_size = self.kernel_group.tile_desc.get_compute_vec_size() - - sram_var = self.buffer_names[name] - zero_var = self.get_const_cse(0) - - _, operand_type = self.var_info[value] - if mlir_dtype != operand_type: - value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) - compute_index_var = ",".join([f"%{zero_var}"] * 
(self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) - # Generate vector load instruction - if compute_vec_size > 1: - operation = "affine.vector_store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" - else: - operation = "affine.store" - line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - self.stores.writeline(line) - def load_epilogue(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) @@ -847,7 +823,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): else: operation = "affine.store" line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" - self.stores.writeline(DeferredLine(name, line)) + self.stores.writeline(line) # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" From 8253c1dbac6d4e00e805d590d3934b110f11df4a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 19 Jul 2025 12:46:49 +0000 Subject: [PATCH 56/62] [Frontend] Avoid tricky cases in the prologue fusion --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 2 ++ tests/test_conv2d.py | 1 + 2 files changed, 3 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 3eff0ddc..e037207f 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -65,6 +65,8 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if len(node1.read_writes.writes) != 1: return False + if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + return False # We don't fuse this case... 
if (isinstance(target_node.template, MLIRBMMTemplate) or isinstance(target_node.template, MLIRGemmTemplate)) and base_template_node2[0].group[1][0][0] == 1: return False diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index 96ee05eb..c679b431 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -43,6 +43,7 @@ def custom_conv2d(a, b, bias): from Scheduler.scheduler import ExecutionEngine module = ExecutionEngine.setup_device() device = module.custom_device() + torch._dynamo.config.cache_size_limit = 64 test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) From edc3f57261c9a8ad8253b9071800018de4cf07f5 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Sat, 19 Jul 2025 13:52:43 +0000 Subject: [PATCH 57/62] [Frontend] Fix store epilogue --- PyTorchSimFrontend/mlir/mlir_template.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 66d6b578..1da2e755 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -805,10 +805,12 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): if name not in self.buffer_names: sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) self.buffer_names[name] = sram_var + store_force = False else: zero_cse = self.get_const_cse(0) sram_dims = len(tile_shape.split("x")) - 1 sram_index_var = ",".join([f"%{zero_cse}"] * sram_dims) + store_force = True sram_var = self.buffer_names[name] zero_var = self.get_const_cse(0) @@ -823,6 +825,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): 
else: operation = "affine.store" line = f"{operation} %{value}, %{sram_var}[{compute_index_var}] : {tile_shape}" + line = line if store_force else DeferredLine(name, line) self.stores.writeline(line) # Generate DMA instruction From 1d49f43e340343b727f883c04c0812a27965b022 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 21 Jul 2025 04:43:57 +0000 Subject: [PATCH 58/62] [Config] Remove deprecated config --- PyTorchSimFrontend/extension_config.py | 3 --- scripts/chiplet_prep.sh | 9 ++++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 1761e05c..d60826a1 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -53,9 +53,6 @@ # For block sparse CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) -CONFIG_FORCE_TILE_M = int(os.environ.get("TORCHSIM_FORCE_TIME_M", default=sys.maxsize)) -CONFIG_FORCE_TILE_N = int(os.environ.get("TORCHSIM_FORCE_TIME_N", default=sys.maxsize)) -CONFIG_FORCE_TILE_K = int(os.environ.get("TORCHSIM_FORCE_TIME_K", default=sys.maxsize)) # For GEMM tile size CONFIG_MANUAL_TILE_SIZE = int(os.environ.get('TORCHSIM_MANUAL_TILE_SIZE', default=False)) diff --git a/scripts/chiplet_prep.sh b/scripts/chiplet_prep.sh index 99fc9b30..cddf1a58 100755 --- a/scripts/chiplet_prep.sh +++ b/scripts/chiplet_prep.sh @@ -1,14 +1,13 @@ #!/bin/bash sizes=(256 512 1024 2048) -# 각 size에 대해 처리 for size in "${sizes[@]}"; do echo "Processing size: $size" - # 환경 변수 설정 - export TORCHSIM_FORCE_TIME_M=$((size / 2)) - export TORCHSIM_FORCE_TIME_K=$((size / 2)) - export TORCHSIM_FORCE_TIME_N=$((size / 2)) + # Set environment variables + export TORCHSIM_TILE_M=$((size / 2)) + export TORCHSIM_TILE_K=$((size / 2)) + export TORCHSIM_TILE_N=$((size / 2)) export TORCHSIM_DUMP_PATH=$(pwd)/chiplet_result/$size python3 chiplet_prep.py $size #python3 chiplet_run.py $(pwd)/chiplet_result From 
903ff136b73d7ead01b15dd599ef90b3705dac08 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 21 Jul 2025 12:43:16 +0000 Subject: [PATCH 59/62] [TogSim] Update tile_stride logic --- AsmParser/onnx_utility.py | 3 +- AsmParser/tog_generator.py | 3 +- PyTorchSimBackend/include/Instruction.h | 16 +-- PyTorchSimBackend/include/TileGraphParser.h | 6 +- PyTorchSimBackend/src/Instruction.cc | 31 ++--- PyTorchSimBackend/src/TileGraphParser.cc | 146 ++++++++------------ 6 files changed, 83 insertions(+), 122 deletions(-) diff --git a/AsmParser/onnx_utility.py b/AsmParser/onnx_utility.py index d46e8347..4f76ef35 100644 --- a/AsmParser/onnx_utility.py +++ b/AsmParser/onnx_utility.py @@ -66,12 +66,13 @@ def __init__(self, tile_info, inst_list=list(), node_id=0): super().__init__(node_id) self.inst = inst_list self.torchsim_base_addr = tile_info["base_addr"] - self.torchsim_stride_list = tile_info["stride_list"] self.torchsim_tile_size = tile_info["tile_size"] + self.torchsim_tile_stride = tile_info["tile_stride"] self.torchsim_element_size = tile_info["element_size"] self.torchsim_tag_idx_list = tile_info["tag_idx_list"] self.torchsim_tag_stride_list = tile_info["tag_stride_list"] self.torchsim_loop_idx_list = tile_info["loop_idx_list"] + self.torchsim_loop_stride_list = tile_info["loop_stride_list"] self.torchsim_is_async = tile_info["is_async"] self.torchsim_indirect_mode = tile_info["indirect_mode"] diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 1dea2f8d..5f586d99 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -91,12 +91,13 @@ def _create_node(self, dump_data): elif node_type == self.DMANodeKind: tile_info = {} tile_info["base_addr"] = dump_data["base_address"] - tile_info["stride_list"] = dump_data["stride_list"] tile_info["tile_size"] = dump_data["tile_size"] + tile_info["tile_stride"] = dump_data["tile_stride"] tile_info["element_size"] = dump_data["element_size"] tile_info["tag_idx_list"] = 
dump_data["tag_idx_list"] tile_info["tag_stride_list"] = dump_data["tag_stride_list"] tile_info["loop_idx_list"] = dump_data["loop_idx_list"] + tile_info["loop_stride_list"] = dump_data["loop_stride_list"] tile_info["is_async"] = dump_data["is_async"] tile_info["indirect_mode"] = dump_data["indirect_mode"] is_write = dump_data["is_write"] diff --git a/PyTorchSimBackend/include/Instruction.h b/PyTorchSimBackend/include/Instruction.h index 84b17d7c..4c14dd81 100644 --- a/PyTorchSimBackend/include/Instruction.h +++ b/PyTorchSimBackend/include/Instruction.h @@ -22,9 +22,10 @@ std::string opcode_to_string(Opcode opcode); class Instruction : public std::enable_shared_from_this { public: Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, addr_type dram_addr, - std::vector tile_size, size_t precision, std::vector &idx_list, - std::vector &stride_list, std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list, std::vector loop_size_list); + std::vector tile_size, std::vector tile_stride, size_t precision, + std::vector tag_idx_list, std::vector tag_stride_list, + std::vector accum_tag_idx_list); + Instruction(Opcode opcode); void finish_instruction(); void add_child(std::shared_ptr child); bool check_ready() { return ready_counter == 0; } @@ -60,10 +61,6 @@ class Instruction : public std::enable_shared_from_this { bool load_indirect_index(const std::string& path, uint64_t*& indirect_index, const std::vector& tile_size); void set_trace_address(std::vector& trace_address) { _trace_address = trace_address; } size_t get_free_sram_size() { return _free_sram_size; } - void adjust_dram_address() { - int offset = std::inner_product(_idx_list.begin(), _idx_list.end(), _stride_list.begin(), 0); - dram_addr += offset * _precision; - } addr_type get_base_dram_address() { return dram_addr; } void set_free_sram_size(size_t sram_size) { _free_sram_size=sram_size; } void* get_owner() { return _owner; } @@ -73,7 +70,6 @@ class Instruction 
: public std::enable_shared_from_this { int get_compute_type() { return _compute_type; } void set_numa_id(int numa_id) { _numa_id = numa_id; } uint32_t get_numa_id() { return _numa_id; } - std::vector& get_idx_list() { return _idx_list; } std::vector& get_tag_idx_list() { return _tag_idx_list; } std::vector& get_tag_stride_list() { return _tag_stride_list; } std::vector& get_tag_id() { return _tag_key; } @@ -103,6 +99,7 @@ class Instruction : public std::enable_shared_from_this { size_t ready_counter; std::set> child_inst; std::vector tile_size; + std::vector tile_stride; size_t _tile_numel; size_t _nr_waiting_request=0; size_t _precision=0; @@ -110,13 +107,10 @@ class Instruction : public std::enable_shared_from_this { addr_type dram_addr; uint32_t _numa_id = 0; // For DMA instruction int _compute_type = 0; - std::vector _idx_list; - std::vector _stride_list; std::vector _tag_idx_list; std::vector _tag_stride_list; std::vector _tag_key; std::vector _accum_tag_idx_list; - std::vector _loop_size_list; std::vector _trace_address; std::string _addr_name; int _addr_id; diff --git a/PyTorchSimBackend/include/TileGraphParser.h b/PyTorchSimBackend/include/TileGraphParser.h index b5322b76..5b561127 100644 --- a/PyTorchSimBackend/include/TileGraphParser.h +++ b/PyTorchSimBackend/include/TileGraphParser.h @@ -175,17 +175,18 @@ class TileMemoryNode : public TileNode { std::string get_base_addr_name() { return _base_addr_name; } size_t get_precision() { return _element_size; } std::vector get_tile_size() { return _tile_size; } - std::vector& get_stride_list () { return _stride_list; } + std::vector& get_tile_stride() { return _tile_stride; } std::vector& get_tag_idx_list() { return _tag_idx_list; } std::vector& get_tag_stride_list() { return _tag_stride_list; } std::vector& get_loop_idx_list() { return _loop_idx_list; } + std::vector& get_loop_stride_list () { return _loop_stride_list; } bool is_async_node() { return _is_async; } bool is_indirect() { return _is_indirect; } 
void print_node() override; private: std::vector _tile_size; - std::vector _stride_list; + std::vector _tile_stride; size_t _element_size; bool _is_async; bool _is_indirect; @@ -193,6 +194,7 @@ class TileMemoryNode : public TileNode { std::vector _tag_idx_list; std::vector _tag_stride_list; std::vector _loop_idx_list; + std::vector _loop_stride_list; }; class TileMemoryWaitNode : public TileNode { diff --git a/PyTorchSimBackend/src/Instruction.cc b/PyTorchSimBackend/src/Instruction.cc index b706ca8f..aef9079c 100644 --- a/PyTorchSimBackend/src/Instruction.cc +++ b/PyTorchSimBackend/src/Instruction.cc @@ -11,23 +11,22 @@ std::string opcode_to_string(Opcode opcode) { } Instruction::Instruction(Opcode opcode, cycle_type compute_cycle, size_t num_parents, - addr_type dram_addr, std::vector tile_size, size_t precision, - std::vector& idx_list, std::vector& stride_list, + addr_type dram_addr, std::vector tile_size, std::vector tile_stride, size_t precision, std::vector tag_idx_list, std::vector tag_stride_list, - std::vector accum_tag_idx_list, std::vector loop_size_list) + std::vector accum_tag_idx_list) : opcode(opcode), compute_cycle(compute_cycle), ready_counter(num_parents), dram_addr(dram_addr), - tile_size(tile_size), _precision(precision), _idx_list(idx_list), - _stride_list(stride_list), _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), - _accum_tag_idx_list(accum_tag_idx_list), _loop_size_list(loop_size_list) { + tile_size(tile_size), tile_stride(tile_stride), _precision(precision), + _tag_idx_list(tag_idx_list), _tag_stride_list(tag_stride_list), + _accum_tag_idx_list(accum_tag_idx_list) { assert(_tag_idx_list.size()==_tag_stride_list.size()); _tile_numel = 1; for (auto dim : tile_size) _tile_numel *= dim; +} - /* Supporting vector */ - if (_stride_list.size() == 1) { - _stride_list.push_back(1); - } +Instruction::Instruction(Opcode opcode) + : opcode(opcode) { + _tile_numel = 1; } void Instruction::finish_instruction() { @@ -73,8 +72,8 @@ 
std::shared_ptr> Instruction::get_dram_address(addr_type dra while (tile_size.size() < 4) tile_size.insert(tile_size.begin(), 1); - while (_stride_list.size() < 4) - _stride_list.insert(_stride_list.begin(), 0); + while (tile_stride.size() < 4) + tile_stride.insert(tile_stride.begin(), 0); if (_is_indirect_mode) { spdlog::trace("[Indirect Access] Indirect mode, dump_path: {}", _indirect_index_path); load_indirect_index(_indirect_index_path, indirect_index, tile_size); @@ -85,10 +84,10 @@ std::shared_ptr> Instruction::get_dram_address(addr_type dra for (int dim1=0; dim1> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto& tile_node: _body_node) { if (tile_node->get_type() == TileType::LOAD_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); - std::vector& tag_idx_list = mem_node->get_tag_idx_list(); - std::vector& tag_stride_list = mem_node->get_tag_stride_list(); - std::vector skip_idx_list; - std::vector values; - - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); std::vector iter_list; - std::vector tag_list; - std::vector accum_tag_list; - std::vector loop_size_list; - std::vector outer_loop_idx; - std::vector outer_loop_size; int nr_inner_loop = 0; auto& loop_idx_list = mem_node->get_loop_idx_list(); for (auto loop_idx: loop_idx_list) { - auto iter_value = getLoopIndexValue(iter, loop_idx); + int iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; } + + /* Base address setting */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = 
std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + + std::vector tag_list; + std::vector accum_tag_list; + std::vector outer_loop_idx; + std::vector outer_loop_size; /* Add accumulation loop info to accum_tag list */ for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end() - nr_inner_loop; ++loop_idx) { @@ -387,7 +387,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } uint32_t systolic_size = std::stoi(tog_parser->getMetaByName("systolic_size")); - for (auto loop_idx: tag_idx_list) { + for (auto loop_idx: mem_node->get_tag_idx_list()) { if (iter.find(loop_idx) == iter.end()) tag_list.push_back(0); else { @@ -406,25 +406,32 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa int stride_idx = calculateAddress(outer_loop_size, tog_parser->lookupNumaInfo(base_addr_name)); numa_id = total_idx / stride_idx; } + /* Check need to make this memory node */ + std::vector& tag_stride_list = mem_node->get_tag_stride_list(); std::vector key = tog_parser->calc_tag(accum_tag_list, tag_list, tag_stride_list); if (tog_parser->check_memory_tag(base_addr_name, key)) continue; tog_parser->register_memory_tag(base_addr_name, key); printIndexMap("[TOGParser] Load Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); + spdlog::trace("[TOGParser] Load Node {} key = [{}], accum = [{}], tag = [{}], stride = [{}]", mem_node->get_base_addr_name(), + fmt::join(key, ", "), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(tag_stride_list, ", ")); std::shared_ptr inst = std::make_shared( Opcode::MOVIN, 0, - 0, base_addr, - mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), tag_list, tag_stride_list, accum_tag_list, loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + tag_list, tag_stride_list, accum_tag_list ); 
inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); + if (mem_node->is_indirect()) { inst->set_indirect_index_path(tog_parser->get_indirect_path()); tog_parser->inc_indirect_counter(); @@ -439,14 +446,7 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa tile_vec.back()->append_instuction(inst); } else if (tile_node->get_type() == TileType::STORE_NODE) { std::shared_ptr mem_node = std::static_pointer_cast(tile_node); - auto base_addr_name = mem_node->get_base_addr_name(); - int base_addr_id = tog_parser->register_addr_name(base_addr_name); - /* Lookup given name's address */ - addr_type base_addr = tog_parser->lookup(base_addr_name); - std::vector& tag_stride_list = mem_node->get_tag_stride_list(); - std::vector accum_tag_list; std::vector iter_list; - std::vector loop_size_list; std::vector outer_loop_idx; std::vector outer_loop_size; int nr_inner_loop = 0; @@ -454,7 +454,6 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa for (auto loop_idx: loop_idx_list) { auto iter_value = getLoopIndexValue(iter, loop_idx); iter_list.push_back(iter_value); - loop_size_list.push_back(tog_parser->get_loop_size(loop_idx)); if (tog_parser->get_loop_type(loop_idx)==LoopType::INNER_LOOP) nr_inner_loop++; if (tog_parser->get_loop_type(loop_idx)==LoopType::PARALLEL_LOOP) { @@ -465,6 +464,12 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } } + /* Lookup given name's address */ + std::string base_addr_name = mem_node->get_base_addr_name(); + int base_addr_id = tog_parser->register_addr_name(base_addr_name); + addr_type base_addr = tog_parser->lookup(base_addr_name); + addr_type offset = std::inner_product(iter_list.begin(), iter_list.end(), mem_node->get_loop_stride_list().begin(), 0); + /* Calc numa id */ int numa_id = 0; auto numa_stride_size = 
tog_parser->lookupNumaInfo(base_addr_name).size(); @@ -477,14 +482,13 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa printIndexMap("[TOGParser] Store Node " + mem_node->get_base_addr_name() + " Numa_id: " + std::to_string(numa_id), iter); std::shared_ptr inst = std::make_shared( Opcode::MOVOUT, 0, - 0, base_addr, - mem_node->get_tile_size(), mem_node->get_precision(), iter_list, - mem_node->get_stride_list(), std::vector(1), tag_stride_list, accum_tag_list, loop_size_list + 0, base_addr+offset, + mem_node->get_tile_size(), mem_node->get_tile_stride(), mem_node->get_precision(), + std::vector(1), mem_node->get_tag_stride_list(), std::vector() ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); inst->set_nr_inner_loop(nr_inner_loop); - inst->adjust_dram_address(); inst->set_is_async(mem_node->is_async_node()); inst->set_numa_id(numa_id); if (mem_node->is_indirect()) { @@ -530,11 +534,16 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa new_tag_stride_list.push_back(i); } + spdlog::trace("[TOGParser] Wait Node {}, accum = [{}], tag = [{}], stride = [{}]", wait_node->get_base_addr_name(), + fmt::join(accum_tag_list, ", "), + fmt::join(tag_list, ", "), + fmt::join(new_tag_stride_list, ", ")); + std::shared_ptr inst = std::make_shared( Opcode::BAR, 0, 0, base_addr, - std::vector(), 0, iter_list, - iter_list, tag_list, new_tag_stride_list, accum_tag_list, std::vector() + std::vector(), std::vector(), 0, + tag_list, new_tag_stride_list, accum_tag_list ); inst->set_addr_name(base_addr_name, base_addr_id); inst->prepare_tag_key(); @@ -543,15 +552,14 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } else if (tile_node->get_type() == TileType::COMPUTE_NODE) { printIndexMap("[TOGParser] Compute Node ", iter); std::shared_ptr compute_node = std::static_pointer_cast(tile_node); - std::vector iter_list; std::vector tag_list = {0}; std::vector tag_stride_list = {1}; std::vector accum_tag_list; std::shared_ptr inst 
= std::make_shared( Opcode::COMP, compute_node->get_cycle(), 0, 0, - std::vector(), 0, iter_list, iter_list, - tag_list, tag_stride_list, accum_tag_list, std::vector() + std::vector(), std::vector(), 0, + tag_list, tag_stride_list, accum_tag_list ); inst->set_overlapping_cycle(compute_node->get_overlapping_cycle()); inst->set_compute_type(compute_node->get_compute_type()); @@ -620,72 +628,28 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa } else if (tile_node->get_type() == TileType::STONNE_NODE) { printIndexMap("[TOGParser] Stonne Node ", iter); std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - /* Put dummy computation instruction */ - std::shared_ptr inst = std::make_shared( - Opcode::COMP, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::COMP); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_custom_data(stonne_node->getDesc()); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == TileType::STONNE_TRACE_COMPUTE_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::COMP, stonne_node->get_cycle(), - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::COMP); + inst->set_compute_cycle(stonne_node->get_cycle()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == 
TileType::STONNE_TRACE_LOAD_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::MOVIN, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::MOVIN); inst->set_trace_address(stonne_node->get_address()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); tile_vec.back()->set_stonne_tile(true); } else if (tile_node->get_type() == TileType::STONNE_TRACE_STORE_NODE) { std::shared_ptr stonne_node = std::static_pointer_cast(tile_node); - /* Lookup given name's address */ - std::vector iter_list; - std::vector tag_list; - std::vector tag_stride_list; - std::vector accum_tag_list; - - std::shared_ptr inst = std::make_shared( - Opcode::MOVOUT, 0, - 0, 0, - std::vector(), 0, iter_list, - iter_list, tag_list, tag_stride_list, accum_tag_list, std::vector() - ); + std::shared_ptr inst = std::make_shared(Opcode::MOVOUT); inst->set_trace_address(stonne_node->get_address()); link_map[tile_node] = inst; tile_vec.back()->append_instuction(inst); From 94b13e1a71893ba93fcd79b31deeac7d18cb3b1d Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 21 Jul 2025 12:45:46 +0000 Subject: [PATCH 60/62] [Frontend] Make dma tag unique --- PyTorchSimFrontend/extension_config.py | 2 +- .../mlir/mlir_codegen_backend.py | 19 +++++++++++-------- PyTorchSimFrontend/mlir/mlir_template.py | 9 ++++----- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index d60826a1..8994cffe 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -37,7 +37,7 @@ # Backendsim config CONFIG_TORCHSIM_BACKEND_CONFIG = 
os.environ.get('TORCHSIM_CONFIG', default=f'{CONFIG_TORCHSIM_DIR}/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json') -CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", True)) +CONFIG_BACKENDSIM_SPIKE_ONLY = int(os.environ.get("BACKENDSIM_SPIKE_ONLY", False)) CONFIG_BACKENDSIM_EAGER_MODE = int(os.environ.get("BACKENDSIM_EAGER_MODE", default=False)) CONFIG_BACKENDSIM_DRYRUN = int(os.environ.get('BACKENDSIM_DRYRUN', default=False)) CONFIG_BACKENDSIM_DEBUG_LEVEL = os.environ.get("BACKENDSIM_DEBUG_LEVEL", "") diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 79d735a3..725fec5d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -877,6 +877,7 @@ def __init__(self, kernel_group, reason=None): self.spadbuf_counter = 0 self.dma_read_counter = 1 self.dma_write_counter = 1 + self.dma_tag_id = 0 self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} @@ -1028,7 +1029,7 @@ def load(self, name: str, index: sympy.Expr): # MVIN Encoding attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding={padding}}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, attribute) + dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) # FIXME: assignment = False does not support caching compute_index_var = ",".join(sram_index_var.split(",")[:-1] + [f"%{self.compute_idx}"]) # Generate vector load instruction @@ -1090,7 +1091,7 @@ def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, 
sram_index_var, - f"{name}_tag", dram_shape, tile_shape, attribute) + dram_shape, tile_shape, attribute) self.dma_stores.writeline(common.DeferredLine(name, code)) def reduction(self, dtype, src_dtype, reduction_type, value): @@ -1243,7 +1244,7 @@ def store_reduction(self, name, index, value): # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, attribute) + dram_shape, tile_shape, attribute) self.reductions_suffix.writeline(common.DeferredLine(name, code)) # Restore origin cse @@ -1655,7 +1656,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe return local_tile_desc, index_var, dram_stride def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, dram_index_var, sram_var, sram_index_var, - tag_name, dram_shape, tile_shape, attribute): + dram_shape, tile_shape, attribute): dma_key = (vlane_split_axis, vlane_stride, mlir_dtype) if dma_type_name == "MVIN" and dma_key in self.dma_read_cache: dma_type, vlane_split_axis, vlane_stride = self.dma_read_cache[dma_key] @@ -1670,9 +1671,8 @@ def get_dma_code(self, dma_type_name, vlane_split_axis, vlane_stride, mlir_dtype self.dma_read_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] else: dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) - # self.dma_write_counter += 1 Is it okay? 
self.dma_write_cache[dma_key] = [dma_type, vlane_split_axis, vlane_stride] - tag = self.get_tag_cse(tag_name) + tag = self.get_tag_cse() zero_cse = self.get_const_cse(0) # Prepare opearnds and attributes @@ -1742,9 +1742,12 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable: self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") return self.consts[str(value)+dtype] - def get_tag_cse(self, value, shape="memref<1xi32>"): + def get_tag_cse(self, value=None, shape="memref<1xi32>"): + if value is None: + value = self.dma_tag_id + self.dma_tag_id += 1 if value not in self.tags: - self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape}") + self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape} // {value}") return self.tags[value] def get_mask(self): diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 1da2e755..0455cbf1 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -663,7 +663,6 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com # Prepare code block local_code = IndentedBuffer() with V.set_kernel_handler(self): - tag = f"mvint_{self.dma_read_counter}" if dma_type == "MVIN" else f"mvoutt_{self.dma_write_counter}" index_var = self.parse_index_list(index_list, local_code) node_layout = self.named_nodes[dram_var].get_layout() numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() @@ -696,7 +695,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") attribute = " {" + ", ".join(attribute_parts) + "}" code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - tag, 
dram_shape, tile_shape, "") + dram_shape, tile_shape, "") local_code.writeline(code) local_code.writeline(attribute) return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() @@ -749,7 +748,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): sram_var, sram_index_var = self.get_scratchpad_buffer(dtype, name, self.kernel_group.tile_desc, index) attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVIN", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, attribute) + dram_shape, tile_shape, attribute) self.cse.generate(self.dma_loads, code, assignment = False) self.buffer_names[name] = sram_var else: @@ -831,7 +830,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, tile_shape, attribute) + dram_shape, tile_shape, attribute) self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): @@ -991,7 +990,7 @@ def store_reduction_epilogue(self, name, index, value): # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" code = self.get_dma_code("MVOUT", vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - f"{name}_tag", dram_shape, final_tile_shape, attribute) + dram_shape, final_tile_shape, attribute) self.reductions_suffix.writeline(DeferredLine(name, code)) def set_tile_size(self, template_fusion_info, prologue=False): From 737ed02eac858481e40201655c55b34b17645193 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 21 Jul 2025 15:14:59 +0000 Subject: [PATCH 
61/62] [TOGSim] Handle edge case tag matching --- PyTorchSimBackend/src/TileGraphParser.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/PyTorchSimBackend/src/TileGraphParser.cc b/PyTorchSimBackend/src/TileGraphParser.cc index 0f3e2ce9..12056f94 100644 --- a/PyTorchSimBackend/src/TileGraphParser.cc +++ b/PyTorchSimBackend/src/TileGraphParser.cc @@ -375,6 +375,10 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa accum_tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } for (auto loop_idx = loop_idx_list.begin(); loop_idx != loop_idx_list.end(); ++loop_idx) { @@ -527,6 +531,10 @@ std::vector> TileLoopNode::get_tiles_from_iter(TileGraphPa tag_list.push_back(iter_value); } } + /* Default accum tag */ + if (accum_tag_list.empty()) { + accum_tag_list.push_back(0); + } /* Skip accum stride */ for (auto i : tag_stride_list) { From b18bcc0f78bad09e642f208811a4c53e6054e98b Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 21 Jul 2025 15:15:37 +0000 Subject: [PATCH 62/62] [Frontend/Fusion] Do not allow prologue fusion for CONV --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 25 ++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index e037207f..f1c72c44 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -24,6 +24,7 @@ def __init__(self, scheduler): self.scheduler = scheduler self.scheduler.can_fuse_origin = self.scheduler.can_fuse self.scheduler.can_fuse = self.can_fuse_with_exceptions + #self.scheduler.enter_context = self.enter_context_fixed # FIXME. 
Monkey patch: For fixing the inductor bug self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False self.outer_function = set() @@ -67,9 +68,14 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule return False if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME return False - # We don't fuse this case... - if (isinstance(target_node.template, MLIRBMMTemplate) or isinstance(target_node.template, MLIRGemmTemplate)) and base_template_node2[0].group[1][0][0] == 1: - return False + + # Currently only BMM, MM support prologue fusion + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + # We don't fuse this edge case... + if base_template_node2[0].group[1][0][0] == 1: + return False + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) return True @@ -368,4 +374,15 @@ def codegen_template(self, template_node, epilogue_nodes): V.graph.wrapper_code.writeline( f"yield ({target_kernel_name}, ({args}))" ) - self._set_flush_status(True) \ No newline at end of file + self._set_flush_status(True) + + def enter_context_fixed(self, node): + def get_order(n): + if n not in self.scheduler.origin_to_index: + self.scheduler.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.scheduler.origin_to_index[n] + + origins = [(get_order(e), idx, e) for n in node.get_nodes() for idx, e in enumerate(n.node.origins)] + if origins: + _, _, last = max(origins) + V.graph.wrapper_code.enter_context(last)