Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 5 additions & 73 deletions PyTorchSimFrontend/mlir/mlir_codegen_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,6 @@ class MLIRKernel(mlir_common.BaseMLIRKernel):

def __init__(self):
super().__init__(mlir_common.MLIRKernelArgs())

from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel
self.is_template_kernel = isinstance(self, MLIRTemplateKernel)

self.kernel_group = None
self.call_ranges = None
self.ranges = None
Expand Down Expand Up @@ -695,6 +691,7 @@ def __init__(self):
self.affine_yield = {}
self.welford_reduce_out = None
self.reduce_iterator = {}
self.is_template_kernel = False

def get_constant_vector(self, expr):
constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars]
Expand Down Expand Up @@ -892,36 +889,7 @@ def codegen_nodes(self, nodes, kernel_name):
write_atomic(gem5_write_path, self.gem5_header.getvalue())
return src_code

def load_epilogue(self, name: str, index: sympy.Expr):
    """Emit MLIR that reads the epilogue tile of ``name`` out of its scratchpad buffer.

    On the first request for ``name`` a scratchpad buffer is obtained from
    ``get_scratchpad_buffer``, cached in ``self.buffer_names`` and filled with
    an ``affine.dma_start`` from the kernel input argument. Later requests
    reuse the cached buffer and emit only the load.

    Args:
        name: Name of the graph buffer to load.
        index: Index expression for the access; renamed via
            ``rename_indexing`` before being handed to the scratchpad helper.

    Returns:
        The CSE variable produced for the emitted load.
    """
    renamed = self.rename_indexing(index)
    dram_var = self.args.input(name)
    dtype = V.graph.get_dtype(name)
    elem_type = mlir_common.DTYPE_TO_MLIR[dtype]

    tile_m = self.render_options['TILE_M']
    tile_n = self.render_options['TILE_N']
    tile_shape = f"{tile_m}x{tile_n}"
    spad_type = f"memref<{tile_shape}x{elem_type}, 1>"

    if name not in self.buffer_names:
        # Constant referenced as the %c14 operand of the dma_start below.
        mvin3_const = 14
        self.consts.add(mvin3_const)
        self.consts.add(0)
        # NOTE(review): store_epilogue hands this helper one more positional
        # argument (parsed indices plus the raw index) -- confirm that this
        # shorter call matches the helper's signature.
        spad, _ = self.get_scratchpad_buffer(
            dtype, name, tile_m, tile_n, tile_shape, self.loads, renamed
        )
        self.buffer_names[name] = spad
        dma_line = (
            f"affine.dma_start %{dram_var}[%index2], %{spad}[%c0, %c0], %tag[0], %c{mvin3_const}, %N, %c_set : "
            f"memref<{self.buffer_types[name][1]}x{elem_type}>, {spad_type}, memref<1xi32>"
        )
        self.cse.generate(self.loads, dma_line, assignment=False)
    else:
        spad = self.buffer_names[name]

    # A per-lane slice wider than one element is read as a vector.
    per_lane = tile_m * tile_n // self.vector_lane
    if per_lane > 1:
        load_op = "affine.vector_load"
        vec_suffix = f", vector<{per_lane}x{elem_type}>"
    else:
        load_op = "affine.load"
        vec_suffix = ""

    loaded = self.cse.generate(
        self.loads, f"{load_op} %{spad}[0, 0] : {spad_type}{vec_suffix}"
    )
    self.register_var_info(loaded, [per_lane, elem_type])
    return loaded

def load(self, name: str, index: sympy.Expr):
if self.is_template_kernel:
return self.load_epilogue(name, index)
index = self.rename_indexing(index)
indices = self.parse_indices(index)
prefix = self.newvar_prefix
Expand Down Expand Up @@ -961,41 +929,7 @@ def load(self, name: str, index: sympy.Expr):
self.register_var_info(out, var_info)
return out

