From 6a1d6428e47ef2b911c110b4a7c10eef853b3298 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 10 Jan 2025 02:13:05 +0000 Subject: [PATCH 1/2] [Fix #d06d197] Add missing new_name setting part --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 3d65aa53..dae2fed7 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1372,7 +1372,7 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads? if name not in self.global_vars_dict: - self.global_vars_dict[name] = set() + self.global_vars_dict[name] = list() if str(raw_index) not in self.global_vars_dict[name]: new_name = f"{name}_{len(self.global_vars_dict[name])}" @@ -1380,7 +1380,9 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));") self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];") self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") - self.global_vars_dict[name].add(str(raw_index)) + self.global_vars_dict[name].append(str(raw_index)) + else: + new_name = f"{name}_{self.global_vars_dict[name].index(str(raw_index))}" buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer, indices From cbebc9c0c87ce7af4a0cf8aab7b678ac8d2a85b2 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 9 Jan 2025 13:21:02 +0000 Subject: [PATCH 2/2] [Frontend] cleanup fusion code --- .../mlir/mlir_codegen_backend.py | 72 +----------------- PyTorchSimFrontend/mlir/mlir_template.py | 74 +++++++++++++++++++ 2 files changed, 75 insertions(+), 71 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index dae2fed7..09c5aa34 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -664,10 +664,6 @@ class MLIRKernel(mlir_common.BaseMLIRKernel): def __init__(self): super().__init__(mlir_common.MLIRKernelArgs()) - - from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel - self.is_template_kernel = isinstance(self, MLIRTemplateKernel) - self.kernel_group = None self.call_ranges = None self.ranges = None @@ -695,6 +691,7 @@ def __init__(self): self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} + self.is_template_kernel = False def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] @@ -892,36 +889,7 @@ def codegen_nodes(self, nodes, kernel_name): write_atomic(gem5_write_path, self.gem5_header.getvalue()) return src_code - def load_epilogue(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) - var = self.args.input(name) - dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - - if name in self.buffer_names: - buffer = self.buffer_names[name] - else: - mvin3 = 14 - self.consts.add(mvin3) - self.consts.add(0) - dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index) - self.buffer_names[name] = buffer - line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" - self.cse.generate(self.loads, line, assignment = False) - - tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane - operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" - out = self.cse.generate(self.loads, line) - var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(out, var_info) - return out - def load(self, name: str, index: sympy.Expr): - if self.is_template_kernel: - return self.load_epilogue(name, index) index = self.rename_indexing(index) indices = self.parse_indices(index) prefix = self.newvar_prefix @@ -961,41 +929,7 @@ def load(self, name: str, index: sympy.Expr): self.register_var_info(out, var_info) return out - def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.output(name) - dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - - chunk_size = self.tile_desc.get_chunk_size() - chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) - self.consts.add(chunk) - - if name in self.buffer_names: - buffer = self.buffer_names[name] - else: - dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) - self.buffer_names[name] = buffer - - tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane - operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" - self.cse.generate(self.stores, line, assignment = False) - - self.tags.add(f"{name}_tag") - self.consts.add(0) - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag - self.cse.generate(self.stores, code, assignment = False) - def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): - if self.is_template_kernel: - return self.store_epilogue(name, index, value, args, kwargs) index = self.rename_indexing(index) indices = self.parse_indices(index) prefix = self.newvar_prefix @@ -1287,10 +1221,6 @@ def _codegen_kernel(self, arg_defs, kernel_name): return code def adjust_tile_size(self): - if self.is_template_kernel: - self.tile_desc.n_row = self.render_options['TILE_M'] - self.tile_desc.n_col = self.render_options['TILE_N'] - return if self.read_writes is not None: read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) cv_list = [] diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 63d00dab..ec1340e7 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -3,6 +3,8 @@ import textwrap import re import math +import sympy + from typing import List, Optional from unittest.mock import patch @@ -22,6 +24,8 @@ from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, MLIRTile +from . import mlir_common + class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): def __init__(self, kernel_name, @@ -46,6 +50,11 @@ def __init__(self, self.render_options = dict() self.tile_size = [] self.loop_size = None + self.is_template_kernel = True + + # Overwrite ops + self.load = self.load_epilogue + self.store = self.store_epilogue def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): @@ -240,6 +249,71 @@ def render(self, template, kwargs): self.render_hooks, ) + def adjust_tile_size(self): + self.tile_desc.n_row = self.render_options['TILE_M'] + self.tile_desc.n_col = self.render_options['TILE_N'] + return + + def load_epilogue(self, name: str, index: sympy.Expr): + index = self.rename_indexing(index) + var = self.args.input(name) + dtype = V.graph.get_dtype(name) + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + + if name in self.buffer_names: + buffer = self.buffer_names[name] + else: + dram_mlir_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) + mvin3 = 14 + self.consts.add(mvin3) + self.consts.add(0) + dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index) + self.buffer_names[name] = buffer + line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : {dram_mlir_shape}, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" + self.cse.generate(self.loads, line, assignment = False) + + tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane + operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" + shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" + line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + out = self.cse.generate(self.loads, line) + var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] + self.register_var_info(out, var_info) + return out + + def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): + indices = self.parse_indices(index) + prefix = self.newvar_prefix + if index.is_number: + prefix = prefix + "c" + self.consts.add(int(index)) + var = self.args.output(name) + dtype = V.graph.get_dtype(name) + type_name = mlir_common.DTYPE_TO_MLIR[dtype] + + chunk_size = self.tile_desc.get_chunk_size() + chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) + self.consts.add(chunk) + + if name in self.buffer_names: + buffer = self.buffer_names[name] + else: + dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" + buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) + self.buffer_names[name] = buffer + + tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane + operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store" + shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" + line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + self.cse.generate(self.stores, line, assignment = False) + + self.tags.add(f"{name}_tag") + self.consts.add(0) + code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag + self.cse.generate(self.stores, code, assignment = False) + class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})"