diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index cd99d52e..5a68947e 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -35,6 +35,7 @@ %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> %tag = memref.alloc() : memref<1xi32> %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32> + {{- kernel.def_local_vars() }} affine.for %b=0 to {{ B }} { affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index b4e2d00b..5060c50d 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -226,11 +226,20 @@ def to_dtype(operand, dst_mlir_dtype, *args, var_info=None): raise NotImplementedError("floating point to integer conversion") if dst_mlir_dtype[0] == "f" and src_mlir_dtype[0] == "i": raise NotImplementedError("integer to floating point conversion") - else: + if dst_mlir_dtype[0] == "i": + if dst_bits > src_bits: + return f"arith.extui %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + elif dst_bits < src_bits: + return f"arith.trunci %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.maxsi %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] + elif dst_mlir_dtype[0] == "f": if dst_bits > src_bits: - return f"arith.extui %{operand} : {src_shape} to {shape}" + return f"arith.extf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] elif dst_bits < src_bits: - return f"arith.trunc %{operand} : {src_shape} to {shape}" + return f"arith.truncf %{operand} : {src_shape} to {shape}", [tile_size, dst_mlir_dtype] + return f"arith.maximumf %{operand}, %{operand} : {shape}", [tile_size, dst_mlir_dtype] + else: + raise NotImplementedError("Unsupported type for to_dtype 
ops") @staticmethod def constant(value, src_type, *args, var_info=None): @@ -602,7 +611,7 @@ def broadcast(operand1, operand2, *args, var_info=None): "MVIN1": 2, "MVIN2": 1, "MVIN3": 14, - "MVOUT": 3, + "MVOUT1": 3, } class MLIRKernel(mlir_common.BaseMLIRKernel): @@ -611,6 +620,8 @@ class MLIRKernel(mlir_common.BaseMLIRKernel): def __init__(self): super().__init__(mlir_common.MLIRKernelArgs()) + self.const_buffer = IndentedBuffer() + self.alloc_buffer = IndentedBuffer() self.reduction_prefix = IndentedBuffer() self.reduction_suffix = IndentedBuffer() self.body = IndentedBuffer() @@ -623,16 +634,19 @@ def __init__(self): self.iterator_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="iter") self.init_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init") self.init_vec_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="init_vec") + self.const_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="const") + self.alloc_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="alloc") self.map_cse = common.CSE("#", self.suffix, name_prefix="map") - self.consts = set() - self.tags = set() - self.dma_cache = {} - self.dma_counter = 1 + self.consts = dict() + self.tags = dict() + self.dma_read_cache = {} + self.dma_write_cache = {} + self.dma_read_counter = 1 + self.dma_write_counter = 1 self.affine_yield = {} self.welford_reduce_out = None self.reduce_iterator = {} self.is_template_kernel = False - def set_ranges(self, lengths, reduction_lengths, read_writes): ret = super().set_ranges(lengths, reduction_lengths, read_writes) @@ -649,7 +663,12 @@ def get_padding_type(self): return 1 return 0 - def parse_indices(self, expr): + def parse_indices(self, expr) -> common.CSEVariable: + # Constant case + if expr.is_number: + return self.get_const_cse(int(expr)) + + # Identity case if len(expr.args) == 0: return expr @@ -661,9 +680,11 @@ def parse_indices(self, expr): indices[index] = None indices = list(indices.keys()) - args = 
", ".join(map(str, indices)) + # Extract // pattern if "//" in expr_str: expr_str = expr_str.replace("//", " floordiv ") + + # Extract modular pattern pattern = r"ModularIndexing\((.*?)\)" matches = re.search(pattern, expr_str) if matches: @@ -672,6 +693,7 @@ def parse_indices(self, expr): replace_str = f"({args_list[0]} floordiv {args_list[1]}) mod {args_list[2]}" expr_str = re.sub(r"ModularIndexing\([^)]*\)", replace_str, expr_str) + args = ", ".join(map(str, indices)) map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") args = ", ".join([f"%{i}" for i in indices]) index = self.cse.generate(self.loads, f"affine.apply #{map_var}({args})") @@ -679,78 +701,52 @@ def parse_indices(self, expr): def load(self, name: str, index: sympy.Expr): index = self.rename_indexing(index) - indices = self.parse_indices(index) padding = self.get_padding_type() - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.input(name) + index_var = self.parse_indices(index) + dram_var = self.args.input(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) - dram_tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" + tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.loads, indices, index) + sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, tile_shape, self.loads, index_var, index) # MVIN Encoding - dma_key = (stride, chunk, dtype) - if dma_key in self.dma_cache: - dmaType, stride, chunk = self.dma_cache[dma_key] - else: - assert(self.dma_counter < 4) - dmaType = DMA_TYPE[f"MVIN{self.dma_counter}"] - self.dma_counter += 1 - 
self.consts.add(dmaType) - self.consts.add(stride) - self.consts.add(chunk) - self.dma_cache[dma_key] = dmaType, stride, chunk - self.tags.add(f"{name}_tag") - self.consts.add(0) - code = f"affine.dma_start %{var}[{prefix}{indices}], %{buffer}[%c0, %c0], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32> {{padding = {padding}}}" + code = self.get_dma_code("MVIN", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", self.buffer_types[name][1], tile_shape, padding) self.cse.generate(self.loads, code, assignment = False) # FIXME: assignment = False does not support caching + # Generate vector load instruction operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" + shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 else "" + line = f"{operation} %{sram_var}[0, 0] : memref<{tile_shape}x{mlir_dtype}, 1>{shape}" out = self.cse.generate(self.loads, line) - var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(out, var_info) + self.register_var_info(out, [tile_size_per_lane, mlir_dtype]) return out def store(self, name: str, index: sympy.Expr, value, *args, **kwargs): index = self.rename_indexing(index) - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.output(name) + index_var = self.parse_indices(index) + dram_var = self.args.output(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] stride, chunk, tile_shape, tile_size_per_lane = self.get_dma_info(name, index, dtype) - dram_tile_shape = 
f"{tile_shape[0]}x{tile_shape[1]}" + tile_shape = f"{tile_shape[0]}x{tile_shape[1]}" # Define scratch pad buffer - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, dram_tile_shape, self.stores, indices, index) - - # MVOUT Encoding - dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 - self.consts.add(dmaType) - self.consts.add(stride) - self.consts.add(chunk) + sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.tile_desc.n_row, self.tile_desc.n_col, tile_shape, self.stores, index_var, index) + # Generate vector store instruction store_size, operand_type = self.var_info[value] operation = "affine.vector_store" if tile_size_per_lane > 1 and store_size > 1 else "affine.store" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 and store_size > 1 else "" - if type_name != operand_type: - value = ops.custom_cast(value, type_name, var_info=self.var_info) + shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 and store_size > 1 else "" + if mlir_dtype != operand_type: + value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{dram_tile_shape}x{type_name}, 1>{shape}" + line = f"{operation} %{value}, %{sram_var}[0, 0] : memref<{tile_shape}x{mlir_dtype}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) - self.consts.add(0) - self.tags.add(f"{name}_tag") - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{stride}, %c{chunk} : memref<{dram_tile_shape}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" + + # Generate DMA instruction + code = self.get_dma_code("MVOUT", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", self.buffer_types[name][1], tile_shape) self.cse.generate(self.stores, code, assignment = False) def reduction(self, dtype, src_dtype, 
reduction_type, value): @@ -828,28 +824,25 @@ def reduction(self, dtype, src_dtype, reduction_type, value): return acc def store_reduction(self, name, index, value): - var = self.args.output(name) + dram_var = self.args.output(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] index = self.rename_indexing(index) - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) + index_var = self.parse_indices(index) + # Tile is always reuduced in inner loop tile_col = self.tile_desc.n_row tile_row = 1 dram_tile_shape = f"{tile_row}x{tile_col}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, indices, index) + sram_var, index_var = self.get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, self.reductions_suffix, index_var, index) if self.welford_reduce_out is not None: # raise NotImplementedError() sum, sqr_sum, _ = self.welford_reduce_out - shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{type_name}>" if self.buffer_types[name][1] > 1 else type_name + shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{mlir_dtype}>" if self.buffer_types[name][1] > 1 else mlir_dtype # mean divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.ranges[self.reduction_depth])} : f32") if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{type_name}>") + divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to vector<{self.var_info[sum][0]}x{mlir_dtype}>") else: divider_vec = f"f{self.buffer_types[name][1]}" mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{sum}, %{divider_vec} : {shape}") @@ -875,36 +868,31 @@ def store_reduction(self, name, 
index, value): if self.tile_desc.get_rows_per_lane() == 1: shape = "" else: - shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{type_name}>" + shape = f"vector<{self.tile_desc.get_rows_per_lane()}x{mlir_dtype}>" shape = f", {shape}" if self.buffer_types[name][1] > 1 else "" - line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{tile_row}x{tile_col}x{type_name}, 1>{shape}" + line = f"{operation} %{value}, %{sram_var}[0, 0] : memref<{tile_row}x{tile_col}x{mlir_dtype}, 1>{shape}" self.cse.generate(self.reductions_suffix, line, assignment = False) # MVOUT Encoding - dmaType = 3 # MVIN 2, MVIN2 1, MVIN3 14, MVOUT 3 mm_stride = tile_col is_col_major = mlir_common.MLIRTile.TILE_PER_LANE_ROW_WISE chunk_size = self.tile_desc.get_rows_per_lane() chunk = chunk_size << 1 | (is_col_major == mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE) - self.consts.add(dmaType) - self.consts.add(mm_stride) - self.consts.add(chunk) - self.tags.add(f"{name}_tag") + + # Generate DMA instruction # Change row, col - self.consts.add(0) - code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[{prefix}{indices}], %{name}_tag[0], %c{dmaType}, %c{mm_stride}, %c{chunk} : memref<{tile_row}x{tile_col}x{type_name}, 1>, memref<{self.buffer_types[name][1]}x{type_name}>, memref<1xi32>" + code = self.get_dma_code("MVOUT", mm_stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", self.buffer_types[name][1], f"{tile_row}x{tile_col}") self.cse.generate(self.reductions_suffix, code, assignment = False) def codegen_body(self): def template_store(options): subtile_size = [self.vector_lane, self.vector_lane] async_flag = 1 - self.consts.add(0) - line = f"affine.dma_start %Y_buffer[%c0, %c0], %Y[%index2], %tag[0], %c_mvout, %N, %c_set"\ + zero_var = self.get_const_cse(0) + line = f"affine.dma_start %Y_buffer[%{zero_var}, %{zero_var}], %Y[%index2], %tag[0], %c_mvout, %N, %c_set"\ f": memref<{options['TILE_M']}x{options['TILE_N']}xf32, 1>,"\ f"memref<{options['M'] * options['N']}xf32>, 
memref<1xi32>" #FIXME: Using constant index self.cse.generate(self.stores, line, assignment = False) - self.body.splice(self.codegen_init('e_')) self.body.splice(self.loads) self.body.splice(self.compute) if len(self.stores._lines) == 0: @@ -917,16 +905,6 @@ def template_store(options): def codegen_global_init(self): return self.global_vars - def codegen_init(self, prefix=""): - code = IndentedBuffer() - tags = sorted(self.tags) - consts = sorted(self.consts) - for tag in tags: - code.writeline(f"%{prefix}{tag} = memref.alloc() : memref<1xi32>") - for const in consts: - code.writeline(f"%{prefix}c{const} = arith.constant {const} : index") - return code - def codegen_loops(self): code = mlir_common.ParallelLoopBuffer() # Loop body part @@ -943,6 +921,9 @@ def codegen_loops(self): vars = ', '.join([f"%{name}" for name, _ in self.affine_yield.items()]) reduced_shapes = ', '.join([f"{shape}" for _, shape in self.affine_yield.items()]) self.stores.writeline(f"affine.yield {vars} : {reduced_shapes}") + + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) with contextlib.ExitStack() as stack: for loop in loops.loops: loop_lines = loop.lines() @@ -969,7 +950,7 @@ def codegen_loops(self): def codegen_nodes(self, nodes, kernel_name): src_code = super().codegen_nodes(nodes, kernel_name) - # Create extra header for simulatoors + # Create extra headers for simulators write_path = extension_codecache.get_write_path(src_code) if not os.path.exists(write_path): os.makedirs(write_path) @@ -1092,6 +1073,47 @@ def get_dma_info(self, name, index, dtype): chunk = chunk_size << 1 | (current_tile.tile_per_lane_layout == mlir_common.MLIRTile.TILE_PER_LANE_COL_WISE) return mm_stride, chunk, [current_tile.n_row, current_tile.n_col], tile_size_per_lane + def get_dma_code(self, dma_type_name, stride, chunk, mlir_dtype, dram_var, index_var, sram_var, tag_name, dram_shape, tile_shape, padding_type=None): + dma_key = (stride, chunk, mlir_dtype) + if dma_type_name == "MVIN" and 
dma_key in self.dma_read_cache: + dma_type, mm_stride, chunk = self.dma_read_cache[dma_key] + elif dma_type_name == "MVOUT" and dma_key in self.dma_write_cache: + dma_type, mm_stride, chunk = self.dma_write_cache[dma_key] + else: + mm_stride = self.get_const_cse(stride) + chunk = self.get_const_cse(chunk) + if dma_type_name == "MVIN": + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_read_counter}"]) + self.dma_read_counter += 1 + self.dma_read_cache[dma_key] = [dma_type, mm_stride, chunk] + else: + dma_type = self.get_const_cse(DMA_TYPE[f"{dma_type_name}{self.dma_write_counter}"]) + # self.dma_write_counter += 1 Is it okay? + self.dma_write_cache[dma_key] = [dma_type, mm_stride, chunk] + tag = self.get_tag_cse(tag_name) + zero_cse = self.get_const_cse(0) + + # Prepare operands and attributes + dram_operand = f"%{dram_var}[%{index_var}]" + sram_operand = f"%{sram_var}[%{zero_cse}, %{zero_cse}]" + tag_var = f"%{tag}[0]" + dma_attribute = f"%{dma_type}, %{mm_stride}, %{chunk}" + dram_shape = f"memref<{dram_shape}x{mlir_dtype}>" + sram_shape = f"memref<{tile_shape}x{mlir_dtype}, 1>" + tag_shape = "memref<1xi32>" + + if dma_type_name == "MVIN": + src_operand, dst_operand = dram_operand, sram_operand + src_shape, dst_shape = dram_shape, sram_shape + else: + src_operand, dst_operand = sram_operand, dram_operand + src_shape, dst_shape = sram_shape, dram_shape + + code = f"affine.dma_start {src_operand}, {dst_operand}, {tag_var}, {dma_attribute} : {src_shape}, {dst_shape}, {tag_shape}" + if padding_type is not None: + code = code + f" {{padding = {padding_type}}}" + return code + def adjust_tile_size(self): if self.read_writes is not None: read_writes = list(self.read_writes.reads) + list(self.read_writes.writes) @@ -1166,6 +1188,16 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>") return buffer,
indices + def get_const_cse(self, value) -> common.CSEVariable: + if value not in self.consts: + self.consts[value] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : index") + return self.consts[value] + + def get_tag_cse(self, value, shape="memref<1xi32>"): + if value not in self.tags: + self.tags[value] = self.alloc_cse.generate(self.alloc_buffer, f"memref.alloc() : {shape}") + return self.tags[value] + @dataclasses.dataclass class LoopLevel: var: sympy.Expr diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 21612a4c..51399197 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -277,9 +277,6 @@ def codegen_global_init(self): def codegen_loops(self): raise NotImplementedError() - def codegen_init(self): - raise NotImplementedError() - def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() @@ -322,7 +319,6 @@ def _codegen_kernel(self, arg_defs, kernel_name): for old, new in self.kernel_group.args.aliases(): code.writeline(f"auto {old} = {new};") # Loop body part - code.splice(self.codegen_init()) code.splice(self.codegen_loops()) return code diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 3f52a61d..304474cc 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -35,6 +35,7 @@ %W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> %Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> %tag = memref.alloc() : memref<1xi32> + {{- kernel.def_local_vars() }} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 954059c0..c116ebf6 100644 --- 
a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -35,6 +35,7 @@ %tag = memref.alloc() : memref<1xi32>{% if not Bias %} %v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% endif %} %c0 = arith.constant 0 : index + {{- kernel.def_local_vars() }} affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} { affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} { diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 655d2944..c793dcf7 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -30,12 +30,15 @@ def can_fuse_horizontal(self, node1, node2): _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group + # Reduction is currently not supported + if node1.is_reduction() or node2.is_reduction(): + return False + # Convolution is currently not supported - if node1.node.origin_node.target._name == 'aten::convolution' or node2.node.origin_node.target._name == 'aten::convolution': + if not isinstance(node1, FusedSchedulerNode) and node1.node.origin_node is not None and node1.node.origin_node.target._name == 'aten::convolution': return False - # Reduction is currently not supported - if node1.is_reduction() or node2.is_reduction(): + if not isinstance(node2, FusedSchedulerNode) and node2.node.origin_node is not None and node2.node.origin_node.target._name == 'aten::convolution': return False if not isinstance(node1, FusedSchedulerNode) and not isinstance(node2, FusedSchedulerNode): @@ -138,6 +141,7 @@ def codegen_src_code(self, kernel, render, template_node, epilogue_nodes): else partial_code.finalize() ) src_code = kernel.add_extra_global_vars(src_code) + src_code = kernel.add_extra_local_vars(src_code) return src_code def codegen_template(self, template_node, epilogue_nodes): diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 
2c1678a2..930f4777 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -8,12 +8,13 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides +from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE from torch._inductor.ir import Buffer, IRNode, TemplateBuffer from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta from torch._inductor.virtualized import V +from torch._inductor.utils import IndentedBuffer from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo, MLIRTile @@ -46,6 +47,9 @@ def __init__(self, self.tile_size = [] self.loop_size = None self.is_template_kernel = True + self.map_cse = CSE("#", self.suffix, name_prefix="template_map") + self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_const") + self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_alloc") # Overwrite ops self.load = self.load_epilogue @@ -237,6 +241,22 @@ def add_extra_global_vars(self, code): key = "" return code.replace(key, self.replace_global_vars()) + def def_local_vars(self): + return "" + + def replace_local_vars(self): + code = IndentedBuffer() + code.tabwidth = 2 + code.splice("\n") + with code.indent(): + code.splice(self.const_buffer) + code.splice(self.alloc_buffer) + return code.getvalue() + + def add_extra_local_vars(self, code): + key = "" + return code.replace(key, self.replace_local_vars()) + def render(self, template, kwargs): # self.render_hooks = {} return PartialRender( @@ -252,67 +272,68 @@ def adjust_tile_size(self): return def load_epilogue(self, name: str, index: sympy.Expr): - indices = 
self.parse_indices(index) + #index_var = self.parse_indices(index) + index_var = "index2" index = self.rename_indexing(index) - var = self.args.input(name) + dram_var = self.args.input(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] - - if name in self.buffer_names: - buffer = self.buffer_names[name] - else: - mvin3 = 14 - self.consts.add(mvin3) - dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, indices, index) - self.buffer_names[name] = buffer - line = f"affine.dma_start %{var}[%index2], %{buffer}[%e_c0, %e_c0], %tag[0], %e_c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>" - self.cse.generate(self.loads, line, assignment = False) - + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] + if name not in self.buffer_names: + # Allocate sram buffer + tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" + sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], tile_shape, self.loads, index_var, index) + self.buffer_names[name] = sram_var + + # Generate DMA instruction + stride = self.render_options['N'] # FIXME. Is it okay? + chunk = 2 # FIXME. Is it okay? + index_var = "index2" # FIXME. Is it okay? 
+ code = self.get_dma_code("MVIN", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", self.buffer_types[name][1], tile_shape) + self.cse.generate(self.loads, code, assignment = False) + + # Load vector from sram + sram_var = self.buffer_names[name] tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 else "" + zero_var = self.get_const_cse(0) + line = f"{operation} %{sram_var}[%{zero_var}, %{zero_var}] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{mlir_dtype}, 1>{shape}" out = self.cse.generate(self.loads, line) - var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]] - self.register_var_info(out, var_info) - self.consts.add(0) + self.register_var_info(out, [tile_size_per_lane, mlir_dtype]) return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - indices = self.parse_indices(index) - prefix = self.newvar_prefix - if index.is_number: - prefix = prefix + "c" - self.consts.add(int(index)) - var = self.args.output(name) + #index_var = self.parse_indices(index) + index_var = "index2" + dram_var = self.args.output(name) dtype = V.graph.get_dtype(name) - type_name = mlir_common.DTYPE_TO_MLIR[dtype] + mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] chunk_size = 1 # Fixed for template kernel chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE) - self.consts.add(chunk) - if name in self.buffer_names: - buffer = self.buffer_names[name] - else: + if name not in self.buffer_names: dram_tile_shape = 
f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" - buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index) - self.buffer_names[name] = buffer + sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, index_var, index) + self.buffer_names[name] = sram_var + sram_var = self.buffer_names[name] tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store" - shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else "" - line = f"{operation} %{value}, %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}" + shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 else "" + zero_var = self.get_const_cse(0) + line = f"{operation} %{value}, %{sram_var}[%{zero_var}, %{zero_var}] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{mlir_dtype}, 1>{shape}" self.cse.generate(self.stores, line, assignment = False) - self.tags.add(f"{name}_tag") - self.consts.add(0) - code = f"affine.dma_start %{buffer}[%e_c0, %e_c0], %{var}[%index2], %tag[0], %c_mvout, %N, %e_c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag + stride = self.render_options['N'] # FIXME. Is it okay? + index_var = "index2" # FIXME. Is it okay? 
+ dram_shape = f"{self.render_options['M'] * self.render_options['N']}" + tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}" + code = self.get_dma_code("MVOUT", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", dram_shape, tile_shape) self.cse.generate(self.stores, code, assignment = False) - def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index): - return super().get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index, True) + def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, index_var, raw_index): + return super().get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, index_var, raw_index, True) class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self):