def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
    """Emit MLIR that writes ``value`` into the epilogue tile of ``name`` and copies it out.

    The value is stored into the scratchpad buffer associated with ``name``
    (created through ``get_scratchpad_buffer`` and cached in
    ``self.buffer_names`` on first use), then an ``affine.dma_start`` from
    that buffer to the kernel output argument is emitted.

    Args:
        name: Name of the graph buffer being written.
        index: Index expression of the store; a numeric index is recorded in
            ``self.consts``.
        value: CSE variable holding the value to store.
        *args: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.
    """
    indices = self.parse_indices(index)
    # The original also built an unused variable-name prefix here; only the
    # constant registration has an observable effect, so only that is kept.
    if index.is_number:
        self.consts.add(int(index))
    out_var = self.args.output(name)
    dtype = V.graph.get_dtype(name)
    type_name = mlir_common.DTYPE_TO_MLIR[dtype]

    # Chunk operand: chunk size shifted left by one, with the low bit set
    # when the per-lane tile layout is column-wise.
    chunk_size = self.tile_desc.get_chunk_size()
    chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE)
    self.consts.add(chunk)

    tile_m = self.render_options['TILE_M']
    tile_n = self.render_options['TILE_N']
    spad_type = f"memref<{tile_m}x{tile_n}x{type_name}, 1>"

    if name in self.buffer_names:
        buffer = self.buffer_names[name]
    else:
        dram_tile_shape = f"{tile_m}x{tile_n}"
        buffer, _ = self.get_scratchpad_buffer(
            dtype, name, tile_m, tile_n, dram_tile_shape, self.stores, indices, index
        )
        self.buffer_names[name] = buffer

    # A per-lane slice wider than one element is written as a vector.
    tile_size_per_lane = tile_m * tile_n // self.vector_lane
    if tile_size_per_lane > 1:
        operation = "affine.vector_store"
        shape = f", vector<{tile_size_per_lane}x{type_name}>"
    else:
        operation = "affine.store"
        shape = ""
    line = f"{operation} %{value}, %{buffer}[0, 0] : {spad_type}{shape}"
    self.cse.generate(self.stores, line, assignment=False)

    self.tags.add(f"{name}_tag")
    self.consts.add(0)
    # FIXME: Using constant index and tag
    code = (
        f"affine.dma_start %{buffer}[%c0, %c0], %{out_var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : "
        f"{spad_type}, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>"
    )
    self.cse.generate(self.stores, code, assignment=False)

def store(self, name: str, index: sympy.Expr, value, *args, **kwargs):
if self.is_template_kernel:
return self.store_epilogue(name, index, value, args, kwargs)
index = self.rename_indexing(index)
indices = self.parse_indices(index)
prefix = self.newvar_prefix
Expand Down Expand Up @@ -1287,10 +1221,6 @@ def _codegen_kernel(self, arg_defs, kernel_name):
return code

def adjust_tile_size(self):
if self.is_template_kernel:
self.tile_desc.n_row = self.render_options['TILE_M']
self.tile_desc.n_col = self.render_options['TILE_N']
return
if self.read_writes is not None:
read_writes = list(self.read_writes.reads) + list(self.read_writes.writes)
cv_list = []
Expand Down Expand Up @@ -1372,15 +1302,17 @@ def get_scratchpad_buffer(self, dtype, name, tile_row, tile_col, dram_tile_shape
indices = self.cse.generate(self.loads, f"affine.apply #{mapping}(%{indices})") # FIXME. Only loads?

if name not in self.global_vars_dict:
self.global_vars_dict[name] = set()
self.global_vars_dict[name] = list()

if str(raw_index) not in self.global_vars_dict[name]:
new_name = f"{name}_{len(self.global_vars_dict[name])}"
# Add definition to header
self.header.writeline(f"{c_type} {new_name}_spad[{tile_size // self.vector_lane}] __attribute__ ((section(\".spad\")));")
self.gem5_header.writeline(f"{c_type} {new_name}_spad[{tile_size}];")
self.global_vars.writeline(f"memref.global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>")
self.global_vars_dict[name].add(str(raw_index))
self.global_vars_dict[name].append(str(raw_index))
else:
new_name = f"{name}_{self.global_vars_dict[name].index(str(raw_index))}"
buffer = self.cse.generate(code_buffer, f"memref.get_global @{new_name}_spad : memref<{dram_tile_shape}x{mlir_type}, 1>")
return buffer, indices

Expand Down
74 changes: 74 additions & 0 deletions PyTorchSimFrontend/mlir/mlir_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import textwrap
import re
import math
import sympy

from typing import List, Optional
from unittest.mock import patch

Expand All @@ -22,6 +24,8 @@
from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo
from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, MLIRTile

from . import mlir_common

class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo):
def __init__(self,
kernel_name,
Expand All @@ -46,6 +50,11 @@ def __init__(self,
self.render_options = dict()
self.tile_size = []
self.loop_size = None
self.is_template_kernel = True

# Overwrite ops
self.load = self.load_epilogue
self.store = self.store_epilogue

def add_loop_info(self, mat_size, tile_size):
for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)):
Expand Down Expand Up @@ -240,6 +249,71 @@ def render(self, template, kwargs):
self.render_hooks,
)

def adjust_tile_size(self):
    """Copy the template's TILE_M / TILE_N render options onto the tile descriptor."""
    options = self.render_options
    descriptor = self.tile_desc
    descriptor.n_row = options['TILE_M']
    descriptor.n_col = options['TILE_N']

