diff --git a/yateto/__init__.py b/yateto/__init__.py index 0b3dd50..48c3f0b 100644 --- a/yateto/__init__.py +++ b/yateto/__init__.py @@ -2,3 +2,4 @@ from .generator import NamespacedGenerator, Generator, simpleParameterSpace, parameterSpaceFromRanges from .arch import useArchitectureIdentifiedBy from .gemm_configuration import * +from .memory import * diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index f48b87e..16281b9 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -16,13 +16,13 @@ def variables(self): return {self} def maySubstitute(self, when, by): - return self.substituted(when, by).memoryLayout().isCompatible(self.eqspp()) + return self.substituted(when, by).memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) def substituted(self, when, by, memoryLayout=None): return by if self == when else self def resultCompatible(self, result): - return result.memoryLayout().isCompatible(self.eqspp()) + return result.memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) def isGlobal(self): return self.tensor is not None @@ -75,7 +75,7 @@ def variableList(self): def maySubstitute(self, when, by): layouts = [var.substituted(when, by).memoryLayout() for var in self._variables] - c1 = all(layouts[i].isCompatible(var.eqspp()) for i,var in enumerate(self._variables)) + c1 = all(layouts[i].isCompatible(var.memoryLayout(), var.eqspp()) for i,var in enumerate(self._variables)) c2 = self.node.argumentsCompatible(layouts) return c1 and c2 @@ -83,7 +83,7 @@ def substituted(self, when, by, memoryLayout): return Expression(self.node, memoryLayout, [var.substituted(when, by) for var in self._variables]) def resultCompatible(self, result): - c1 = result.memoryLayout().isCompatible(self.eqspp()) + c1 = result.memoryLayout().isCompatible(self.memoryLayout(), self.eqspp()) c2 = self.node.resultCompatible(result.memoryLayout()) return c1 and c2 diff --git a/yateto/memory.py b/yateto/memory.py index 0292044..8d9024c 100644 --- a/yateto/memory.py +++ b/yateto/memory.py @@ -1,10 +1,22 @@ from .ast.indices import BoundingBox, Range import copy +from enum import Enum import itertools import warnings import numpy as np from abc import ABC, abstractmethod +class Alignment(Enum): + """Alignment mode. + + Automatic: Assume aligned memory if stride is divisible by vector width. + Aligned: Pad the leading dimension with zeros such that stride is divisible by memory width. + Unaligned: Always assume unaligned memory access. + """ + Automatic = 0, + Aligned = 1, + Unaligned = 2 + class MemoryLayout(ABC): def __init__(self, shape): self._shape = shape @@ -45,7 +57,7 @@ def __eq__(self, other): pass @abstractmethod - def isCompatible(self, spp): + def isCompatible(self, other, eqspp): pass class DenseMemoryLayout(MemoryLayout): @@ -55,7 +67,15 @@ class DenseMemoryLayout(MemoryLayout): def setAlignmentArch(cls, arch): cls.ALIGNMENT_ARCH = arch - def __init__(self, shape, boundingBox=None, stride=None, alignStride=False): + def __init__(self, shape, boundingBox=None, stride=None, alignStride=Alignment.Automatic): + """Construct DenseMemoryLayout. + + :param shape: tensor shape (tuple of integers) + :param boundingBox: Non-zero BoundingBox, covers complete tensor if None + :param stride: Stride of the leading dimension, computed automatically if None + :param alignStride: Alignment mode. Passing False is equal to Alignment.Automatic and passing + True is equal to Alignment.Aligned. + """ super().__init__(shape) if boundingBox: @@ -63,8 +83,16 @@ def __init__(self, shape, boundingBox=None, stride=None, alignStride=False): else: self._bbox = BoundingBox([Range(0, s) for s in self._shape]) + if alignStride == True: + self._alignment = Alignment.Aligned + elif alignStride == False: + self._alignment = Alignment.Automatic + elif isinstance(alignStride, Alignment): + self._alignment = alignStride + else: + raise ValueError("Unknown type for option alignStride") self._range0 = None - if alignStride: + if self._alignment == Alignment.Aligned: self._alignBB() if stride: @@ -86,8 +114,11 @@ def _alignBB(self): else: warnings.warn('Set architecture with DenseMemoryLayout.setAlignmentArch(arch) if you want to use the align stride feature.', UserWarning) + def alignment(self): + return self._alignment + def alignedStride(self): - if self.ALIGNMENT_ARCH is None: + if self.ALIGNMENT_ARCH is None or self._alignment == Alignment.Unaligned: return False offsetOk = self.ALIGNMENT_ARCH.checkAlignment(self._bbox[0].start) ldOk = self._stride[0] == 1 and (len(self._stride) == 1 or self.ALIGNMENT_ARCH.checkAlignment(self._stride[1])) @@ -99,7 +130,7 @@ def mayVectorizeDim(self, dim): return self.ALIGNMENT_ARCH.checkAlignment(self._bbox[dim].size()) @classmethod - def fromSpp(cls, spp, alignStride=False): + def fromSpp(cls, spp, alignStride=Alignment.Automatic): bbox = BoundingBox.fromSpp(spp) return cls(spp.shape, bbox, alignStride=alignStride) @@ -111,7 +142,7 @@ def permuted(self, permutation): originalBB = BoundingBox([self._range0] + self._bbox[1:]) if self._range0 else self._bbox newBB = BoundingBox([copy.copy(originalBB[p]) for p in permutation]) - return DenseMemoryLayout(newShape, newBB, alignStride=self._range0 is not None) + return DenseMemoryLayout(newShape, newBB, alignStride=self._alignment) def address(self, entry): assert entry in self._bbox @@ -237,11 +268,15 @@ def defuse(self, fusedRange, indices, I): stop -= B*s return ranges - def isCompatible(self, spp): - return BoundingBox.fromSpp(spp) in self.bbox() + def isCompatible(self, other, eqspp): + bb_contained = BoundingBox.fromSpp(eqspp) in self.bbox() + alignment_ok = self.alignedStride() == other.alignedStride() + return bb_contained and alignment_ok + def __eq__(self, other): - return self._stride == other._stride and self._bbox == other._bbox and self._stride == other._stride + return self._stride == other._stride and self._bbox == other._bbox and self._alignment == other._alignment + def __str__(self): return '{}(shape: {}, bounding box: {}, stride: {})'.format(type(self).__name__, self._shape, self._bbox, self._stride) @@ -321,8 +356,8 @@ def fromSpp(cls, spp, **kwargs): def __contains__(self, entry): return entry in self._bbox - def isCompatible(self, spp): - return self.fromSpp(spp) == self + def isCompatible(self, other, eqspp): + return other == self def __eq__(self, other): return self._bbox == other._bbox and np.array_equal(self._rowIndex, other._rowIndex) and np.array_equal(self._colPtr, other._colPtr) diff --git a/yateto/type.py b/yateto/type.py index fba8f3b..e8c52f0 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -1,7 +1,7 @@ import re from .ast.node import Node, IndexedTensor from numpy import ndarray, zeros, float64 -from .memory import DenseMemoryLayout +from .memory import DenseMemoryLayout, Alignment from . import aspp class AbstractType(object): @@ -80,7 +80,7 @@ def __init__(self, shape, spp=None, memoryLayoutClass=DenseMemoryLayout, - alignStride=False, + alignStride=Alignment.Automatic, namespace=None): super().__init__(name, namespace=namespace) if not isinstance(shape, tuple): @@ -126,7 +126,7 @@ def __init__(self, def __hash__(self): return hash(self._name) - def setMemoryLayout(self, memoryLayoutClass, alignStride=False): + def setMemoryLayout(self, memoryLayoutClass, alignStride=Alignment.Automatic): self._memoryLayout = memoryLayoutClass.fromSpp(self._groupSpp, alignStride=alignStride) def _setSparsityPattern(self, spp, setOnlyGroupSpp=False): @@ -139,7 +139,7 @@ def _setSparsityPattern(self, spp, setOnlyGroupSpp=False): def setGroupSpp(self, spp): self._setSparsityPattern(spp, setOnlyGroupSpp=True) - self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignedStride()) + self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignment()) def __getitem__(self, indexNames): return IndexedTensor(self, indexNames)