diff --git a/.github/workflows/yateto-cpu.yml b/.github/workflows/yateto-cpu.yml
index e9c719a..7ccfa5b 100644
--- a/.github/workflows/yateto-cpu.yml
+++ b/.github/workflows/yateto-cpu.yml
@@ -73,7 +73,7 @@ jobs:
       - name: Codegen Tests
         run: |
           cd ./tests/code-gen
-          for example in matmul minimal indices slicing; do
+          for example in matmul minimal indices slicing elementwise reduction datatype conditional; do
             for build_type in Debug Release; do
               for precision in single double; do
                 echo " ====== Test Config: ======"
diff --git a/include/yateto.h b/include/yateto.h
index bee5074..ccc005e 100644
--- a/include/yateto.h
+++ b/include/yateto.h
@@ -5,5 +5,6 @@
 #include "yateto/LinearAllocator.h"
 #include "yateto/Misc.h"
 #include "yateto/TensorView.h"
+#include "yateto/Type.h"
 
 #endif
diff --git a/include/yateto/LinearAllocator.h b/include/yateto/LinearAllocator.h
index a0c41fd..cfc45b5 100644
--- a/include/yateto/LinearAllocator.h
+++ b/include/yateto/LinearAllocator.h
@@ -13,6 +13,12 @@ struct LinearAllocatorT {
     userSpaceMem = ptr;
   }
 
+  template<typename S>
+  void initialize(S* ptr) {
+    isInit = true;
+    userSpaceMem = reinterpret_cast<T*>(ptr);
+  }
+
   T* allocate(size_t size) {
     assert(isInit && "YATETO: Temporary-Memory manager hasn't been initialized");
     int currentByteCount = byteCount;
diff --git a/include/yateto/Type.h b/include/yateto/Type.h
new file mode 100644
index 0000000..1e1a798
--- /dev/null
+++ b/include/yateto/Type.h
@@ -0,0 +1,39 @@
+#ifndef YATETO_TYPE_H_
+#define YATETO_TYPE_H_
+
+#include
+
+// C++23 include
+#if __has_include(<stdfloat>)
+#include <stdfloat>
+#endif
+
+// cf. https://stackoverflow.com/a/70868019
+#define __STDC_WANT_IEC_60559_TYPES_EXT__
+#include <float.h>
+
+namespace yateto {
+
+#ifdef __STDCPP_FLOAT128_T__
+using f128_ty = std::float128_t;
+#elif defined(FLT128_MIN)
+using f128_ty = _Float128;
+#else
+using f128_ty = __float128;
+#endif
+#ifdef __STDCPP_FLOAT16_T__
+using f16_ty = std::float16_t;
+#elif defined(FLT16_MIN)
+using f16_ty = _Float16;
+#else
+using f16_ty = __fp16;
+#endif
+#ifdef __STDCPP_BFLOAT16_T__
+using bf16_ty = std::bfloat16_t;
+#else
+using bf16_ty = __bf16;
+#endif
+
+} // namespace yateto
+
+#endif // YATETO_TYPE_H_
diff --git a/tests/code-gen/conditional.py b/tests/code-gen/conditional.py
new file mode 100644
index 0000000..7295a26
--- /dev/null
+++ b/tests/code-gen/conditional.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+
+from yateto import *
+
+import yateto.functions as yf
+
+def add(g):
+  N = 8
+  A = Tensor('A', (N, N))
+  B = Tensor('B', (N, N))
+  C = Tensor('C', (N, N))
+
+  AI = Tensor('AI', (N, N), datatype=Datatype.I32)
+  BI = Tensor('BI', (N, N), datatype=Datatype.I32)
+  CI = Tensor('CI', (N, N), datatype=Datatype.I32)
+
+  AB = Tensor('AB', (N, N), datatype=Datatype.BOOL)
+
+  X = Tensor('X', (), datatype=Datatype.BOOL)
+  X1 = Tensor('X1', (), datatype=Datatype.BOOL)
+  X2 = Tensor('X2', (), datatype=Datatype.BOOL)
+  X3 = Tensor('X3', (), datatype=Datatype.BOOL)
+
+  class Counter:
+    def __init__(self):
+      self.counter = 0
+
+  counter = Counter()
+
+  def _(kernel):
+    counter.counter += 1
+    g.add(f'kernel{counter.counter}', kernel)
+
+  _(yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij'])))
+  _(yf.assignIf(yf.all(AB['ij'], 'ij'), A['ij'], yf.sqrt(B['ij'])))
+  _([
+    yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij'])),
+    yf.assignIf(X[''], AI['ij'], -BI['ij'])
+  ])
+  _([
+    yf.assignIf(X1[''], A['ij'], B['ik'] * C['kj'] + C['ij']),
+    yf.assignIf(X1[''], A['ij'], A['ij'] + B['ik'] * C['kj'] + C['ij']),
+    yf.assignIf(X2[''], C['ij'], yf.sqrt(B['ij']))
+  ])
diff --git a/tests/code-gen/datatype.py
b/tests/code-gen/datatype.py new file mode 100644 index 0000000..c2f7581 --- /dev/null +++ b/tests/code-gen/datatype.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A = Tensor('A', (N, N)) + B = Tensor('B', (N, N)) + C = Tensor('C', (N, N)) + + AI = Tensor('AI', (N, N), datatype=Datatype.I32) + BI = Tensor('BI', (N, N), datatype=Datatype.I32) + CI = Tensor('CI', (N, N), datatype=Datatype.I32) + + AB = Tensor('AB', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(AI['ij'] <= yf.cast(A['ij'], Datatype.I32)) diff --git a/tests/code-gen/elementwise.py b/tests/code-gen/elementwise.py new file mode 100644 index 0000000..a289d3e --- /dev/null +++ b/tests/code-gen/elementwise.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A = Tensor('A', (N, N)) + B = Tensor('B', (N, N)) + C = Tensor('C', (N, N)) + + AI = Tensor('AI', (N, N), datatype=Datatype.I32) + BI = Tensor('BI', (N, N), datatype=Datatype.I32) + CI = Tensor('CI', (N, N), datatype=Datatype.I32) + + AB = Tensor('AB', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(A['ij'] <= yf.sqrt(B['ij'])) + _(A['ij'] <= yf.sqrt(B['ij']) + yf.sin(C['ij'])) + _(A['ij'] <= yf.sqrt(B['ij']) * yf.sin(C['ij'])) + _(A['ij'] <= yf.minimum(B['ij'], C['ij'])) + _(A['ij'] <= yf.minimum(B['ij'], C['ij'] + yf.atanh(B['ij']))) + + _(AI['ij'] <= BI['ij'] + CI['ij']) + _(AI['ij'] <= yf.bitwise_and(BI['ij'], CI['ij'])) + + _(AB['ij'] <= yf.greater_equal(BI['ij'], CI['ij'])) + _(A['ij'] <= yf.where(yf.greater_equal(BI['ij'], CI['ij']), B['ij'], C['ij'])) + + _(AI['ij'] <= yf.cast(A['ij'], Datatype.I32)) diff --git a/tests/code-gen/reduction.py b/tests/code-gen/reduction.py new file mode 100644 index 0000000..420a246 --- /dev/null +++ b/tests/code-gen/reduction.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A0 = Tensor('A0', ()) + A1 = Tensor('A1', (N,)) + A2 = Tensor('A2', (N, N)) + + AI0 = Tensor('AI0', (), datatype=Datatype.I32) + AI1 = Tensor('AI1', (N,), datatype=Datatype.I32) + AI2 = Tensor('AI2', (N, N), datatype=Datatype.I32) + + AB0 = Tensor('AB0', (), datatype=Datatype.BOOL) + AB1 = Tensor('AB1', (N,), datatype=Datatype.BOOL) + AB2 = Tensor('AB2', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(A0[''] <= yf.sum(A1['i'], 'i')) + _(A0[''] <= yf.sum(A2['ij'], 'ij')) + _(A0[''] <= yf.min(A2['ij'], 'ij')) + + _(AI0[''] <= yf.sum(AI2['ij'], 'ij')) + _(AI0[''] <= yf.min(AI2['ij'], 'ij')) + _(AI0[''] <= yf.all(AI2['ij'], 'ij')) + _(AI0[''] <= yf.any(AI2['ij'], 'ij')) + + _(AB0[''] <= yf.all(AB2['ij'], 'ij')) + _(AB0[''] <= yf.any(AB2['ij'], 'ij')) diff --git a/yateto/arch.py b/yateto/arch.py index ff99cbc..3fd2f76 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -38,6 +38,7 @@ # from .memory import DenseMemoryLayout +from .type import Datatype from collections import namedtuple from typing import Union import re @@ -69,19 +70,20 @@ def __init__(self, self.host_name = 
host_name self.precision = precision.upper() - if self.precision == 'D': - self.bytesPerReal = 8 - self.typename = 'double' - self.epsilon = 2.22e-16 + if self.precision == 'Q': + self.epsilon = 2**-112 + self.datatype = Datatype.F128 + elif self.precision == 'D': + self.epsilon = 2**-52 + self.datatype = Datatype.F64 elif self.precision == 'S': - self.bytesPerReal = 4 - self.typename = 'float' - self.epsilon = 1.19e-7 + self.epsilon = 2**-23 + self.datatype = Datatype.F32 else: raise ValueError(f'Unknown precision type {self.precision}') self.alignment = alignment - assert self.alignment % self.bytesPerReal == 0 - self.alignedReals = self.alignment // self.bytesPerReal + assert self.alignment % self.datatype.size() == 0 + self.alignedReals = self.alignment // self.datatype.size() self.enablePrefetch = enablePrefetch self.uintTypename = 'unsigned' @@ -110,10 +112,10 @@ def checkAlignment(self, offset): return offset % self.alignedReals == 0 def formatConstant(self, constant): - return str(constant) + ('f' if self.precision == 'S' else '') + return self.datatype.literal(constant) - def onHeap(self, numReals): - return (numReals * self.bytesPerReal) > self._tmpStackLimit + def onHeap(self, byteCount): + return byteCount > self._tmpStackLimit def __eq__(self, other): return self.name == other.name diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 435ca9f..4d3e2b0 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -2,7 +2,7 @@ from ..memory import DenseMemoryLayout from .indices import BoundingBox, Indices, LoGCost from abc import ABC, abstractmethod -from .. import aspp +from .. import aspp, ops import numpy as np class Node(ABC): @@ -10,6 +10,7 @@ def __init__(self): self.indices = None self._children = [] self._eqspp = None + self.datatype = None self.prefetch = None def size(self): @@ -108,7 +109,7 @@ def __rmul__(self, other): def __add__(self, other): if not isinstance(other, Node): - raise ValueError('Unsupported operation: Cannot add {} to {}.'.format(self, other)) + raise ValueError(f'Unsupported operation: Cannot add {self} to {other}.') return self._binOp(other, Add) def __radd__(self, other): @@ -124,6 +125,12 @@ def __sub__(self, other): def __le__(self, other): return Assign(self, other) + def __truediv__(self, other): + return Elementwise(ops.Div(), self, other) + + def __rtruediv__(self, other): + return Elementwise(ops.Div(), other, self) + def subslice(self, index, start, end): return SliceView(self, index, start, end) @@ -209,7 +216,7 @@ def __deepcopy__(self, memo): return it def __str__(self): - return '{}[{}]'.format(self.tensor.name(), str(self.indices)) + return f'{self.tensor.name()}[{str(self.indices)}]' class Op(Node): def __init__(self, *args): @@ -227,7 +234,7 @@ def computeMemoryLayout(self): alignStride = False alignOffset = float('inf') - if len(self.indices) > 0: + if self.indices is not None and len(self.indices) > 0: for child in self: if self.indices[0] in child.indices: position = child.indices.find(self.indices[0]) @@ -342,7 +349,24 @@ def setChildren(self, children): raise ValueError('BinOp node must have exactly 2 children.') super().setChildren(children) -class Assign(BinOp): +class Assign(Op): + def __init__(self, lTerm, rTerm, condition=True): + if isinstance(condition, Node): + super().__init__(lTerm, rTerm, condition) + else: + super().__init__(lTerm, rTerm) + + self._condition = condition + + def leftTerm(self): + return self._children[0] + + def rightTerm(self): + return self._children[1] + + def condition(self): + return 
self._condition
+
   def setChildren(self, children):
     if not isinstance(children[0].viewed(), IndexedTensor):
       raise ValueError('First child of Assign node must be an IndexedTensor: ' + str(children[0].viewed()))
@@ -352,9 +376,15 @@ def nonZeroFlops(self):
     return 0
 
   def computeSparsityPattern(self, *spps):
-    spp = spps[1] if len(spps) == 2 else self.rightTerm().eqspp()
+    spp = spps[1] if len(spps) >= 2 else self.rightTerm().eqspp()
     return self.broadcast(self.rightTerm().indices, self.permute(self.rightTerm().indices, spp, False))
 
+  def __str__(self):
+    selfname = type(self).__name__
+    indices = self.indices if self.indices is not None else ''
+    condition = '' if isinstance(self.condition(), bool) and self.condition() else f' if {self.condition()}'
+    return f'{selfname}[{indices}]: {self.leftTerm()} <- {self.rightTerm()}{condition}'
+
 class Permute(UnaryOp):
   # permute a given tensor
@@ -553,7 +583,6 @@ def is_pure_gemm(self):
 
     return True if len(left_indices - right_indices) == 1 else False
 
-
 class FusedGEMMs(Op):
   def __init__(self):
     super().__init__()
@@ -578,3 +607,123 @@ def nonZeroFlops(self):
 
   def is_empty(self):
     return len(self._children) == 0
+
+class IfThenElse(Op):
+  def __init__(self, condition, yesTerm, noTerm):
+    if isinstance(condition, Node):
+      super().__init__(yesTerm, noTerm, condition)
+    else:
+      super().__init__(yesTerm, noTerm)
+
+    self._condition = condition
+
+  def yesTerm(self):
+    return self._children[0]
+
+  def noTerm(self):
+    return self._children[1]
+
+  def condition(self):
+    return self._condition
+
+  def nonZeroFlops(self):
+    return 0
+
+  def computeSparsityPattern(self, *spps):
+    # TODO: yesTerm OR noTerm
+    spp = spps[0] if len(spps) >= 2 else self.yesTerm().eqspp()
+    return spp
+
+  def __str__(self):
+    indices = self.indices if self.indices is not None else ''
+    return f'{type(self).__name__}[{indices}]'
+
+class Elementwise(Op):
+  def __init__(self, optype: ops.Operation, *terms):
+    nodeTerms = [term for term in terms if isinstance(term, Node)]
+    super().__init__(*nodeTerms)
+
+    self.nodeTermIndices = [None] * len(terms)
+    self.termTemplate = [None] * len(terms)
+    index = 0
+    for i, term in enumerate(terms):
+      if isinstance(term, Node):
+        self.nodeTermIndices[i] = index
+        index += 1
+      else:
+        self.nodeTermIndices[i] = None
+        self.termTemplate[i] = term
+
+    self.optype = optype
+    self.terms = terms
+
+    self.indices = Indices()
+    for nodeTerm in nodeTerms:
+      nodeIndices = nodeTerm.indices if nodeTerm.indices is not None else Indices()
+      K = self.indices & nodeIndices
+      # assert self.indices.subShape(K) == nodeTerm.subShape(K)
+      self.indices = self.indices.merged(nodeIndices - K)
+
+  def nonZeroFlops(self):
+    return self.eqspp().count_nonzero()
+
+  def fillTerms(self, terms):
+    assert len(terms) == len(self)
+    return [terms[index] if template is None else template for template, index in zip(self.termTemplate, self.nodeTermIndices)]
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    return spps[0]
+
+  def __str__(self):
+    indices = self.indices if self.indices is not None else ''
+    return f'{type(self).__name__}({self.optype})[{indices}]'
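The `Elementwise` node above keeps only `Node` operands as AST children, but records the slot of every operand, so `fillTerms` can later re-interleave generated child expressions with the literal (non-`Node`) operands. A minimal standalone sketch of that bookkeeping; `split_terms`/`fill_terms` are illustrative names, not yateto API:

```python
# Standalone sketch of the slot bookkeeping used by Elementwise:
# node children hold only Node operands; literal operands stay in a
# template list and are re-interleaved by fill_terms.

def split_terms(terms, is_node):
    """Return (node_terms, slots, template) like Elementwise.__init__."""
    node_terms = [t for t in terms if is_node(t)]
    slots = [None] * len(terms)      # index into node_terms, or None
    template = [None] * len(terms)   # literal operand, or None
    index = 0
    for i, t in enumerate(terms):
        if is_node(t):
            slots[i] = index
            index += 1
        else:
            template[i] = t
    return node_terms, slots, template

def fill_terms(generated, slots, template):
    """Re-interleave generated child expressions with literal operands."""
    return [generated[idx] if tpl is None else tpl
            for tpl, idx in zip(template, slots)]

# 'B[ij]' and 'C[ij]' stand for Node operands, 2.0 is a literal scalar.
is_node = lambda t: isinstance(t, str)
nodes, slots, template = split_terms(['B[ij]', 2.0, 'C[ij]'], is_node)
assert fill_terms(['expr_B', 'expr_C'], slots, template) == ['expr_B', 2.0, 'expr_C']
```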
+class Reduction(UnaryOp):
+  def __init__(self, optype, term, sumIndex):
+    # TODO: what if the datatype/field does not match the operation? (w.r.t. the sparsity patterns)
+    super().__init__(term)
+    self.indices = term.indices - set([sumIndex])
+    self._reductionIndex = term.indices.extract(sumIndex)
+    self.optype = optype
+
+  def nonZeroFlops(self):
+    return self.term().eqspp().count_nonzero() - self.eqspp().count_nonzero()
+
+  def reductionIndex(self):
+    return self._reductionIndex
+
+  def reductionIndices(self):
+    return [self._reductionIndex]
+
+  def computeSparsityPattern(self, *spps):
+    assert len(spps) <= 1
+    spp = spps[0] if len(spps) == 1 else self.term().eqspp()
+    return spp.indexSum(self.term().indices, self.indices)
+
+  def __str__(self):
+    indices = self.indices if self.indices is not None else ''
+    return f'{type(self).__name__}({self.optype})[{indices}]'
+
+class Accumulate(Op):
+  def __init__(self, optype, *operands):
+    super().__init__(*operands)
+
+    self.optype = optype
+
+  def computeSparsityPattern(self, *spps):
+    if len(spps) == 0:
+      spps = [node.eqspp() for node in self]
+    permute_summand = lambda i: self.permute(self[i].indices, spps[i])
+    spp = permute_summand(0)
+    for i in range(1, len(spps)):
+      add_spp = permute_summand(i)
+      spp = aspp.add(spp, add_spp)
+    return spp
+
+  def nonZeroFlops(self):
+    nzFlops = 0
+    for child in self:
+      nzFlops += child.eqspp().count_nonzero()
+    return nzFlops - self.eqspp().count_nonzero()
diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py
index 505e6e4..13385b7 100644
--- a/yateto/ast/transformer.py
+++ b/yateto/ast/transformer.py
@@ -34,16 +34,16 @@ def visit(self, node, bound=None):
     elif isinstance(self._targetIndices, Indices):
       node.indices = self._targetIndices
     else:
-      raise ValueError('Target indices type ({}) is not supported.'.format(self._targetIndices.__class__.__name__))
+      raise ValueError(f'Target indices type ({self._targetIndices.__class__.__name__}) is not supported.')
 
     if not (node.indices <= oldIndices and oldIndices <= node.indices):
-      raise ValueError('Target index dimensions do not match: {} != {}'.format(node.indices.__repr__(), oldIndices.__repr__()))
+      raise ValueError(f'Target index dimensions do not match: {node.indices.__repr__()} != {oldIndices.__repr__()}')
 
     return node
 
   def visit_IndexedTensor(self, node, bound):
     if set(node.indices) > bound:
       free = node.indices - bound
-      raise ValueError('The indices {} are not bound in {}.'.format(free.__repr__(), node))
+      raise ValueError(f'The indices {free.__repr__()} are not bound in {node}.')
     return node
 
   def visit_Einsum(self, node, bound):
@@ -98,6 +98,17 @@ def visit_ScalarMultiplication(self, node, bound):
     node.indices = deepcopy(node.term().indices)
     return node
 
+  def visit_Elementwise(self, node, bound):
+    for child in node:
+      self.visit(child, bound)
+    node.indices = deepcopy(node[0].indices)
+    return node
+
+  def visit_Reduction(self, node, bound):
+    subbound = bound | set(node.reductionIndices())
+    self.visit(node.term(), subbound)
+    return node
+
   def visit_SliceView(self, node, bound):
     self.visit(node.term(), bound)
     node.indices = Indices(node.term().indices, [shape if index != node.index else (node.end - node.start) for index, shape in zip(node.term().indices, node.term().shape())])
@@ -122,7 +133,7 @@ def visit_Assign(self, node, bound):
     node.indices = lhs.indices
 
     if not (rhs.indices <= lhs.indices):
-      raise ValueError('Index dimensions do not match: {} != {}'.format(lhs.indices.__repr__(), rhs.indices.__repr__()))
+      raise ValueError(f'Index dimensions do not match: {lhs.indices.__repr__()} != {rhs.indices.__repr__()}')
 
     return node
 
@@ -219,6 +230,16 @@ def visit_Assign(self, node):
     node.setEqspp( 
node.computeSparsityPattern() ) return node + def visit_Elementwise(self, node): + self.generic_visit(node) + node.setEqspp( node.computeSparsityPattern() ) + return node + + def visit_Reduction(self, node): + self.generic_visit(node) + node.setEqspp( node.computeSparsityPattern() ) + return node + def getEqspp(self, terms, targetIndices): # Shortcut if all terms have dense eqspps if all(term.eqspp().is_dense() for term in terms): @@ -266,3 +287,51 @@ def generic_visit(self, node): def visit_IndexedTensor(self, node): return node + +class SetDatatype1(Transformer): + def __init__(self, arch): + self.arch = arch + + def generic_visit(self, node): + super().generic_visit(node) + assert(len(node) > 0) + assert(all(child.datatype == node[0].datatype for child in node)) + node.datatype = node[0].datatype + return node + + def visit_IndexedTensor(self, node): + super().generic_visit(node) + node.datatype = node.tensor.getDatatype(self.arch) + return node + + def visit_Elementwise(self, node): + super().generic_visit(node) + node.datatype = node.optype.datatypeResult([c.datatype for c in node]) + return node + + def visit_Assign(self, node): + super().generic_visit(node) + node.datatype = node[0].datatype + return node + +class SetDatatype2(Transformer): + def generic_visit(self, node): + super().generic_visit(node) + assert(len(node) > 0) + assert(all(child.datatype == node[0].datatype for child in node)) + node.datatype = node[0].datatype + return node + + def visit_IndexedTensor(self, node): + super().generic_visit(node) + return node + + def visit_Elementwise(self, node): + super().generic_visit(node) + node.datatype = node.optype.datatypeResult([c.datatype for c in node]) + return node + + def visit_Assign(self, node): + super().generic_visit(node) + node.datatype = node[0].datatype + return node diff --git a/yateto/ast/visitor.py b/yateto/ast/visitor.py index c333d58..12106b9 100644 --- a/yateto/ast/visitor.py +++ b/yateto/ast/visitor.py @@ -1,4 +1,4 @@ -from numpy import ndindex, arange, float64, add, einsum +from numpy import ndindex, arange, float64, add, einsum, apply_along_axis import math import collections import itertools @@ -110,29 +110,31 @@ def variantsFixedRootPermutation(self, node, fixedPerm, permutationVariants): variants[fixedPerm] = self.Variant(minCost, minInd) return variants - def allPermutationsNoCostBinaryOp(self, node): + def allPermutationsNoCostNAryOp(self, node): permutationVariants = self.findVariants(node) - lV = permutationVariants[node.leftTerm()] - rV = permutationVariants[node.rightTerm()] + V = [permutationVariants[child] for child in node] minCost = LoGCost() - minAind = None - minBind = None - for Aind in sorted(lV): - for Bind in sorted(rV): - cost = lV[Aind]._cost + rV[Bind]._cost - if cost < minCost: - minCost = cost - minAind = Aind - minBind = Bind - assert minAind is not None and minBind is not None + minInd = None + + for ind in itertools.product(*V): + + cost = sum((V[i][Vind]._cost for i,Vind in enumerate(ind)), + LoGCost.addIdentity()) + + if cost < minCost: + minCost = cost + minInd = ind + + assert minInd is not None + iterator = itertools.permutations(node.indices) - permutationVariants[node] = {''.join(Cs): self.Variant(minCost, [minAind, minBind]) for Cs in iterator} + permutationVariants[node] = {''.join(Cs): self.Variant(minCost, list(minInd)) for Cs in iterator} return permutationVariants def generic_visit(self, node): permutationVariants = self.findVariants(node) variants = self.variantsFixedRootPermutation(node, str(node.indices), 
permutationVariants)
-    assert variants, 'Could not find implementation for {}.'.format(type(node))
+    assert variants, f'Could not find implementation for {node}.'
     permutationVariants[node] = variants
     return permutationVariants
 
@@ -152,7 +154,10 @@ def visit_ScalarMultiplication(self, node):
     return permutationVariants
 
   def visit_Product(self, node):
-    return self.allPermutationsNoCostBinaryOp(node)
+    return self.allPermutationsNoCostNAryOp(node)
+
+  def visit_Elementwise(self, node):
+    return self.allPermutationsNoCostNAryOp(node)
 
   def visit_IndexSum(self, node):
     permutationVariants = self.findVariants(node)
@@ -169,6 +174,9 @@ def visit_IndexSum(self, node):
     permutationVariants[node] = {''.join(Cs): self.Variant(minCost, [minTind]) for Cs in iterator}
     return permutationVariants
 
+  def visit_Reduction(self, node):
+    return self.visit_IndexSum(node)
+
   def visit_Contraction(self, node):
     permutationVariants = self.findVariants(node)
 
@@ -299,28 +307,43 @@ def generic_visit(self, node):
   def visit_Einsum(self, node):
     terms = self.generic_visit(node)
     childIndices = [child.indices for child in node]
-    assert None not in childIndices and node.indices is not None, 'Use DeduceIndices before {}.'.format(self.__class__.__name__)
+    assert None not in childIndices and node.indices is not None, f'Use DeduceIndices before {self.__class__.__name__}.'
     einsumDescription = ','.join(indices.tostring() for indices in childIndices)
     einsumDescription = '{}->{}'.format(einsumDescription, node.indices.tostring())
     return einsum(einsumDescription, *terms)
 
   def visit_Add(self, node):
     terms = self.generic_visit(node)
-    assert len(terms) > 1
     permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices))
     return reduce(add, [permute(child.indices, terms[i]) for i,child in enumerate(node)])
 
   def visit_ScalarMultiplication(self, node):
-    assert node.is_constant() is not None, '{} may only be used when all involved scalars are constant.'.format(self.__class__.__name__)
+    assert node.is_constant() is not None, f'{self.__class__.__name__} may only be used when all involved scalars are constant.'
     terms = self.generic_visit(node)
     assert len(terms) == 1
     return node.scalar() * terms[0]
 
   def visit_IndexedTensor(self, node):
     term = node.tensor.values_as_ndarray(self._dtype)
-    assert term is not None, '{} may only be used when all involved tensors are constant.'.format(self.__class__.__name__)
+    assert term is not None, f'{self.__class__.__name__} may only be used when all involved tensors are constant.'
     return term
 
+  def visit_Elementwise(self, node):
+    terms = self.generic_visit(node)
+    fullTerms = node.fillTerms(terms)
+    return node.optype.call(*fullTerms)
+
+  def visit_Reduction(self, node):
+    terms = self.generic_visit(node)
+    assert len(terms) == 1
+    indexpos = node.term().indices.find(node.reductionIndex())
+    return apply_along_axis(lambda x: reduce(node.optype.call, x), indexpos, terms[0])
+
+  def visit_Accumulate(self, node):
+    terms = self.generic_visit(node)
+    permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices))
+    return reduce(node.optype.call, [permute(child.indices, terms[i]) for i,child in enumerate(node)])
+
 class ComputeIndexSet(CachedVisitor):
   def generic_visit(self, node):
     union = set()
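`visit_Reduction` above evaluates a reduction by folding the node's binary operation over the reduced axis with `functools.reduce`. A standalone numpy sketch of the same strategy, with Python's built-in `min` standing in for `node.optype.call`:

```python
# Standalone sketch of the visit_Reduction evaluation strategy:
# fold a binary op over the reduced axis, applied slice-wise via
# numpy's apply_along_axis.
from functools import reduce
from numpy import apply_along_axis, arange

term = arange(12.0).reshape(3, 4)   # a term with indices 'ij'
binop = min                         # stands in for node.optype.call
indexpos = 1                        # position of the reduced index 'j'

result = apply_along_axis(lambda x: reduce(binop, x), indexpos, term)
assert result.tolist() == [0.0, 4.0, 8.0]  # min over each row
```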
diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py
index 626bbb8..1537137 100644
--- a/yateto/codegen/common.py
+++ b/yateto/codegen/common.py
@@ -1,11 +1,11 @@
 from .. import aspp
+from ..type import AddressingMode, Datatype
 from ..ast.indices import BoundingBox
 from ..ast.log import splitByDistance
-from .tiny_tensor_language import Dump, Function, IntegerType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst
+from .tiny_tensor_language import Dump, Function, IntegerType, FloatingType, MemrefType, GroupType, IntImmValue, FloatImmValue, ScalarType, DYNAMIC, SubviewInst, LoadInst
 import hashlib
-
 class TensorDescription(object):
   def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temporary=False, values=None, datatype=None, addressing=None):
     """
@@ -18,6 +18,9 @@ def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temp
           elements are known at compile time
         is_temporary (bool): if true then the description is for a temporary tensor which usually
           results from a result of an intermediate computation
+        values (Union[np.ndarray, None]): the values of the compute_constant tensor, if they are known at compile time
+        datatype (Datatype): the datatype of the tensor elements
+        addressing (AddressingMode): the addressing mode for the tensor
     """
     self.name = name
     self.memoryLayout = memoryLayout
@@ -39,33 +42,33 @@ def __init__(self, name, indices, memoryLayout, eqspp, is_compute_constant=False
 
   @classmethod
   def fromNode(cls, var, node):
+    baseNode = node.viewed()
+    datatype = baseNode.datatype
+
     is_const = False
     values = None
-    datatype = None
     addressing = None
-    baseNode = node.viewed()
     if hasattr(baseNode, 'tensor'):
       is_const = baseNode.tensor.is_compute_constant()
       if is_const:
         values = baseNode.tensor.values()
-      datatype = None # node.tensor.datatype
-      addressing = None # node.tensor.addressing
+      addressing = baseNode.tensor.addressing
 
     return cls(str(var), node.indices, var.memoryLayout(), node.eqspp(), is_const, var.is_temporary, values, datatype, addressing)
 
   @classmethod
   def fromVar(cls, var, indices):
+    datatype = var.datatype
+
     is_const = False
     values = None
-    datatype = None
     addressing = None
     if hasattr(var, 'tensor'):
       if var.tensor is not None:
         is_const = var.tensor.is_compute_constant()
         if is_const:
           values = var.tensor.values()
-        datatype = None # var.tensor.datatype
-        addressing = None # var.tensor.addressing
+        addressing = var.tensor.addressing
 
     return cls(str(var), indices, var.memoryLayout(), var.eqspp(), is_const, var.is_temporary, values, datatype, addressing)
 
 def forLoops(cpp, indexNames, ranges, body, pragmaSimd=True, prefix='_', indexNo=None):
@@ -103,17 +106,17 @@ def boundingBoxFromLoopRanges(indices, loopRanges):
 def reduceSpp(spp, sourceIndices, targetIndices):
   return spp.indexSum(sourceIndices, targetIndices)
 
-def initializeWithZero(cpp, arch, result: TensorDescription, writeBB = None):
+def initializeWithZero(cpp, result: TensorDescription, writeBB = None):
   if writeBB:
     addresses = sorted(result.memoryLayout.notWrittenAddresses(writeBB))
     if len(addresses) > 0:
       regions = splitByDistance(addresses)
       for region in regions:
         m, M = min(region), max(region)
-        initialAddress = '{} + {}'.format(result.name, m)
-        cpp.memset(initialAddress, M-m+1, arch.typename)
+        initialAddress = f'{result.name} + {m}'
+        cpp.memset(initialAddress, M-m+1, result.datatype.ctype())
   else:
-    cpp.memset(result.name, result.memoryLayout.requiredReals(), arch.typename)
+    cpp.memset(result.name, result.memoryLayout.requiredReals(), result.datatype.ctype())
 
 class BatchedOperationsAux:
@@ -123,30 +126,37 @@ class BatchedOperationsAux:
   FLAGS_NAME = 'flags'
 
   FORBIDDEN_STREAM_PTR = 'reinterpret_cast<void*>(std::numeric_limits<uintptr_t>::max())'
 
-  def __init__(self, 
underlying_data_type): - self.underlying_data_type = underlying_data_type + @classmethod + def _get_ptr_type(cls, addressing: AddressingMode): + return addressing.pointer_type() - def _get_ptr_type(self, addressing): - return '**' if addressing == 'pointer_based' else '*' + @classmethod + def deduce_addresing(cls, term): + if term.addressing is not None: + return term.addressing - def deduce_addresing(self, term): + # default deduction if term.is_compute_constant: - return 'none' + return AddressingMode.DIRECT if term.is_temporary: - return 'strided' + return AddressingMode.STRIDED else: - return 'pointer_based' + return AddressingMode.INDIRECT - def deduce_ptr_arg(self, term, as_const=False): + @classmethod + def deduce_ptr_arg(cls, term, as_const=False): if as_const: - addressing = self.deduce_addresing(term) - ptr = self._get_ptr_type(addressing) - const_ptr_type = f'const {self.underlying_data_type} {ptr}' + addressing = cls.deduce_addresing(term) + ptr = cls._get_ptr_type(addressing) + assert term.datatype is not None + datatype = term.datatype.ctype() + const_ptr_type = f'const {datatype} {ptr}' return f'const_cast<{const_ptr_type}>({term.name})' else: return f'{term.name}' - def deduce_offset_arg(self, term): + @classmethod + def deduce_offset_arg(cls, term): if term.is_compute_constant or term.is_temporary: return '0' else: @@ -154,11 +164,12 @@ def deduce_offset_arg(self, term): class TinytcKernelArgument: - def __init__(self, name: str, call_expr: str, constant: bool, temporary: bool, modified: bool, offset: int = 0): + def __init__(self, name: str, datatype: str, call_expr: str, constant: bool, temporary: bool, modified: bool, offset: int = 0): """Kernel argument for TinytcWrapper. Arguments: name -- Argument name + datatype -- Argument datatype in C/C++ call_expr -- Expression used in calling wrapper constant -- Whether a tensor is invariant to group id temporary -- Whether a tensor is stored in a temporary buffer @@ -179,7 +190,7 @@ def __init__(self, name: str, call_expr: str): class TinytcWrapper: - def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | TinytcScalarKernelArgument], real_type: str, name: str = ''): + def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | TinytcScalarKernelArgument], name: str = ''): self.kernel_name = kernel.name self.source = Dump().visit(kernel) if name: @@ -194,13 +205,13 @@ def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | Tiny self.call_args = [] for arg in arguments: if isinstance(arg, TinytcScalarKernelArgument): - self.wrapper_args.append(f'{real_type} {arg.name}') + self.wrapper_args.append(f'{arg.datatype} {arg.name}') self.wrapper_call_args.append(arg.name) self.call_args.append(arg.call_expr) else: ptr2ptr = '*' if not (arg.constant or arg.temporary) else '' const = ' const' if not (arg.modified or arg.temporary) else '' - wrapper_type = f'{real_type}{const}*{ptr2ptr}' + wrapper_type = f'{arg.datatype}{const}*{ptr2ptr}' self.wrapper_args.append(f'{wrapper_type} {arg.name}') self.wrapper_call_args.append(arg.name) self.call_args.append(f'const_cast<{wrapper_type}>({arg.call_expr})') @@ -285,3 +296,25 @@ def makeLoad(bb, operand, gid, isComputeConstant: bool, isTemporary: bool): return bb.add(SubviewInst(operand, offsetList, sizeList)) else: return bb.add(LoadInst(operand, [gid])) + +def toTinyTCType(datatype: Datatype): + return { + Datatype.BOOL: ScalarType(IntegerType.i1), # presumably, maybe i8 + Datatype.I8: ScalarType(IntegerType.i8), + Datatype.I16: 
ScalarType(IntegerType.i16),
+    Datatype.I32: ScalarType(IntegerType.i32),
+    Datatype.I64: ScalarType(IntegerType.i64),
+    Datatype.F32: ScalarType(FloatingType.f32),
+    Datatype.F64: ScalarType(FloatingType.f64)
+  }[datatype]
+
+def toTinyTCImmediate(datatype: Datatype, value):
+  return {
+    Datatype.BOOL: lambda value: IntImmValue(IntegerType.i1, value),
+    Datatype.I8: lambda value: IntImmValue(IntegerType.i8, value),
+    Datatype.I16: lambda value: IntImmValue(IntegerType.i16, value),
+    Datatype.I32: lambda value: IntImmValue(IntegerType.i32, value),
+    Datatype.I64: lambda value: IntImmValue(IntegerType.i64, value),
+    Datatype.F32: lambda value: FloatImmValue(FloatingType.f32, value),
+    Datatype.F64: lambda value: FloatImmValue(FloatingType.f64, value),
+  }[datatype](value)
diff --git a/yateto/codegen/copyscaleadd/csa_gen.py b/yateto/codegen/copyscaleadd/csa_gen.py
index 8f8c6d5..fad4c81 100644
--- a/yateto/codegen/copyscaleadd/csa_gen.py
+++ b/yateto/codegen/copyscaleadd/csa_gen.py
@@ -53,9 +53,9 @@ def _formatTerm(self, alpha, term):
     if alpha == 1.0:
       prefix = term.name
     else:
-      prefix = '{} * {}'.format(alpha, term.name)
+      prefix = f'{alpha} * {term.name}'
 
-    return '{}[{}]'.format(prefix, term.memoryLayout.addressString(term.indices))
+    return f'{prefix}[{term.memoryLayout.addressString(term.indices)}]'
 
   def generate(self, cpp, routineCache):
     """Generates a tensor equation of a form: B = beta * B + alpha * A
@@ -72,7 +72,7 @@ def generate(self, cpp, routineCache):
     n = d.loopRanges[d.result.indices[1]]
     alpha = d.alpha
 
-    aux = BatchedOperationsAux(self._arch.typename)
+    aux = BatchedOperationsAux()
     matrix_a = gf.YatetoInterface.produce_dense_matrix((m, n),
                                                        d.term.memoryLayout.bbox(),
                                                        addressing=aux.deduce_addresing(d.term),
@@ -84,7 +84,7 @@ def generate(self, cpp, routineCache):
                                                        transpose=False)
 
     try:
-      vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=self._arch.typename)
+      vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=d.result.datatype.ctype())
       forge_generator = gf.CsaGenerator(vm)
       forge_generator.set(matrix_a, matrix_b, alpha, d.beta)
       routine_name = forge_generator.get_base_name()
diff --git a/yateto/codegen/copyscaleadd/generic.py b/yateto/codegen/copyscaleadd/generic.py
index 078199d..b941849 100644
--- a/yateto/codegen/copyscaleadd/generic.py
+++ b/yateto/codegen/copyscaleadd/generic.py
@@ -20,7 +20,7 @@ def generate(self, cpp, routineCache):
 
     if d.beta == 0.0:
       writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges)
-      initializeWithZero(cpp, self._arch, d.result, writeBB)
+      initializeWithZero(cpp, d.result, writeBB)
 
 class CopyScaleAddBody(object):
   def __call__(s):
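`toTinyTCImmediate` dispatches on the yateto `Datatype`, not on the already-converted tiny-tensor IR type — which is why the tinytc backend in the next hunk passes the result's `Datatype` rather than the `ty` derived from it. A standalone sketch of the dispatch-table pattern; `DT`, `IntImm` and `FloatImm` are stand-ins for the real `Datatype`, `IntImmValue` and `FloatImmValue`:

```python
# Standalone sketch of the Datatype -> IR immediate dispatch.
from enum import Enum
from collections import namedtuple

IntImm = namedtuple('IntImm', 'ir_type value')
FloatImm = namedtuple('FloatImm', 'ir_type value')

class DT(Enum):
    I32 = 'i32'
    F64 = 'f64'

def to_immediate(datatype, value):
    # one entry per supported datatype; a KeyError flags unsupported ones
    return {
        DT.I32: lambda v: IntImm('i32', int(v)),
        DT.F64: lambda v: FloatImm('f64', float(v)),
    }[datatype](value)

assert to_immediate(DT.F64, 1.0) == FloatImm('f64', 1.0)
assert to_immediate(DT.I32, 3) == IntImm('i32', 3)
```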
diff --git a/yateto/codegen/copyscaleadd/tinytc.py b/yateto/codegen/copyscaleadd/tinytc.py
index 03cd069..5cc33b6 100644
--- a/yateto/codegen/copyscaleadd/tinytc.py
+++ b/yateto/codegen/copyscaleadd/tinytc.py
@@ -12,14 +12,12 @@ class CopyScaleAddTinytc:
   def __init__(self, arch, descr):
     self._arch = arch
     self._descr = descr
-    self._scalar_type = 'f64' if self._arch.bytesPerReal == 8 else 'f32'
-    self._ty = ScalarType(
-        FloatingType.f64) if self._arch.bytesPerReal == 8 else ScalarType(
-            FloatingType.f32)
 
   def generate(self, cpp, routineCache):
     d = self._descr
 
+    ty = toTinyTCType(d.result.datatype)
+
     # Order can be 1 or 2
     def MakeLoopOverAxpby(d, order, transpose, A, B):
       A_offset_list = [None] * len(d.term.indices)
@@ -49,7 +47,7 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
       csa_bb = RegionBuilder()
       a = csa_bb.add(SubviewInst(A, A_offset_list, A_size_list))
       b = csa_bb.add(SubviewInst(B, B_offset_list, B_size_list))
-      beta = csa_bb.add(ConstantInst(FloatImmValue(self._ty, d.beta)))
+      beta = csa_bb.add(ConstantInst(toTinyTCImmediate(d.result.datatype, d.beta)))
       csa_bb.add(AxpbyInst(trans, alpha, a, beta, b))
       csa_region = csa_bb.get_product()
 
@@ -62,13 +60,13 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
           [ForInst(B_offset_list[j], start, stop, csa_region)])
       return csa_region
 
-    alpha = LocalValue(self._ty, 'alpha')
+    alpha = LocalValue(ty, 'alpha')
     Abatch = LocalValue(
-        makeBatchType(self._ty, d.term.memoryLayout,
+        makeBatchType(ty, d.term.memoryLayout,
                       d.term.is_compute_constant, d.term.is_temporary), 'A')
     Bbatch = LocalValue(
-        makeBatchType(self._ty, d.result.memoryLayout,
+        makeBatchType(ty, d.result.memoryLayout,
                       d.result.is_compute_constant, d.result.is_temporary), 'B')
     kernel = Function(None, [alpha, Abatch, Bbatch], None)
@@ -105,7 +103,7 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
                             d.result.is_compute_constant,
                             d.result.is_temporary, True)
     ]
-    wrapper = TinytcWrapper(kernel, args, self._arch.typename)
+    wrapper = TinytcWrapper(kernel, args)
     prototype = wrapper.prototype()
 
     routineCache.addRoutine(prototype,
diff --git a/yateto/codegen/elementwise/__init__.py b/yateto/codegen/elementwise/__init__.py
new file mode 100644
index 0000000..fb914e0
--- /dev/null
+++ b/yateto/codegen/elementwise/__init__.py
@@ -0,0 +1 @@
+from .factory import Description, generator
diff --git a/yateto/codegen/elementwise/factory.py b/yateto/codegen/elementwise/factory.py
new file mode 100644
index 0000000..8771498
--- /dev/null
+++ b/yateto/codegen/elementwise/factory.py
@@ -0,0 +1,40 @@
+from ...memory import CSCMemoryLayout
+from ..common import *
+from .generic import Generic
+
+from ...ops import Operation
+
+from typing import Union
+
+class Description(object):
+  def __init__(self, alpha, add: bool, optype: Operation, result: IndexedTensorDescription, terms: list[IndexedTensorDescription], termTemplate, nodeTermIndices):
+    self.alpha = alpha
+    self.add = add
+    self.result = result
+    self.terms = terms
+    self.optype = optype
+    self.termTemplate = termTemplate
+    self.nodeTermIndices = nodeTermIndices
+
+    self.isSparse = [isinstance(term.memoryLayout, CSCMemoryLayout) for term in terms]
+
+    rR = loopRanges(self.result, self.result.indices)
+
+    # TODO: shall we allow boundingboxing?
+    if len(terms) == 0:
+      self.loopRanges = rR
+    else:
+      self.loopRanges = loopRanges(self.terms[0], self.result.indices)
+      assert testLoopRangesAContainedInB(self.loopRanges, rR)
+      for term in self.terms[1:]:
+        newRange = loopRanges(term, self.result.indices)
+        assert testLoopRangesEqual(newRange, self.loopRanges)
+        assert testLoopRangesAContainedInB(newRange, rR)
+
+        self.loopRanges.update(newRange)
+
+def generator(arch, descr, target):
+  if target == 'cpu':
+    return Generic(arch, descr)
+  elif target == 'gpu':
+    raise RuntimeError("Elementwise operations have not been implemented for GPU-like architectures yet.")
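The `Description` constructor above requires every term to share one loop range per index, and those ranges to fit inside the result's. A standalone sketch of that containment check with plain `range` objects (helper names are illustrative, not yateto's):

```python
# Standalone sketch of the loop-range checks in elementwise.Description:
# all terms must share one loop range per index, contained in the result's.
def contained_in(a, b):
    """Every range in a lies inside the corresponding range in b."""
    return all(b[idx].start <= r.start and r.stop <= b[idx].stop
               for idx, r in a.items())

result_ranges = {'i': range(0, 8), 'j': range(0, 8)}
term_ranges = [{'i': range(0, 8), 'j': range(0, 8)},
               {'i': range(0, 8), 'j': range(0, 8)}]

loop_ranges = term_ranges[0]
assert contained_in(loop_ranges, result_ranges)
for rng in term_ranges[1:]:
    assert rng == loop_ranges            # like testLoopRangesEqual
    assert contained_in(rng, result_ranges)
```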
At least not like this.") diff --git a/yateto/codegen/elementwise/generic.py b/yateto/codegen/elementwise/generic.py new file mode 100644 index 0000000..a64cf4e --- /dev/null +++ b/yateto/codegen/elementwise/generic.py @@ -0,0 +1,40 @@ +from ..common import * + +class Generic(object): + def __init__(self, arch, descr): + self._arch = arch + self._descr = descr + + def _affine(self, add, alpha): + flops = 1 + scale = f'{alpha} * ' if alpha != 1.0 else '' + assign = '+=' if add else '=' + + if alpha != 1.0: flops += 1 + if add: flops += 1 + + return flops, lambda left, right: f'{left} {assign} {scale}{right};' + + def _generateDenseDense(self, cpp): + d = self._descr + + if not d.add: + writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) + initializeWithZero(cpp, d.result, writeBB) + + flops, assigner = self._affine(d.add, d.alpha) + + class ProductBody(object): + def __call__(s): + args = [f'{arg.name}[{arg.memoryLayout.addressString(arg.indices)}]' for arg in d.terms] + fullArgs = [args[index] if template is None else template for index, template in zip(d.nodeTermIndices, d.termTemplate)] + opstr = d.optype.callstr(*fullArgs) + resultstr = f'{d.result.name}[{d.result.memoryLayout.addressString(d.result.indices)}]' + cpp(assigner(resultstr, opstr)) + return flops + return forLoops(cpp, d.result.indices, d.loopRanges, ProductBody()) + + def generate(self, cpp, routineCache): + d = self._descr + + return self._generateDenseDense(cpp) diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index bff1654..c6bb2f2 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -3,7 +3,8 @@ from ..ast.node import IndexedTensor from ..memory import DenseMemoryLayout from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux -from . import copyscaleadd, indexsum, log, product, fused_gemms +from . 
diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py
index bff1654..c6bb2f2 100644
--- a/yateto/codegen/factory.py
+++ b/yateto/codegen/factory.py
@@ -3,7 +3,8 @@ from ..ast.node import IndexedTensor
 from ..memory import DenseMemoryLayout
 from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux
-from . import copyscaleadd, indexsum, log, product, fused_gemms
+from . import copyscaleadd, indexsum, log, product, fused_gemms, elementwise, reduction
+from ..type import Datatype, AddressingMode, Scalar
 
 class KernelFactory(object):
   ERROR_NAME = '_error'
@@ -22,25 +23,23 @@ def create(self, node, *args):
   def generic_create(self, node, *args):
     raise NotImplementedError
 
-  def simple(self, result, term, add, scalar, routineCache, gemm_cfg):
+  def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg):
     raise NotImplementedError
 
-  def temporary(self, bufname, size, iniZero=False, memory=list()):
+  def temporary(self, bufname, size, datatype, iniZero=False, memory=list()):
     assert(iniZero == False or len(memory) == 0)
+
+    if datatype is None:
+      datatype = Datatype.I8
+
     if self._target == 'cpu':
       if self._arch.onHeap(size):
         if len(self._freeList) == 0:
-          self._cpp('int {};'.format(self.ERROR_NAME))
-          self._cpp('{}* {};'.format(self._arch.typename, bufname))
-          self._cpp('{} = posix_memalign(reinterpret_cast<void**>(&{}), {}, {}*sizeof({}));'.format(
-            self.ERROR_NAME,
-            bufname,
-            self._arch.cacheline,
-            size,
-            self._arch.typename))
+          self._cpp(f'int {self.ERROR_NAME};')
+          self._cpp(f'{datatype.ctype()}* {bufname};')
+          self._cpp(f'{self.ERROR_NAME} = posix_memalign(reinterpret_cast<void**>(&{bufname}), {self._arch.cacheline}, {size}*sizeof({datatype.ctype()}));')
         if iniZero:
-          self._cpp.memset(bufname, size, self._arch.typename)
+          self._cpp.memset(bufname, size, datatype.ctype())
         if memory:
           for i, data in enumerate(memory):
             self._cpp(f'{bufname}[{i}] = {data};')
@@ -51,9 +50,9 @@ def temporary(self, bufname, size, iniZero=False, memory=list()):
           ini = ' = {}'
         elif memory:
           ini = ' = {{{}}}'.format(', '.join(memory))
-        self._cpp(f'alignas({self._arch.cacheline}) {self._arch.typename} {bufname}[{size}] {ini};')
+        self._cpp(f'alignas({self._arch.cacheline}) {datatype.ctype()} {bufname}[{size}] {ini};')
     else:
-      declaration = f'{self._arch.typename}* {bufname}'
+      declaration = f'{datatype.ctype()}* {bufname}'
       total_size = f'{BatchedOperationsAux.NUM_ELEMENTS_NAME} * {size}'
       self._cpp(f'{declaration} = linearAllocator.allocate({total_size});')
@@ -94,11 +93,26 @@ def _indices(self, var):
     shape = var.memoryLayout().shape()
     return Indices(string.ascii_lowercase[:len(shape)], shape)
 
+  def _conditional(self, condition, generate):
+    if isinstance(condition, bool):
+      if condition:
+        return generate()
+      else:
+        return 0
+    else:
+      if condition.tautology():
+        return generate()
+      elif condition.unfulfillable():
+        return 0
+      else:
+        with self._cpp.If(f'{condition.ccode()}'):
+          return generate()
+
 class OptimizedKernelFactory(KernelFactory):
   def __init__(self, cpp, arch, target):
     super().__init__(cpp, arch, target)
 
-  def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 2
     description = log.Description(
       alpha = scalar,
@@ -112,14 +126,14 @@ def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName
       prefetchName = prefetchName
     )
     generator = log.generator(self._arch, description, self._target)
-    return generator.generate(self._cpp, routineCache, gemm_cfg)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache, gemm_cfg))
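`_conditional` above folds constant conditions at generation time: a plain `True`/`False` or a tautological/unfulfillable condition never emits a guard, and only a genuine runtime condition produces an `if` block around the kernel body. A standalone sketch of the same dispatch; the `Cond` stub mimics the `tautology`/`unfulfillable`/`ccode` interface assumed here:

```python
# Standalone sketch of the _conditional dispatch: constant-fold booleans,
# drop unreachable kernels, and guard the rest with a runtime if.
class Cond:
    """Stub condition with the tautology/unfulfillable/ccode interface."""
    def __init__(self, expr): self.expr = expr
    def tautology(self): return False
    def unfulfillable(self): return False
    def ccode(self): return self.expr

def conditional(condition, generate, emit):
    if isinstance(condition, bool):
        return generate() if condition else 0   # fold at generation time
    if condition.tautology():
        return generate()
    if condition.unfulfillable():
        return 0                                # kernel can never run
    emit(f'if ({condition.ccode()}) {{ ... }}') # runtime guard
    return generate()

lines = []
flops = conditional(Cond('*X'), lambda: 42, lines.append)
assert flops == 42 and lines == ['if (*X) { ... }']
assert conditional(False, lambda: 42, lines.append) == 0
```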
-  def create_FusedGEMMs(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
-    description = fused_gemms.Description(node, result, arguments, add, scalar)
+  def create_FusedGEMMs(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    description = fused_gemms.Description(node, result, arguments, condition, add, scalar)
     generator = fused_gemms.generator(self._arch, description, gemm_cfg, self._target)
-    return generator.generate(self._cpp, routineCache, gemm_cfg)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache, gemm_cfg))
 
-  def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 1
     description = indexsum.Description(
       alpha = scalar,
@@ -128,9 +142,9 @@ def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, ro
       term = IndexedTensorDescription.fromNode(arguments[0], node.term())
     )
     generator = indexsum.generator(self._arch, description, self._target)
-    return generator.generate(self._cpp, routineCache)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache))
 
-  def create_Product(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert len(arguments) == 2
     description = product.Description(
       alpha = scalar,
@@ -140,24 +154,48 @@ def create_Product(self, node, result, arguments, add, scalar, prefetchName, rou
       rightTerm = IndexedTensorDescription.fromNode(arguments[1], node.rightTerm())
     )
     generator = product.generator(self._arch, description, self._target)
-    return generator.generate(self._cpp, routineCache)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache))
+
+  def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    description = elementwise.Description(
+      alpha = scalar,
+      add = add,
+      result = IndexedTensorDescription.fromNode(result, node),
+      terms = [IndexedTensorDescription.fromNode(argument, term) for argument, term in zip(arguments, node)],
+      optype = node.optype,
+      termTemplate = node.termTemplate,
+      nodeTermIndices = node.nodeTermIndices
+    )
+    generator = elementwise.generator(self._arch, description, self._target)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache))
+
+  def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    description = reduction.Description(
+      alpha = scalar,
+      add = add,
+      result = IndexedTensorDescription.fromNode(result, node),
+      term = IndexedTensorDescription.fromNode(arguments[0], node.term()),
+      optype = node.optype,
+    )
+    generator = reduction.generator(self._arch, description, self._target)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache))
 
-  def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     result = IndexedTensorDescription.fromNode(result, node)
     term = IndexedTensorDescription.fromNode(arguments[0], node.term())
-    return self._csa(result, term, add, scalar, routineCache, gemm_cfg)
+    return self._csa(result, term, condition, add, scalar, routineCache, gemm_cfg)
-  def create_Broadcast(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Broadcast(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     result = IndexedTensorDescription.fromNode(result, node)
     term = IndexedTensorDescription.fromNode(arguments[0], node.term())
-    return self._csa(result, term, add, scalar, routineCache, gemm_cfg)
+    return self._csa(result, term, condition, add, scalar, routineCache, gemm_cfg)
 
-  def simple(self, result, term, add, scalar, routineCache, gemm_cfg):
+  def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg):
     result = IndexedTensorDescription.fromVar(result, self._indices(result))
     term = IndexedTensorDescription.fromVar(term, self._indices(term))
-    return self._csa(result, term, add, scalar, routineCache, gemm_cfg)
+    return self._csa(result, term, condition, add, scalar, routineCache, gemm_cfg)
 
-  def _csa(self, result, term, add, scalar, routineCache, gemm_cfg):
+  def _csa(self, result, term, condition, add, scalar, routineCache, gemm_cfg):
     description = copyscaleadd.Description(
       alpha = scalar,
       beta = 1.0 if add else 0.0,
@@ -165,7 +203,7 @@ def _csa(self, result, term, add, scalar, routineCache, gemm_cfg):
       term = term
     )
     generator = copyscaleadd.generator(self._arch, description, gemm_cfg, self._target)
-    return generator.generate(self._cpp, routineCache)
+    return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache))
 
 class UnitTestFactory(KernelFactory):
   def __init__(self, cpp, arch, nameFun, testFramework):
@@ -176,9 +214,9 @@ def __init__(self, cpp, arch, nameFun, testFramework):
 
   def _formatTerm(self, var, indices):
     address = var.memoryLayout().addressString(indices)
-    return '{}[{}]'.format(self._name(var), address)
+    return f'{self._name(var)}[{address}]'
 
-  def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Einsum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     g = node.indices
     for child in node:
       g = g.merged(child.indices - g)
@@ -192,50 +230,86 @@ def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, rout
       terms.insert(0, str(scalar))
 
     if not add:
-      self._cpp.memset(self._name(result), result.memoryLayout().requiredReals(), self._arch.typename)
+      self._cpp.memset(self._name(result), result.memoryLayout().requiredReals(), result.datatype.ctype())
 
     class EinsumBody(object):
       def __call__(s):
-        self._cpp( '{} += {};'.format(resultTerm, ' * '.join(terms)) )
+        self._cpp(f"{resultTerm} += {' * '.join(terms)};")
         return len(terms)
 
-    return forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False)
+    return self._conditional(condition, lambda: forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False))
 
-  def create_ScalarMultiplication(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
-    return self.simple(result, arguments[0], add, scalar, routineCache)
+  def create_ScalarMultiplication(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    return self.simple(result, arguments[0], condition, add, scalar, routineCache, gemm_cfg)
-  def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert node.indices <= node.term().indices and node.term().indices <= node.indices
     resultTerm = self._formatTerm(result, node.indices)
     termTerm = self._formatTerm(arguments[0], node.term().indices)
-    return self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)
+    return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, node.indices))
 
-  def create_Broadcast(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+  def create_Broadcast(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
     assert node.term().indices <= node.indices
     resultTerm = self._formatTerm(result, node.indices)
     termTerm = self._formatTerm(arguments[0], node.term().indices)
-    return self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)
+    return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, node.indices))
+
+  def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    g = self._indices(result)
+    resultTerm = self._formatTerm(result, node.indices)
+
+    argTerms = [self._formatTerm(argument, term.indices) for argument, term in zip(arguments, node)]
+    termTerm = f'({argTerms[0]}) * ({argTerms[1]})'
+
+    return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g))
+
+  def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    g = self._indices(result)
+    resultTerm = self._formatTerm(result, node.indices)
+
+    argTerms = [self._formatTerm(argument, term.indices) for argument, term in zip(arguments, node)]
+    termTerm = node.optype.callstr(*node.fillTerms(argTerms))
+
+    return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g))
 
-  def _simpleBody(self, resultTerm, termTerm, add, scalar, indices):
+  def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    g = self._indices(result)
+    resultTerm = self._formatTerm(result, node.indices)
+    termTerm = self._formatTerm(arguments[0], node.term().indices)
+
+    return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g))
+
+  def create_IfThenElse(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+    g = self._indices(result)
+    resultTerm = self._formatTerm(result, node.indices)
+    yesTerm = self._formatTerm(arguments[0], node.yesTerm().indices)
+    noTerm = self._formatTerm(arguments[1], node.noTerm().indices)
+    conditionTerm = self._formatTerm(arguments[2], node.condition().indices)
+
+    termTerm = f'(({conditionTerm}) ? 
({yesTerm}) : ({noTerm}))' + + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) + + def _simpleBody(self, resultTerm, termTerm, add, scalar, indices, reduceIdx = None): ranges = {idx: Range(0, indices.indexSize(idx)) for idx in indices} if scalar and scalar != 1.0: - termTerm = '{} * {}'.format(scalar, termTerm) + termTerm = f'{scalar} * {termTerm}' class AssignBody(object): def __call__(s): - self._cpp( '{} {} {};'.format(resultTerm, '+=' if add else '=', termTerm) ) + self._cpp(f"{resultTerm} {'+=' if add else '='} {termTerm};") return 1 if add else 0 return forLoops(self._cpp, indices, ranges, AssignBody(), pragmaSimd=False) - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, g) termTerm = self._formatTerm(term, g) - return self._simpleBody(resultTerm, termTerm, add, scalar, g) + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) def compare(self, ref, target, epsMult = 100.0): g = self._indices(ref) @@ -244,8 +318,8 @@ def compare(self, ref, target, epsMult = 100.0): class CompareBody(object): def __call__(s): - self._cpp( 'double ref = {};'.format(refTerm) ) - self._cpp( 'double diff = ref - {};'.format(targetTerm) ) + self._cpp( f'double ref = {refTerm};' ) + self._cpp( f'double diff = ref - {targetTerm};' ) self._cpp( 'error += diff * diff;' ) self._cpp( 'refNorm += ref * ref;' ) return 0 @@ -262,19 +336,21 @@ def tensor(self, node, resultName, maxValue = 512, scale = 1 / 512): ml = node.memoryLayout() size = ml.requiredReals() + datatype = node.getDatatype(self._arch) + spp = node.spp() isDense = spp.count_nonzero() == size if isDense: - self.temporary(resultName, size) - with self._cpp.For('int i = 0; i < {}; ++i'.format(size)): - self._cpp(f'{resultName}[i] = static_cast<{self._arch.typename}>((i + {self._rand}) % {maxValue} + 1) * static_cast<{self._arch.typename}>({scale});') + self.temporary(resultName, size, node.getDatatype(self._arch)) + with self._cpp.For(f'int i = 0; i < {size}; ++i'): + self._cpp(f'{resultName}[i] = static_cast<{datatype.ctype()}>((i + {self._rand}) % {maxValue} + 1);') else: - memory = ['0.0']*size + memory = [datatype.literal(0)]*size nz = spp.nonzero() for entry in zip(*nz): addr = ml.address(entry) - memory[addr] = str((float((addr + self._rand) % maxValue)+1.0) * scale) - self.temporary(resultName, size, memory=memory) + memory[addr] = datatype.literal(((addr + self._rand) % maxValue)+1.0) + self.temporary(resultName, size, datatype, memory=memory) self._rand += 1 class ExportGenerator: @@ -289,6 +365,12 @@ def generate(self, cpp, cache): def add_linear_operation(self, dest, ops, target, permute, add): pass + def add_operation(self, description): + pass + + def add_tensor(self, description): + pass + class ExportFactory(KernelFactory): @classmethod def makeFactory(cls, generator): @@ -297,6 +379,8 @@ def makeFactory(cls, generator): def __init__(self, generator, cpp, arch, target): super().__init__(cpp, arch, target) self.generator = generator + self.tensors = {} + self.scalarcounter = 0 def post_generate(self, routine_cache): self.generator.generate(self._cpp, routine_cache) @@ -304,63 +388,233 @@ def post_generate(self, routine_cache): def allocateTemporary(self): return False - def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): - 
assert len(arguments) == 2 - makeNode = IndexedTensorDescription.fromNode - argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())] - return self.handleLinear(makeNode(result, node), argnodes, add, scalar, node.transA(), node.transB()) + def _nodeTensor(self, tensor, node): + return self._handleTensorDesc(IndexedTensorDescription.fromNode(tensor, node)) - def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): - assert len(arguments) == 1 - makeNode = IndexedTensorDescription.fromNode - argnodes = [makeNode(arguments[0], node.term())] - return self.handleLinear(makeNode(result, node), argnodes, add, scalar, False, False) + def _varTensor(self, var, indices): + return self._handleTensorDesc(IndexedTensorDescription.fromVar(var, indices)) + + def _handleAddressing(self, desc): + if desc.addressing is None: + addressing = BatchedOperationsAux.deduce_addresing(desc) + else: + addressing = desc.addressing + + # & == deref + # n == current element + # N == element size + # o == extraOffset + # *,+ == default add and mul + # read left to right + if addressing == AddressingMode.DIRECT: + return '&' + elif addressing == AddressingMode.STRIDED: + return 'n*N+o&' + elif addressing == AddressingMode.INDIRECT: + return 'n&+o&' + elif addressing == AddressingMode.SCALAR: + return '' + + raise NotImplementedError(addressing) + + def _handleTensorDesc(self, tensorIndexed: IndexedTensorDescription): + if isinstance(tensorIndexed.memoryLayout, DenseMemoryLayout): + shape = list(tensorIndexed.memoryLayout.shape()) + shapeXt = [max(rng.stop - rng.start, shp) for rng, shp in zip(tensorIndexed.memoryLayout.bbox(), shape)] + storage = { + 'shape': shapeXt, + 'type': 'bbox', + 'start': [rng.start for rng in tensorIndexed.memoryLayout.bbox()], + 'sizes': [rng.stop - rng.start for rng in tensorIndexed.memoryLayout.bbox()] + } + else: + assert False + + eqsppnz = tensorIndexed.eqspp.nonzero() + spp = [elem for elem in zip(*eqsppnz)] + + values = None if tensorIndexed.values is None else list(tensorIndexed.values) + + tensor = { + 'name': tensorIndexed.name, + 'addressing': self._handleAddressing(tensorIndexed), + #'eqspp': spp, + 'datatype': str(tensorIndexed.datatype), + 'storage': storage, + 'values': values, + 'flags': { + 'temporary': tensorIndexed.is_temporary, + 'constant': tensorIndexed.is_compute_constant + } + } + + return self._handleTensor(tensor, spp, tensorIndexed.indices) + + def _scalarTensor(self, scalar): + if isinstance(scalar, (int, float)): # TODO numpy types + name = f'_scalar{self.scalarcounter}' + self.scalarcounter += 1 + + tensor = { + 'name': name, + 'addressing': '', + 'eqspp': (), + 'datatype': str(self._arch.datatype), + 'storage': { + 'shape': (), + 'type': 'full' + }, + 'values': { + (): scalar + }, + 'flags': { + 'temporary': False, + 'constant': True + } + } + elif isinstance(scalar, Scalar): + tensor = { + 'name': scalar.name(), + 'addressing': '', + 'eqspp': (), + 'datatype': str(scalar.getDatatype(self._arch)), + 'storage': { + 'shape': (), + 'type': 'full' + }, + 'values': None, + 'flags': { + 'temporary': False, + 'constant': True + } + } + else: + assert False - def create_Product(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + return self._handleTensor(tensor, (), ()) + + def _handleTensor(self, tensor, eqspp, indices): + if tensor['name'] not in self.tensors: + self.tensors[tensor['name']] = tensor + self.generator.add_tensor(tensor) + else: + assert 
+
+    def _handleTensor(self, tensor, eqspp, indices):
+        if tensor['name'] not in self.tensors:
+            self.tensors[tensor['name']] = tensor
+            self.generator.add_tensor(tensor)
+        else:
+            assert tensor == self.tensors[tensor['name']]
+
+        return {
+            'name': tensor['name'],
+            'spp': eqspp,
+            'indices': indices
+        }
+
+    def _handleCondition(self, condition):
+        out = []
+        for clause in condition.clauses:
+            outclause = []
+            for var in clause.variables:
+                # clause.variables maps a variable to its sign; pass the variable itself
+                tensor = self._varTensor(var, ())
+                outclause += [tensor]
+            out += [outclause]
+        return out
+
+    def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+        result = self._nodeTensor(result, node)
+        preArgs = [self._nodeTensor(argument, term) for argument, term in zip(arguments, node)]
+        args = node.fillTerms(preArgs)
+
+        description = {
+            'type': 'elementwise',
+            'result': result,
+            'args': args,
+            'condition': self._handleCondition(condition),
+            'linear': {
+                'alpha': self._scalarTensor(scalar),
+                'add': add,
+            },
+            'optype': str(node.optype)
+        }
+        return self.generator.add_operation(description)
+
+    def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+        assert len(arguments) == 1
+        result = self._nodeTensor(result, node)
+        argnodes = [self._nodeTensor(arguments[0], node.term())]
+
+        description = {
+            'type': 'reduction',
+            'result': result,
+            'args': argnodes,
+            'condition': self._handleCondition(condition),
+            'linear': {
+                'alpha': self._scalarTensor(scalar),
+                'add': add,
+            },
+            'optype': str(node.optype)
+        }
+        return self.generator.add_operation(description)
+
+    def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
         assert len(arguments) == 2
-        makeNode = IndexedTensorDescription.fromNode
-        argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())]
+        argnodes = [self._nodeTensor(arguments[0], node.leftTerm()), self._nodeTensor(arguments[1], node.rightTerm())]
+        return self.handleLinear(self._nodeTensor(result, node), argnodes, condition, add, scalar, node.transA(), node.transB())

-    def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+    def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+        return self.create_Reduction(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg)
+
+    def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
+        return self.create_Elementwise(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg)
+
+    def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
         term = arguments[0]
-        return self.handleLinear(IndexedTensorDescription.fromVar(result, node.indices), [IndexedTensorDescription.fromVar(term, node.term().indices)], add, scalar, False, False)
+        return self.handleLinear(self._varTensor(result, node.indices), [self._varTensor(term, node.term().indices)], condition, add, scalar, False, False)

-    def create_Broadcast(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg):
+    def create_Broadcast(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg):
         term = arguments[0]
-        return self.handleLinear(IndexedTensorDescription.fromVar(result, node.indices), [IndexedTensorDescription.fromVar(term, node.term().indices)], add, scalar, False, False)
+        return
self.handleLinear(self._varTensor(result, node.indices), [self._varTensor(term, node.term().indices)], condition, add, scalar, False, False) - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): - return self.handleLinear(IndexedTensorDescription.fromVar(result, self._indices(result)), [IndexedTensorDescription.fromVar(term, self._indices(term))], add, scalar, False, False) + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): + return self.handleLinear(self._varTensor(result, self._indices(result)), [self._varTensor(term, self._indices(term))], condition, add, scalar, False, False) def getIndices(self, dest, ops): if dest is None: target_indices = [] else: - target_indices = dest.indices + target_indices = dest['indices'] indexindex = {index:i for i, index in enumerate(target_indices)} contract_counter = -1 for op in ops: - for index in op.indices: + for index in op['indices']: if index not in indexindex: indexindex[index] = contract_counter contract_counter -= 1 - target = [[indexindex[index] for index in op.indices] for op in ops] - permute = [[i for i,_ in enumerate(op.indices)] for op in ops] + target = [[indexindex[index] for index in op['indices']] for op in ops] + permute = [[i for i,_ in enumerate(op['indices'])] for op in ops] return target, permute - def handleLinear(self, dest, ops, add, scalar, transposeA, transposeB): + def handleLinear(self, dest, ops, condition, add, scalar, transposeA, transposeB): # convert indices to loop numbers target, permute = self.getIndices(dest, ops) if not (scalar == 1 or scalar == 1.0): - ops += [scalar] + ops += [self._scalarTensor(scalar)] target += [[]] permute += [[]] - return self.generator.add_linear_operation(dest, ops, target, permute, add) + description = { + 'type': 'multilinear', + 'result': dest, + 'args': ops, + 'condition': self._handleCondition(condition), + 'permute': permute, + 'target': target, + 'linear': { + 'alpha': self._scalarTensor(scalar), + 'add': add, + }, + # 'optype': node.optype + } + return self.generator.add_operation(description) diff --git a/yateto/codegen/fused_gemms/external_generator.py b/yateto/codegen/fused_gemms/external_generator.py index 4fea176..04a5d7b 100644 --- a/yateto/codegen/fused_gemms/external_generator.py +++ b/yateto/codegen/fused_gemms/external_generator.py @@ -11,7 +11,8 @@ class FusedGemms: def __init__(self, arch, descr): self._arch = arch self._descr = descr - self._batch_aux = BatchedOperationsAux(self._arch.typename) + self._datatype = self._descr[0].node.datatype + self._batch_aux = BatchedOperationsAux() self._cache = {} self._tmp_matrices = {} @@ -39,7 +40,7 @@ def generate(self, cpp, routineCache, cfg): context = Context(arch=self._arch.name, backend=self._arch.backend, - fp_type=FloatingPointType.str2enum(self._arch.typename)) + fp_type=FloatingPointType.str2enum(self._datatype.ctype())) chainforge_generator = ChainForgeGenerator(gemm_list, context) chainforge_generator.register() @@ -123,7 +124,7 @@ def _gen_call_site(self, generator): offset_name_map = {} for name, matrix in self._cache.items(): if matrix.direction == DataFlowDirection.SOURCE: - ptr_type = f'{self._arch.typename} {Addressing.addr2ptr_type(matrix.addressing)}' + ptr_type = f'{self._datatype.ctype()} {Addressing.addr2ptr_type(matrix.addressing)}' mat_name_map[name] = f'const_cast<{ptr_type}>({name})' else: mat_name_map[name] = name diff --git a/yateto/codegen/fused_gemms/tinytc.py b/yateto/codegen/fused_gemms/tinytc.py index 24f1a37..a73d182 100644 --- 
a/yateto/codegen/fused_gemms/tinytc.py +++ b/yateto/codegen/fused_gemms/tinytc.py @@ -13,9 +13,6 @@ class FusedGemmsTinytc: def __init__(self, arch, descr): self._arch = arch self._descr = descr - self._ty = ScalarType( - FloatingType.f64) if self._arch.bytesPerReal == 8 else ScalarType( - FloatingType.f32) def generate(self, cpp, routineCache, cfg): args = dict() @@ -39,7 +36,7 @@ def addVal(var, node): is_constant[var] = node.tensor.is_compute_constant( ) if isinstance(node, IndexedTensor) else False arg = LocalValue( - makeBatchType(self._ty, node.memoryLayout(), + makeBatchType(toTinyTCType(var.datatype), node.memoryLayout(), is_constant[var], var.is_temporary), name) args[var] = arg vals[var] = makeLoad(bb, arg, gid, is_constant[var], var.is_temporary) @@ -59,7 +56,7 @@ def addVal(var, node): if res.is_temporary: res_val = bb.add( AllocaInst( - makeMemrefType(self._ty, res.memoryLayout(), False, True))) + makeMemrefType(toTinyTCType(res.datatype), res.memoryLayout(), False, True))) vals[res] = res_val else: modified.add(res) @@ -87,7 +84,7 @@ def offsetSizeLists(ml, range0, range1): return ([IntImmValue(IntegerType.index, o) for o in offsets], [IntImmValue(IntegerType.index, s) for s in sizes]) - alpha = bb.add(ConstantInst(FloatImmValue(self._ty, scalar))) + alpha = bb.add(ConstantInst(toTinyTCImmediate(toTinyTCType(res.datatype), scalar))) op1_sub = bb.add( SubviewInst( op1_val, @@ -96,7 +93,7 @@ def offsetSizeLists(ml, range0, range1): SubviewInst( op2_val, *offsetSizeLists(node.rightTerm().memoryLayout(), k, n))) - beta = bb.add(ConstantInst(FloatImmValue(self._ty, 1.0 if add else 0.0))) + beta = bb.add(ConstantInst(toTinyTCImmediate(toTinyTCType(res.datatype), 1.0 if add else 0.0))) res_sub = bb.add( SubviewInst(res_val, *offsetSizeLists(node.memoryLayout(), m, n))) @@ -119,7 +116,7 @@ def offsetSizeLists(ml, range0, range1): wrapper_args.append( TinytcKernelArgument(name, str(key), is_constant[key], key.is_temporary, key in modified)) - wrapper = TinytcWrapper(kernel, wrapper_args, self._arch.typename) + wrapper = TinytcWrapper(kernel, wrapper_args) cpp(wrapper.call()) prototype = wrapper.prototype() routineCache.addRoutine(prototype, diff --git a/yateto/codegen/gemm/factory.py b/yateto/codegen/gemm/factory.py index 0f38428..5ac0de9 100644 --- a/yateto/codegen/gemm/factory.py +++ b/yateto/codegen/gemm/factory.py @@ -81,6 +81,9 @@ def generator(arch, descr, gemm_cfg, target): descr.beta, descr.alignedA, descr.alignedC, + descr.leftTerm.datatype, + descr.rightTerm.datatype, + descr.result.datatype, target) if gemmTool: return GemmGen(arch, descr, gemmTool) diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 5510d18..cfd0324 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -7,8 +7,9 @@ from ..cache import RoutineGenerator, GpuRoutineGenerator, TinytcWriter from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, tinytc -from ..common import BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper +from ..common import BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper, toTinyTCType, toTinyTCImmediate from ..tiny_tensor_language import * +from ...type import Datatype import importlib.util @@ -66,9 +67,8 @@ def generateRoutineName(self, gemm, sppA, sppB): sha = hashlib.new('md5', usedforsecurity=False) sha.update(str(sppB).encode()) name += '_' + sha.hexdigest() - return 
'{name}_{datatype}_{arch}_m{M}_n{N}_k{K}_ldA{LDA}_ldB{LDB}_ldC{LDC}_alpha{alphaSubs}_beta{betaSubs}_alignedA{alignedA}_alignedC{alignedC}_transA{transA}_transB{transB}_{prefetch}'.format( + return '{name}_{datatypeA}_{datatypeB}_{datatypeC}_{arch}_m{M}_n{N}_k{K}_ldA{LDA}_ldB{LDB}_ldC{LDC}_alpha{alphaSubs}_beta{betaSubs}_alignedA{alignedA}_alignedC{alignedC}_transA{transA}_transB{transB}_{prefetch}'.format( name=name, - datatype=self._arch.typename, arch=self._arch.name.replace('-', '_'), alphaSubs=self._alpha(gemm['alpha']), betaSubs=self._beta(gemm['beta']), @@ -139,11 +139,18 @@ def generate(self, cpp, routineCache): d.beta, ptr_c, ldC, alignedA=d.alignedA, alignedC=d.alignedC, - prefetchName=d.prefetchName)) + prefetchName=d.prefetchName, + datatypeA=d.leftTerm.datatype, + datatypeB=d.rightTerm.datatype, + datatypeC=d.result.datatype)) elif isinstance(self._gemm_cfg, GemmForge): + assert d.result.datatype == d.leftTerm.datatype + assert d.result.datatype == d.rightTerm.datatype + ctype = d.result.datatype.ctype() + if gf_spec: - aux = BatchedOperationsAux(self._arch.typename) + aux = BatchedOperationsAux() matrix_a = gf.YatetoInterface.produce_dense_matrix((m, k), d.leftTerm.memoryLayout.bbox(), @@ -161,7 +168,7 @@ def generate(self, cpp, routineCache): transpose=False) try: - vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=self._arch.typename) + vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=ctype) forge_generator = gf.GemmGenerator(vm) forge_generator.set(d.transA, d.transB, matrix_a, matrix_b, matrix_c, d.alpha, d.beta) routine_name = forge_generator.get_base_name() @@ -191,7 +198,7 @@ def generate(self, cpp, routineCache): raise RuntimeError('gemmforge module is not found. You can install it with pip3. ' 'e.g., pip3 install gemmforge') elif isinstance(self._gemm_cfg, tinytc): - aux = BatchedOperationsAux(self._arch.typename) + aux = BatchedOperationsAux() gemm = { 'M': m.size(), 'N': n.size(), @@ -209,12 +216,15 @@ def generate(self, cpp, routineCache): 'beta': self._beta(d.beta), 'transA': d.transA, 'transB': d.transB, + 'datatypeA': d.leftTerm.datatype, + 'datatypeB': d.rightTerm.datatype, + 'datatypeC': d.result.datatype, } kernel = tinytcGemmGen(self._arch, gemm) def call_arg(name, term, modified, offset): - return TinytcKernelArgument(name, term.name, term.is_compute_constant, term.is_temporary, modified, offset) + return TinytcKernelArgument(name, term.datatype, term.name, term.is_compute_constant, term.is_temporary, modified, offset) offset_a = self._offset(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA) offset_b = self._offset(term=d.rightTerm, offset2=(k.start, n.start), transpose=d.transB) offset_c = self._offset(term=d.result, offset2=(m.start, n.start), transpose=False) @@ -222,8 +232,8 @@ def call_arg(name, term, modified, offset): call_arg('A', d.leftTerm, False, offset_a), call_arg('B', d.rightTerm, False, offset_b), call_arg('C', d.result, True, offset_c)] - routine_name = 'tinytc_wrapper_{typename}_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_tA{transA}_tB{transB}'.format(typename=self._arch.typename, **gemm) - wrapper = TinytcWrapper(kernel, args, self._arch.typename, routine_name) + routine_name = 'tinytc_wrapper_{datatypeA}_{datatypeB}_{datatypeC}_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_tA{transA}_tB{transB}'.format(**gemm) + wrapper = TinytcWrapper(kernel, args, 
routine_name) cpp(wrapper.call()) prototype = wrapper.prototype() @@ -243,7 +253,9 @@ def call_arg(name, term, modified, offset): 'prefetch': 'BL2viaC' if self._arch.enablePrefetch and d.prefetchName is not None else 'nopf', 'transA': d.transA, 'transB': d.transB, - + 'datatypeA': d.leftTerm.datatype, + 'datatypeB': d.rightTerm.datatype, + 'datatypeC': d.result.datatype, } routineName = self.generateRoutineName(gemm, sppA, sppB) @@ -323,7 +335,19 @@ def _callGenerator(self, argList): def __call__(self, routineName, fileName): cpu_arch = self._arch.host_name + assert self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeA'] + assert self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeB'] + if self._mode == 'pspamm': + assert self._gemmDescr['datatypeC'] in [Datatype.BF16, Datatype.F16, Datatype.F32, Datatype.F64] + + precision = { + Datatype.BF16: 'BF16', + Datatype.F16: 'H', + Datatype.F32: 'S', + Datatype.F64: 'D' + }[self._gemmDescr['datatypeC']] + pspamm_arch = cpu_arch if cpu_arch == 'a64fx': pspamm_arch = 'arm_sve512' @@ -358,7 +382,7 @@ def __call__(self, routineName, fileName): '--output_filename', fileName, '--precision', - self._arch.precision + precision ] if self._gemmDescr['prefetch'] != 'nopf': argList.extend(['--prefetching', self._gemmDescr['prefetch']]) @@ -369,6 +393,14 @@ def __call__(self, routineName, fileName): for key, val in self._blockSize.items(): argList.extend(['--' + key, val]) else: + assert self._gemmDescr['datatypeC'] in [Datatype.I16, Datatype.F32, Datatype.F64] + + precision = { + Datatype.I16: 'I16', + Datatype.F32: 'SP', + Datatype.F64: 'DP' + }[self._gemmDescr['datatypeC']] + libxsmm_arch = cpu_arch if cpu_arch in ['naples', 'rome', 'milan', 'avx2-256']: # names are Zen1, Zen2, Zen3, respectively @@ -393,7 +425,7 @@ def __call__(self, routineName, fileName): self._gemmDescr['alignedC'], libxsmm_arch, # libxsmm has no support for rome, hsw works well in practice self._gemmDescr['prefetch'], - self._arch.precision + 'P' + precision ] class SparsityWrapper: def __init__(self, shape, spp): @@ -438,13 +470,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._callGenerator(argList) if self._mode == 'pspamm': - return 'void {name}(const {type}* A, const {type}* B, {type}* C, {type} alpha, {type} beta, const {type}* prefetch);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {atype}* A, const {btype}* B, {ctype}* C, {ctype} alpha, {ctype} beta, const {ctype}* prefetch);'.format(name=routineName, + atype=self._gemmDescr['datatypeA'].ctype(), + btype=self._gemmDescr['datatypeB'].ctype(), + ctype=self._gemmDescr['datatypeC'].ctype(), + ) # LIBXSMM header + datatype = self._gemmDescr['datatypeC'].ctype() if self._gemmDescr['prefetch'] == 'nopf': - return 'void {name}(const {type}* A, const {type}* B, {type}* C);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {type}* A, const {type}* B, {type}* C);'.format(name=routineName, type=datatype) else: - return 'void {name}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch);'.format(name=routineName, type=datatype) class GemmForgeWriter(GpuRoutineGenerator): def __init__(self, forge_generator, headers): @@ -500,6 +537,9 @@ def _kernel(self, routine_name): prefetch = 
self._gemmDescr['prefetch'] transA = self._gemmDescr['transA'] transB = self._gemmDescr['transB'] + datatypeA = self._gemmDescr['datatypeA'] + datatypeB = self._gemmDescr['datatypeB'] + datatypeC = self._gemmDescr['datatypeC'] flags = ["LIBXSMM_GEMM_FLAG_NONE"] if transA: @@ -535,7 +575,7 @@ def _kernel(self, routine_name): {prefetch_flag} // prefetch ); """.format(kernel_var_name=kernel_var_name, - prec=self._arch.typename, M=M, N=N, K=K, + prec=datatypeC.ctype(), M=M, N=N, K=K, ldA=ldA, ldB=ldB, ldC=ldC, alpha=alpha, beta=beta, flag=libxsmm_flag_str, @@ -563,7 +603,8 @@ def _call(self, routineName): """ def _functionSignature(self, routineName): - return 'void {routineName}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch)'.format(routineName=routineName, type=self._arch.typename) + datatypeC = self._gemmDescr['datatypeC'].ctype() + return 'void {routineName}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch)'.format(routineName=routineName, type=datatypeC) def __call__(self, routineName, fileName): func_signature = self._functionSignature(routineName) @@ -574,50 +615,55 @@ def __call__(self, routineName, fileName): return func_signature + ";" def tinytcGemmGen(arch, gd): - scalar_ty = ScalarType(FloatingType.f64) if arch.bytesPerReal == 8 else ScalarType(FloatingType.f32) - - Operand = namedtuple('Operand', ['name', 'addr', 'rows', 'cols', 'ld', 'dist']) + Operand = namedtuple('Operand', ['name', 'addr', 'rows', 'cols', 'ld', 'dist', 'datatype']) + def scalar_type(op): + return toTinyTCType(op.datatype) def data_type(op): - if op.addr == 'pointer_based': + scalar_ty = scalar_type(op) + if op.addr == AddressingMode.INDIRECT: return GroupType(MemrefType(scalar_ty, (op.rows, op.cols), (1, op.ld)), DYNAMIC) - elif op.addr == 'strided': + elif op.addr == AddressingMode.STRIDED: return MemrefType(scalar_ty, (op.rows, op.cols, DYNAMIC), (1, op.ld, op.dist)) - elif op.addr == 'none': + elif op.addr == AddressingMode.NONE: return MemrefType(scalar_ty, (op.rows, op.cols), (1, op.ld)) else: raise NameError(op.addr) def load_inst(op, batch, gid): zero = IntImmValue(IntegerType.index, 0) dyn = IntImmValue(IntegerType.index, DYNAMIC) - if op.addr == 'pointer_based': + if op.addr == AddressingMode.INDIRECT: return LoadInst(batch, [gid]) - elif op.addr == 'strided': + elif op.addr == AddressingMode.STRIDED: return SubviewInst(batch, [zero,zero,gid], [dyn,dyn,None]) - elif op.addr == 'none': + elif op.addr == AddressingMode.NONE: return SubviewInst(batch, [zero,zero], [dyn,dyn]) else: raise NameError(op.addr) - opA = Operand('A', gd['addrA'], gd['M'], gd['K'], gd['LDA'], gd['distA']) - opB = Operand('B', gd['addrB'], gd['K'], gd['N'], gd['LDB'], gd['distB']) - opC = Operand('C', gd['addrC'], gd['M'], gd['N'], gd['LDC'], gd['distC']) + opA = Operand('A', gd['addrA'], gd['M'], gd['K'], gd['LDA'], gd['distA'], gd['datatypeA']) + opB = Operand('B', gd['addrB'], gd['K'], gd['N'], gd['LDB'], gd['distB'], gd['datatypeB']) + opC = Operand('C', gd['addrC'], gd['M'], gd['N'], gd['LDC'], gd['distC'], gd['datatypeC']) T = lambda x: Transpose.t if x else Transpose.n tA = T(gd['transA']) tB = T(gd['transB']) - beta = gd['beta'] - - alpha = LocalValue(scalar_ty, 'alpha') - Abatch = LocalValue(data_type(opA), 'Abatch') - Bbatch = LocalValue(data_type(opB), 'Bbatch') - Cbatch = LocalValue(data_type(opC), 'Cbatch') + alphaV = gd['alpha'] + betaV = gd['beta'] + + scalar_ty_a = 
toTinyTCType(opA.datatype)
+    scalar_ty_b = toTinyTCType(opB.datatype)
+    scalar_ty_c = toTinyTCType(opC.datatype)
+    alpha = LocalValue(scalar_ty_c, 'alpha')
+    Abatch = LocalValue(data_type(opA), 'Abatch')
+    Bbatch = LocalValue(data_type(opB), 'Bbatch')
+    Cbatch = LocalValue(data_type(opC), 'Cbatch')
     kernel = Function('gemm', [alpha, Abatch, Bbatch, Cbatch], None)
     bb = RegionBuilder()
     gid = bb.add(GroupIdInst())
     A = bb.add(load_inst(opA, Abatch, gid))
     B = bb.add(load_inst(opB, Bbatch, gid))
     C = bb.add(load_inst(opC, Cbatch, gid))
-    beta = bb.add(ConstantInst(FloatImmValue(scalar_ty, beta)))
+    # literal alpha/beta values are materialized as constants inside the region;
+    # a named scalar keeps its kernel-argument LocalValue
+    alpha = bb.add(ConstantInst(toTinyTCImmediate(scalar_ty_c, alphaV))) if isinstance(alphaV, (int, float)) else alpha
+    beta = bb.add(ConstantInst(toTinyTCImmediate(scalar_ty_c, betaV))) if isinstance(betaV, (int, float)) else LocalValue(scalar_ty_c, 'beta')
     bb.add(GemmInst(tA, tB, alpha, A, B, beta, C))
     kernel.body = bb.get_product()
     AssignIdentifiers().visit(kernel)
diff --git a/yateto/codegen/indexsum/generic.py b/yateto/codegen/indexsum/generic.py
index 6f072ab..5a57764 100644
--- a/yateto/codegen/indexsum/generic.py
+++ b/yateto/codegen/indexsum/generic.py
@@ -10,19 +10,19 @@ def generate(self, cpp, routineCache):
     if not d.add:
         writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges)
-        initializeWithZero(cpp, self._arch, d.result, writeBB)
+        initializeWithZero(cpp, d.result, writeBB)

     sumIndex = d.term.indices - d.result.indices
     assert len(sumIndex) == 1
     class IndexSumBody(object):
         def __call__(s):
             target = '{}[{}]'.format(d.result.name, d.result.memoryLayout.addressString(d.result.indices))
-            initialValue = target if d.add else self._arch.formatConstant(0.0)
-            cpp( '{} sum = {};'.format(self._arch.typename, initialValue) )
+            initialValue = target if d.add else d.result.datatype.literal(0.0)
+            cpp(f'{d.result.datatype.ctype()} sum = {initialValue};')
             with cpp.For('int {0} = {1}; {0} < {2}; ++{0}'.format(sumIndex, d.sumLoopRange.start, d.sumLoopRange.stop)):
-                cpp( 'sum += {}[{}];'.format(d.term.name, d.term.memoryLayout.addressString(d.term.indices)) )
-            mult = '{} * '.format(d.alpha) if d.alpha != 1.0 else ''
-            cpp( '{} = {}sum;'.format(target, mult) )
+                cpp( f'sum += {d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}];' )
+            mult = f'{d.alpha} * ' if d.alpha != 1.0 else ''
+            cpp( f'{target} = {mult}sum;' )

             flop = 1 if d.alpha != 1.0 else 0
             return d.sumLoopRange.size() + flop
diff --git a/yateto/codegen/log/generic.py b/yateto/codegen/log/generic.py
index a239e8c..59177c0 100644
--- a/yateto/codegen/log/generic.py
+++ b/yateto/codegen/log/generic.py
@@ -14,7 +14,7 @@ def _pointer(self, cpp, targetName, baseName, term, loopIndices, const=True):
     addressStr = term.memoryLayout.addressString(term.indices, indices) if len(indices) > 0 else ''
     if len(addressStr) > 0:
         addressStr = ' + ' + addressStr
-    cpp('{} {}* {} = {}{};'.format(self._arch.typename, 'const' if const else '', targetName, baseName, addressStr))
+    cpp('{} {}* {} = {}{};'.format(term.datatype.ctype(), 'const' if const else '', targetName, baseName, addressStr))

 def _alignedStart(self, term, loopIndices):
     return term.memoryLayout.isAlignedAddressString(term.indices, term.indices & loopIndices)
@@ -76,9 +76,9 @@ def generate(self, cpp, routineCache, gemm_cfg):
     Ceqspp = self._reduce(d.result, C, CmemLayout)

     gemmDescr = gemm.Description(
-        leftTerm = TensorDescription(innerAname, AmemLayout, Aeqspp, d.leftTerm.is_compute_constant, d.leftTerm.is_temporary),
-        rightTerm = TensorDescription(innerBname, BmemLayout, Beqspp, d.rightTerm.is_compute_constant, d.rightTerm.is_temporary),
-        result = TensorDescription(innerCname, CmemLayout, Ceqspp, d.result.is_compute_constant, d.result.is_temporary),
+        leftTerm = TensorDescription(innerAname, AmemLayout, Aeqspp, d.leftTerm.is_compute_constant, d.leftTerm.is_temporary, datatype=d.leftTerm.datatype),
+        rightTerm = TensorDescription(innerBname, BmemLayout, Beqspp, d.rightTerm.is_compute_constant, d.rightTerm.is_temporary, datatype=d.rightTerm.datatype),
+        result = TensorDescription(innerCname, CmemLayout, Ceqspp, d.result.is_compute_constant, d.result.is_temporary, datatype=d.result.datatype),
         transA = d.transA,
         transB = d.transB,
         alpha = d.alpha,
@@ -96,7 +96,7 @@ def generate(self, cpp, routineCache, gemm_cfg):
     lr.update( self._defuse(m, d.leftTerm, Im) )
     lr.update( self._defuse(n, d.rightTerm, In) )
     writeBB = boundingBoxFromLoopRanges(d.result.indices, lr)
-    initializeWithZero(cpp, self._arch, d.result, writeBB)
+    initializeWithZero(cpp, d.result, writeBB)

     class LoGBody(object):
         def __call__(s):
diff --git a/yateto/codegen/product/generic.py b/yateto/codegen/product/generic.py
index 452f2bf..f6f8587 100644
--- a/yateto/codegen/product/generic.py
+++ b/yateto/codegen/product/generic.py
@@ -21,7 +21,7 @@ def _generateDenseDense(self, cpp):
     if not d.add:
         writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges)
-        initializeWithZero(cpp, self._arch, d.result, writeBB)
+        initializeWithZero(cpp, d.result, writeBB)

     class ProductBody(object):
         def __call__(s):
@@ -42,7 +42,7 @@ def _generateSparseSparse(self, cpp):
     d = self._descr

     if not d.add:
-        initializeWithZero(cpp, self._arch, d.result)
+        initializeWithZero(cpp, d.result)

     left = d.result.indices.positions(d.leftTerm.indices, sort=False)
     right = d.result.indices.positions(d.rightTerm.indices, sort=False)
diff --git a/yateto/codegen/reduction/__init__.py b/yateto/codegen/reduction/__init__.py
new file mode 100644
index 0000000..fb914e0
--- /dev/null
+++ b/yateto/codegen/reduction/__init__.py
@@ -0,0 +1 @@
+from .factory import Description, generator
diff --git a/yateto/codegen/reduction/factory.py b/yateto/codegen/reduction/factory.py
new file mode 100644
index 0000000..4133201
--- /dev/null
+++ b/yateto/codegen/reduction/factory.py
@@ -0,0 +1,30 @@
+from ..common import *
+from .generic import Generic
+
+from ...ops import Operation
+
+class Description(object):
+    def __init__(self, alpha, add: bool, result: IndexedTensorDescription, term: IndexedTensorDescription, optype: Operation):
+        self.alpha = alpha
+        self.add = add
+        self.result = result
+        self.term = term
+        self.optype = optype
+
+        rA = loopRanges(self.term, self.result.indices)
+        rB = loopRanges(self.result, self.result.indices)
+        assert testLoopRangesAContainedInB(rA, rB)
+
+        self.loopRanges = rA
+
+        self.sumIndex = self.term.indices - self.result.indices
+        assert len(self.sumIndex) == 1
+
+        self.sumLoopRange = loopRanges(self.term, self.sumIndex)[str(self.sumIndex)]
+
+
+def generator(arch, descr, target):
+    if target == 'cpu':
+        return Generic(arch, descr)
+    elif target == 'gpu':
+        raise RuntimeError("Reduction operations have not been implemented for GPU-like architectures")
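+# Illustration only (a sketch, not part of the generated sources): for the test
+# kernel `A0[''] <= yf.sum(A1['i'], 'i')` with N = 8 and double precision, the
+# Generic backend below emits roughly
+#   double acc = 0.0;          // d.optype.neutral() for Add
+#   for (int i = 0; i < 8; ++i) {
+#     acc = acc + A1[i];       // d.optype.callstr("acc", ...)
+#   }
+#   A0[0] = acc;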
diff --git a/yateto/codegen/reduction/generic.py b/yateto/codegen/reduction/generic.py
new file mode 100644
index 0000000..522ef21
--- /dev/null
+++ b/yateto/codegen/reduction/generic.py
@@ -0,0 +1,31 @@
+from ..common import *
+
+class Generic(object):
+    def __init__(self, arch, descr):
+        self._arch = arch
+        self._descr = descr
+
+    def generate(self, cpp, routineCache):
+        d = self._descr
+
+        if not d.add:
+            writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges)
+            initializeWithZero(cpp, d.result, writeBB)
+
+        sumIndex = d.term.indices - d.result.indices
+        assert len(sumIndex) == 1
+        class ReductionBody(object):
+            def __call__(s):
+                target = f'{d.result.name}[{d.result.memoryLayout.addressString(d.result.indices)}]'
+                initialValue = target if d.add else d.result.datatype.literal(d.optype.neutral())
+                cpp(f'{d.result.datatype.ctype()} acc = {initialValue};')
+                with cpp.For(f'int {sumIndex} = {d.sumLoopRange.start}; {sumIndex} < {d.sumLoopRange.stop}; ++{sumIndex}'):
+                    argstr = f'{d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}]'
+                    # double quotes avoid nesting single quotes inside the f-string (a syntax error before Python 3.12)
+                    cpp( f'acc = {d.optype.callstr("acc", argstr)};' )
+                mult = f'{d.alpha} * ' if d.alpha != 1.0 else ''
+                cpp( f'{target} = {mult}acc;' )
+
+                flop = 1 if d.alpha != 1.0 else 0
+                return d.sumLoopRange.size() + flop
+
+        return forLoops(cpp, d.result.indices, d.loopRanges, ReductionBody())
diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py
index 0ec981d..3ee650d 100644
--- a/yateto/codegen/visitor.py
+++ b/yateto/codegen/visitor.py
@@ -78,24 +78,24 @@ def generate(self, cpp, cfg, factory, routineCache, gemm_cfg):
     localPtrs = set()
     for pp in cfg:
         localPtrs.update(pp.bufferMap.keys())
-    if localPtrs:
-        cpp( '{}{};'.format(self._arch.typename, ','.join(map(lambda x: ' *' + str(x), localPtrs))) )
+    for localPtr in localPtrs:
+        cpp(f'{localPtr.datatype.ctype()}* {localPtr};')
     for pp in cfg:
         if factory.allocateTemporary():
             for buf, size in pp.initBuffer.items():
-                required_tmp_mem += size * self._arch.bytesPerReal
+                required_tmp_mem += size
                 bufname = self._bufferName(buf)
-                factory.temporary(bufname, size)
+                factory.temporary(bufname, size, None)
         for local, buf in pp.bufferMap.items():
-            cpp('{} = {};'.format(local, self._bufferName(buf)))
+            cpp(f'{local} = reinterpret_cast<{local.datatype.ctype()}*>({self._bufferName(buf)});')
         action = pp.action
         if action:
             scalar = self.deduce_scalar(action)
             if action.isRHSExpression():
                 prefetchName = '{}.{}'.format(self.PREFETCHVAR_NAME, action.term.node.prefetch.name()) if action.term.node.prefetch is not None else None
-                hwFlops += factory.create(action.term.node, action.result, action.term.variableList(), action.add, scalar, prefetchName, routineCache, gemm_cfg)
+                hwFlops += factory.create(action.term.node, action.result, action.term.variableList(), action.condition, action.add, scalar, prefetchName, routineCache, gemm_cfg)
             else:
-                hwFlops += factory.simple(action.result, action.term, action.add, scalar, routineCache, gemm_cfg)
+                hwFlops += factory.simple(action.result, action.term, action.condition, action.add, scalar, routineCache, gemm_cfg)
     return hwFlops, required_tmp_mem

 class OptimizedKernelGenerator(KernelGenerator):
@@ -134,6 +134,7 @@ def __init__(self,
                  function,
                  tmp_mem_size,
                  is_compute_constant_tensors,
+                 datatype,
                  target):
         self.nonZeroFlops = nonZeroFlops
@@ -145,6 +146,7 @@ def __init__(self,
         self.function = function
         self.tmp_mem_size = tmp_mem_size
         self.is_compute_constant_tensors = is_compute_constant_tensors
+        self.datatype = datatype
         self.target = target

     @classmethod
@@ -166,12 +168,16 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target):
     writable = dict()
     is_compute_constant_tensors = dict()
     scalars = collections.OrderedDict()
+    datatype = dict()
     for scalar in scalarsP:
         self.KernelOutline._addTensor(scalar, scalars)
+        datatype[scalar.baseNameWithNamespace()] = scalar.getDatatype(self._arch)

     for var in variables:
         self.KernelOutline._addTensor(var.tensor, tensors)
         bn = var.tensor.baseNameWithNamespace()
+        datatype[bn] = var.datatype
+
         if bn in writable:
             if var.writable:
                 writable[bn] = True
@@ -204,6 +210,7 @@
                                   function,
                                   tmp_memory,
                                   is_compute_constant_tensors,
+                                  datatype,
                                   target)

     @classmethod
@@ -211,7 +218,7 @@ def _addFromKO(cls, koEntries, entries):
     for key, value in koEntries.items():
         if key not in entries:
             entries[key] = value
-        else:
+        elif entries[key] != value:
             entries[key] = entries[key] | value

@@ -221,6 +228,7 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None):
     writable = dict()
     scalars = collections.OrderedDict()
     is_compute_constant_tensors = dict()
+    datatype = dict()
     for ko in kernelOutlines:
         if ko:
             self._addFromKO(ko.scalars, scalars)
@@ -228,6 +236,7 @@
             self._addFromKO(ko.writable, writable)
             self._addFromKO(ko.prefetch, prefetch)
             self._addFromKO(ko.is_compute_constant_tensors, is_compute_constant_tensors)
+            self._addFromKO(ko.datatype, datatype)

     target = kernelOutlines[-1].target
     is_same_target = True
@@ -279,13 +288,14 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None):
     if target == 'gpu':
         # LinearAllocatorT controls external extra mem. allocated on gpu for tmp. variables
-        header(f'yateto::LinearAllocatorT<{self._arch.typename}> linearAllocator;')
+        # back-casted to char for now
+        header('yateto::LinearAllocatorT<char> linearAllocator;')
         header.emptyline()

-    def kernelArgs(base_name_with_namespace, groups, writable, is_constant, target):
+    def kernelArgs(base_name_with_namespace, groups, writable, is_constant, datatype, target):
         prefix, base_name = Tensor.splitBasename(base_name_with_namespace)
-        typ = self._arch.typename
+        typ = datatype.ctype()
         ptr_type = '**' if not is_constant and target == 'gpu' else '*'
         if not writable:
             typ += ' const'
@@ -294,11 +304,11 @@
             container_type = f'{InitializerGenerator.CONTAINER_CLASS_NAME}<{typ}{ptr_type}>'
             header(f'{class_name}::{container_type} {base_name};')
         else:
-            header(f'{typ}{ptr_type} {base_name}{{}};')
+            header(f'{typ}{ptr_type} {base_name}{{nullptr}};')

-    def scalarArgs(base_name_with_namespace, groups):
+    def scalarArgs(base_name_with_namespace, datatype, groups):
         prefix, base_name = Tensor.splitBasename(base_name_with_namespace)
-        typ = self._arch.typename
+        typ = datatype.ctype()
         if len(next(iter(groups))) > 0:
             class_name = f'{prefix}{InitializerGenerator.TENSOR_NAMESPACE}::{base_name}'
             container_type = f'{InitializerGenerator.CONTAINER_CLASS_NAME}<{typ}>'
@@ -308,12 +318,14 @@
     for baseName, groups in scalars.items():
         scalarArgs(baseName,
+                   datatype[baseName],
                    groups)
     for baseName, groups in tensors.items():
         kernelArgs(baseName,
                    groups,
                    writable[baseName],
                    is_compute_constant_tensors[baseName],
+                   datatype[baseName],
                    target)
     header.emptyline()
@@ -341,7 +353,7 @@ def generate_extra_offset_args(base_name_with_namespace, groups):
     if len(prefetch) > 0:
         with header.Struct(self.PREFETCHSTRUCT_NAME):
             for baseName, groups in prefetch.items():
-                kernelArgs(baseName, groups, writable=False, is_constant=False, target='any')
+                kernelArgs(baseName, groups, writable=False, is_constant=False, datatype=self._arch.datatype, target='any')
         header('{} {};'.format(self.PREFETCHSTRUCT_NAME, self.PREFETCHVAR_NAME))
         header.emptyline()
@@
-454,7 +466,7 @@ def _devTensorKernelArgument(self, var, writable): elif writable[var.tensor.baseNameWithNamespace()]: return self._devPtrTensorName(var) else: - return f'const_cast<{self._arch.typename} const**>({self._devPtrTensorName(var)})' + return f'const_cast<{var.datatype.ctype()} const**>({self._devPtrTensorName(var)})' @classmethod def _nameS(cls, var): @@ -529,16 +541,17 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, factory = UnitTestFactory(cpp, self._arch, self._name, testFramework) for i,scalar in enumerate(scalars): - cpp('{} {} = {};'.format(self._arch.typename, self._tensorNameS(scalar), float(i+2))) + cpp('{} {} = {};'.format(scalar.getDatatype(self._arch).ctype(), self._tensorNameS(scalar), float(i+2))) for var in variables: factory.tensor(var.tensor, self._tensorName(var)) - factory.temporary(self._name(var), var.memoryLayout().requiredReals(), iniZero=True) + factory.temporary(self._name(var), var.memoryLayout().requiredReals(), var.datatype, iniZero=True) shape = var.memoryLayout().shape() - cpp('{supportNS}::DenseTensorView<{dim},{arch.typename},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{stop}}});'.format( + cpp('{supportNS}::DenseTensorView<{dim},{datatype},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{stop}}});'.format( supportNS = SUPPORT_LIBRARY_NAMESPACE, dim=len(shape), + datatype=var.datatype.ctype(), arch = self._arch, utName=self._name(var), viewName=self._viewName(var), @@ -572,13 +585,13 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, kernelTensorName = lambda var: self._devTensorKernelArgument(var, writable) stream_new(self.STREAM) - data_malloc(self.TMP_MEM, self.TMP_SIZE, f'{self._arch.typename}*', self.STREAM) + data_malloc(self.TMP_MEM, self.TMP_SIZE, f'char*', self.STREAM) for var in variables: - data_malloc(self._devTensorName(var), f'sizeof({self._tensorName(var)})', f'{self._arch.typename}*', self.STREAM) - data_malloc(self._devPtrTensorName(var), f'sizeof({self._arch.typename}*)', f'{self._arch.typename}**', self.STREAM) + data_malloc(self._devTensorName(var), f'sizeof({self._tensorName(var)})', f'{var.datatype.ctype()}*', self.STREAM) + data_malloc(self._devPtrTensorName(var), f'sizeof({var.datatype.ctype()}*)', f'{var.datatype.ctype()}**', self.STREAM) for var in variables: data_memcpy(self._devTensorName(var), self._tensorName(var), f'sizeof({self._tensorName(var)})', self.STREAM) - data_memcpy(self._devPtrTensorName(var), f'&{self._devTensorName(var)}', f'sizeof({self._arch.typename}*)', self.STREAM) + data_memcpy(self._devPtrTensorName(var), f'&{self._devTensorName(var)}', f'sizeof({var.datatype.ctype()}*)', self.STREAM) stream_wait(self.STREAM) cpp.emptyline() @@ -614,7 +627,7 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, for var in variables: if var.writable: - factory.compare(var, Variable(self._tensorName(var), False, var.tensor.memoryLayout())) + factory.compare(var, Variable(self._tensorName(var), False, var.tensor.memoryLayout(), datatype=var.datatype)) factory.freeTmp() @@ -636,28 +649,28 @@ class InitializerGenerator(object): class TensorView(object): ARGUMENT_NAME = 'values' + def __init__(self, datatype): + self._datatype = datatype + def typename(self, dim, arch, const): constStr = 'true' if const else 'false' - return '::{}::{}<{},{},{},{}>'.format(SUPPORT_LIBRARY_NAMESPACE, type(self).__name__, dim, arch.typename, arch.uintTypename, constStr) + return 
f'::{SUPPORT_LIBRARY_NAMESPACE}::{type(self).__name__}<{dim},{self._datatype.ctype()},{arch.uintTypename},{constStr}>'

-        @classmethod
-        def arguments(cls, arch, const):
-            if const:
-                return '{} const* {}'.format(arch.typename, cls.ARGUMENT_NAME)
-            else:
-                return '{} * {}'.format(arch.typename, cls.ARGUMENT_NAME)
+        def arguments(self, const):
+            conststr = ' const*' if const else '*'
+            return f'{self._datatype.ctype()}{conststr} {self.ARGUMENT_NAME}'

-        def generate(cpp, group, memLayout, arch, index, const):
+        def generate(self, cpp, group, memLayout):
             raise NotImplementedError

         def listToInitializerList(self, lst):
             return '{{{}}}'.format(', '.join([str(l) for l in lst]))

         def formatArray(self, numberType, name, values, declarationOnly):
-            lhs = '{} {}[]'.format(numberType, name)
+            lhs = f'{numberType} {name}[]'
             if declarationOnly:
                 return ''
-            return '{} {} = {};'.format(MODIFIERS, lhs, self.listToInitializerList(values))
+            return f'{MODIFIERS} {lhs} = {self.listToInitializerList(values)};'

     class DenseTensorView(TensorView):
         START_NAME = 'Start'
@@ -683,7 +696,7 @@ class CSCMatrixView(TensorView):
     def typename(self, dim, arch, const):
         constStr = 'true' if const else 'false'
-        return '::{}::{}<{},{},{}>'.format(SUPPORT_LIBRARY_NAMESPACE, type(self).__name__, arch.typename, arch.uintTypename, constStr)
+        return f'::{SUPPORT_LIBRARY_NAMESPACE}::{type(self).__name__}<{self._datatype.ctype()},{arch.uintTypename},{constStr}>'

     def generate(self, cpp, memLayout, arch, index, const):
         cpp( 'return {}({}, {}, {}, {});'.format(
@@ -700,9 +713,9 @@ def arrays(self, cpp, memLayout, arch, namespace, index, numberType, declaration
 def __init__(self, arch, tensors, scalars):
     self._arch = arch
-    self._numberType = '{} const'.format(self._arch.uintTypename)
-    self._realType = '{} const'.format(self._arch.typename)
-    self._realPtrType = self._realType + '*'
+    self._numberType = f'{self._arch.uintTypename} const'
+    self._realType = lambda datatype: f'{datatype.ctype()} const'
+    self._realPtrType = lambda datatype: self._realType(datatype) + '*'
     self._scalarCollect = collections.OrderedDict()
     self._collect = collections.OrderedDict()
     for tensor in tensors:
@@ -737,12 +750,13 @@ def __init__(self, arch, tensors, scalars):
     maxIndexScalar = {baseName: tuple(map(max, *groups.keys())) if len(groups) > 1 else next(iter(groups.keys())) for baseName, groups in self._scalarCollect.items()}
     self._groupSizeScalar = {baseName: tuple(map(lambda x: x+1, mi)) for baseName, mi in maxIndexScalar.items()}

-    def _tensorViewGenerator(self, memoryLayout):
+    def _tensorViewGenerator(self, tensor):
+        memoryLayout = tensor.memoryLayout()
         memLayoutMap = {
             'DenseMemoryLayout': self.DenseTensorView,
             'CSCMemoryLayout': self.CSCMatrixView
         }
-        return memLayoutMap[type(memoryLayout).__name__]()
+        return memLayoutMap[type(memoryLayout).__name__](tensor.getDatatype(self._arch))

     def iterate_collect(self):
         cur_namespace = ''
@@ -870,47 +884,52 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat
     if declarationOnly:
         for group,tensor in tensors.items():
             ml = tensor.memoryLayout()
-            tv = self._tensorViewGenerator(ml)
+            tv = self._tensorViewGenerator(tensor)
             tv.arrays(cpp, ml, self._arch, name, index(group), self._numberType, True)
     valueNames = dict()
     for group,tensor in tensors.items():
         values = tensor.values()
         memLayout = tensor.memoryLayout()
+        datatype = tensor.getDatatype(self._arch)
         if values is not None:
-            memory = ['0.']*memLayout.requiredReals()
+            memory =
[datatype.literal(0)]*memLayout.requiredReals() for idx,x in values.items(): - memory[memLayout.address(idx)] = x - valuesName = '{}{}{}'.format(name, self.VALUES_BASENAME, index(group)) - valueNames[group] = ['&{}[0]'.format(valuesName)] - cpp('{} {}[] = {{{}}};'.format(self._realType, valuesName, ', '.join(memory))) + memory[memLayout.address(idx)] = datatype.literal(x) + valuesName = f'{name}{self.VALUES_BASENAME}{index(group)}' + valueNames[group] = [f'&{valuesName}[0]'] + cpp('{} {}[] = {{{}}};'.format(self._realType(datatype), valuesName, ', '.join(memory))) if len(valueNames) > 1: - self._array(cpp, self._realPtrType, name + self.VALUES_BASENAME, valueNames, groupSize, alwaysArray=False, constexpr=False, static=False) + _,prototensor = next(iter(tensors.items())) + datatype = prototensor.getDatatype(self._arch) + self._array(cpp, self._realPtrType(datatype), name + self.VALUES_BASENAME, valueNames, groupSize, alwaysArray=False, constexpr=False, static=False) else: with cpp.Struct('{0} : {1}::{0}'.format(baseNameWithoutNamespace, self.TENSOR_NAMESPACE)): for group,tensor in tensors.items(): ml = tensor.memoryLayout() - tv = self._tensorViewGenerator(ml) + tv = self._tensorViewGenerator(tensor) tv.arrays(cpp, ml, self._arch, name, index(group), self._numberType, False) nValueArrays = 0 for group,tensor in tensors.items(): values = tensor.values() + datatype = tensor.getDatatype(self._arch) if values is not None: - name = '{}{}'.format(self.VALUES_BASENAME, index(group)) + name = f'{self.VALUES_BASENAME}{index(group)}' aligned = '' if tensor.memoryLayout().alignedStride(): - aligned = ' __attribute__((aligned({})))'.format(self._arch.cacheline) - cpp('{} {} {}[]{};'.format(STATIC, self._realType, name, aligned)) + aligned = f' __attribute__((aligned({self._arch.cacheline})))' + cpp(f'{STATIC} {self._realType(datatype)} {name}[]{aligned};') nValueArrays += 1 if nValueArrays > 1: - cpp('{} {} {}[];'.format(STATIC, self._realPtrType, self.VALUES_BASENAME)) + cpp(f'{STATIC} {self._realPtrType(datatype)} {self.VALUES_BASENAME}[];') cpp.emptyline() - viewArgs = self.TensorView.arguments(self._arch, False) - viewArgsConst = self.TensorView.arguments(self._arch, True) if len(groupSize) == 0: - ml = next(iter(tensors.values())).memoryLayout() - tv = self._tensorViewGenerator(ml) + prototensor = next(iter(tensors.values())) + ml = prototensor.memoryLayout() + tv = self._tensorViewGenerator(prototensor) + viewArgs = tv.arguments(False) + viewArgsConst = tv.arguments(True) with cpp.Struct(self.VIEW_STRUCT_NAME): cpp(f'using {self.VIEW_TYPE_NAME} = {tv.typename(len(ml.shape()), self._arch, False)};') cpp(f'using {self.VIEW_TYPE_NAME_CONST} = {tv.typename(len(ml.shape()), self._arch, True)};') @@ -925,7 +944,9 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat if len(groupSize) > 0: for group,tensor in tensors.items(): ml = tensor.memoryLayout() - tv = self._tensorViewGenerator(ml) + tv = self._tensorViewGenerator(tensor) + viewArgs = tv.arguments(False) + viewArgsConst = tv.arguments(True) typename = tv.typename(len(ml.shape()), self._arch, False) typenameConst = tv.typename(len(ml.shape()), self._arch, True) special = ','.join(str(g) for g in group) diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index a5af954..3eb5dfb 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -1,16 +1,18 @@ from ..ast.node import Node, FusedGEMMs, LoopOverGEMM +from ..ast.indices import Indices from collections import OrderedDict from 
typing import Dict, List - +from ..type import Scalar class Variable(object): - def __init__(self, name, writable, memoryLayout, eqspp=None, tensor=None, is_temporary=False): + def __init__(self, name, writable, memoryLayout, eqspp=None, tensor=None, is_temporary=False, datatype=None): self.name = name self.writable = writable self.tensor = tensor self._memoryLayout = memoryLayout self._eqspp = eqspp self.is_temporary = is_temporary + self.datatype = datatype def variables(self): return {self} @@ -157,13 +159,13 @@ def setWritable(self, name): for v in self._variables: v.setWritable(name) - class ProgramAction(object): - def __init__(self, result, term, add, scalar=None): + def __init__(self, result, term, add, scalar=None, condition=True): self.result = result self.term = term self.add = add self.scalar = scalar + self.condition = condition def isRHSExpression(self): return isinstance(self.term, Expression) @@ -194,15 +196,22 @@ def maySubstitute(self, when, by, result = True, term = True): return (not term or maySubsTerm) and (not result or maySubsResult) and compatible - def substituted(self, when, by, result = True, term = True): + def substituted(self, when, by, cond, result = True, term = True): rsubs = self.result.substituted(when, by) if result else self.result tsubs = self.term.substituted(when, by, rsubs.memoryLayout()) if term else self.term - return ProgramAction(rsubs, tsubs, self.add, self.scalar) + csubs = self.condition & cond + return ProgramAction(rsubs, tsubs, self.add, self.scalar, csubs) def setVariablesWritable(self, name): self.result.setWritable(name) self.term.setWritable(name) + def getCondition(self): + if isinstance(self.condition, CNFCondition): + return self.condition + else: + return CNFCondition(self.condition) + # TODO: probably should be a subclass of ProgramAction class FusedActions(object): @@ -211,6 +220,7 @@ def __init__(self): self._variables: List[Variable] = [] self._adds: List[bool] = [] self._scalars = [] + self._conditions = [] def add(self, action: ProgramAction) -> None: if not isinstance(action.term.node, LoopOverGEMM): @@ -222,13 +232,15 @@ def add(self, action: ProgramAction) -> None: self._variables.extend(action.term.variableList()) self._adds.append(action.add) self._scalars.append(action.scalar) + self._conditions.append(action.condition) def gen_program_action(self) -> ProgramAction: last_action: ProgramAction = self._actions[-1] return ProgramAction(result=last_action.result, term=self._gen_expr(), add=self._adds, - scalar=self._scalars) + scalar=self._scalars, + condition=self._conditions) def _gen_expr(self) -> Expression: node = FusedGEMMs() @@ -252,6 +264,169 @@ def __init__(self, action): self.bufferMap = None +# a rather primitive CNF implementation. +# do not overuse (i.e. 
avoid conditional assigns where possible)
+
+class CNFClause:
+    def __init__(self, variables):
+        if isinstance(variables, list):
+            self.variables = {var: True for var in variables}
+        else:
+            self.variables = variables
+        self.fulfilled = False
+
+    def negateVariables(self):
+        # boolean negation of each literal's sign (bitwise ~ on a bool would yield -1/-2)
+        return {var: not self.variables[var] for var in self.variables}
+
+    def unite(self, clause):
+        output = CNFClause([])
+        for v in self.variables:
+            if v in clause.variables and clause.variables[v] != self.variables[v]:
+                # x or ~x: the united clause is trivially true
+                output.fulfilled = True
+        if not output.fulfilled:
+            output.variables = {**self.variables, **clause.variables}
+        return output
+
+    def __repr__(self):
+        formatvar = lambda name: f'{name}' if self.variables[name] else f'~{name}'
+        return f'[{", ".join(formatvar(var) for var in self.variables)}]'
+
+    def ccode(self):
+        if not self.fulfilled and len(self.variables) == 0:
+            return 'false'
+        # for now, only allow scalarly-indexed variables
+        printvar = lambda var: f'{var}' if isinstance(var, Scalar) else f'{var}[{var.memoryLayout().addressString(Indices())}]'
+        formatvar = lambda name: f'{printvar(name)}' if self.variables[name] else f'!{printvar(name)}'
+        return f'({" || ".join(formatvar(var) for var in self.variables)})'
+
+    def variableIterator(self):
+        return (var for var in self.variables if isinstance(var, Variable))
+
+class CNFCondition:
+    def __init__(self, data):
+        if isinstance(data, bool):
+            if data == True:
+                self.clauses = []
+            elif data == False:
+                self.clauses = [CNFClause([])]
+        else:
+            self.clauses = [CNFClause([data])]
+
+    def _prune(self):
+        newclauses = []
+        for clause in self.clauses:
+            if clause.fulfilled:
+                # an always-true clause is redundant in a conjunction
+                continue
+            if len(clause.variables) == 0:
+                # an empty, unfulfilled clause makes the whole condition false
+                newclauses = [clause]
+                break
+            newclauses += [clause]
+        self.clauses = newclauses
+
+    def tautology(self):
+        return len(self.clauses) == 0
+
+    def unfulfillable(self):
+        return any(not clause.fulfilled and len(clause.variables) == 0 for clause in self.clauses)
+
+    # negation must be spelled __invert__ so that the ~ operator works
+    def __invert__(self):
+        if self.tautology():
+            return CNFCondition(False)
+
+        # this is the actually painful step (as it's also pretty inefficient right now)
+        result = CNFCondition(True)
+        for clause in self.clauses:
+            clauseInv = CNFCondition(True)
+            negVar = clause.negateVariables()
+            clauseInv.clauses = [CNFClause({var: negVar[var]}) for var in negVar]
+
+            result = result | clauseInv
+        return result
+
+    def __and__(self, other):
+        return self.__rand__(other)
+
+    def __rand__(self, other):
+        if not isinstance(other, CNFCondition):
+            other = CNFCondition(other)
+
+        clauses = self.clauses + other.clauses
+
+        condition = CNFCondition(True)
+        condition.clauses = clauses
+        condition._prune()
+
+        return condition
+
+    def __or__(self, other):
+        return self.__ror__(other)
+
+    def __ror__(self, other):
+        if not isinstance(other, CNFCondition):
+            other = CNFCondition(other)
+
+        clauses = [clause.unite(oclause) for clause in self.clauses for oclause in other.clauses]
+        condition = CNFCondition(True)
+        condition.clauses = clauses
+        condition._prune()
+
+        return condition
+
+    def __repr__(self):
+        return f'{self.clauses}'
+
+    def ccode(self):
+        if self.tautology():
+            return 'true'
+        elif self.unfulfillable():
+            return 'false'
+        return f'({" && ".join(clause.ccode() for clause in self.clauses)})'
+
+    def variables(self):
+        return {var for clause in self.clauses for var in clause.variableIterator()}
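+# Worked example (sketch, using the boolean scalars X1, X2 from tests/code-gen/conditional.py):
+#   CNFCondition(X1) & CNFCondition(X2)   -> clauses [[X1], [X2]]   i.e. (X1) && (X2)
+#   CNFCondition(X1) | CNFCondition(X2)   -> clauses [[X1, X2]]     i.e. (X1 || X2)
+#   ~CNFCondition(X1)                     -> clauses [[~X1]]
+# ccode() renders such a condition as a C guard, e.g. '((X1[0]) && (X2[0]))'
+# (the exact addressing depends on the memory layout).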
+class LiveSet:
+    def __init__(self, data: dict):
+        self.data = data
+
+    def __sub__(self, other):
+        if isinstance(other, dict):
+            other = LiveSet(other)
+
+        result = {k: self.data[k] for k in self.data}
+
+        for var in other.data:
+            if var in result:
+                result[var] &= ~other.data[var]
+                if result[var].unfulfillable():
+                    del result[var]
+
+        return LiveSet(result)
+
+    def __or__(self, other):
+        if isinstance(other, dict):
+            other = LiveSet(other)
+
+        result = {k: self.data[k] for k in self.data}
+
+        for var in other.data:
+            if var in result:
+                result[var] |= other.data[var]
+            else:
+                result[var] = other.data[var]
+
+        return LiveSet(result)
+
+    def __contains__(self, element):
+        if isinstance(element, tuple):
+            return element[0] in self.data and (~self.data[element[0]] & element[1]).unfulfillable()
+        else:
+            return element in self.data
+
+    def variables(self):
+        return set(k for k in self.data)
+
 class FusedProgramPoint(ProgramPoint):
     def __init__(self, action: FusedActions):
         super().__init__(action.gen_program_action())
diff --git a/yateto/controlflow/transformer.py b/yateto/controlflow/transformer.py
index 5d28db8..e48e91d 100644
--- a/yateto/controlflow/transformer.py
+++ b/yateto/controlflow/transformer.py
@@ -23,9 +23,9 @@ def visit(self, cfg):
 class LivenessAnalysis(object):
     def visit(self, cfg):
-        cfg[-1].live = set()
+        cfg[-1].live = LiveSet({})
         for i in reversed(range(len(cfg)-1)):
-            cfg[i].live = (cfg[i+1].live - {cfg[i].action.result}) | cfg[i].action.variables()
+            cfg[i].live = (cfg[i+1].live - {cfg[i].action.result: cfg[i].action.getCondition()}) | {var: cfg[i].action.getCondition() for var in cfg[i].action.variables()}
         return cfg

 class SubstituteForward(object):
@@ -39,7 +39,7 @@ def visit(self, cfg):
             and ua.isRHSVariable() \
             and ua.term.writable \
             and ua.result.isLocal() \
-            and ua.term not in v.live \
+            and (ua.term, ua.condition) not in v.live \
             and (ua.hasTrivialScalar() or ua.term.isLocal()):
             when = ua.result
@@ -47,7 +47,7 @@
             maySubs = all([cfg[j].action.maySubstitute(when, by) for j in range(i, n)])
             if maySubs:
                 for j in range(i, n):
-                    cfg[j].action = cfg[j].action.substituted(when, by)
+                    cfg[j].action = cfg[j].action.substituted(when, by, ua.condition)
                 cfg = LivenessAnalysis().visit(cfg)
         return cfg
@@ -62,16 +62,16 @@ def visit(self, cfg):
             found = -1
             for j in range(i):
                 u = cfg[j]
-                if by not in u.live and not u.action.isCompound() and u.action.result == va.term:
+                if (va.result, va.condition) not in u.live and not u.action.isCompound() and u.action.result == va.term:
                     found = j
                     break
             if found >= 0:
                 when = u.action.result
                 maySubs = cfg[found].action.maySubstitute(when, by, term=False) and all([cfg[j].action.maySubstitute(when, by) for j in range(found+1,i+1)])
                 if maySubs:
-                    cfg[found].action = cfg[found].action.substituted(when, by, term=False)
+                    cfg[found].action = cfg[found].action.substituted(when, by, va.condition, term=False)
                     for j in range(found+1,i+1):
-                        cfg[j].action = cfg[j].action.substituted(when, by)
+                        cfg[j].action = cfg[j].action.substituted(when, by, va.condition)
                     cfg = LivenessAnalysis().visit(cfg)
         return cfg
@@ -109,7 +109,7 @@ def visit(self, cfg):
             if found >= 0:
                 va = cfg[found].action
                 if ua.maySubstitute(ua.result, va.result, term=False):
-                    cfg[i].action = ua.substituted(ua.result, va.result, term=False)
+                    cfg[i].action = ua.substituted(ua.result, va.result, va.condition, term=False)
                     cfg[i].action.add = va.add
                     if not va.hasTrivialScalar():
                         cfg[i].action.scalar = va.scalar
@@ -136,7 +136,7 @@ def visit(self, cfg):
             # assign buffer
             if ua and not ua.isCompound() and not ua.result.isGlobal():
                 if ua.result in usedBuffers:
-                    buf = usedBuffers[ua.result]
+                    buf = usedBuffers[ua.result]
                 elif len(freeBuffers) > 0:
                     buf = freeBuffers.pop()
                 else:
@@ -145,14 +145,15 @@ def visit(self, cfg):
                 cfg[i].bufferMap[ua.result] = buf
                 usedBuffers[ua.result] = buf

-                size = ua.result.viewed().memoryLayout().storage().requiredReals()
+                # NOTE: size in bytes
+                size = ua.result.viewed().memoryLayout().storage().requiredReals() * ua.result.datatype.size()
                 if buf in bufferSize:
                     bufferSize[buf] = max(bufferSize[buf], size)
                 else:
                     bufferSize[buf] = size

             # free buffers
-            free = cfg[i].live - cfg[i+1].live
+            free = cfg[i].live.variables() - cfg[i+1].live.variables()
             for local in free:
                 # warning: local.isLocal() check is suboptimal (but currently good enough)
                 # refactor liveness for better results
diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py
index 1c111a2..707b562 100644
--- a/yateto/controlflow/visitor.py
+++ b/yateto/controlflow/visitor.py
@@ -2,7 +2,7 @@ from yateto import Scalar
 from .graph import *
 from ..memory import DenseMemoryLayout
-from ..ast.node import Permute, Broadcast
+from ..ast.node import Permute, Node, Broadcast

 class AST2ControlFlow(Visitor):
     TEMPORARY_RESULT = '_tmp'
@@ -12,6 +12,7 @@ def __init__(self, simpleMemoryLayout=False):
     self._cfg = []
     self._writable = set()
     self._simpleMemoryLayout = simpleMemoryLayout
+    self._condition = [True]

     def cfg(self):
         return self._cfg + [ProgramPoint(None)]
@@ -23,8 +24,9 @@ def _addTransformOp(self, permute, variable):
     if not self._simpleMemoryLayout:
         permute.setEqspp( permute.computeSparsityPattern() )
     permute.computeMemoryLayout()
+    permute.datatype = permute[0].datatype
     result = self._nextTemporary(permute)
-    action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False)
+    action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False, condition=self._condition[-1])
     self._addAction(action)
     return result
@@ -56,7 +58,7 @@ def generic_visit(self, node):
     variables = [self.visit(child) for child in node]

     result = self._nextTemporary(node)
-    action = ProgramAction(result, Expression(node, self._ml(node), variables), False)
+    action = ProgramAction(result, Expression(node, self._ml(node), variables), False, condition=self._condition[-1])
     self._addAction(action)

     return result
@@ -76,7 +78,7 @@ def visit_Add(self, node):
     add = False
     for i,var in enumerate(variables):
         rhs = self._addPermuteIfRequired(node.indices, node[i], var)
-        action = ProgramAction(tmp, rhs, add)
+        action = ProgramAction(tmp, rhs, add, condition=self._condition[-1])
         self._addAction(action)
         add = True
@@ -86,23 +88,48 @@ def visit_ScalarMultiplication(self, node):
     variable = self.visit(node.term())

     result = self._nextTemporary(node)
-    action = ProgramAction(result, variable, False, node.scalar())
+    action = ProgramAction(result, variable, False, node.scalar(), condition=self._condition[-1])
     self._addAction(action)

     return result

     def visit_Assign(self, node):
+        condition = self._condition[-1]
+        if isinstance(node.condition(), Node):
+            myCondition = self.visit(node[2])
+        else:
+            myCondition = node.condition()
+
         self.updateWritable(node[0].name())
-        variables = [self.visit(child) for child in node]
-        rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), variables[1])
-        action = ProgramAction(variables[0], rhs, False)
+        newCondition = condition & CNFCondition(myCondition)
+        self._condition.append(newCondition)
+
+        rVar = self.visit(node[1])
+        rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), rVar)
+
+        lVar = self.visit(node[0])
+        # pop only after both operands have been visited, so that temporaries
+        # created for the right-hand side inherit the assign condition
+        self._condition.pop()
+        action = ProgramAction(lVar, rhs, False, condition=newCondition)
         self._addAction(action)
-        return variables[0]
+        return lVar
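+    # Example (sketch): `yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij']))` from
+    # tests/code-gen/conditional.py reaches visit_Assign with node.condition() == X,
+    # so the temporary holding sqrt(B) and the final store into A both carry the
+    # CNF condition on X.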
   def visit_IndexedTensor(self, node):
-    return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor, is_temporary=node.tensor.temporary)
+    return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor, datatype=node.datatype, is_temporary=node.tensor.temporary)
+
+  def visit_IfThenElse(self, node):
+    # NOTE: currently unused -- where() lowers to ops.Ternary (cf. functions.py);
+    # kept as a stub for a future short-circuiting implementation
+    if len(self._condition) > 0:
+      condition = self._condition[-1]
+    else:
+      condition = True
+    myCondition = node.condition()
+    self._condition.append(condition & myCondition)
+    self.visit(node.yesTerm())
+    self.visit(node.noTerm())
+    self._condition.pop()
+    # self._addAction(ProgramAction())  # TODO: emit an action for the merged result
+    return self.visit(node.term())

   def _addAction(self, action):
     self._cfg.append(ProgramPoint(action))
@@ -110,7 +137,7 @@ def _nextTemporary(self, node):
     name = f'{self.TEMPORARY_RESULT}{self._tmp}'
     self._tmp += 1
-    return Variable(name, True, self._ml(node), node.eqspp(), is_temporary=True)
+    return Variable(name, True, self._ml(node), node.eqspp(), is_temporary=True, datatype=node.datatype)

   def updateWritable(self, name):
     self._writable = self._writable | {name}
@@ -124,7 +151,7 @@ def visit(self, cfg):
     V = set()
     for pp in cfg:
       if pp.action:
-        V = V | pp.action.result.variables() | pp.action.variables()
+        V = V | pp.action.result.variables() | pp.action.variables() | pp.action.getCondition().variables()
     return sorted([var for var in V if var.isGlobal()], key=lambda x: str(x))

 class SortedPrefetchList(object):
diff --git a/yateto/functions.py b/yateto/functions.py
new file mode 100644
index 0000000..c5a9196
--- /dev/null
+++ b/yateto/functions.py
@@ -0,0 +1,65 @@
+from . import ops
+from .ast import node
+from .type import Datatype
+
+def add(x, y): return node.Elementwise(ops.Add(), x, y)
+def mul(x, y): return node.Elementwise(ops.Mul(), x, y)
+def bitwise_or(x, y): return node.Elementwise(ops.Or(), x, y)
+def bitwise_and(x, y): return node.Elementwise(ops.And(), x, y)
+def bitwise_xor(x, y): return node.Elementwise(ops.Xor(), x, y)
+
+def sin(x): return node.Elementwise(ops.Sin(), x)
+def cos(x): return node.Elementwise(ops.Cos(), x)
+def tan(x): return node.Elementwise(ops.Tan(), x)
+def asin(x): return node.Elementwise(ops.Asin(), x)
+def acos(x): return node.Elementwise(ops.Acos(), x)
+def atan(x): return node.Elementwise(ops.Atan(), x)
+
+def sinh(x): return node.Elementwise(ops.Sinh(), x)
+def cosh(x): return node.Elementwise(ops.Cosh(), x)
+def tanh(x): return node.Elementwise(ops.Tanh(), x)
+def asinh(x): return node.Elementwise(ops.Asinh(), x)
+def acosh(x): return node.Elementwise(ops.Acosh(), x)
+def atanh(x): return node.Elementwise(ops.Atanh(), x)
+
+def log(x): return node.Elementwise(ops.Log(), x)
+def log1p(x): return node.Elementwise(ops.Log1p(), x)
+def exp(x): return node.Elementwise(ops.Exp(), x)
+def expm1(x): return node.Elementwise(ops.Expm1(), x)
+def sqrt(x): return node.Elementwise(ops.Sqrt(), x)
+def cbrt(x): return node.Elementwise(ops.Cbrt(), x)
+
+def abs(x): return node.Elementwise(ops.Abs(), x)
+
+def maximum(x, y): return node.Elementwise(ops.Max(), x, y)
+def minimum(x, y): return node.Elementwise(ops.Min(), x, y)
+def pow(x, y): return node.Elementwise(ops.Pow(), x, y)
+
+def assign(lhs, rhs): return node.Assign(lhs, rhs)
+def assignIf(condition, lhs, rhs): return node.Assign(lhs, rhs, condition)
+
+# def where(condition, yes, no): return node.IfThenElse(condition, yes, no)
+def where(condition, yes, no): return
node.Elementwise(ops.Ternary(), yes, no, condition) + +def equal(x, y): return node.Elementwise(ops.CmpEq(), x, y) +def not_equal(x, y): return node.Elementwise(ops.CmpNe(), x, y) +def less(x, y): return node.Elementwise(ops.CmpLt(), x, y) +def less_equal(x, y): return node.Elementwise(ops.CmpLe(), x, y) +def greater(x, y): return node.Elementwise(ops.CmpGt(), x, y) +def greater_equal(x, y): return node.Elementwise(ops.CmpGe(), x, y) + +# extra reduction functions; e.g. for input to `where` +def reduction(op, term, indices): + if len(indices) == 0: + return term + else: + return reduction(op, node.Reduction(op, term, indices[0]), indices[1:]) + +def sum(term, indices): return reduction(ops.Add(), term, indices) +def prod(term, indices): return reduction(ops.Mul(), term, indices) +def all(term, indices): return reduction(ops.And(), term, indices) +def any(term, indices): return reduction(ops.Or(), term, indices) +def min(term, indices): return reduction(ops.Min(), term, indices) +def max(term, indices): return reduction(ops.Max(), term, indices) + +def cast(x, dtype): return node.Elementwise(ops.Typecast(dtype), x) diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index 6eed1ca..22ea731 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -1,5 +1,6 @@ from typing import List from abc import ABC, abstractmethod +from .type import Datatype import operator class Preference(object): @@ -18,31 +19,40 @@ def archSupported(self): return True @abstractmethod - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): pass @abstractmethod def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): pass + # shortcut for legacy reasons + @classmethod + def _equalType(cls, datatypeA, datatypeB, datatypeC, types=(Datatype.F32, Datatype.F64)): + return datatypeA == datatypeC and datatypeB == datatypeC and datatypeC in types + class BLASlike(GemmTool): - def __init__(self, operation_name: str, includes: List[str], c_code_init: str = ''): - super().__init__(operation_name, includes) + def __init__(self, prefix, includes: List[str], c_code_init: str = ''): + super().__init__(prefix, includes) self.c_code_init = c_code_init - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): return Preference.MODERATE def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return (not sparseA and not sparseB and target == 'cpu') + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return (not sparseA and not sparseB and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC)) def bool2Trans(self, trans): return 'Cblas{}Trans'.format('' if trans else 'No') def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, - alignedA, alignedC, prefetchName): + alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName): + precision = { + Datatype.F32: 's', + Datatype.F64: 'd' + }[datatypeC] parameters = [ 'CblasColMajor', self.bool2Trans(transA), @@ -51,39 +61,43 @@ def call(self, transA, transB, M, 
N, K, alpha, A, ldA, B, ldB, beta, C, ldC,
                alpha, A, ldA, B, ldB, beta, C, ldC]
-    return '{}({});'.format(self.operation_name, ', '.join(str(p) for p in parameters))
+    return '{}_{}gemm({});'.format(self.prefix, precision, ', '.join(str(p) for p in parameters))

 class MKL(BLASlike):
   def __init__(self, arch):
     self._arch = arch
-    super().__init__('cblas_{}gemm'.format(arch.precision.lower()), ['mkl_cblas.h'])
+    super().__init__('cblas', ['mkl_cblas.h'])

   def archSupported(self):
     return self._arch.host_name.lower() in {'snb', 'hsw', 'skx', 'knl'} or self._arch.host_name.lower().startswith('avx')

 class OpenBLAS(BLASlike):
   def __init__(self, arch):
-    super().__init__('cblas_{}gemm'.format(arch.precision.lower()), ['cblas.h'])
+    super().__init__('cblas', ['cblas.h'])

 class BLIS(BLASlike):
   def __init__(self, arch):
-    super().__init__('bli_{}gemm'.format(arch.precision.lower()), ['blis.h'], '{0} _blis_alpha; {0} _blis_beta;'.format(arch.typename))
-    self._typename = arch.typename
+    super().__init__('bli', ['blis.h'])

   def bool2Trans(self, trans):
     return 'BLIS{}TRANSPOSE'.format('_' if trans else '_NO_')

   def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC,
-           alignedA, alignedC, prefetchName):
-    init = '_blis_alpha = {}; _blis_beta = {};'.format(alpha, beta)
+           alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName):
+    precision = {
+      Datatype.F32: 's',
+      Datatype.F64: 'd'
+    }[datatypeC]
+    initA = f'{datatypeC.ctype()} _blis_alpha = {alpha};'
+    initB = f'{datatypeC.ctype()} _blis_beta = {beta};'
     parameters = [
       self.bool2Trans(transA),
       self.bool2Trans(transB),
       M, N, K,
-      '&_blis_alpha', 'const_cast<{}*>({})'.format(self._typename, A), 1, ldA,
-      'const_cast<{}*>({})'.format(self._typename, B), 1, ldB,
+      '&_blis_alpha', f'const_cast<{datatypeA.ctype()}*>({A})', 1, ldA,
+      f'const_cast<{datatypeB.ctype()}*>({B})', 1, ldB,
       '&_blis_beta', C, 1, ldC]
-    return '{} {}({});'.format(init, self.operation_name, ', '.join(str(p) for p in parameters))
+    return '{{ {} {} {}_{}gemm({}); }}'.format(initA, initB, self.prefix, precision, ', '.join(str(p) for p in parameters))
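Aside: with the operation name split into a library prefix plus a precision letter, the BLAS-like backends now compose their symbol per call instead of at construction time. A quick illustration (gemm_symbol is a hypothetical helper; only the s/d mapping comes from the patch):

    from yateto.type import Datatype

    PRECISION = {Datatype.F32: 's', Datatype.F64: 'd'}

    def gemm_symbol(prefix, datatypeC):
        # e.g. ('cblas', F64) -> 'cblas_dgemm', ('bli', F32) -> 'bli_sgemm'
        return f'{prefix}_{PRECISION[datatypeC]}gemm'

    assert gemm_symbol('cblas', Datatype.F64) == 'cblas_dgemm'
    assert gemm_symbol('bli', Datatype.F32) == 'bli_sgemm'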
 class Eigen(BLASlike):
   def __init__(self, arch):
@@ -91,24 +105,24 @@ def __init__(self, arch):
     self._arch = arch

   def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha,
-                beta, alignedA, alignedC, target):
+                beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     return (not sparseA and not sparseB and target == 'cpu')

   def bool2Trans(self, trans):
     return '.transpose()' if trans else ''

   def sizeTrans(self, rows, cols, trans):
-    return '{},{}'.format(cols,rows) if trans else '{},{}'.format(rows,cols)
+    return f'{cols},{rows}' if trans else f'{rows},{cols}'

   def align(self, ld):
     aligned = 'Unaligned'
     if self._arch.checkAlignment(ld) and self._arch.alignment in [16,32,64,128]:
-      aligned = 'Aligned{}'.format(self._arch.alignment)
+      aligned = f'Aligned{self._arch.alignment}'
     return aligned

   def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC,
-           alignedA, alignedC, prefetchName):
+           alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName):
     AxB = '{alpha}_mapA{transA}*_mapB{transB}'.format(
       alpha=str(alpha) + '*' if alpha != 1.0 else '',
       transA=self.bool2Trans(transA),
       transB=self.bool2Trans(transB),
@@ -124,12 +138,15 @@ def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC,
       using Eigen::Matrix;
       using Eigen::Map;
       using Eigen::Stride;
-      Map<Matrix<{prec},{sizeA}>,Eigen::{alignA},Stride<{ldA},1>> _mapA(const_cast<{prec}*>({A}));
-      Map<Matrix<{prec},{sizeB}>,Eigen::Unaligned,Stride<{ldB},1>> _mapB(const_cast<{prec}*>({B}));
-      Map<Matrix<{prec},{M},{N}>,Eigen::{alignC},Stride<{ldC},1>> _mapC({C});
+      Map<Matrix<{precA},{sizeA}>,Eigen::{alignA},Stride<{ldA},1>> _mapA(const_cast<{precA}*>({A}));
+      Map<Matrix<{precB},{sizeB}>,Eigen::Unaligned,Stride<{ldB},1>> _mapB(const_cast<{precB}*>({B}));
+      Map<Matrix<{precC},{M},{N}>,Eigen::{alignC},Stride<{ldC},1>> _mapC({C});
       {code}
     }}
-    """.format(prec=self._arch.typename, M=M, N=N,
+    """.format(precA=datatypeA.ctype(),
+               precB=datatypeB.ctype(),
+               precC=datatypeC.ctype(),
+               M=M, N=N,
                sizeA=self.sizeTrans(M,K,transA),
                sizeB=self.sizeTrans(K,N,transB),
                ldA=ldA, ldB=ldB, ldC=ldC,
                A=A, B=B, C=C,
@@ -163,7 +180,7 @@ def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 1
     self._threshold = threshold
     self._arch = arch

-  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC):
+  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     if (m*n*k)**(1./3.) <= self._threshold:
       return Preference.HIGH
     return Preference.LOW
@@ -173,13 +190,13 @@ def archSupported(self):
     return self._arch.host_name.lower() in supported_set

   def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha,
-                beta, alignedA, alignedC, target):
+                beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     # Note:
     # Libxsmm falls back to blas for transA and more general alpha/beta
     # See e.g. here:
     # https://libxsmm.readthedocs.io/en/latest/libxsmm_qna/#what-is-a-small-matrix-multiplication
     # https://github.com/hfp/libxsmm/issues/396#issuecomment-674741063
-    return self.archSupported() and not (sparseA or sparseB) and (not transA) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu'
+    return self.archSupported() and not (sparseA or sparseB) and (not transA) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC) # TODO: no, there's more

 class LIBXSMM(CodeGenerator):
   def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 128):
@@ -191,10 +208,10 @@ def archSupported(self):
     return self._arch.host_name.lower() in supported_set

   def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha,
-                beta, alignedA, alignedC, target):
-    return self.archSupported() and not (sparseA and sparseB) and (not transA and not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu'
+                beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
+    return self.archSupported() and not (sparseA and sparseB) and (not transA and not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' and (self._equalType(datatypeA, datatypeB, datatypeC) or (self._equalType(datatypeA, datatypeB, datatypeC, (Datatype.I16,)) and not sparseA and not sparseB))

-  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC):
+  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     if sparseA:
       return Preference.LOW
     if sparseB:
@@ -213,14 +230,14 @@ def archSupported(self):
     return self._arch.host_name.lower() in supported_set

   def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha,
-                beta, alignedA, alignedC, target):
+                beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     # NOTE: PSpaMM 0.3.0+ supports SIMD-aligned block sparsity in A (which is currently covered by sparseA + alignedA)
     # also, it supports unaligned matmuls for AVX512/10 and SVE in
0.3.1 noAlign = self._arch.host_name.lower() in {'thunderx2t99', 'knl', 'skx', 'a64fx', 'bergamo', 'turin', 'sve128', 'sve256', 'sve512', 'sve1024', 'sve2048', 'avx10-128', 'avx10-256', 'avx10-512'} alignment = sparseA and alignedA or not sparseA and (noAlign or alignedA) - return self.archSupported() and (alignedC or noAlign) and alignment and (not transA and not transB) and target == 'cpu' + return self.archSupported() and (alignedC or noAlign) and alignment and (not transA and not transB) and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC, [Datatype.BF16, Datatype.F16, Datatype.F32, Datatype.F64]) - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if sparseB: return Preference.HIGH if sparseA and alignedA: @@ -246,10 +263,10 @@ def archSupported(self): return self._arch.backend.lower() in {'cuda', 'hip', 'oneapi', 'acpp', 'hipsycl'} def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return self.archSupported() and not (sparseA or sparseB) and target == 'gpu' + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return self.archSupported() and not (sparseA or sparseB) and target == 'gpu' and self._equalType(datatypeA, datatypeB, datatypeC) - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if sparseA and sparseB: return Preference.LOWEST if not transA: @@ -263,15 +280,15 @@ def __init__(self, arch): super().__init__('', [], '', arch) self._arch = arch - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): return Preference.HIGHEST def archSupported(self): return self._arch.backend.lower() in {'oneapi'} def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return self.archSupported() and not (sparseA or sparseB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return self.archSupported() and not (sparseA or sparseB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' and self._equalType(datatypeA, datatypeB, datatypeC) # TODO: really? 
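Aside: every supported()/preference() overload now receives the three operand datatypes, and most backends gate on the `_equalType` helper. Its behaviour, for reference (the helper is the classmethod from this patch; the asserts are illustrative):

    from yateto.gemm_configuration import GemmTool
    from yateto.type import Datatype

    # uniform F32/F64 passes the default type set ...
    assert GemmTool._equalType(Datatype.F64, Datatype.F64, Datatype.F64)
    # ... mixed precision and non-default types do not
    assert not GemmTool._equalType(Datatype.F32, Datatype.F64, Datatype.F64)
    assert not GemmTool._equalType(Datatype.I32, Datatype.I32, Datatype.I32)
    # backends widen the set explicitly, e.g. PSpaMM:
    assert GemmTool._equalType(Datatype.BF16, Datatype.BF16, Datatype.BF16,
                               [Datatype.BF16, Datatype.F16, Datatype.F32, Datatype.F64])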
class GeneratorCollection(object): @@ -280,13 +297,13 @@ def __init__(self, gemmTools: List[GemmTool]): self.selected = set() def getGemmTool(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): tools = dict() for gemmTool in reversed(self.gemmTools): if gemmTool.supported(m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): tools[gemmTool.preference(m, n, k, sparseA, sparseB, transA, transB, alpha, beta, - alignedA, alignedC)] = gemmTool + alignedA, alignedC, datatypeA, datatypeB, datatypeC, target)] = gemmTool select = None if tools: diff --git a/yateto/generator.py b/yateto/generator.py index 77ed2c9..9f97d58 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -56,8 +56,11 @@ def __init__(self, name, ast, prefetch=None, namespace=None, target='cpu'): def isValidName(cls, name): return re.match(cls.VALID_NAME, name) is not None - def prepareUntilUnitTest(self): + def prepareUntilUnitTest(self, arch): self.ast = [DeduceIndices().visit(ast) for ast in self.ast] + dtd = SetDatatype1(arch) + for ast in self.ast: + dtd.visit(ast) ast2cf = AST2ControlFlow(simpleMemoryLayout=True) for ast in self.ast: ast2cf.visit(ast) @@ -83,6 +86,7 @@ def prepareUntilCodeGen(self, cost_estimator, enableFusedGemm: bool): permutationVariants = FindIndexPermutations().visit(ast) ast = SelectIndexPermutations(permutationVariants).visit(ast) ast = ImplementContractions().visit(ast) + ast = SetDatatype2().visit(ast) if self._prefetch is not None: prefetchCapabilities = FindPrefetchCapabilities().visit(ast) assignPf = AssignPrefetch(prefetchCapabilities, prefetch) @@ -171,9 +175,9 @@ def add(self, name, ast, prefetch=None, namespace=None, target='cpu'): def kernels(self): return self._kernels.values() - def prepareUntilUnitTest(self): + def prepareUntilUnitTest(self, arch): for kernel in self._kernels.values(): - kernel.prepareUntilUnitTest() + kernel.prepareUntilUnitTest(arch) def prepareUntilCodeGen(self, costEstimator, enableFusedGemm: bool): for kernel in self._kernels.values(): @@ -300,9 +304,9 @@ def generate(self, print('Deducing indices...') for kernel in self._kernels: - kernel.prepareUntilUnitTest() + kernel.prepareUntilUnitTest(self._arch) for family in self._kernelFamilies.values(): - family.prepareUntilUnitTest() + family.prepareUntilUnitTest(self._arch) fUTdoctest = self.FileNames(outputDir, self.DOCTEST_FILE_NAME) fUTcxxtest = self.FileNames(outputDir, self.CXXTEST_FILE_NAME) @@ -360,6 +364,7 @@ def unit_test_body(cpp, testFramework): kernelSourceContent = '' with Cpp(kernelSource) as cpp: cpp.includeSys('cassert') + cpp.includeSys('cmath') cpp.includeSys('cstring') cpp.includeSys('cstdlib') cpp.includeSys('limits') @@ -455,7 +460,7 @@ def unit_test_body(cpp, testFramework): cpp.include(fInit.hName) with cpp.Namespace(namespace): initGen.generateInitCpp(cpp) - + prefixnsp = lambda a: a.name if a.namespace == '' else f'{a.namespace}::{a.name}' return { 'namespace': namespace, diff --git a/yateto/metagen.py b/yateto/metagen.py index cc910d2..fc1d36d 100644 --- a/yateto/metagen.py +++ b/yateto/metagen.py @@ -95,7 +95,7 @@ def headerForward(name, data): for entry in data: self.template(header, entry, data[entry], f'{name}') - + headerForward('tensor', tensors) headerForward('init', tensors) headerForward('kernel', kernels) @@ -105,7 +105,7 @@ def cppForward(name): for gendata in 
self.generators: outdirname = f'metagen_{gendata["name"]}' header.include(f'{outdirname}/{name}.cpp') - + cppForward('tensor') cppForward('init') cppForward('kernel') @@ -131,12 +131,12 @@ def inner(): templatetypes = ', '.join(f'{typ} Arg{i}' for i, typ in enumerate(self.templateType)) templateargs = ', '.join(f'Arg{i}' for i, _ in enumerate(self.templateType)) - + with header.Namespace('internal'): header(f'template<{templatetypes}> struct {internalName} {"{"} using Type = void; {"}"};') for gnsp, spec in foundin: spectext = ', '.join(str(specpart) for specpart in spec) header(f'template<> struct {internalName}<{spectext}> {"{"} using Type = ::{gnsp}::{fullname}; {"}"};') header(f'template<{templatetypes}> using {name} = typename internal::{internalName}<{templateargs}>::Type;') - + self.namespacing(header, splitname[:-1] + [subnsp], inner) diff --git a/yateto/ops.py b/yateto/ops.py new file mode 100644 index 0000000..825a05b --- /dev/null +++ b/yateto/ops.py @@ -0,0 +1,349 @@ +import numpy as np +from .type import Datatype + +class Operation: + #def callstr(self, *args) -> str: + # raise NotImplementedError() + + def call(self, *args): + raise NotImplementedError() + + def datatypeResult(self, argtypes): + raise NotImplementedError() # TODO + + def __str__(self): + return type(self).__name__ + + def __eq__(self, other): + # we're more or less using "dummy" types here + return type(self).__name__ == type(other).__name__ + +class CommutativeMonoidMixin: + def neutral(self): + pass + +class RingMixin: + def formsRing(self, op): + pass + +class UnaryArgsMixin: + pass + +class BinaryArgsMixin: + pass + +class CFunctionMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self, *args) -> str: + return f'{self.cppname()}({", ".join(str(arg) for arg in args)})' + +class CUnaryOperatorMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self, *args) -> str: + return f'{self.cppname()}({args[0]})' + +class CBinaryOperatorMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self, *args) -> str: + return f'({args[0]}) {self.cppname()} ({args[1]})' + +class Sin(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::sin' + def call(self, *args): + return np.sin(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] +class Cos(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::cos' + def call(self, *args): + return np.cos(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] +class Tan(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::tan' + def call(self, *args): + return np.tan(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] +class Asin(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::asin' + def call(self, *args): + return np.asin(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] +class Acos(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::acos' + def call(self, *args): + return np.acos(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] +class Atan(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::atan' + def call(self, *args): + return np.atan(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] + +class Sinh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::sinh' + def call(self, *args): + return np.sinh(args[0]) + 
def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Cosh(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::cosh'
+  def call(self, *args):
+    return np.cosh(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Tanh(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::tanh'
+  def call(self, *args):
+    return np.tanh(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Asinh(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::asinh'
+  def call(self, *args):
+    return np.asinh(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Acosh(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::acosh'
+  def call(self, *args):
+    return np.acosh(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Atanh(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::atanh'
+  def call(self, *args):
+    return np.atanh(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+
+class Log(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::log'
+  def call(self, *args):
+    return np.log(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Exp(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::exp'
+  def call(self, *args):
+    return np.exp(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Log1p(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::log1p'
+  def call(self, *args):
+    return np.log1p(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Expm1(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::expm1'
+  def call(self, *args):
+    return np.expm1(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Sqrt(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::sqrt'
+  def call(self, *args):
+    return np.sqrt(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+class Cbrt(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::cbrt'
+  def call(self, *args):
+    return np.cbrt(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+
+class Abs(Operation, CFunctionMixin, UnaryArgsMixin):
+  def cppname(self):
+    return 'std::abs'
+  def call(self, *args):
+    return np.abs(args[0])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+
+
+class Max(Operation, CFunctionMixin, BinaryArgsMixin, CommutativeMonoidMixin):
+  def neutral(self):
+    return -float('inf')
+  def cppname(self):
+    return 'std::max'
+  def call(self, *args):
+    return np.maximum(args[0], args[1])
+  def datatypeResult(self, argtypes):
+    # assert argtypes[0] == argtypes[1]
+    return argtypes[0]
+class Min(Operation, CFunctionMixin, BinaryArgsMixin, CommutativeMonoidMixin):
+  def neutral(self):
+    return float('inf')
+  def cppname(self):
+    return 'std::min'
+  def call(self, *args):
+    return np.minimum(args[0], args[1])
+  def datatypeResult(self, argtypes):
+    # assert argtypes[0] == argtypes[1]
+    return argtypes[0]
+class Pow(Operation, CFunctionMixin, BinaryArgsMixin):
+  def cppname(self):
+    return 'std::pow'
+  def call(self, *args):
+    return pow(args[0], args[1])
+  def datatypeResult(self, argtypes):
+    return argtypes[0]
+
+class Xor(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin):
+  # referenced by functions.bitwise_xor; was missing from this file
+  def cppname(self, *args):
+    return '^'
+  def call(self, *args):
+    return args[0] ^ args[1]
+  def neutral(self):
+    return False
+  def datatypeResult(self, argtypes):
+    # assert argtypes[0] == argtypes[1]
+    return argtypes[0]
+
+class Div(Operation, CBinaryOperatorMixin, BinaryArgsMixin):
+  def cppname(self, *args):
+    return '/'
+  def call(self, *args):
+    return
args[0] / args[1] + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] + +class Add(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin): + def cppname(self, *args): + return '+' + def call(self, *args): + return args[0] + args[1] + def neutral(self): + return 0 + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] +class Mul(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def cppname(self, *args): + return '*' + def call(self, *args): + return args[0] * args[1] + def neutral(self): + return 1 + def formsRing(self, op): + return op == Add() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] + +class And(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def cppname(self, *args): + return '&' + def call(self, *args): + return args[0] & args[1] + def neutral(self): + return True + def formsRing(self, op): + return op == Or() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] +class Or(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def cppname(self, *args): + return '|' + def call(self, *args): + return args[0] | args[1] + def neutral(self): + return False + def formsRing(self, op): + return op == And() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] +class Not(Operation, CUnaryOperatorMixin, UnaryArgsMixin): + def cppname(self, *args): + return '~' + def call(self, *args): + return ~args[0] + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] + +class CmpEq(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '==' + def call(self, *args): + return args[0] == args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL +class CmpNe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '!=' + def call(self, *args): + return args[0] != args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL +class CmpLt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '<' + def call(self, *args): + return args[0] < args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL +class CmpLe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '<=' + def call(self, *args): + return args[0] <= args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL +class CmpGt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '>' + def call(self, *args): + return args[0] > args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL +class CmpGe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '>=' + def call(self, *args): + return args[0] >= args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL + +# replacement; however it'll execute both code paths, regardless of the result +class Ternary(Operation): + def callstr(self, *args): + return f'(({args[2]}) ? 
({args[0]}) : ({args[1]}))'
+  def call(self, *args):
+    return np.where(args[2], args[0], args[1])
+  def datatypeResult(self, argtypes):
+    assert argtypes[0] == argtypes[1]
+    return argtypes[0]
+class Typecast(Operation, CFunctionMixin, UnaryArgsMixin):
+  def __init__(self, target: Datatype):
+    self.target = target
+  def cppname(self, *args):
+    return f'static_cast<{self.target.ctype()}>'
+  def call(self, *args):
+    return np.astype(args[0], self.target.nptype())
+  def datatypeResult(self, argtypes):
+    return self.target
+  def __str__(self):
+    return f'Cast<{self.target}>'
diff --git a/yateto/type.py b/yateto/type.py
index 874645e..ed6505c 100644
--- a/yateto/type.py
+++ b/yateto/type.py
@@ -1,10 +1,134 @@
 import re
-from .ast.node import Node, IndexedTensor
 from numpy import ndarray, zeros, float64
 from .memory import DenseMemoryLayout
 from . import aspp
+from enum import Enum
+
+import numpy as np
+
+class Datatype(Enum):
+  BOOL = 0
+  I8 = 1
+  I16 = 2
+  I32 = 3
+  I64 = 4
+  F32 = 5
+  F64 = 6
+  F16 = 7
+  BF16 = 8
+  F128 = 9
+
+  def __str__(self):
+    return {
+      Datatype.BOOL: 'bool',
+      Datatype.I8: 'i8',
+      Datatype.I16: 'i16',
+      Datatype.I32: 'i32',
+      Datatype.I64: 'i64',
+      Datatype.F32: 'f32',
+      Datatype.F64: 'f64',
+      Datatype.F128: 'f128',
+      Datatype.F16: 'f16',
+      Datatype.BF16: 'bf16',
+    }[self]
+
+  def ctype(self):
+    return {
+      Datatype.BOOL: 'bool',
+      Datatype.I8: 'int8_t',
+      Datatype.I16: 'int16_t',
+      Datatype.I32: 'int32_t',
+      Datatype.I64: 'int64_t',
+      Datatype.F32: 'float',
+      Datatype.F64: 'double',
+      Datatype.F16: 'yateto::f16_ty',
+      Datatype.BF16: 'yateto::bf16_ty',
+      Datatype.F128: 'yateto::f128_ty',
+    }[self]
+
+  def nptype(self):
+    return {
+      Datatype.BOOL: np.bool_,
+      Datatype.I8: np.int8,
+      Datatype.I16: np.int16,
+      Datatype.I32: np.int32,
+      Datatype.I64: np.int64,
+      Datatype.F32: np.float32,
+      Datatype.F64: np.float64,
+      Datatype.F16: np.float16,
+      Datatype.BF16: np.float32, # NYI
+      Datatype.F128: np.float128,
+    }[self]
+
+  def size(self):
+    # unpacked size
+    return {
+      Datatype.BOOL: 1,
+      Datatype.I8: 1,
+      Datatype.I16: 2,
+      Datatype.I32: 4,
+      Datatype.I64: 8,
+      Datatype.F32: 4,
+      Datatype.F64: 8,
+      Datatype.F16: 2,
+      Datatype.BF16: 2,
+      Datatype.F128: 16,
+    }[self]
+
+  def safeint(self, value):
+    # allow inf/-inf to be treated as int; clamp to the long long range
+    # so that the emitted LL literal stays valid
+    return int(max(-2**63, min(2**63 - 1, value)))
+
+  def literal(self, value):
+    # (note: the extra lambda mapping is needed to prevent type errors)
+    return {
+      Datatype.BOOL: lambda value: 'true' if value else 'false',
+      Datatype.I8: lambda value: f'static_cast<int8_t>({self.safeint(value)}LL)',
+      Datatype.I16: lambda value: f'static_cast<int16_t>({self.safeint(value)}LL)',
+      Datatype.I32: lambda value: f'static_cast<int32_t>({self.safeint(value)}LL)',
+      Datatype.I64: lambda value: f'static_cast<int64_t>({self.safeint(value)}LL)',
+      Datatype.F32: lambda value: f'{float(value):.16}f',
+      Datatype.F64: lambda value: f'{float(value):.16}',
+      Datatype.F16: lambda value: f'static_cast<yateto::f16_ty>({float(value):.16})',
+      Datatype.BF16: lambda value: f'static_cast<yateto::bf16_ty>({float(value):.16})',
+      Datatype.F128: lambda value: f'static_cast<yateto::f128_ty>({float(value):.32}q)',
+    }[self](value)
+
+class AddressingMode(Enum):
+  DIRECT = 0
+  STRIDED = 1
+  INDIRECT = 2
+  SCALAR = 3
+
+  def pointer_type(self):
+    return {
+      AddressingMode.DIRECT: '*',
+      AddressingMode.STRIDED: '*',
+      AddressingMode.INDIRECT: '**',
+      AddressingMode.SCALAR: '',
+    }[self]
+
+class Symbol(object):
+  def __init__(self, datatype):
+    # datatype == None is treated as datatype == arch.datatype
+    self.datatype = datatype
+
+  def getDatatype(self, arch):
return arch.datatype if self.datatype is None else self.datatype + +class ScalarMixin: + pass + +class ImmediateScalar(Symbol, ScalarMixin): + def __init__(self, data, datatype): + super().__init__(datatype) + self.data = data + +class AbstractType(Symbol): + def __init__(self, name, datatype): + super().__init__(datatype) + self._name = name -class AbstractType(object): @classmethod def isValidName(cls, name): return re.match(cls.VALID_NAME, name) is not None @@ -15,17 +139,18 @@ def name(self): class IdentifiedType(AbstractType): BASE_NAME = r'[a-zA-Z]\w*' GROUP_INDEX = r'(0|[1-9]\d*)' - GROUP_INDICES = r'\(({0}(,{0})*)\)'.format(GROUP_INDEX) - VALID_NAME = r'^{}({})?$'.format(BASE_NAME, GROUP_INDICES) + GROUP_INDICES = rf'\(({GROUP_INDEX}(,{GROUP_INDEX})*)\)' + VALID_NAME = rf'^{BASE_NAME}({GROUP_INDICES})?$' - def __init__(self, name, namespace=None): + def __init__(self, name, namespace=None, datatype=None): + super().__init__(name, datatype) if not self.isValidName(name): - raise ValueError('Invalid name (must match regexp {}): {}'.format(self.VALID_NAME, name)) + raise ValueError(f'Invalid name (must match regexp {self.VALID_NAME}): {name}') self._name = name self.namespace = namespace - self.datatype = None # TODO + self.datatype = datatype def __str__(self): return self._name @@ -69,9 +194,9 @@ def nameWithNamespace(self): def __hash__(self): return hash(self._name) -class Scalar(IdentifiedType): - def __init__(self, name, namespace=None): - super().__init__(name, namespace=namespace) +class Scalar(IdentifiedType, ScalarMixin): + def __init__(self, name, namespace=None, datatype=None): + super().__init__(name, namespace=namespace, datatype=datatype) def __hash__(self): return hash(self._name) @@ -84,8 +209,10 @@ def __init__(self, memoryLayoutClass=DenseMemoryLayout, alignStride=False, namespace=None, + datatype=None, + addressing=None, temporary=False): - super().__init__(name, namespace=namespace) + super().__init__(name, namespace=namespace, datatype=datatype) if not isinstance(shape, tuple): raise ValueError('shape must be a tuple') @@ -93,12 +220,15 @@ def __init__(self, raise ValueError('shape must not contain entries smaller than 1') if not self.isValidName(name): - raise ValueError('Tensor name invalid (must match regexp {}): {}'.format(self.VALID_NAME, name)) + raise ValueError(f'Tensor name invalid (must match regexp {self.VALID_NAME}): {name}') self._name = name self._shape = shape self._values = None + # default addressing mode. If not given, deduce it + self.addressing = addressing + self.temporary = temporary if namespace is None: @@ -147,6 +277,7 @@ def setGroupSpp(self, spp): self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignedStride()) def __getitem__(self, indexNames): + from .ast.node import IndexedTensor return IndexedTensor(self, indexNames) def shape(self):
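Aside: the Datatype helpers introduced in this patch compose as follows (the values follow directly from the mapping tables above; illustrative asserts, to be run against the patched module):

    from yateto.type import Datatype

    assert Datatype.F32.ctype() == 'float'
    assert Datatype.I64.size() == 8
    assert str(Datatype.BF16) == 'bf16'
    # literal() renders a C++ constant of the matching type
    assert Datatype.BOOL.literal(1) == 'true'
    assert Datatype.F32.literal(0.5) == '0.5f'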