diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index 1faa934..b94c0ba 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -105,6 +105,9 @@ class OptimizedKernelGenerator(KernelGenerator): EXECUTE_ARRAY_NAME = 'ExecutePtrs' NONZEROFLOPS_NAME = 'NonZeroFlops' HARDWAREFLOPS_NAME = 'HardwareFlops' + OUTBYTES_NAME = 'OutboundBytes' + INCONSTBYTES_NAME = 'InboundConstBytes' + INBYTES_NAME = 'InboundBytes' MEMBER_FUNCTION_PTR_NAME = 'member_function_ptr' TEMP_MEM_REQUIRED_NAME = 'TmpMemRequiredInBytes' TEMP_MAX_MEM_REQUIRED_NAME = 'TmpMaxMemRequiredInBytes' @@ -127,6 +130,9 @@ class KernelOutline(object): def __init__(self, nonZeroFlops, hwFlops, + inConstBytes, + inBytes, + outBytes, tensors, writable, prefetch, @@ -138,6 +144,9 @@ def __init__(self, self.nonZeroFlops = nonZeroFlops self.hwFlops = hwFlops + self.inConstBytes = inConstBytes + self.inBytes = inBytes + self.outBytes = outBytes self.tensors = tensors self.writable = writable self.prefetch = prefetch @@ -166,6 +175,11 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): writable = dict() is_compute_constant_tensors = dict() scalars = collections.OrderedDict() + + inConstTensors = {} + inTensors = {} + outTensors = {} + for scalar in scalarsP: self.KernelOutline._addTensor(scalar, scalars) for var in variables: @@ -180,6 +194,19 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): is_compute_constant_tensors[bn] = var.tensor.is_compute_constant() + size = var.tensor.memoryLayout().storage().requiredReals() * self._arch.bytesPerReal + if var.tensor.is_compute_constant(): + inConstTensors[bn] = size + else: + if var.writable: + outTensors[bn] = size + else: + inTensors[bn] = size + + inConstBytes = sum(size for size in inConstTensors.values()) + inBytes = sum(size for size in inTensors.values()) + outBytes = sum(size for size in outTensors.values()) + prefetchTensors = SortedPrefetchList().visit(cfg) prefetch = collections.OrderedDict() for tensor in prefetchTensors: @@ -197,6 +224,9 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): function = functionIO.getvalue() return self.KernelOutline(nonZeroFlops, hwFlops, + inConstBytes, + inBytes, + outBytes, tensors, writable, prefetch, @@ -249,20 +279,20 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None): with header.Namespace(self.NAMESPACE): with header.Struct(name): - header('{} {} const {}{} = {};'.format( - MODIFIERS, - self._arch.ulongTypename, - self.NONZEROFLOPS_NAME, - brackets, - formatArray([kernelOutline.nonZeroFlops if kernelOutline else 0 for kernelOutline in kernelOutlines]) - )) - header('{} {} const {}{} = {};'.format( - MODIFIERS, - self._arch.ulongTypename, - self.HARDWAREFLOPS_NAME, - brackets, - formatArray([kernelOutline.hwFlops if kernelOutline else 0 for kernelOutline in kernelOutlines]) - )) + def addConst(name, attrcall): + header('{} {} const {}{} = {};'.format( + MODIFIERS, + self._arch.ulongTypename, + name, + brackets, + formatArray([attrcall(kernelOutline) if kernelOutline else 0 for kernelOutline in kernelOutlines]) + )) + + addConst(self.NONZEROFLOPS_NAME, lambda ko: ko.nonZeroFlops) + addConst(self.HARDWAREFLOPS_NAME, lambda ko: ko.hwFlops) + addConst(self.INCONSTBYTES_NAME, lambda ko: ko.inConstBytes) + addConst(self.INBYTES_NAME, lambda ko: ko.inBytes) + addConst(self.OUTBYTES_NAME, lambda ko: ko.outBytes) # tmp mem required by a kernel(s) tmp_mem_list = [kernelOutline.tmp_mem_size if kernelOutline else 0 for kernelOutline in kernelOutlines] @@ -370,16 +400,6 @@ def generate_extra_offset_args(base_name_with_namespace, groups): with header.Function(funName, args, '{} {}'.format(MODIFIERS, self._arch.ulongTypename)): header('return {}[{}];'.format(function, indexF)) - flopCounters = [self.NONZEROFLOPS_NAME, self.HARDWAREFLOPS_NAME] - for fc in flopCounters: - cpp('{} {} const {}::{}::{}{};'.format( - CONSTEXPR, - self._arch.ulongTypename, - self.NAMESPACE, - name, - fc, - brackets - )) if familyStride is not None: cpp('{0} {1}::{2}::{3} {1}::{2}::{4}[];'.format( CONSTEXPR,