Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import OrderedDict
import torch
from torch._inductor.codegen import cpp, wrapper, common
from torch._inductor.scheduler import BaseScheduling
from torch._inductor.virtualized import V, _ops as ops
from torch._inductor.codecache import write_atomic, write
from torch._inductor.utils import (
Expand Down Expand Up @@ -903,7 +904,7 @@ def template_store(options):
f": memref<{options['TILE_M']}x{options['TILE_N']}xf32, 1>,"\
f"memref<{options['M'] * options['N']}xf32>, memref<1xi32>" #FIXME: Using constant index
self.cse.generate(self.stores, line, assignment = False)
self.body.splice(self.codegen_init())
self.body.splice(self.codegen_init('e_'))
self.body.splice(self.loads)
self.body.splice(self.compute)
if len(self.stores._lines) == 0:
Expand All @@ -916,14 +917,14 @@ def template_store(options):
def codegen_global_init(self):
    """Return the buffer holding module-level (global) MLIR declarations."""
    global_decls = self.global_vars
    return global_decls

def codegen_init(self, prefix=""):
    """Emit per-kernel init code: DMA tag buffers and constant index values.

    Every SSA value name is prepended with *prefix* so that a second
    instantiation (e.g. the epilogue's 'e_' names) does not collide with
    names already emitted for the template kernel itself.
    """
    code = IndentedBuffer()
    # Sort for deterministic output across runs.
    for tag in sorted(self.tags):
        code.writeline(f"%{prefix}{tag} = memref.alloc() : memref<1xi32>")
    for const in sorted(self.consts):
        code.writeline(f"%{prefix}c{const} = arith.constant {const} : index")
    return code

def codegen_loops(self):
Expand Down Expand Up @@ -1140,17 +1141,14 @@ def adjust_tile_size(self):
if len(self.itervars) >= 3 and self.reduction_depth < len(self.itervars):
raise NotImplementedError()

def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index):
def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index, is_template=False):
c_type = mlir_common.DTYPE_TO_C[dtype]
mlir_type = mlir_common.DTYPE_TO_MLIR[dtype]
# Make sure each lane's buffer has at least two element
tile_size = max(self.roundup_vectorlane(tile_row * tile_col), self.vector_lane * 2)
if dtype == torch.bool and not self.is_template_kernel: #FIXME: epilogue ReLU does not need this
if self.is_template_kernel:
mapping = f"template_{indices} "
self.map_cse.generate(self.global_vars, f"#{mapping} = affine_map<({indices}) -> ({indices} floordiv 8)>", assignment=False)
else:
mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>")

if dtype == torch.bool and not is_template:
mapping = self.map_cse.generate(self.global_vars, f"affine_map<({indices}) -> ({indices} floordiv 8)>")
indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads?

if name not in self.global_vars_dict:
Expand Down Expand Up @@ -1210,4 +1208,4 @@ def mark_parallel(self, par_depth):
loops[0].parallel = par_depth
for i in range(1, par_depth):
loops[i].collapsed = True
loops[0].simd = loops[par_depth - 1].simd
loops[0].simd = loops[par_depth - 1].simd
27 changes: 23 additions & 4 deletions PyTorchSimFrontend/mlir/mlir_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel

from torch._inductor import config
from torch._inductor.scheduler import BaseScheduling
from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode
from torch._inductor.utils import IndentedBuffer
from torch._inductor.virtualized import V

Expand All @@ -24,24 +24,43 @@ def _set_flush_status(self, status: bool):
self._ready_to_flush = status

def can_fuse_vertical(self, node1, node2):
    """Producer->consumer fusion: allowed only when horizontal fusion is
    allowed and the producer (node1) is not a reduction."""
    if not self.can_fuse_horizontal(node1, node2):
        return False
    return not node1.is_reduction()

def can_fuse_horizontal(self, node1, node2):
    """Decide whether two scheduler nodes may be fused side-by-side.

    Fusion is declined for convolutions and reductions, and (for plain,
    non-fused nodes) for mismatched dtypes or sizes. Otherwise the nodes
    fuse when their iteration/reduction spaces line up.
    """
    _, (vars1, reduce1) = node1.group
    _, (vars2, reduce2) = node2.group

    # Convolution is currently not supported.
    # Guard the attribute chain: nodes synthesized by inductor may have no
    # origin_node (or a target without `_name`); the previous unguarded
    # `node.node.origin_node.target._name` raised AttributeError instead of
    # simply declining the fusion.
    for node in (node1, node2):
        origin = getattr(node.node, "origin_node", None)
        if origin is not None and getattr(origin.target, "_name", None) == 'aten::convolution':
            return False

    # Reduction is currently not supported
    if node1.is_reduction() or node2.is_reduction():
        return False

    if not isinstance(node1, FusedSchedulerNode) and not isinstance(node2, FusedSchedulerNode):
        # Different layout is not supported
        if node1.node.layout.dtype != node2.node.layout.dtype:
            return False

        # Different size is not supported for non-template node
        if not node1.is_template() and (node1._sizes[0] != node2._sizes[0]):
            return False

    # Identical iteration and reduction spaces fuse trivially.
    if vars1 == vars2 and reduce1 == reduce2:
        return True
    # node1 is pointwise over node2's combined (vars + reduce) space.
    if reduce1 == () and vars1 == vars2 + reduce2:
        return True

    #TODO: Temporary solution determining the fusion condition similar to CPP/OpenMP
    v1_total = math.prod(vars1) if len(vars1) else 0
    v2_total = math.prod(vars2) if len(vars2) else 0
    r2_total = math.prod(reduce2) if len(reduce2) else 0
    if reduce1 == () \
            and v1_total == (v2_total + r2_total):
        # and node1.node.layout.size == node2.node.layout.size: #FIXME: Need to check layout too?
        return True

    return False

