diff --git a/yateto/__init__.py b/yateto/__init__.py index 0b3dd50..48c3f0b 100644 --- a/yateto/__init__.py +++ b/yateto/__init__.py @@ -2,3 +2,4 @@ from .generator import NamespacedGenerator, Generator, simpleParameterSpace, parameterSpaceFromRanges from .arch import useArchitectureIdentifiedBy from .gemm_configuration import * +from .memory import * diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index eb3f7a0..b155881 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -89,9 +89,10 @@ def generate(self, cpp, routineCache): cpp( self._gemm_cfg.call(d.transA, d.transB, m.size(), n.size(), k.size(), - d.alpha, self._pointer(d.leftTerm, (m.start, k.start), d.transA), ldA, + d.alpha, self._pointer(d.leftTerm, (m.start, k.start), d.transA), + ldA, d.alignedA, self._pointer(d.rightTerm, (k.start, n.start), d.transB), ldB, - d.beta, self._pointer(d.result, (m.start, n.start), False), ldC)) + d.beta, self._pointer(d.result, (m.start, n.start), False), ldC, d.alignedC)) elif isinstance(self._gemm_cfg, GemmForge): diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index ecbaa5f..4e81308 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -61,9 +61,9 @@ def generate(self, cpp, cfg, factory, routineCache, gemm_cfg): # an provided by the user required_tmp_mem = 0 cfg = DetermineLocalInitialization().visit(cfg) - localPtrs = list() + localPtrs = set() for pp in cfg: - localPtrs.extend(pp.bufferMap.keys()) + localPtrs.update(pp.bufferMap.keys()) if localPtrs: cpp( '{}{};'.format(self._arch.typename, ','.join(map(lambda x: ' *' + str(x), localPtrs))) ) for pp in cfg: diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index 9990cb6..e99a655 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -13,13 +13,13 @@ def variables(self): return {self} def maySubstitute(self, when, by): - return self.substituted(when, by).memoryLayout().isCompatible(self.eqspp()) + return self.substituted(when, by).memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) def substituted(self, when, by, memoryLayout=None): return by if self == when else self def resultCompatible(self, result): - return result.memoryLayout().isCompatible(self.eqspp()) + return result.memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) def isGlobal(self): return self.tensor is not None @@ -71,7 +71,7 @@ def variableList(self): def maySubstitute(self, when, by): layouts = [var.substituted(when, by).memoryLayout() for var in self._variables] - c1 = all(layouts[i].isCompatible(var.eqspp()) for i,var in enumerate(self._variables)) + c1 = all(layouts[i].isCompatible(var.memoryLayout(), var.eqspp()) for i,var in enumerate(self._variables)) c2 = self.node.argumentsCompatible(layouts) return c1 and c2 @@ -79,7 +79,7 @@ def substituted(self, when, by, memoryLayout): return Expression(self.node, memoryLayout, [var.substituted(when, by) for var in self._variables]) def resultCompatible(self, result): - c1 = result.memoryLayout().isCompatible(self.eqspp()) + c1 = result.memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) c2 = self.node.resultCompatible(result.memoryLayout()) return c1 and c2 diff --git a/yateto/controlflow/transformer.py b/yateto/controlflow/transformer.py index e041389..73c16a6 100644 --- a/yateto/controlflow/transformer.py +++ b/yateto/controlflow/transformer.py @@ -124,7 +124,9 @@ def visit(self, cfg): ua = cfg[i].action # assign buffer if ua and not ua.isCompound() and ua.result.isLocal(): - if len(freeBuffers) > 0: + if ua.result in usedBuffers: + buf = usedBuffers[ua.result] + elif len(freeBuffers) > 0: buf = freeBuffers.pop() else: buf = numBuffers diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index c02b285..9e491ab 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -38,7 +38,7 @@ def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, def bool2Trans(self, trans): return 'Cblas{}Trans'.format('' if trans else 'No') - def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC): + def call(self, transA, transB, M, N, K, alpha, A, ldA, alignedA, B, ldB, beta, C, ldC, alignedC): parameters = [ 'CblasColMajor', self.bool2Trans(transA), @@ -65,7 +65,7 @@ def __init__(self, arch): def bool2Trans(self, trans): return 'BLIS{}TRANSPOSE'.format('_' if trans else '_NO_') - def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC): + def call(self, transA, transB, M, N, K, alpha, A, ldA, alignedA, B, ldB, beta, C, ldC, alignedC): init = '_blis_alpha = {}; _blis_beta = {};'.format(alpha, beta) parameters = [ self.bool2Trans(transA), @@ -91,13 +91,13 @@ def bool2Trans(self, trans): def sizeTrans(self, rows, cols, trans): return '{},{}'.format(cols,rows) if trans else '{},{}'.format(rows,cols) - def align(self, ld): + def align(self, ld, is_aligned): aligned = 'Unaligned' - if self._arch.checkAlignment(ld) and self._arch.alignment in [16,32,64,128]: + if is_aligned and self._arch.checkAlignment(ld) and self._arch.alignment in [16,32,64,128]: aligned = 'Aligned{}'.format(self._arch.alignment) return aligned - def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC): + def call(self, transA, transB, M, N, K, alpha, A, ldA, alignedA, B, ldB, beta, C, ldC, alignedC): AxB = '{alpha}_mapA{transA}*_mapB{transB}'.format( alpha=str(alpha) + '*' if alpha != 1.0 else '', transA=self.bool2Trans(transA), transB=self.bool2Trans(transB), @@ -122,7 +122,7 @@ def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC): sizeA=self.sizeTrans(M,K,transA), sizeB=self.sizeTrans(K,N,transB), ldA=ldA, ldB=ldB, ldC=ldC, A=A, B=B, C=C, - alignA=self.align(ldA), alignC=self.align(ldC), + alignA=self.align(ldA, alignedA), alignC=self.align(ldC, alignedC), code=code) return code @@ -133,8 +133,8 @@ def __init__(self, operation_name: str, includes: List[str], cmd: str, arch): self._arch = arch class LIBXSMM(CodeGenerator): - def __init__(self, arch, threshold: int = 128): - super().__init__('libxsmm', [], 'libxsmm_gemm_generator', arch) + def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 128): + super().__init__('libxsmm', [], cmd, arch) self._threshold = threshold def _archSupported(self): @@ -159,8 +159,8 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali return Preference.LOW class PSpaMM(CodeGenerator): - def __init__(self, arch, threshold: int = 128): - super().__init__('pspamm', [], 'pspamm.py', arch) + def __init__(self, arch, cmd: str = 'pspamm.py', threshold: int = 128): + super().__init__('pspamm', [], cmd, arch) self._threshold = threshold def _archSupported(self): diff --git a/yateto/memory.py b/yateto/memory.py index 4cc8acf..48a755b 100644 --- a/yateto/memory.py +++ b/yateto/memory.py @@ -1,10 +1,22 @@ from .ast.indices import BoundingBox, Range import copy +from enum import Enum import itertools import warnings import numpy as np from abc import ABC, abstractmethod +class Alignment(Enum): + """Alignment mode. + + Automatic: Assume aligned memory if stride is divisible by vector width. + Aligned: Pad the leading dimension with zeros such that stride is divisible by memory width. + Unaligned: Always assume unaligned memory access. + """ + Automatic = 0, + Aligned = 1, + Unaligned = 2 + class MemoryLayout(ABC): def __init__(self, shape): self._shape = shape @@ -45,7 +57,7 @@ def __eq__(self, other): pass @abstractmethod - def isCompatible(self, spp): + def isCompatible(self, other, eqspp): pass class DenseMemoryLayout(MemoryLayout): @@ -55,7 +67,15 @@ class DenseMemoryLayout(MemoryLayout): def setAlignmentArch(cls, arch): cls.ALIGNMENT_ARCH = arch - def __init__(self, shape, boundingBox=None, stride=None, alignStride=False): + def __init__(self, shape, boundingBox=None, stride=None, alignStride=Alignment.Automatic): + """Construct DenseMemoryLayout. + + :param shape: tensor shape (tuple of integers) + :param boundingBox: Non-zero BoundingBox, covers complete tensor if None + :param stride: Stride of the leading dimension, computed automatically if None + :param alignStride: Alignment mode. Passing False is equal to Alignment.Automatic and passing + True is equal to Alignment.Aligned. + """ super().__init__(shape) if boundingBox: @@ -63,8 +83,16 @@ def __init__(self, shape, boundingBox=None, stride=None, alignStride=False): else: self._bbox = BoundingBox([Range(0, s) for s in self._shape]) + if alignStride == True: + self._alignment = Alignment.Aligned + elif alignStride == False: + self._alignment = Alignment.Automatic + elif isinstance(alignStride, Alignment): + self._alignment = alignStride + else: + raise ValueError("Unknown type for option alignStride") self._range0 = None - if alignStride: + if self._alignment == Alignment.Aligned: self._alignBB() if stride: @@ -86,8 +114,11 @@ def _alignBB(self): else: warnings.warn('Set architecture with DenseMemoryLayout.setAlignmentArch(arch) if you want to use the align stride feature.', UserWarning) + def alignment(self): + return self._alignment + def alignedStride(self): - if self.ALIGNMENT_ARCH is None: + if self.ALIGNMENT_ARCH is None or self._alignment == Alignment.Unaligned: return False offsetOk = self.ALIGNMENT_ARCH.checkAlignment(self._bbox[0].start) ldOk = self._stride[0] == 1 and (len(self._stride) == 1 or self.ALIGNMENT_ARCH.checkAlignment(self._stride[1])) @@ -99,7 +130,7 @@ def mayVectorizeDim(self, dim): return self.ALIGNMENT_ARCH.checkAlignment(self._bbox[dim].size()) @classmethod - def fromSpp(cls, spp, alignStride=False): + def fromSpp(cls, spp, alignStride=Alignment.Automatic): bbox = BoundingBox.fromSpp(spp) return cls(spp.shape, bbox, alignStride=alignStride) @@ -111,7 +142,7 @@ def permuted(self, permutation): originalBB = BoundingBox([self._range0] + self._bbox[1:]) if self._range0 else self._bbox newBB = BoundingBox([copy.copy(originalBB[p]) for p in permutation]) - return DenseMemoryLayout(newShape, newBB, alignStride=self._range0 is not None) + return DenseMemoryLayout(newShape, newBB, alignStride=self._alignment) def address(self, entry): assert entry in self._bbox @@ -240,11 +271,15 @@ def defuse(self, fusedRange, indices, I): stop -= B*s return ranges - def isCompatible(self, spp): - return BoundingBox.fromSpp(spp) in self.bbox() + def isCompatible(self, other, eqspp): + bb_contained = BoundingBox.fromSpp(eqspp) in self.bbox() + alignment_ok = self.alignedStride() == other.alignedStride() + return bb_contained and alignment_ok + def __eq__(self, other): - return self._stride == other._stride and self._bbox == other._bbox and self._stride == other._stride + return self._stride == other._stride and self._bbox == other._bbox and self._alignment == other._alignment + def __str__(self): return '{}(shape: {}, bounding box: {}, stride: {})'.format(type(self).__name__, self._shape, self._bbox, self._stride) @@ -324,8 +359,8 @@ def fromSpp(cls, spp, **kwargs): def __contains__(self, entry): return entry in self._bbox - def isCompatible(self, spp): - return self.fromSpp(spp) == self + def isCompatible(self, other, eqspp): + return other == self def __eq__(self, other): return self._bbox == other._bbox and np.array_equal(self._rowIndex, other._rowIndex) and np.array_equal(self._colPtr, other._colPtr) diff --git a/yateto/type.py b/yateto/type.py index a22e397..204c76b 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -1,7 +1,7 @@ import re from .ast.node import Node, IndexedTensor from numpy import ndarray, zeros, float64 -from .memory import DenseMemoryLayout +from .memory import DenseMemoryLayout, Alignment from . import aspp class AbstractType(object): @@ -35,7 +35,7 @@ def __init__(self, shape, spp=None, memoryLayoutClass=DenseMemoryLayout, - alignStride=False, + alignStride=Alignment.Automatic, namespace=None): if not isinstance(shape, tuple): raise ValueError('shape must be a tuple') @@ -77,7 +77,7 @@ def __init__(self, self.setMemoryLayout(memoryLayoutClass, alignStride) - def setMemoryLayout(self, memoryLayoutClass, alignStride=False): + def setMemoryLayout(self, memoryLayoutClass, alignStride=Alignment.Automatic): self._memoryLayout = memoryLayoutClass.fromSpp(self._groupSpp, alignStride=alignStride) def _setSparsityPattern(self, spp, setOnlyGroupSpp=False): @@ -90,7 +90,7 @@ def _setSparsityPattern(self, spp, setOnlyGroupSpp=False): def setGroupSpp(self, spp): self._setSparsityPattern(spp, setOnlyGroupSpp=True) - self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignedStride()) + self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignment()) def __getitem__(self, indexNames): return IndexedTensor(self, indexNames)