From f10038e4c49c93c921abe135cc1b4144b94ac16e Mon Sep 17 00:00:00 2001 From: Yunseon Shin Date: Fri, 3 Oct 2025 07:56:15 +0000 Subject: [PATCH 1/5] [Frontend] Template autotune --- PyTorchSimFrontend/extension_config.py | 2 + .../mlir/mlir_codegen_backend.py | 18 +- .../mlir/mlir_conv_mt_template.py | 2 +- .../mlir/mlir_conv_sb_template.py | 2 +- .../mlir/mlir_conv_sbs_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_conv_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 2 +- PyTorchSimFrontend/mlir/mlir_scheduling.py | 97 +-------- PyTorchSimFrontend/mlir/mlir_template.py | 184 ++++++++++++++++-- 9 files changed, 184 insertions(+), 127 deletions(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 80675682..7eddfcb9 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -46,7 +46,9 @@ # AUTOTUNE config CONFIG_AUTOTUNE = int(os.environ.get('AUTOTUNE', default=True)) +CONFIG_AUTOTUNE_TEMPLATE = int(os.environ.get('AUTOTUNE_TEMPLATE', default=True)) CONFIG_MAX_AUTOTUNE_TRY = int(os.environ.get('MAX_AUTOTUNE_TRY', default=10)) +CONFIG_AUTOTUNE_TOPK = int(os.environ.get('AUTOTUNE_TOPK', default=3)) # For block sparse CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 09ee129b..d54963c2 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -17,8 +17,7 @@ sympy_product ) from torch.utils._sympy.functions import ModularIndexing, FloorDiv -import PyTorchSimFrontend.extension_codecache as extension_codecache - +from PyTorchSimFrontend import extension_codecache from PyTorchSimFrontend import extension_config from . 
import mlir_common from .mlir_common import LoopLevel, LoopNest @@ -1608,9 +1607,9 @@ def make_choices(self, nodes, kernel_name): self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold return choices - def autotune(self, nodes, kernel_name): + def autotune(self, *args): def get_cycle(choice): - bench_runner, src_code, kernel_group = choice + bench_runner = choice[0] for n_try in range(extension_config.CONFIG_MAX_AUTOTUNE_TRY): # TODO: make simple try: # bench_runner = self.run_bench(nodes, kernel_name, src_code) @@ -1619,7 +1618,7 @@ def get_cycle(choice): except (extension_codecache.SpadOverflowError, RuntimeError) as e: return float("inf") return float("inf") # Exceeded maximum number of autotuning attempts - choices = self.make_choices(nodes, kernel_name) + choices = self.make_choices(*args) if len(choices) == 0: # can't autotune return None @@ -1635,14 +1634,11 @@ def get_cycle(choice): def codegen_nodes(self, nodes, kernel_name): src_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) - if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: - return src_code - else: + if extension_config.CONFIG_AUTOTUNE and extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: optimal_src_code = self.autotune(nodes, kernel_name) - if optimal_src_code: + if optimal_src_code is not None: return optimal_src_code - else: - return src_code + return src_code def _prepare_simulator_headers(self, src_code): write_path = extension_codecache.get_write_path(src_code) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 6dd17576..ddbdf793 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -342,7 +342,7 @@ def compute_stride(shape): def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): - 
os.makedirs(write_path) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 8b1bf7c5..46cdb4d0 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -338,7 +338,7 @@ def compute_stride(shape): def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): - os.makedirs(write_path) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 2284c86c..006d5112 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -339,7 +339,7 @@ def compute_stride(shape): def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): - os.makedirs(write_path) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 890b76b7..c744258c 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -346,7 +346,7 @@ def compute_stride(shape): def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): 
- os.makedirs(write_path) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index ae793c06..119debd9 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -334,7 +334,7 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no def codegen_header(self, code, extra_headers): write_path = extension_codecache.get_write_path(code) if not os.path.exists(write_path): - os.makedirs(write_path) + os.makedirs(write_path, exist_ok=True) spike_write_path = os.path.join(write_path, "global_var.h") gem5_write_path = os.path.join(write_path, "gem5_global_var.h") if not os.path.exists(spike_write_path): diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 2bbdb41d..7b7b179b 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -4,8 +4,11 @@ from functools import reduce import operator from sympy import symbols, sympify, Symbol +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from torch._inductor import config from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode @@ -259,85 +262,6 @@ def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) return kernel_name - def codegen_template_code(self, kernel, render, template_node, prologue_nodes, epilogue_nodes): - with 
kernel: - _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - for node in [template_node, *prologue_nodes, *epilogue_nodes]: - node.mark_run() - # Partial codgen template nodes - partial_code = render() - - # Swap load/store functions - kernel.load = kernel.load_epilogue - kernel.store = kernel.store_epilogue - kernel.store_reduction = kernel.store_reduction_epilogue - kernel.reduction = kernel.reduction_epilogue - - # Codegen prologue nodes - if prologue_nodes: - # Flush created varaibles, since template fusion doen't share variable - with kernel.prologue_buffer_group.as_local(): - _, (group, reduction_group) = max( - [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) - ).group - prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) - kernel.kernel_group.set_tile_info(prologue_tile_desc) - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - for node in prologue_nodes: - # Reuse created spad - read_list = sorted([i.name for i in node.read_writes.reads]) - candidate_found = False - # Why? There is a case that memdep.get_size() != data.get_size() - buf_dict = {} - buf_dict.update({val.name : val for val in V.graph.buffers}) - buf_dict.update(V.graph.graph_inputs) - for candidate_read in read_list: - if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel(): - prologue_input_arg = candidate_read - candidate_found = True - break - assert(candidate_found) - assert(len(node.read_writes.writes)==1) - prologue_output_arg = list(node.read_writes.writes)[0].name - template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] - target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? 
- - # To skip the dma code gen - kernel.buffer_names[prologue_input_arg] = target_buf - kernel.buffer_names[prologue_output_arg] = target_buf - - # Edge delete - kernel.kernel_group.args.input_buffers = { - (arg if buf != template_buf else prologue_input_arg): buf - for arg, buf in kernel.kernel_group.args.input_buffers.items() - } - node.codegen((vars, reduction_vars)) - - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None - if epilogue_nodes: - with kernel.epilogue_buffer_group.as_local(): - _, (group, reduction_group) = max( - epilogue_nodes, key=lambda x: int(x.is_reduction()) - ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) - for node in epilogue_nodes: - node.codegen((vars, reduction_vars)) - - with V.set_kernel_handler(kernel): - src_code = ( - partial_code - if isinstance(partial_code, str) - else partial_code.finalize() - ) - - # For consistency, white space could make wrong write_path - buffer = IndentedBuffer() - buffer.splice(src_code) - return buffer.getvalue() - def codegen_template(self, template_node, epilogue_nodes): # Handle prologue pattern prologue_nodes = [] @@ -350,24 +274,13 @@ def codegen_template(self, template_node, epilogue_nodes): epilogue_nodes = epilogue_nodes[i+1:] break - _, (numel, rnumel) = template_node.group + # Generate template code template_buffer = template_node.node kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - - src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) - wrapper = V.graph.wrapper_code - - if src_code in wrapper.src_to_kernel: # [CONV] check inner function is already defined - kernel_name = wrapper.src_to_kernel[src_code] - kernel, 
render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_name=kernel_name) # update kernel name - src_code = self.codegen_template_code(kernel, render, template_node, prologue_nodes, epilogue_nodes) + src_code = kernel.codegen_nodes(render, codegen_header, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): - spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" - spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({kernel.spad_info['spad_size']*kernel.vector_lane})));" - codegen_header(src_code, (kernel.header.getvalue()+spad_end_symbol+spad_section_end_symbol, kernel.gem5_header.getvalue())) - kernel.meta_kernel() kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) self.define_function(kernel) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 4f75dd84..762d2a93 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -6,6 +6,8 @@ import contextlib import math import sympy +from functools import reduce +import operator from collections import OrderedDict from typing import List, Optional @@ -25,7 +27,7 @@ from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode from torch._inductor.codegen import common -from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR, CONFIG_AUTOTUNE_TEMPLATE, CONFIG_AUTOTUNE, CONFIG_BACKENDSIM_SPIKE_ONLY from . 
import mlir_common class IndentedBufferGroup: @@ -93,7 +95,8 @@ def __init__(self, kernel_group = None, outer_func_name=None, outer_func_render=None, - kernel_arg_attributes=None) -> None: + kernel_arg_attributes=None, + reason=None) -> None: super().__init__(kernel_group if kernel_group is not None else mlir_common.MLIRWrapperKenrelGroup()) self.kernel_name = kernel_name self.input_nodes = input_nodes @@ -125,6 +128,16 @@ def __init__(self, self.reduction_mean = [] # Dim info self.dim_aliasing = {} + self.autotune_idx = 0 + self.reason = reason + + def reset(self, reason): + self.__init__( + self.kernel_name, self.input_nodes, + self.call_size, self.kernel_group, + self.outer_func_name, self.outer_func_render, + self.kernel_arg_attributes, reason + ) def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -185,7 +198,8 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer @@ -249,6 +263,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N mapping = (tile_M, tile_N, tile_K) + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_M, tile_N, tile_K))) + if CONFIG_AUTOTUNE_TEMPLATE and not is_conv: + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping return mapping def search_mapping_space(self, mapping, idx, increment, stride, dilation, 
n_extra_node=0): @@ -288,13 +307,14 @@ def pseudo_auto_tune(self, mapping, stride, dilation, O_H, O_W, n_extra_node=0): return mapping def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 # double buffer max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True) max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -312,27 +332,34 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w and max_o_h_w <= o_h * o_w: - max_used_spad_size = used_spad_size - max_k_h_w = k_h * k_w - max_o_h_w = o_h * o_w - mapping = (k_h, k_w, o_h, o_w, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h * k_w and max_o_h_w <= o_h * o_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + max_o_h_w = o_h * o_w + mapping = (k_h, k_w, o_h, o_w, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") # FIXME: this should be implemented 
with auto-tuning mapping = self.pseudo_auto_tune(mapping, stride, dilation, O_H, O_W, n_extra_node=n_extra_node) + if CONFIG_AUTOTUNE_TEMPLATE: + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping return mapping def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True) max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -347,22 +374,29 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h: - max_used_spad_size = used_spad_size - max_k_h_w = k_h - mapping = (k_h, K_W, o_h, o_w, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, K_W, o_h, o_w, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h: + max_used_spad_size = used_spad_size + max_k_h_w = k_h + mapping = (k_h, K_W, o_h, o_w, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid 
mapping") + if CONFIG_AUTOTUNE_TEMPLATE: + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping return mapping def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane max_spad_size = spad_size // 2 max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False) + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True) max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -377,12 +411,18 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and max_used_spad_size < used_spad_size and used_spad_size_per_lane < max_spad_per_lane and max_k_h_w <= k_h * k_w: - max_used_spad_size = used_spad_size - max_k_h_w = k_h * k_w - mapping = (k_h, k_w, o_h, M, M, N, K) + check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + if check_spad_size: + tile_candidates.append((used_spad_size, (k_h, k_w, o_h, M, M, N, K))) + if max_used_spad_size < used_spad_size and max_k_h_w <= k_h * k_w: + max_used_spad_size = used_spad_size + max_k_h_w = k_h * k_w + mapping = (k_h, k_w, o_h, M, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") + if CONFIG_AUTOTUNE_TEMPLATE: + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], 
reverse=True) + mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping return mapping def meta_kernel(self): @@ -407,6 +447,112 @@ def call_kernel(self, kernel_name): kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) + def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes): + with self as kernel: + _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() + for node in [template_node, *prologue_nodes, *epilogue_nodes]: + node.mark_run() + + # Partial codgen template nodes + partial_code = render() + + # Swap load/store functions + kernel.load = kernel.load_epilogue + kernel.store = kernel.store_epilogue + kernel.store_reduction = kernel.store_reduction_epilogue + kernel.reduction = kernel.reduction_epilogue + + # Codegen prologue nodes + if prologue_nodes: + # Flush created varaibles, since template fusion doen't share variable + with kernel.prologue_buffer_group.as_local(): + _, (group, reduction_group) = max( + [prologue_nodes[-1]], key=lambda x: int(x.is_reduction()) + ).group + prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) + kernel.kernel_group.set_tile_info(prologue_tile_desc) + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in prologue_nodes: + # Reuse created spad + read_list = sorted([i.name for i in node.read_writes.reads]) + candidate_found = False + # Why? 
There is a case that memdep.get_size() != data.get_size() + buf_dict = {} + buf_dict.update({val.name : val for val in V.graph.buffers}) + buf_dict.update(V.graph.graph_inputs) + for candidate_read in read_list: + if candidate_read in buf_dict and reduce(operator.mul, buf_dict[candidate_read].get_size(), 1) == node.node.get_numel(): + prologue_input_arg = candidate_read + candidate_found = True + break + assert(candidate_found) + assert(len(node.read_writes.writes)==1) + prologue_output_arg = list(node.read_writes.writes)[0].name + template_buf = self.kernel_group.args.input_buffers[prologue_output_arg] + target_buf = f"{template_buf}_buffer" # FIXME. How to pass spad buffer name? + + # To skip the dma code gen + kernel.buffer_names[prologue_input_arg] = target_buf + kernel.buffer_names[prologue_output_arg] = target_buf + + # Edge delete + kernel.kernel_group.args.input_buffers = { + (arg if buf != template_buf else prologue_input_arg): buf + for arg, buf in kernel.kernel_group.args.input_buffers.items() + } + node.codegen((vars, reduction_vars)) + + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None + if epilogue_nodes: + with kernel.epilogue_buffer_group.as_local(): + _, (group, reduction_group) = max( + epilogue_nodes, key=lambda x: int(x.is_reduction()) + ).group + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + for node in epilogue_nodes: + node.codegen((vars, reduction_vars)) + + with V.set_kernel_handler(kernel): + src_code = ( + partial_code + if isinstance(partial_code, str) + else partial_code.finalize() + ) + + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + return buffer.getvalue() + + def make_choices(self, render, template_node, prologue_nodes, epilogue_nodes): + choices = [] + for i in range(3): + self.autotune_idx = i + self.reset(reason=None) + src_code = 
self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes) + bench_runner = self.run_bench([template_node], self.kernel_name, src_code) + choices.append((bench_runner, src_code, self.kernel_group)) + return choices + + def codegen_nodes(self, render, codegen_header, template_node, prologue_nodes, epilogue_nodes): + src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes) + + if False:# CONFIG_AUTOTUNE_TEMPLATE and not CONFIG_BACKENDSIM_SPIKE_ONLY: + src_code = self.autotune(render, template_node, prologue_nodes, epilogue_nodes) + + with V.set_kernel_handler(self): + self._prepare_simulator_headers(src_code, codegen_header) + self.meta_kernel() + return src_code + + def _prepare_simulator_headers(self, src_code, codegen_header): + spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" + spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" + codegen_header(src_code, (self.header.getvalue()+spad_end_symbol+spad_section_end_symbol, self.gem5_header.getvalue())) + def codegen_prologue_body(self): body = IndentedBuffer() with self.prologue_buffer_group.as_local(): From fe22e9b133b3da2e6686bc9053f38c5815c25511 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 5 Nov 2025 03:20:54 +0000 Subject: [PATCH 2/5] [Cleanup] Remove codegen_headers --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 25 ++--- .../mlir/mlir_conv_mt_template.py | 21 +--- .../mlir/mlir_conv_sb_template.py | 21 +--- .../mlir/mlir_conv_sbs_template.py | 21 +--- PyTorchSimFrontend/mlir/mlir_conv_template.py | 22 +---- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 95 ++++++++++--------- .../mlir/mlir_maxpool_template.py | 13 --- PyTorchSimFrontend/mlir/mlir_scheduling.py | 4 +- PyTorchSimFrontend/mlir/mlir_template.py | 52 ++++++---- 9 files changed, 105 insertions(+), 169 deletions(-) diff --git 
a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 79e03bd5..0c6583a7 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -6,8 +6,6 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common BMM_TEMPLATE = r""" @@ -184,14 +182,10 @@ def render(self, # Select tile size n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or prologue_nodes else kernel.vector_lane - SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane - SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, len(prologue_nodes)) TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] - TILE_K = TILE_K // 2 if prologue_nodes else TILE_K # Select template code nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] @@ -329,13 +323,10 @@ def render(self, kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = 
os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TILE_K = TILE_K // 2 if n_prologue_node else TILE_K + return TILE_M,TILE_N,TILE_K,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index ddbdf793..26657712 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -7,10 +7,7 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" @@ -185,8 +182,9 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, 
K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -294,8 +292,7 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] @@ -338,15 +335,3 @@ def compute_stride(shape): arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py 
b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 46cdb4d0..856d4c09 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -7,10 +7,7 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" @@ -186,8 +183,9 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] # Prepare tile descriptors @@ -290,8 +288,7 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - return 
TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] @@ -334,15 +331,3 @@ def compute_stride(shape): arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 006d5112..14b7d432 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -7,10 +7,7 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" @@ -186,8 +183,9 @@ def render(self, # Select tile size 
adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -291,8 +289,7 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane SUB_TILE_K = TILE_K - TOG_latency = O_W if TILE_M > O_W else TILE_M - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] @@ -335,15 +332,3 @@ def compute_stride(shape): arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) - spike_write_path = os.path.join(write_path, "global_var.h") - 
gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index c744258c..ff426ceb 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -7,8 +7,6 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common from torch._inductor.codecache import get_hash from PyTorchSimFrontend import extension_config @@ -190,8 +188,8 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K, TOG_latency = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TOG_latency = BATCH if TILE_M > BATCH else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -297,9 +295,7 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + 
(TILE_K_W - 1) * self.dilation[1] SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N - TOG_latency = BATCH if TILE_M > BATCH else TILE_M - TOG_latency = 8 if TOG_latency < 8 else TOG_latency - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K,TOG_latency + return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] @@ -342,15 +338,3 @@ def compute_stride(shape): arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) return arg_attributes - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) - self.hash_value = get_hash(code.strip()) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 119debd9..9d3d3acf 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -8,8 +8,6 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode -from torch._inductor.codecache import write_atomic -import 
PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir import mlir_common @@ -114,30 +112,13 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): - if template_buffer_node is not None: - self.output_node = template_buffer_node - - # Extract input arguments info - X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node - X_tensor = empty_strided(X.layout.size, X.layout.stride) - W_tensor = empty_strided(W.layout.size, W.layout.stride) - if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: - raise NotImplementedError("Please report this case to us...") - - # Extract fusion info - n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 - n_extra_read = set() - if epilogue_nodes is not None: - for enode in epilogue_nodes: - n_extra_read.update(enode.node.get_read_names()) - if self.output_node.name in n_extra_read: - n_extra_read.remove(self.output_node.name) - - # Select tile size - M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + if tile_info is None: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + else: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info # Select template code if (M == 0) or (N == 0) or (K == 0): # exception for MoE @@ -281,6 +262,41 @@ def render(self, kernel.add_loop_info([kernel.render_options["M"], 
kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + **kwargs): + X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + return [[TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K]] + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + # Extract input arguments info + X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + X_tensor = empty_strided(X.layout.size, X.layout.stride) + W_tensor = empty_strided(W.layout.size, W.layout.stride) + if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: + raise NotImplementedError("Please report this case to us...") + + # Extract fusion info + n_epilogue_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + n_extra_read = set() + if epilogue_nodes is not None: + for enode in epilogue_nodes: + n_extra_read.update(enode.node.get_read_names()) + if self.output_node.name in n_extra_read: + n_extra_read.remove(self.output_node.name) + + # Select tile size + M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] + return X,W,Y,M,N,K,n_epilogue_node,n_prologue_node,len(n_extra_read) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): # Check cheat sheet cheatsheet_path = extension_config.CONFIG_GEMM_CHEATSHEET_PATH @@ -292,19 
+308,21 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no data = json.load(f) gemm_shape = f"{M}_{K}_{N}" - if gemm_shape in data: + if extension_config.CONFIG_MANUAL_TILE_SIZE: + # case 1: use manual tile size + TILE_M = extension_config.CONFIG_TILE_M + TILE_N = extension_config.CONFIG_TILE_N + TILE_K = extension_config.CONFIG_TILE_K + elif gemm_shape in data: + # case 2: cached tile size tile_info = data[gemm_shape] TILE_M = tile_info["TILE_M"] TILE_N = tile_info["TILE_N"] TILE_K = tile_info["TILE_K"] - else: # case 2: use gemm_combination_mapping + else: + # case 3: use gemm_combination_mapping min_tile = (n_extra_node + n_prologue_node) == 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(len(n_extra_read)-2, 0), n_prologue_node, min_tile=True) - # case 3: use manual tile size - if extension_config.CONFIG_MANUAL_TILE_SIZE: - TILE_M = extension_config.CONFIG_TILE_M - TILE_N = extension_config.CONFIG_TILE_N - TILE_K = extension_config.CONFIG_TILE_K + TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) # Edge case if (M == 0) or (N == 0) or (K == 0): @@ -330,14 +348,3 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no SUB_TILE_N = TILE_N SUB_TILE_K = TILE_K return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path, exist_ok=True) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py 
b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index 6f605d56..a779e598 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -6,8 +6,6 @@ from torch._inductor.ir import Buffer from torch._inductor.ir import IRNode from torch._inductor.ir import ReinterpretView -from torch._inductor.codecache import write_atomic -import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir import mlir_common import sympy @@ -99,14 +97,3 @@ def render(self, code = self._template_from_string(TEMPLATE).render(**kernel.render_options) kernel.add_loop_info([X.get_numel()], [kernel.vector_lane, kernel.vector_lane]) return code - - def codegen_header(self, code, extra_headers): - write_path = extension_codecache.get_write_path(code) - if not os.path.exists(write_path): - os.makedirs(write_path) - spike_write_path = os.path.join(write_path, "global_var.h") - gem5_write_path = os.path.join(write_path, "gem5_global_var.h") - if not os.path.exists(spike_write_path): - write_atomic(spike_write_path, extra_headers[0]) - if not os.path.exists(gem5_write_path): - write_atomic(gem5_write_path, extra_headers[1]) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 7b7b179b..26b90401 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -276,9 +276,9 @@ def codegen_template(self, template_node, epilogue_nodes): # Generate template code template_buffer = template_node.node - kernel, render, codegen_header = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) + kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = 
self.kernel_group.args.mlir_argdefs() - src_code = kernel.codegen_nodes(render, codegen_header, template_node, prologue_nodes, epilogue_nodes) + src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(kernel): kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 762d2a93..07ebec51 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -20,7 +20,9 @@ from torch._inductor.autotune_process import TensorMeta from torch._inductor.virtualized import V, NullHandler, _ops as ops from torch._inductor.utils import IndentedBuffer +from torch._inductor.codecache import write_atomic +import PyTorchSimFrontend.extension_codecache as extension_codecache from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction @@ -447,14 +449,14 @@ def call_kernel(self, kernel_name): kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) - def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes): + def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() for node in [template_node, *prologue_nodes, *epilogue_nodes]: node.mark_run() # Partial codgen template nodes - partial_code = render() + partial_code = render(kwargs={**render.keywords['kwargs'], 'tile_info': tile_info}) # Swap load/store functions kernel.load = kernel.load_epilogue @@ -522,36 +524,42 @@ def 
codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ else partial_code.finalize() ) - # For consistency, white space could make wrong write_path - buffer = IndentedBuffer() - buffer.splice(src_code) - return buffer.getvalue() + # For consistency, white space could make wrong write_path + buffer = IndentedBuffer() + buffer.splice(src_code) + src_code = buffer.getvalue() + self._prepare_simulator_headers(src_code) + return src_code - def make_choices(self, render, template_node, prologue_nodes, epilogue_nodes): + def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): choices = [] - for i in range(3): - self.autotune_idx = i - self.reset(reason=None) - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes) + for tile_info in tile_candidates: + src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) bench_runner = self.run_bench([template_node], self.kernel_name, src_code) choices.append((bench_runner, src_code, self.kernel_group)) + self.reset(reason=None) return choices - def codegen_nodes(self, render, codegen_header, template_node, prologue_nodes, epilogue_nodes): - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes) - - if False:# CONFIG_AUTOTUNE_TEMPLATE and not CONFIG_BACKENDSIM_SPIKE_ONLY: - src_code = self.autotune(render, template_node, prologue_nodes, epilogue_nodes) + def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): + src_code = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with V.set_kernel_handler(self): - self._prepare_simulator_headers(src_code, codegen_header) self.meta_kernel() return src_code - def _prepare_simulator_headers(self, src_code, codegen_header): + def _prepare_simulator_headers(self, src_code): spad_end_symbol = f"int spad_end[0] __attribute__ 
((section(\".spad\")));\n" spad_section_end_symbol = f"int spad_section_end[0] __attribute__ ((section(\".spad\"), aligned({self.spad_info['spad_size']*self.vector_lane})));" - codegen_header(src_code, (self.header.getvalue()+spad_end_symbol+spad_section_end_symbol, self.gem5_header.getvalue())) + + write_path = extension_codecache.get_write_path(src_code) + if not os.path.exists(write_path): + os.makedirs(write_path, exist_ok=True) + spike_write_path = os.path.join(write_path, "global_var.h") + gem5_write_path = os.path.join(write_path, "gem5_global_var.h") + if not os.path.exists(spike_write_path): + write_atomic(spike_write_path, self.header.getvalue()+spad_end_symbol+spad_section_end_symbol) + if not os.path.exists(gem5_write_path): + write_atomic(gem5_write_path, self.gem5_header.getvalue()) def codegen_prologue_body(self): body = IndentedBuffer() @@ -1256,7 +1264,8 @@ def make_kernel_render( template=self, kwargs=kwargs ) - return kernel, render, self.codegen_header + tile_candidates = self.get_tile_candidates(**kwargs) + return kernel, tile_candidates, render return MLIRTemplateCaller( kernel_hash_name, @@ -1268,5 +1277,8 @@ def make_kernel_render( self, ) + def get_tile_candidates(self, **kwargs): + return [] + def render(self, **kwargs) -> str: raise NotImplementedError \ No newline at end of file From f400d6800e4d7d0aacdf2e86f9359950e7c53215 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 5 Nov 2025 04:35:49 +0000 Subject: [PATCH 3/5] [Cleanup] Refactor conv template + autotune primitive --- PyTorchSimFrontend/mlir/mlir_conv_common.py | 120 ++++++++++++++++++ .../mlir/mlir_conv_mt_template.py | 102 ++------------- .../mlir/mlir_conv_sb_template.py | 103 ++------------- .../mlir/mlir_conv_sbs_template.py | 103 ++------------- PyTorchSimFrontend/mlir/mlir_conv_template.py | 104 ++------------- .../mlir/mlir_maxpool_template.py | 1 + 6 files changed, 161 insertions(+), 372 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_conv_common.py 
diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py new file mode 100644 index 00000000..e6379597 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -0,0 +1,120 @@ +import os +import math +from typing import List, Optional + +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +from torch._inductor.ir import IRNode +from PyTorchSimFrontend import extension_config + +class MLIRConvCommonTemplate(MLIRTemplate): + WRAPPER_TEMPLATE = None + def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.stride = kwargs["stride"] + self.padding = kwargs["padding"] + self.dilation = kwargs["dilation"] + self.weight_shape = [str(i) for i in input_nodes[1].layout.size] + self.input_shape = [str(i) for i in input_nodes[0].layout.size] + self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ + + "_".join([str(i) for i in self.stride]) \ + + "_" + "_".join([str(i) for i in self.padding]) \ + + "_" + "_".join([str(i) for i in self.dilation]) + self.kernel_args = ['X', 'W', 'Bias', 'Y'] + + def get_padded_input_size(self, X): + input_padded = list(X.layout.size) + input_padded[2] += 2 * self.padding[0] + input_padded[3] += 2 * self.padding[1] + return math.prod(input_padded) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + raise NotImplementedError() + + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + raise NotImplementedError() + + def extract_info(self, kernel, template_buffer_node, epilogue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + self.kernel = 
kernel + self.epilogue_nodes = epilogue_nodes + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + if epilogue_nodes is not None: + extra_node_rw = { + item.name for epilogue_node in epilogue_nodes + for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes + if item.name != Y.name + } + n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 + + BATCH, I_C, I_H, I_W = X.layout.size + O_C, _, K_H, K_W = W.layout.size + O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] + O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] + PADDING_H=self.padding[0] + PADDING_W=self.padding[1] + STRIDE_H=self.stride[0] + STRIDE_W=self.stride[1] + return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W + + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs): + # Extract input arguments info + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + return [self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)] + + def outer_func_render(self, kernel_name, input_args): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) + options = dict( + kernel=self.kernel, + KERNEL_NAME=kernel_name, + FUNC_NAME=self.function_name + f"_{len(input_args)}", + INPUT=X, + WEIGHT=W, + BIAS=Bias, + OUTPUT=Y, + PADDING_H=self.padding[0], + PADDING_W=self.padding[1], + 
VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + BACKENDSIM_EAGER_MODE=eager_mode, + input_reorder=self.input_reorder + ) + code = self._template_from_string(self.WRAPPER_TEMPLATE).render(**options) + return code, self.function_name + f"_{len(input_args)}" + + def get_arg_attributes(self): + arg_attributes = [] + + X = self.input_nodes[0] + X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] + X_shape[0] += 2 * self.padding[0] + X_shape[1] += 2 * self.padding[1] + + def compute_stride(shape): + stride = [1] * len(shape) + for i in range(len(shape)-2, -1, -1): + stride[i] = stride[i+1] * shape[i+1] + return stride + + X_stride = compute_stride(X_shape) + arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) + + return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 26657712..3facedd5 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -1,10 +1,7 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common @@ -101,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvMultiTileTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -127,62 +125,24 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvMultiTileTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, 
**kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - 
STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency @@ -293,45 +253,3 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_K = TILE_K return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - 
VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 856d4c09..6f3492c6 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -1,14 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Single Batch Conv2D kernel @@ -102,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvSingleBatchTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -128,62 +125,24 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class 
MLIRConvSingleBatchTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - 
PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency @@ -289,45 +248,3 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane SUB_TILE_K = TILE_K return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - 
INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 14b7d432..53292858 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -1,14 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Single Batch Conv2D (Stride != 1) kernel @@ -102,7 +98,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class MLIRConvSingleBatchStridedTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = 
list(X.shape) @@ -128,62 +125,24 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvSingleBatchStridedTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None 
else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N TOG_latency = O_W if TILE_M > O_W else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency @@ -290,45 +249,3 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane SUB_TILE_K = TILE_K return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', 
default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index ff426ceb..6fa3be53 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -1,15 +1,10 @@ -import os -import math from sympy import Symbol, Number from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs -from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_conv_common import MLIRConvCommonTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common -from torch._inductor.codecache import get_hash -from PyTorchSimFrontend import extension_config CONV_TEMPLATE = r""" // Conv2D kernel @@ -107,7 +102,8 @@ } """ -WRAPPER_TEMPLATE = r""" +class 
MLIRConvTemplate(MLIRConvCommonTemplate): + WRAPPER_TEMPLATE = r""" def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: # Padding input padded_shape = list(X.shape) @@ -133,62 +129,24 @@ def {{ FUNC_NAME }}{{kernel.def_wrapper()}}: yield ({{KERNEL_NAME}}, ) {%- endif %} """ - -class MLIRConvTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.stride = kwargs["stride"] - self.padding = kwargs["padding"] - self.dilation = kwargs["dilation"] - self.weight_shape = [str(i) for i in input_nodes[1].layout.size] - self.input_shape = [str(i) for i in input_nodes[0].layout.size] - self.function_name = "Conv2D_" + "_".join(self.input_shape) + "_".join(self.weight_shape)+ "_" \ - + "_".join([str(i) for i in self.stride]) \ - + "_" + "_".join([str(i) for i in self.padding]) \ - + "_" + "_".join([str(i) for i in self.dilation]) - self.kernel_args = ['X', 'W', 'Bias', 'Y'] - - def get_padded_input_size(self, X): - input_padded = list(X.layout.size) - input_padded[2] += 2 * self.padding[0] - input_padded[3] += 2 * self.padding[1] - return math.prod(input_padded) + super().__init__(input_nodes, layout, input_reorder, **kwargs) def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): # Extract input arguments info - if template_buffer_node is not None: - self.output_node = template_buffer_node - self.kernel = kernel - self.epilogue_nodes = epilogue_nodes - - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - if epilogue_nodes is not None: - extra_node_rw = { - item.name for epilogue_node in epilogue_nodes - for item in epilogue_node.read_writes.reads | epilogue_node.read_writes.writes - if item.name != Y.name - } - n_extra_node = len(extra_node_rw) if epilogue_nodes is not None else 0 - - 
BATCH, I_C, I_H, I_W = X.layout.size - O_C, _, K_H, K_W = W.layout.size - O_H = Y.layout.size[2] if template_buffer_node is None else template_buffer_node.layout.size[2] - O_W = Y.layout.size[3] if template_buffer_node is None else template_buffer_node.layout.size[3] - PADDING_H=self.padding[0] - PADDING_W=self.padding[1] - STRIDE_H=self.stride[0] - STRIDE_W=self.stride[1] + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + if tile_info is None: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + else: + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = BATCH if TILE_M > BATCH else TILE_M TOG_latency = 8 if TOG_latency < 8 else TOG_latency kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -296,45 +254,3 @@ def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K - - def outer_func_render(self, kernel_name, input_args): - X, W = self.input_nodes[0], 
self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - eager_mode = int(os.environ.get('BACKENDSIM_EAGER_MODE', default=False)) - options = dict( - kernel=self.kernel, - KERNEL_NAME=kernel_name, - FUNC_NAME=self.function_name + f"_{len(input_args)}", - INPUT=X, - WEIGHT=W, - BIAS=Bias, - OUTPUT=Y, - PADDING_H=self.padding[0], - PADDING_W=self.padding[1], - VALIDATION_MODE=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, - BACKENDSIM_EAGER_MODE=eager_mode, - input_reorder=self.input_reorder - ) - code = self._template_from_string(WRAPPER_TEMPLATE).render(**options) - return code, self.function_name + f"_{len(input_args)}" - - def get_arg_attributes(self): - arg_attributes = [] - - X = self.input_nodes[0] - X_shape = [X.get_size()[i] for i in (2, 3, 0, 1)] - X_shape[0] += 2 * self.padding[0] - X_shape[1] += 2 * self.padding[1] - - def compute_stride(shape): - stride = [1] * len(shape) - for i in range(len(shape)-2, -1, -1): - stride[i] = stride[i+1] * shape[i+1] - return stride - - X_stride = compute_stride(X_shape) - arg_attributes.append([X.data.data.name, [MLIRKernelArgs.MLIR_ARGS_IN, X.layout.dtype, math.prod(X_shape), X_shape, X_stride]]) - - return arg_attributes diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py index a779e598..2cca36b6 100644 --- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py +++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py @@ -40,6 +40,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): if template_buffer_node is not None: self.output_node = template_buffer_node From 8c96a5a57ec6888040905aaf0ac5521192d0da83 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 5 Nov 2025 05:29:22 +0000 Subject: [PATCH 4/5] [Autotune] Connect autotune template --- PyTorchSimFrontend/extension_codecache.py | 2 +- 
PyTorchSimFrontend/extension_config.py | 2 +- PyTorchSimFrontend/mlir/mlir_autotune.py | 3 +- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 94 +++++++++++-------- .../mlir/mlir_codegen_backend.py | 29 ++++-- PyTorchSimFrontend/mlir/mlir_common.py | 5 +- PyTorchSimFrontend/mlir/mlir_conv_common.py | 2 +- .../mlir/mlir_conv_mt_template.py | 23 +++-- .../mlir/mlir_conv_sb_template.py | 20 ++-- .../mlir/mlir_conv_sbs_template.py | 20 ++-- PyTorchSimFrontend/mlir/mlir_conv_template.py | 22 +++-- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 51 +++++----- PyTorchSimFrontend/mlir/mlir_template.py | 87 +++++------------ 13 files changed, 186 insertions(+), 174 deletions(-) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 1e756f96..ca669361 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -299,7 +299,7 @@ def dummy_simulator(*args, **kwargs): # Dump arguments and meta data dump_metadata(args, arg_attributes, result_path) runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path) - if extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate: + if not autotune and (extension_config.CONFIG_TORCHSIM_VALIDATION_MODE or validate): funcsim = FunctionalSimulator(result_path, key) funcsim.run_spike(args, arg_attributes, runtime_path, self.validation_binary_name, diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 7eddfcb9..fa5d22b5 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -48,7 +48,7 @@ CONFIG_AUTOTUNE = int(os.environ.get('AUTOTUNE', default=True)) CONFIG_AUTOTUNE_TEMPLATE = int(os.environ.get('AUTOTUNE_TEMPLATE', default=True)) CONFIG_MAX_AUTOTUNE_TRY = int(os.environ.get('MAX_AUTOTUNE_TRY', default=10)) -CONFIG_AUTOTUNE_TOPK = int(os.environ.get('AUTOTUNE_TOPK', default=3)) +CONFIG_AUTOTUNE_TEMPLATE_TOPK = 
int(os.environ.get('AUTOTUNE_TEMPLATE_TOPK', default=4)) # For block sparse CONFIG_BLOCK_SPARSE = int(os.environ.get('BLOCK_SPARSE', default=0)) diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 54aed9c0..537809de 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -74,7 +74,8 @@ def cached_run_fn(*args, **kwargs): self.source_code, vectorlane_size=self.extra_args["vector_lane"], loop_size=None, spad_info=self.extra_args["spad_info"], vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], - origins="Unknown", silent_mode=True) + origins="Unknown", silent_mode=True, + validate=self.extra_args['validate'], autotune=self.extra_args['autotune']) args = [ tensor diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 0c6583a7..9a12076a 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -160,29 +160,13 @@ def render(self, template_buffer_node = None, epilogue_nodes: Optional[List[IRNode]] = None, prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, **kwargs): - if template_buffer_node is not None: - self.output_node = template_buffer_node - - # Extract input arguments info - X, W = self.input_nodes[0], self.input_nodes[1] - Y = self.output_node - Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] - - W_tensor = empty_strided(W.layout.size, W.layout.stride) - X_tensor = empty_strided(X.layout.size, X.layout.stride) - if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2: - W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) - if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2: - X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) - B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] - - W_stride = W_tensor.stride() 
- X_stride = X_tensor.stride() - - # Select tile size - n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, len(prologue_nodes)) + X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + if tile_info is None: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0] + else: + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = M if TILE_M > M else TILE_M kernel.loop_size = [TOG_latency, TILE_N, TILE_K] @@ -190,17 +174,17 @@ def render(self, # Select template code nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] if nr_reduction_nodes: - template = BMM_REDUCTION_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} - nr_rdim = 1 + template = BMM_REDUCTION_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index2", "index2": "index1"} + nr_rdim = 1 elif prologue_nodes: - template = BMM_PROLOGUE_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} - nr_rdim = 0 + template = BMM_PROLOGUE_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 else: - template = BMM_TEMPLATE - epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} - nr_rdim = 0 + template = BMM_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2"} + nr_rdim = 0 # Prepare tile descriptors vlane_stride = 1 @@ -323,10 +307,46 @@ def render(self, kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"], kernel.render_options["K"]], [kernel.render_options["TILE_M"], 
kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + # Extract input arguments info + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + + W_tensor = empty_strided(W.layout.size, W.layout.stride) + X_tensor = empty_strided(X.layout.size, X.layout.stride) + if len(W_tensor.size()) > 3 or len(W_tensor.size()) == 2: + W_tensor = W_tensor.view([-1, W_tensor.shape[-2], W_tensor.shape[-1]]) + if len(X_tensor.size()) > 3 or len(X_tensor.size()) == 2: + X_tensor = X_tensor.view([-1, X_tensor.shape[-2], X_tensor.shape[-1]]) + B, M, N, K = X_tensor.size()[0], X_tensor.size()[1], W_tensor.size()[2], X_tensor.size()[2] + + W_stride = W_tensor.stride() + X_stride = X_tensor.stride() + + # Select tile size + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + return X,W,Y,Bias,W_tensor,X_tensor,B,M,N,K,n_extra_node, n_prologue_node + + def get_tile_candidates(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + **kwargs): + X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane - SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or 
prologue_nodes else kernel.vector_lane - SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane - TILE_K = TILE_K // 2 if n_prologue_node else TILE_K - return TILE_M,TILE_N,TILE_K,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + SUB_TILE_K = TILE_K # if (TILE_K < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + TILE_K = TILE_K // 2 if n_prologue_node else TILE_K + tile_candidates[idx] = TILE_M,TILE_N,TILE_K,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index d54963c2..b3352ea6 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1564,10 +1564,10 @@ def make_choices(self, nodes, kernel_name): current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) - print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group)) + choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) while prevent_infinite_loop < 10 and candidate_axes: for axis in list(candidate_axes): @@ 
-1592,6 +1592,13 @@ def make_choices(self, nodes, kernel_name):
                 src_code = super().codegen_nodes(nodes, kernel_name)
                 current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size())
 