def group_fn(self, sizes):
Expand Down
30 changes: 15 additions & 15 deletions PyTorchSimFrontend/mlir/mlir_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,8 @@
from typing import List, Optional
from unittest.mock import patch

from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.codegen.common import ChoiceCaller
from torch._inductor.codegen.common import Kernel
from torch._inductor.codegen.common import OpOverrides
from torch._inductor.ir import Buffer
from torch._inductor.ir import IRNode
from torch._inductor.ir import TemplateBuffer
from torch._inductor.codegen.common import Kernel, KernelTemplate, ChoiceCaller, OpOverrides
from torch._inductor.ir import Buffer, IRNode, TemplateBuffer
from torch._inductor.select_algorithm import PartialRender
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.autotune_process import TensorMeta
Expand Down Expand Up @@ -250,11 +245,14 @@ def render(self, template, kwargs):
)

def adjust_tile_size(self):
    """Pin the tile descriptor to the fixed, column-wise template tile shape.

    Template kernels do not auto-tune their tiling: the tile dimensions come
    straight from the render options chosen for the template.
    """
    opts = self.render_options
    desc = self.tile_desc
    desc.tile_layout = MLIRTile.TILE_COL_WISE
    desc.n_row = opts['TILE_M']
    desc.n_col = opts['TILE_N']

def load_epilogue(self, name: str, index: sympy.Expr):
indices = self.parse_indices(index)
index = self.rename_indexing(index)
var = self.args.input(name)
dtype = V.graph.get_dtype(name)
Expand All @@ -263,23 +261,22 @@ def load_epilogue(self, name: str, index: sympy.Expr):
if name in self.buffer_names:
buffer = self.buffer_names[name]
else:
dram_mlir_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
mvin3 = 14
self.consts.add(mvin3)
self.consts.add(0)
dram_tile_shape = f"{self.render_options['TILE_M']}x{self.render_options['TILE_N']}"
buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, index)
buffer, indices = self.get_scratchpad_buffer(dtype, name, self.render_options['TILE_M'], self.render_options['TILE_N'], dram_tile_shape, self.loads, indices, index)
self.buffer_names[name] = buffer
line = f"affine.dma_start %{var}[%index2], %{buffer}[%c0, %c0], %tag[0], %c{mvin3}, %N, %c_set : {dram_mlir_shape}, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>"
line = f"affine.dma_start %{var}[%index2], %{buffer}[%e_c0, %e_c0], %tag[0], %e_c{mvin3}, %N, %c_set : memref<{self.buffer_types[name][1]}x{type_name}>, memref<{dram_tile_shape}x{type_name}, 1>, memref<1xi32>"
self.cse.generate(self.loads, line, assignment = False)

tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane
operation = "affine.vector_load" if tile_size_per_lane > 1 else "affine.load"
shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else ""
line = f"{operation} %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
line = f"{operation} %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
out = self.cse.generate(self.loads, line)
var_info = [tile_size_per_lane, mlir_common.DTYPE_TO_MLIR[dtype]]
self.register_var_info(out, var_info)
self.consts.add(0)
return out

def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
Expand All @@ -292,7 +289,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
dtype = V.graph.get_dtype(name)
type_name = mlir_common.DTYPE_TO_MLIR[dtype]

chunk_size = self.tile_desc.get_chunk_size()
chunk_size = 1 # Fixed for template kernel
chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE)
self.consts.add(chunk)

Expand All @@ -306,14 +303,17 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
tile_size_per_lane = self.render_options['TILE_M'] * self.render_options['TILE_N'] // self.vector_lane
operation = "affine.vector_store" if tile_size_per_lane > 1 else "affine.store"
shape = f", vector<{tile_size_per_lane}x{type_name}>" if tile_size_per_lane > 1 else ""
line = f"{operation} %{value}, %{buffer}[0, 0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
line = f"{operation} %{value}, %{buffer}[%e_c0, %e_c0] : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>{shape}"
self.cse.generate(self.stores, line, assignment = False)

self.tags.add(f"{name}_tag")
self.consts.add(0)
code = f"affine.dma_start %{buffer}[%c0, %c0], %{var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag
code = f"affine.dma_start %{buffer}[%e_c0, %e_c0], %{var}[%index2], %tag[0], %c_mvout, %N, %e_c{chunk} : memref<{self.render_options['TILE_M']}x{self.render_options['TILE_N']}x{type_name}, 1>, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>" #FIXME: Using constant index and tag
self.cse.generate(self.stores, code, assignment = False)

def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape, code_buffer, indices, raw_index):
    """Template-kernel override: delegate to the base implementation with
    the template flag raised (the base skips its bool-packing path when
    is_template is set — presumably epilogue loads don't need it; confirm)."""
    # Pass the flag by keyword: a bare positional True is easy to misread and
    # brittle if the base method's signature ever grows another parameter.
    return super().get_scratchpad_buffer(
        dtype, name, tile_row, tile_col, dram_tile_shape,
        code_buffer, indices, raw_index, is_template=True,
    )

class MLIRTemplateCaller(CUDATemplateCaller):
def __str__(self):
    """Human-readable identity naming the generated MLIR source file."""
    source_file = self.bmreq.source_file
    return f"MLIRTemplateCaller(source_file={source_file})"
Expand Down
1 change: 1 addition & 0 deletions test_extension_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@

# # Fusion Test
# Exercises the fusion paths: a scalar epilogue, elementwise activation
# epilogues (relu and sigmoid) on matmul, and a residual-add epilogue.
test_matmul_scalar(device)
test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="relu")
test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid")
test_addmm_residual(device)
Loading