def load_epilogue(self, name: str, index: sympy.Expr):
    """Emit MLIR that reads the epilogue tile of ``name`` out of its scratchpad buffer.

    The first request for ``name`` obtains a scratchpad buffer from
    ``get_scratchpad_buffer``, caches it in ``self.buffer_names`` and emits an
    ``affine.dma_start`` from the kernel input argument into it. Subsequent
    requests reuse the cached buffer and emit only the load.

    Args:
        name: Name of the graph buffer to load.
        index: Index expression for the access; renamed via
            ``rename_indexing`` before being handed to the scratchpad helper.

    Returns:
        The CSE variable produced for the emitted load.
    """
    renamed = self.rename_indexing(index)
    dram_var = self.args.input(name)
    dtype = V.graph.get_dtype(name)
    elem_type = mlir_common.DTYPE_TO_MLIR[dtype]

    tile_m = self.render_options['TILE_M']
    tile_n = self.render_options['TILE_N']
    tile_shape = f"{tile_m}x{tile_n}"
    spad_type = f"memref<{tile_shape}x{elem_type}, 1>"

    if name not in self.buffer_names:
        # MLIR type of the DRAM-side operand, derived from the recorded buffer type.
        dram_type = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name])
        # Constant referenced as the %c14 operand of the dma_start below.
        mvin3_const = 14
        self.consts.add(mvin3_const)
        self.consts.add(0)
        # NOTE(review): store_epilogue hands this helper one more positional
        # argument (parsed indices plus the raw index) -- confirm that this
        # shorter call matches the helper's signature.
        spad, _ = self.get_scratchpad_buffer(
            dtype, name, tile_m, tile_n, tile_shape, self.loads, renamed
        )
        self.buffer_names[name] = spad
        dma_line = (
            f"affine.dma_start %{dram_var}[%index2], %{spad}[%c0, %c0], %tag[0], %c{mvin3_const}, %N, %c_set : "
            f"{dram_type}, {spad_type}, memref<1xi32>"
        )
        self.cse.generate(self.loads, dma_line, assignment=False)
    else:
        spad = self.buffer_names[name]

    # A per-lane slice wider than one element is read as a vector.
    per_lane = tile_m * tile_n // self.vector_lane
    if per_lane > 1:
        load_op = "affine.vector_load"
        vec_suffix = f", vector<{per_lane}x{elem_type}>"
    else:
        load_op = "affine.load"
        vec_suffix = ""

    loaded = self.cse.generate(
        self.loads, f"{load_op} %{spad}[0, 0] : {spad_type}{vec_suffix}"
    )
    self.register_var_info(loaded, [per_lane, elem_type])
    return loaded

def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs):
    """Emit MLIR that writes ``value`` into the epilogue tile of ``name`` and copies it out.

    The value is stored into the scratchpad buffer associated with ``name``
    (created through ``get_scratchpad_buffer`` and cached in
    ``self.buffer_names`` on first use), then an ``affine.dma_start`` from
    that buffer to the kernel output argument is emitted.

    Args:
        name: Name of the graph buffer being written.
        index: Index expression of the store; a numeric index is recorded in
            ``self.consts``.
        value: CSE variable holding the value to store.
        *args: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.
    """
    indices = self.parse_indices(index)
    # The original also built an unused variable-name prefix here; only the
    # constant registration has an observable effect, so only that is kept.
    if index.is_number:
        self.consts.add(int(index))
    out_var = self.args.output(name)
    dtype = V.graph.get_dtype(name)
    type_name = mlir_common.DTYPE_TO_MLIR[dtype]

    # Chunk operand: chunk size shifted left by one, with the low bit set
    # when the per-lane tile layout is column-wise.
    chunk_size = self.tile_desc.get_chunk_size()
    chunk = chunk_size << 1 | (self.tile_desc.tile_per_lane_layout == MLIRTile.TILE_PER_LANE_COL_WISE)
    self.consts.add(chunk)

    tile_m = self.render_options['TILE_M']
    tile_n = self.render_options['TILE_N']
    spad_type = f"memref<{tile_m}x{tile_n}x{type_name}, 1>"

    if name in self.buffer_names:
        buffer = self.buffer_names[name]
    else:
        dram_tile_shape = f"{tile_m}x{tile_n}"
        buffer, _ = self.get_scratchpad_buffer(
            dtype, name, tile_m, tile_n, dram_tile_shape, self.stores, indices, index
        )
        self.buffer_names[name] = buffer

    # A per-lane slice wider than one element is written as a vector.
    tile_size_per_lane = tile_m * tile_n // self.vector_lane
    if tile_size_per_lane > 1:
        operation = "affine.vector_store"
        shape = f", vector<{tile_size_per_lane}x{type_name}>"
    else:
        operation = "affine.store"
        shape = ""
    line = f"{operation} %{value}, %{buffer}[0, 0] : {spad_type}{shape}"
    self.cse.generate(self.stores, line, assignment=False)

    self.tags.add(f"{name}_tag")
    self.consts.add(0)
    # FIXME: Using constant index and tag
    code = (
        f"affine.dma_start %{buffer}[%c0, %c0], %{out_var}[%index2], %tag[0], %c_mvout, %N, %c{chunk} : "
        f"{spad_type}, memref<{self.render_options['M'] * self.render_options['N']}x{type_name}>, memref<1xi32>"
    )
    self.cse.generate(self.stores, code, assignment=False)

class MLIRTemplateCaller(CUDATemplateCaller):
def __str__(self):
return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})"
Expand Down
Loading