From 498da27a9f4e1da9697eb6673e03ff237a040dfc Mon Sep 17 00:00:00 2001 From: OkkyunWoo Date: Fri, 10 Jan 2025 11:19:13 +0000 Subject: [PATCH] [Frontend] fusion basic case debug --- .../mlir/mlir_codegen_backend.py | 22 ++++++------- PyTorchSimFrontend/mlir/mlir_scheduling.py | 23 ++++++++++--- PyTorchSimFrontend/mlir/mlir_template.py | 32 ++++++++++--------- test_extension_backend.py | 1 + 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 99c39322..827e0be9 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -12,6 +12,7 @@ from collections import OrderedDict import torch from torch._inductor.codegen import cpp, wrapper, common +from torch._inductor.scheduler import BaseScheduling from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic, write from torch._inductor.utils import ( @@ -893,7 +894,7 @@ def template_store(options): f": memref<{options['TILE_M']}x{options['TILE_N']}xf32, 1>,"\ f"memref<{options['M'] * options['N']}xf32>, memref<1xi32>" #FIXME: Using constant index self.cse.generate(self.stores, line, assignment = False) - self.body.splice(self.codegen_init()) + self.body.splice(self.codegen_init('e_')) self.body.splice(self.loads) self.body.splice(self.compute) if len(self.stores._lines) == 0: @@ -906,14 +907,14 @@ def template_store(options): def codegen_global_init(self): return self.global_vars - def codegen_init(self): + def codegen_init(self, prefix=""): code = IndentedBuffer() tags = sorted(self.tags) consts = sorted(self.consts) for tag in tags: - code.writeline(f"%{tag} = memref.alloc() : memref<1xi32>") + code.writeline(f"%{prefix}{tag} = memref.alloc() : memref<1xi32>") for const in consts: - code.writeline(f"%c{const} = arith.constant {const} : index") + code.writeline(f"%{prefix}c{const} = arith.constant {const} : 
index") return code def codegen_loops(self): @@ -1122,17 +1123,14 @@ def adjust_tile_size(self): if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars): raise NotImplementedError() - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): + def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index, is_template=False): c_type = mlir_common.DTYPE_TO_C[dtype] mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] # Make sure each lane's buffer has at least two element tile_size = max(self.roundup_vectorlane(tile_row * tile_col), self.vector_lane * 2) - if dtype == torch.bool and not self.is_template_kernel: #FIXME: epilogue ReLU does not need this - if self.is_template_kernel: - mapping = f"template_{indices} " - self.map_cse.generate(self.global_vars, f"#{mapping} = affine_map<({indices}) -> ({indices} floordiv 8)>", assignment=False) - else: - mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") + + if dtype == torch.bool and not is_template: + mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>") indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? 
if name not in self.global_vars_dict: @@ -1192,4 +1190,4 @@ def mark_parallel(self, par_depth): loops[0].parallel = par_depth for i in range(1, par_depth): loops[i].collapsed = True - loops[0].simd = loops[par_depth - 1].simd + loops[0].simd = loops[par_depth - 1].simd \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 752fa8b4..7425728c 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -3,7 +3,7 @@ from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel from torch._inductor import config -from torch._inductor.scheduler import BaseScheduling +from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V @@ -24,15 +24,30 @@ def _set_flush_status(self, status: bool): self._ready_to_flush = status def can_fuse_vertical(self, node1, node2): - return False return self.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_fuse_horizontal(self, node1, node2): - return False _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group + + # Reduction is currently not supported + if node1.is_reduction() or node2.is_reduction(): + return False + + if not isinstance(node1, FusedSchedulerNode) and not isinstance(node2, FusedSchedulerNode): + # Different layout is not supported + if node1.node.layout.dtype != node2.node.layout.dtype: + return False + + # Different size is not supported for non-template node + if not node1.is_template() and (node1._sizes[0] != node2._sizes[0]): + return False + if vars1 == vars2 and reduce1 == reduce2: return True + if reduce1 == () and vars1 == vars2 + reduce2: + return True + #TODO: Temporary solution determining the fusion condition similar to CPP/OpenMP v1_total = math.prod(vars1) if len(vars1) else 0 v2_total = math.prod(vars2) if len(vars2) else 0 @@ -40,8 +55,8 @@ 
def can_fuse_horizontal(self, node1, node2): r2_total = math.prod(reduce2) if len(reduce2) else 0 if reduce1 == () \ and v1_total == (v2_total + r2_total): - # and node1.node.layout.size == node2.node.layout.size: #FIXME: Need to check layout too? return True + return False def group_fn(self, sizes): diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index f8e2b428..c45c5d29 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -8,13 +8,8 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import KernelTemplate -from torch._inductor.codegen.common import ChoiceCaller -from torch._inductor.codegen.common import Kernel -from torch._inductor.codegen.common import OpOverrides -from torch._inductor.ir import Buffer -from torch._inductor.ir import IRNode -from torch._inductor.ir import TemplateBuffer +from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -250,11 +245,14 @@ def render(self, template, kwargs): ) def adjust_tile_size(self): + # Fixed tile size for template kernel + self.tile_desc.tile_layout = MLIRTile.TILE_COL_WISE self.tile_desc.n_row = self.render_options['TILE_M'] self.tile_desc.n_col = self.render_options['TILE_N'] return def load_epilogue(self, name: str, index: sympy.Expr): + indices = self.parse_indices(index) index = self.rename_indexing(index) var = self.args.input(name) dtype = V.graph.get_dtype(name) @@ -263,23 +261,24 @@ def load_epilogue(self, name: str, index: sympy.Expr): if name in self.buffer_names: buffer = self.buffer_names[name] else: - dram_mlir_shape = 
mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + # dram_mlir_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) mvin3 = 14 self.consts.add(mvin3) - self.consts.add(0) dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index) + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, indices, index) self.buffer_names[name] = buffer - line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : {dram_mlir_shape}, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + # line = f"affine.dma_start %{var}[%index2], %{buffer}[%e_c0, %e_c0], %tag[0], %e_c{mvin3}, %N, %c_set : {dram_mlir_shape}, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + line = f"affine.dma_start %{var}[%index2], %{buffer}[%e_c0, %e_c0], %tag[0], %e_c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" self.cse.generate(self.loads, line, assignment = False) tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + line = f"{operation} %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" out = self.cse.generate(self.loads, line) var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] self.register_var_info(out, var_info) + self.consts.add(0) 
return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): @@ -292,7 +291,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): dtype = V.graph.get_dtype(name) type_name = mlir_common.DTYPE_TO_MLIR[dtype] - chunk_size = self.tile_desc.get_chunk_size() + chunk_size = 1 # Fixed for template kernel chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) self.consts.add(chunk) @@ -306,14 +305,17 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store" shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + line = f"{operation} %{value}, %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) self.tags.add(f"{name}_tag") self.consts.add(0) - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag + code = f"affine.dma_start %{buffer}[%e_c0, %e_c0], %{var}[%index2], %tag[0], %c_mvout, %N, %e_c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag self.cse.generate(self.stores, code, assignment = False) + def get_scratchpad_buffer(self, dtype, 
name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): + return super().get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index, True) + class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" diff --git a/test_extension_backend.py b/test_extension_backend.py index 6c5429c7..bee28729 100644 --- a/test_extension_backend.py +++ b/test_extension_backend.py @@ -50,5 +50,6 @@ # # Fusion Test test_matmul_scalar(device) + test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="relu") test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid") test_addmm_residual(device)