Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions PyTorchSimFrontend/mlir/mlir_bmm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
%Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>
%tag = memref.alloc() : memref<1xi32>
%v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>
{{- kernel.def_local_vars() }}

affine.for %b=0 to {{ B }} {
affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {
Expand Down
218 changes: 125 additions & 93 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions PyTorchSimFrontend/mlir/mlir_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,6 @@ def codegen_global_init(self):
def codegen_loops(self):
raise NotImplementedError()

def codegen_init(self):
raise NotImplementedError()

def call_kernel(self, kernel_name):
wrapper = V.graph.wrapper_code
_, call_args, _, _ = self.kernel_group.args.mlir_argdefs()
Expand Down Expand Up @@ -322,7 +319,6 @@ def _codegen_kernel(self, arg_defs, kernel_name):
for old, new in self.kernel_group.args.aliases():
code.writeline(f"auto {old} = {new};")
# Loop body part
code.splice(self.codegen_init())
code.splice(self.codegen_loops())
return code

Expand Down
1 change: 1 addition & 0 deletions PyTorchSimFrontend/mlir/mlir_conv_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
%W_buffer = memref.get_global @W_spad : memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1>
%Y_buffer = memref.get_global @Y_spad : memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1>
%tag = memref.alloc() : memref<1xi32>
{{- kernel.def_local_vars() }}

affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {
affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} {
Expand Down
1 change: 1 addition & 0 deletions PyTorchSimFrontend/mlir/mlir_gemm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
%tag = memref.alloc() : memref<1xi32>{% if not Bias %}
%v0 = arith.constant dense<0.0> : vector<{{ TILE_M * TILE_N // kernel.vector_lane }}xf32>{% endif %}
%c0 = arith.constant 0 : index
{{- kernel.def_local_vars() }}

affine.for %t_m = 0 to {{ M }} step {{ TILE_M }} {
affine.for %t_n = 0 to {{ N }} step {{ TILE_N }} {
Expand Down
10 changes: 7 additions & 3 deletions PyTorchSimFrontend/mlir/mlir_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ def can_fuse_horizontal(self, node1, node2):
_, (vars1, reduce1) = node1.group
_, (vars2, reduce2) = node2.group

# Reduction is currently not supported
if node1.is_reduction() or node2.is_reduction():
return False

# Convolution is currently not supported
if node1.node.origin_node.target._name == 'aten::convolution' or node2.node.origin_node.target._name == 'aten::convolution':
if not isinstance(node1, FusedSchedulerNode) and node1.node.origin_node is not None and node1.node.origin_node.target._name == 'aten::convolution':
return False

# Reduction is currently not supported
if node1.is_reduction() or node2.is_reduction():
if not isinstance(node2, FusedSchedulerNode) and node2.node.origin_node is not None and node2.node.origin_node.target._name == 'aten::convolution':
return False

if not isinstance(node1, FusedSchedulerNode) and not isinstance(node2, FusedSchedulerNode):
Expand Down Expand Up @@ -138,6 +141,7 @@ def codegen_src_code(self, kernel, render, template_node, epilogue_nodes):
else partial_code.finalize()
)
src_code = kernel.add_extra_global_vars(src_code)
src_code = kernel.add_extra_local_vars(src_code)
return src_code

def codegen_template(self, template_node, epilogue_nodes):
Expand Down
103 changes: 62 additions & 41 deletions PyTorchSimFrontend/mlir/mlir_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
from typing import List, Optional
from unittest.mock import patch

from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides
from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides, CSE
from torch._inductor.ir import Buffer, IRNode, TemplateBuffer
from torch._inductor.select_algorithm import PartialRender
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.autotune_process import TensorMeta
from torch._inductor.virtualized import V
from torch._inductor.utils import IndentedBuffer

from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest
from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo, MLIRTile
Expand Down Expand Up @@ -46,6 +47,9 @@ def __init__(self,
self.tile_size = []
self.loop_size = None
self.is_template_kernel = True
self.map_cse = CSE("#", self.suffix, name_prefix="template_map")
self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_const")
self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="template_alloc")

# Overwrite ops
self.load = self.load_epilogue
Expand Down Expand Up @@ -237,6 +241,22 @@ def add_extra_global_vars(self, code):
key = "<GLOBAL_VARS>"
return code.replace(key, self.replace_global_vars())

def def_local_vars(self):
    """Return the placeholder token that marks where local variables go.

    The token is the same literal that ``add_extra_local_vars`` searches
    for and substitutes in the rendered source.
    """
    placeholder = "<LOCAL_VARS>"
    return placeholder

def replace_local_vars(self):
    """Render the text that is substituted for the local-variable placeholder.

    A fresh ``IndentedBuffer`` with a tab width of 2 is created, a newline
    is spliced in first, and then ``self.const_buffer`` followed by
    ``self.alloc_buffer`` are spliced one indent level deeper.

    Returns:
        The buffer contents as a string.
    """
    rendered = IndentedBuffer()
    rendered.tabwidth = 2
    rendered.splice("\n")
    with rendered.indent():
        # Constants are emitted before allocations; the order is preserved.
        for section_attr in ("const_buffer", "alloc_buffer"):
            rendered.splice(getattr(self, section_attr))
    return rendered.getvalue()

def add_extra_local_vars(self, code):
    """Fill in every ``<LOCAL_VARS>`` placeholder found in ``code``.

    Args:
        code: Rendered source text that may contain the placeholder.

    Returns:
        ``code`` with each occurrence of the placeholder replaced by the
        string produced by ``self.replace_local_vars()``.
    """
    placeholder = "<LOCAL_VARS>"
    return code.replace(placeholder, self.replace_local_vars())

def render(self, template, kwargs):
# self.render_hooks = {}
return PartialRender(
Expand All @@ -252,67 +272,68 @@ def adjust_tile_size(self):
return

def load_epilogue(self, name: str, index: sympy.Expr):
indices = self.parse_indices(index)
#index_var = self.parse_indices(index)
index_var = "index2"
index = self.rename_indexing(index)
var = self.args.input(name)
dram_var = self.args.input(name)
dtype = V.graph.get_dtype(name)
type_name = mlir_common.DTYPE_TO_MLIR[dtype]

if name in self.buffer_names:
buffer = self.buffer_names[name]
else:
mvin3 = 14
self.consts.add(mvin3)
dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}"
buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, indices, index)
self.buffer_names[name] = buffer
line = f"affine.dma_start %{var}[%index2], %{buffer}[%e_c0, %e_c0], %tag[0], %e_c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>"
self.cse.generate(self.loads, line, assignment = False)

mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]
if name not in self.buffer_names:
# Allocate sram buffer
tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}"
sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], tile_shape, self.loads, index_var, index)
self.buffer_names[name] = sram_var

# Generate DMA instruction
stride = self.render_options['N'] # FIXME. Is it okay?
chunk = 2 # FIXME. Is it okay?
index_var = "index2" # FIXME. Is it okay?
code = self.get_dma_code("MVIN", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", self.buffer_types[name][1], tile_shape)
self.cse.generate(self.loads, code, assignment = False)

# Load vector from sram
sram_var = self.buffer_names[name]
tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane
operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load"
shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else ""
line = f"{operation} %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 else ""
zero_var = self.get_const_cse(0)
line = f"{operation} %{sram_var}[%{zero_var}, %{zero_var}] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{mlir_dtype}, 1>{shape}"
out = self.cse.generate(self.loads, line)
var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]]
self.register_var_info(out, var_info)
self.consts.add(0)
self.register_var_info(out, [tile_size_per_lane, mlir_dtype])
return out