+                # FIXME. How to integrate this constraint into the tile system?
+                pad = self.kernel_group.tile_desc.vmap.get_used_vlane(current_tile_sz) * self.kernel_group.tile_desc.vmap.vlane_stride
+                vlane_size = current_tile_sz[self.kernel_group.tile_desc.vmap.vlane_split_axis]
+                if vlane_size > pad and vlane_size % pad:
+                    prevent_infinite_loop += 1
+                    continue
+
                 # If tile size is converged for this axis, remove from candidate axes
                 if current_tile_sz in search_space:
                     candidate_axes.remove(axis)
@@ -1599,10 +1606,10 @@ def make_choices(self, nodes, kernel_name):
                 # Add this choice
                 search_space.add(current_tile_sz)
-                print(f"[Auto-tune] Trying tile size: {current_tile_sz}, vlane_stride: {vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
+                print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}")
                 self._prepare_simulator_headers(src_code)
                 bench_runner = self.run_bench(nodes, kernel_name, src_code)
-                choices.append((bench_runner, src_code, self.kernel_group))
+                choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride))
                 prevent_infinite_loop += 1
         self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold
         return choices
@@ -1612,8 +1619,7 @@ def get_cycle(choice):
             bench_runner = choice[0]
             for n_try in range(extension_config.CONFIG_MAX_AUTOTUNE_TRY): # TODO: make simple
                 try:
