Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions yateto/codegen/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class OptimizedKernelGenerator(KernelGenerator):
EXECUTE_ARRAY_NAME = 'ExecutePtrs'
NONZEROFLOPS_NAME = 'NonZeroFlops'
HARDWAREFLOPS_NAME = 'HardwareFlops'
OUTBYTES_NAME = 'OutboundBytes'
INCONSTBYTES_NAME = 'InboundConstBytes'
INBYTES_NAME = 'InboundBytes'
MEMBER_FUNCTION_PTR_NAME = 'member_function_ptr'
TEMP_MEM_REQUIRED_NAME = 'TmpMemRequiredInBytes'
TEMP_MAX_MEM_REQUIRED_NAME = 'TmpMaxMemRequiredInBytes'
Expand All @@ -127,6 +130,9 @@ class KernelOutline(object):
def __init__(self,
nonZeroFlops,
hwFlops,
inConstBytes,
inBytes,
outBytes,
tensors,
writable,
prefetch,
Expand All @@ -138,6 +144,9 @@ def __init__(self,

self.nonZeroFlops = nonZeroFlops
self.hwFlops = hwFlops
self.inConstBytes = inConstBytes
self.inBytes = inBytes
self.outBytes = outBytes
self.tensors = tensors
self.writable = writable
self.prefetch = prefetch
Expand Down Expand Up @@ -166,6 +175,11 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target):
writable = dict()
is_compute_constant_tensors = dict()
scalars = collections.OrderedDict()

inConstTensors = {}
inTensors = {}
outTensors = {}

for scalar in scalarsP:
self.KernelOutline._addTensor(scalar, scalars)
for var in variables:
Expand All @@ -180,6 +194,19 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target):

is_compute_constant_tensors[bn] = var.tensor.is_compute_constant()

size = var.tensor.memoryLayout().storage().requiredReals() * self._arch.bytesPerReal
if var.tensor.is_compute_constant():
inConstTensors[bn] = size
else:
if var.writable:
outTensors[bn] = size
else:
inTensors[bn] = size

inConstBytes = sum(size for size in inConstTensors.values())
inBytes = sum(size for size in inTensors.values())
outBytes = sum(size for size in outTensors.values())

prefetchTensors = SortedPrefetchList().visit(cfg)
prefetch = collections.OrderedDict()
for tensor in prefetchTensors:
Expand All @@ -197,6 +224,9 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target):
function = functionIO.getvalue()
return self.KernelOutline(nonZeroFlops,
hwFlops,
inConstBytes,
inBytes,
outBytes,
tensors,
writable,
prefetch,
Expand Down Expand Up @@ -249,20 +279,20 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None):

with header.Namespace(self.NAMESPACE):
with header.Struct(name):
header('{} {} const {}{} = {};'.format(
MODIFIERS,
self._arch.ulongTypename,
self.NONZEROFLOPS_NAME,
brackets,
formatArray([kernelOutline.nonZeroFlops if kernelOutline else 0 for kernelOutline in kernelOutlines])
))
header('{} {} const {}{} = {};'.format(
MODIFIERS,
self._arch.ulongTypename,
self.HARDWAREFLOPS_NAME,
brackets,
formatArray([kernelOutline.hwFlops if kernelOutline else 0 for kernelOutline in kernelOutlines])
))
def addConst(name, attrcall):
    """Write one constant-array declaration line to ``header``.

    The emitted line has the form
    ``<MODIFIERS> <ulong type> const <name><brackets> = <array>;``.

    name -- identifier placed in the declaration.
    attrcall -- callable applied to each kernel outline to obtain its value.

    Falsy entries in ``kernelOutlines`` contribute 0 rather than being
    passed to ``attrcall``.
    """
    values = []
    for outline in kernelOutlines:
        values.append(attrcall(outline) if outline else 0)

    declaration = '{} {} const {}{} = {};'.format(
        MODIFIERS,
        self._arch.ulongTypename,
        name,
        brackets,
        formatArray(values),
    )
    header(declaration)

addConst(self.NONZEROFLOPS_NAME, lambda ko: ko.nonZeroFlops)
addConst(self.HARDWAREFLOPS_NAME, lambda ko: ko.hwFlops)
addConst(self.INCONSTBYTES_NAME, lambda ko: ko.inConstBytes)
addConst(self.INBYTES_NAME, lambda ko: ko.inBytes)
addConst(self.OUTBYTES_NAME, lambda ko: ko.outBytes)

# tmp mem required by a kernel(s)
tmp_mem_list = [kernelOutline.tmp_mem_size if kernelOutline else 0 for kernelOutline in kernelOutlines]
Expand Down Expand Up @@ -370,16 +400,6 @@ def generate_extra_offset_args(base_name_with_namespace, groups):
with header.Function(funName, args, '{} {}'.format(MODIFIERS, self._arch.ulongTypename)):
header('return {}[{}];'.format(function, indexF))

flopCounters = [self.NONZEROFLOPS_NAME, self.HARDWAREFLOPS_NAME]
for fc in flopCounters:
cpp('{} {} const {}::{}::{}{};'.format(
CONSTEXPR,
self._arch.ulongTypename,
self.NAMESPACE,
name,
fc,
brackets
))
if familyStride is not None:
cpp('{0} {1}::{2}::{3} {1}::{2}::{4}[];'.format(
CONSTEXPR,
Expand Down