def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
indices = self.parse_indices(index)
prefix = self.newvar_prefix
if index.is_number:
prefix = prefix + "c"
self.consts.add(int(index))
var = self.args.output(name)
#index_var = self.parse_indices(index)
index_var = "index2"
dram_var = self.args.output(name)
dtype = V.graph.get_dtype(name)
type_name = mlir_common.DTYPE_TO_MLIR[dtype]
mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype]

chunk_size = 1 # Fixed for template kernel
chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE)
self.consts.add(chunk)

if name in self.buffer_names:
buffer = self.buffer_names[name]
else:
if name not in self.buffer_names:
dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}"
buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, indices, index)
self.buffer_names[name] = buffer
sram_var, index_var = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.stores, index_var, index)
self.buffer_names[name] = sram_var
sram_var = self.buffer_names[name]

tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane
operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store"
shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else ""
line = f"{operation} %{value}, %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
shape = f", vector<{tile_size_per_lane}x{mlir_dtype}>" if tile_size_per_lane > 1 else ""
zero_var = self.get_const_cse(0)
line = f"{operation} %{value}, %{sram_var}[%{zero_var}, %{zero_var}] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{mlir_dtype}, 1>{shape}"
self.cse.generate(self.stores, line, assignment = False)

self.tags.add(f"{name}_tag")
self.consts.add(0)
code = f"affine.dma_start %{buffer}[%e_c0, %e_c0], %{var}[%index2], %tag[0], %c_mvout, %N, %e_c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag
stride = self.render_options['N'] # FIXME. Is it okay?
index_var = "index2" # FIXME. Is it okay?
dram_shape = f"{self.render_options['M'] * self.render_options['N']}"
tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}"
code = self.get_dma_code("MVOUT", stride, chunk, mlir_dtype, dram_var, index_var, sram_var, f"{name}_tag", dram_shape, tile_shape)
self.cse.generate(self.stores, code, assignment = False)

def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index):
return super().get_scratchpad_buffer(dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index, True)
def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, index_var, raw_index):
    """Delegate to the parent ``get_scratchpad_buffer`` with a trailing ``True``.

    All arguments are forwarded unchanged and in the same order, and one
    extra positional ``True`` is appended.
    NOTE(review): presumably that flag marks the request as coming from a
    template kernel - confirm against the base-class signature.

    Returns:
        Whatever the parent implementation returns; callers in this file
        unpack it as a ``(sram_var, index_var)`` pair.
    """
    parent_impl = super().get_scratchpad_buffer
    return parent_impl(
        dtype,
        name,
        tile_row,
        tile_col,
        dram_tile_shape,
        code_buffer,
        index_var,
        raw_index,
        True,
    )

class MLIRTemplateCaller(CUDATemplateCaller):
def __str__(self):
Expand Down
Loading