-                    # bench_runner = self.run_bench(nodes, kernel_name, src_code)
-                    out = bench_runner(validate=extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, autotune=True)
+                    out = bench_runner()
                     return out[-1]
                 except (extension_codecache.SpadOverflowError, RuntimeError) as e:
                     return float("inf")
@@ 
-1627,14 +1633,21 @@ def get_cycle(choice): max_idx = results.index(min(results)) if min(results) == float("inf"): raise RuntimeError("Failed to find optimal tile size...") - print(f"[Auto-tune] Optimal tile size: {choices[max_idx][2].tile_desc.get_tile_size()}, vlane_stride: {choices[max_idx][2].tile_desc.vmap.vlane_stride}, cycles: {results[max_idx]}") + self._log_autotune_result(choices[max_idx], results[max_idx]) optimal_src_code = choices[max_idx][1] return optimal_src_code + def _log_autotune_result(self, best_choice, best_cycle): + print( + f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, " + f"vlane_stride: {best_choice[3]}, " + f"cycles: {best_cycle}" + ) + def codegen_nodes(self, nodes, kernel_name): src_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) - if extension_config.CONFIG_AUTOTUNE and extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: + if extension_config.CONFIG_AUTOTUNE and not extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY: optimal_src_code = self.autotune(nodes, kernel_name) if optimal_src_code is not None: return optimal_src_code diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 67d5380f..2644f125 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -408,6 +408,7 @@ def select_vlane_axis(self): self.vmap.vlane_split_axis = best_vlane_split_axis def pad_vlane_tile(self): + # FIXME. this doesn't follow tile constraints... 
vlane_split_axis, vlane_stride, vector_lane = self.vmap.vlane_split_axis, self.vmap.vlane_stride, self.vmap.vector_lane used_vlane = min(math.ceil(self._tile_size[vlane_split_axis] / vlane_stride), vector_lane) padded_size = used_vlane * vlane_stride @@ -790,7 +791,9 @@ def run_bench(self, nodes, kernel_name, src_code): "vector_lane" : self.vector_lane, "spad_info": self.spad_info, "vlen" : self.vlen, - "arg_attributes" : arg_attributes + "arg_attributes" : arg_attributes, + "validate" : extension_config.CONFIG_TORCHSIM_VALIDATION_MODE, + "autotune" : True, }, source_code=src_code, ) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index e6379597..52979d73 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -75,7 +75,7 @@ def get_tile_candidates(self, **kwargs): # Extract input arguments info X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) - return [self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)] + return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index 3facedd5..26018a94 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -140,7 +140,7 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, 
O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -242,14 +242,13 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_K = TILE_K - - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + 
SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index 6f3492c6..a2959b4d 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -140,7 +140,7 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -240,11 +240,13 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 
1) * self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_I_W if TILE_I_W < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 53292858..afbe9289 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -140,7 +140,7 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, 
TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -241,11 +241,13 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N 
< kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 6fa3be53..777d0a7b 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -144,7 +144,7 @@ def render(self, # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = BATCH if TILE_M > BATCH else TILE_M @@ -245,12 +245,14 @@ def render(self, return code def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) - SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - SUB_TILE_K = TILE_K - TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] - TILE_I_W = 1 + (TILE_O_W - 
1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] - SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 - SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N - return TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] + TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] + SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W = 1, 1, 1, 1 + SUB_TILE_M = TILE_M if TILE_M < kernel.vector_lane else kernel.vector_lane + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + SUB_TILE_K = TILE_K + SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N + tile_candidates[idx] = TILE_K_H,TILE_K_W,TILE_O_H,TILE_O_W,TILE_M,TILE_N,TILE_K,TILE_I_H,TILE_I_W,SUB_TILE_I_H,SUB_TILE_I_W,SUB_TILE_K_H,SUB_TILE_K_W,SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 9d3d3acf..0830b4e6 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -116,7 +116,7 @@ def render(self, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, 
n_prologue_node)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -269,8 +269,7 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) - return [[TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K]] + return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): if template_buffer_node is not None: @@ -313,38 +312,44 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no TILE_M = extension_config.CONFIG_TILE_M TILE_N = extension_config.CONFIG_TILE_N TILE_K = extension_config.CONFIG_TILE_K + tile_candidates = [[TILE_M, TILE_N, TILE_K]] elif gemm_shape in data: # case 2: cached tile size tile_info = data[gemm_shape] TILE_M = tile_info["TILE_M"] TILE_N = tile_info["TILE_N"] TILE_K = tile_info["TILE_K"] + tile_candidates = [[TILE_M, TILE_N, TILE_K]] else: # case 3: use gemm_combination_mapping min_tile = (n_extra_node + n_prologue_node) == 0 - TILE_M, TILE_N, TILE_K = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) + tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) # Edge case if (M == 0) or (N == 0) or (K == 0): TILE_M, TILE_N, TILE_K = 1, 1, 1 + tile_candidates = [[TILE_M, TILE_N, TILE_K]] - # Calculate Sub Tile Size for fine-grained DMA - if extension_config.CONFIG_SUBTILE: - # Case 1: adjust selective fine-grained DMA (SFG-DMA) - SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane - if (TILE_M == M and TILE_N == N and 
TILE_N <= 512): - SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane - else: # Avoid Row Conflict of weights + full_tile_candidates = [] + for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): + # Calculate Sub Tile Size for fine-grained DMA + if extension_config.CONFIG_SUBTILE: + # Case 1: adjust selective fine-grained DMA (SFG-DMA) + SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane or n_prologue_node) else kernel.vector_lane + if (TILE_M == M and TILE_N == N and TILE_N <= 512): + SUB_TILE_N = TILE_N if TILE_N < kernel.vector_lane else kernel.vector_lane + else: # Avoid Row Conflict of weights + SUB_TILE_N = TILE_N + SUB_TILE_K = TILE_K + # Case 2: use manual sub tile size (FG-DMA) + if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: + SUB_TILE_M = extension_config.CONFIG_SUBTILE_M + SUB_TILE_N = extension_config.CONFIG_SUBTILE_N + SUB_TILE_K = extension_config.CONFIG_SUBTILE_K + # Case 3: None Subtile + else: + SUB_TILE_M = TILE_M SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K - # Case 2: use manual sub tile size (FG-DMA) - if extension_config.CONFIG_MANUAL_SUBTILE_SIZE: - SUB_TILE_M = extension_config.CONFIG_SUBTILE_M - SUB_TILE_N = extension_config.CONFIG_SUBTILE_N - SUB_TILE_K = extension_config.CONFIG_SUBTILE_K - # Case 3: None Subtile - else: - SUB_TILE_M = TILE_M - SUB_TILE_N = TILE_N - SUB_TILE_K = TILE_K - return TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K + SUB_TILE_K = TILE_K + full_tile_candidates.append([TILE_M,TILE_N,TILE_K, SUB_TILE_M,SUB_TILE_N,SUB_TILE_K]) + return full_tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 07ebec51..50fa6204 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -29,7 +29,7 @@ from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode from torch._inductor.codegen import common -from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR, 
CONFIG_AUTOTUNE_TEMPLATE, CONFIG_AUTOTUNE, CONFIG_BACKENDSIM_SPIKE_ONLY +from PyTorchSimFrontend.extension_config import CONFIG_TORCHSIM_DIR, CONFIG_AUTOTUNE_TEMPLATE_TOPK from . import mlir_common class IndentedBufferGroup: @@ -130,7 +130,6 @@ def __init__(self, self.reduction_mean = [] # Dim info self.dim_aliasing = {} - self.autotune_idx = 0 self.reason = reason def reset(self, reason): @@ -267,46 +266,10 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p mapping = (tile_M, tile_N, tile_K) if check_spad_size: tile_candidates.append((used_spad_size, (tile_M, tile_N, tile_K))) - if CONFIG_AUTOTUNE_TEMPLATE and not is_conv: - tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) - mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping - return mapping - - def search_mapping_space(self, mapping, idx, increment, stride, dilation, n_extra_node=0): - if idx == 0 or idx == 1 or idx == 4 or idx == 5 or idx == 6: - raise NotImplementedError("Only O_H and O_W are supported for search_mapping_space") - spad_size_per_lane = self.spad_info["spad_size"] - spad_size = spad_size_per_lane * self.vector_lane - max_spad_size = spad_size // 2 # double buffer - max_spad_per_lane = spad_size_per_lane // 2 # double buffer - - mapping = list(mapping) - mapping[idx] += increment - k_h, k_w, o_h, o_w, M, N, K = mapping - i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] - i_w = 1 + (o_w - 1) * stride[1] + (k_w - 1) * dilation[1] - weight_size = k_w * k_h * K * N - input_size = i_w * i_h * M * K - output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision - weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) - input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) - output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane 
= (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision - if used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane: - mapping = (k_h, k_w, o_h, o_w, M, N, K) - else: - mapping[idx] -= increment - - return mapping - def pseudo_auto_tune(self, mapping, stride, dilation, O_H, O_W, n_extra_node=0): - # pseudo auto-tune - if mapping[2] == 1 and not (O_H == 1): - mapping = self.search_mapping_space(mapping, 2, 1, stride, dilation, n_extra_node=n_extra_node) - if mapping[3] == 1 and not (O_W == 1): - mapping = self.search_mapping_space(mapping, 3, 1, stride, dilation, n_extra_node=n_extra_node) - return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): tile_candidates = [] @@ -316,7 +279,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True) + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -345,13 +308,9 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - # FIXME: this should be implemented with auto-tuning - mapping = self.pseudo_auto_tune(mapping, stride, dilation, O_H, O_W, n_extra_node=n_extra_node) - if CONFIG_AUTOTUNE_TEMPLATE: - tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) - mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping - - 
return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): tile_candidates = [] @@ -361,7 +320,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True) + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -385,10 +344,9 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, mapping = (k_h, K_W, o_h, o_w, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - if CONFIG_AUTOTUNE_TEMPLATE: - tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) - mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping - return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): tile_candidates = [] @@ -398,7 +356,7 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True) + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -422,10 +380,9 @@ def 
conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio mapping = (k_h, k_w, o_h, M, M, N, K) if max_used_spad_size == 0: raise RuntimeError("Cannot find a valid mapping") - if CONFIG_AUTOTUNE_TEMPLATE: - tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) - mapping = tile_candidates[self.autotune_idx][1] if self.autotune_idx < len(tile_candidates) else mapping - return mapping + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + return tile_candidates def meta_kernel(self): wrapper = V.graph.wrapper_code @@ -534,12 +491,20 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): choices = [] for tile_info in tile_candidates: + print(f"[Auto-tune] Trying tile size: {list(tile_info)}") src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) bench_runner = self.run_bench([template_node], self.kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group)) + choices.append((bench_runner, src_code, tile_info)) self.reset(reason=None) return choices + def _log_autotune_result(self, best_choice, best_cycle): + tile_size = best_choice[2] + print( + f"[Auto-tune] Optimal tile size: {list(tile_size)}, " + f"cycles: {best_cycle}" + ) + def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): src_code = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) @@ -1264,7 +1229,7 @@ def make_kernel_render( template=self, kwargs=kwargs ) - tile_candidates = self.get_tile_candidates(**kwargs) + tile_candidates = self.get_tile_candidates(**kwargs)[:CONFIG_AUTOTUNE_TEMPLATE_TOPK] return kernel, tile_candidates, render return MLIRTemplateCaller( From 3692365be9631792dd7bdd9deba3159e0c46fa66 Mon Sep 17 
00:00:00 2001 From: Wonhyuk Yang Date: Thu, 6 Nov 2025 04:17:40 +0000 Subject: [PATCH 5/5] [Fix] Fix wrong divider in reduction fusion --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 1 + PyTorchSimFrontend/mlir/mlir_gemm_template.py | 1 + PyTorchSimFrontend/mlir/mlir_template.py | 18 +++++++----------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 9a12076a..178ea987 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -301,6 +301,7 @@ def render(self, dram_idx = Y_idx, dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, + r_dim_size = M, dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 0830b4e6..c2120e7b 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -256,6 +256,7 @@ def render(self, dram_idx = Y_idx, dram_tile_desc = Y_tile_desc, nr_rdim = nr_rdim, + r_dim_size = M, dim_aliasing = epilogue_dim_aliasing ) code = self._template_from_string(template).render(**kernel.render_options) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 50fa6204..e6e9dd0c 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -892,7 +892,7 @@ def load_epilogue(self, name: str, index: sympy.Expr): vshape = f"vector<{vsize}x{mlir_dtype}>" if compute_vec_size > 1: - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.reduction_axis_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, 
%{self.reduction_loop_idx})") compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) operation = "affine.vector_load" line = f"{operation} %{sram_var}[{compute_index_var}] : {tile_shape}, {vshape}" @@ -1077,12 +1077,7 @@ def store_reduction_epilogue(self, name, index, value): if self.welford_reduce_out is not None: # NOTE: It not a real welford algorithm... We just used E(X^2) - E(X)^2 - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size)} : f32") - if self.reduction_axis_size - 1 > 0: - divider2 = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.reduction_axis_size-1)} : f32") - else: - divider2 = divider - + divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.r_dim_size)} : f32") if self.buffer_types[name][1] > 1: divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") else: @@ -1121,15 +1116,16 @@ def set_tile_size(self, template_fusion_info, prologue=False): if 'nr_rdim' in template_fusion_info and template_fusion_info['nr_rdim']==1: tile_desc.nr_rdim = 1 numel_per_lane = tile_desc.get_numel_per_lane() - reduction_axis_size = tile_desc.get_tile_size()[-1] - nr_outer_loop = (numel_per_lane + reduction_axis_size-1) // reduction_axis_size + r_tile_size = tile_desc.get_tile_size()[-1] + nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size tile_desc.vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... 
self.reduction_fusion = True - self.reduction_axis_size = tile_desc.get_tile_size()[-1] + self.r_tile_size = tile_desc.get_tile_size()[-1] + self.r_dim_size = template_fusion_info['r_dim_size'] self.reduction_nr_outer_loop = nr_outer_loop self.reduction_loop_idx = "reduce_loop_idx" - self.compute_body_loop.size = reduction_axis_size + self.compute_body_loop.size = r_tile_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: