From 4124a58c0afb89376dbb929b80f89149c7024249 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Fri, 31 Jan 2025 01:07:47 +0100 Subject: [PATCH 01/18] Add a tensor datatype, introduce tensor addressing --- yateto/arch.py | 16 ++++++--- yateto/type.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/yateto/arch.py b/yateto/arch.py index 8690788..a322d04 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -38,6 +38,7 @@ # from .memory import DenseMemoryLayout +from .type import Datatype class Architecture(object): def __init__(self, @@ -66,13 +67,11 @@ def __init__(self, self.precision = precision.upper() if self.precision == 'D': - self.bytesPerReal = 8 - self.typename = 'double' self.epsilon = 2.22e-16 + self.datatype = Datatype.F64 elif self.precision == 'S': - self.bytesPerReal = 4 - self.typename = 'float' self.epsilon = 1.19e-7 + self.datatype = Datatype.F32 else: raise ValueError(f'Unknown precision type {self.precision}') self.alignment = alignment @@ -102,11 +101,18 @@ def checkAlignment(self, offset): return offset % self.alignedReals == 0 def formatConstant(self, constant): - return str(constant) + ('f' if self.precision == 'S' else '') + return self.datatype.literal(constant) def onHeap(self, numReals): return (numReals * self.bytesPerReal) > self._tmpStackLimit + @property + def typename(self): + return self.datatype.ctype() + + @property + def bytesPerReal(self): + return self.datatype.size() def _get_name_and_precision(ident): return ident[1:], ident[0].upper() diff --git a/yateto/type.py b/yateto/type.py index fba8f3b..4ed0980 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -3,6 +3,71 @@ from numpy import ndarray, zeros, float64 from .memory import DenseMemoryLayout from . 
import aspp +from enum import Enum + +class Datatype(Enum): + BOOL = 0 + I8 = 1 + I16 = 2 + I32 = 3 + I64 = 4 + F32 = 5 + F64 = 6 + F16 = 7 + BF16 = 8 + + def ctype(self): + return { + Datatype.BOOL: 'bool', + Datatype.I8: 'int8_t', + Datatype.I16: 'int16_t', + Datatype.I32: 'int32_t', + Datatype.I64: 'int64_t', + Datatype.F32: 'float', + Datatype.F64: 'double', + Datatype.F16: 'int16_t', + Datatype.BF16: 'int16_t', + }[self] + + def size(self): + # unpacked size + return { + Datatype.BOOL: 1, + Datatype.I8: 1, + Datatype.I16: 2, + Datatype.I32: 4, + Datatype.I64: 8, + Datatype.F32: 4, + Datatype.F64: 8, + Datatype.F16: 2, + Datatype.BF16: 2, + }[self] + + def literal(self, value): + # TODO: BF16, F16 + return { + Datatype.BOOL: 'true' if value else 'false', + Datatype.I8: f'{int(value)}', + Datatype.I16: f'{int(value)}', + Datatype.I32: f'{int(value)}', + Datatype.I64: f'{int(value)}LL', + Datatype.F32: f'{float(value):.16}f', + Datatype.F64: f'{float(value):.16}' + }[self] + +class AddressingMode(Enum): + DIRECT = 0 + STRIDED = 1 + INDIRECT = 2 + SCALAR = 3 + + def pointer_type(self): + return { + AddressingMode.DIRECT: '*', + AddressingMode.STRIDED: '*', + AddressingMode.INDIRECT: '**', + AddressingMode.SCALAR: '', + }[self] class AbstractType(object): @classmethod @@ -18,13 +83,19 @@ class IdentifiedType(AbstractType): GROUP_INDICES = r'\(({0}(,{0})*)\)'.format(GROUP_INDEX) VALID_NAME = r'^{}({})?$'.format(BASE_NAME, GROUP_INDICES) - def __init__(self, name, namespace=None): + def __init__(self, name, namespace=None, datatype=None): if not self.isValidName(name): raise ValueError('Invalid name (must match regexp {}): {}'.format(self.VALID_NAME, name)) self._name = name self.namespace = namespace - + + # datatype == None is treated as datatype == self._arch.datatype + self.datatype = datatype + + def getDatatype(self, arch): + return self._arch.datatype if self.datatype is None else self.datatype + def __str__(self): return self._name @@ -68,8 +139,8 @@ def 
__hash__(self): return hash(self._name) class Scalar(IdentifiedType): - def __init__(self, name, namespace=None): - super().__init__(name, namespace=namespace) + def __init__(self, name, namespace=None, datatype=None): + super().__init__(name, namespace=namespace, datatype=datatype) def __hash__(self): return hash(self._name) @@ -81,8 +152,10 @@ def __init__(self, spp=None, memoryLayoutClass=DenseMemoryLayout, alignStride=False, - namespace=None): - super().__init__(name, namespace=namespace) + namespace=None, + datatype=None, + addressing=None): + super().__init__(name, namespace=namespace, datatype=datatype) if not isinstance(shape, tuple): raise ValueError('shape must be a tuple') @@ -96,6 +169,9 @@ def __init__(self, self._shape = shape self._values = None + # default addressing mode. If not given, deduce it + self.addressing = addressing + if namespace is None: self.namespace = '' else: From 3dbe6e9ad029bbbed82803a3e8556852fb0fafc1 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Fri, 31 Jan 2025 01:08:03 +0100 Subject: [PATCH 02/18] Start propagating the datatype to the AST --- yateto/ast/node.py | 13 +++++++++++++ yateto/ast/transformer.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 93da011..3c0761e 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -9,6 +9,7 @@ def __init__(self): self.indices = None self._children = [] self._eqspp = None + self.datatype = None def size(self): return self.indices.size() @@ -337,6 +338,18 @@ def computeSparsityPattern(self, *spps): spp = spps[0] if len(spps) == 1 else self.term().eqspp() return spp.indexSum(self.term().indices, self.indices) +class DatatypeCast(UnaryOp): + def __init__(self, term, datatype): + super().__init__(term) + self.indices = term.indices + self.newDatatype = datatype + + def nonZeroFlops(self): + return self.term().eqspp().count_nonzero() + + def computeSparsityPattern(self, *spps): + return self.term().indices + 
class Contraction(BinOp): def __init__(self, indices, lTerm, rTerm, sumIndices): super().__init__(lTerm, rTerm) diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index 152ce50..e11e99e 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -238,3 +238,19 @@ def generic_visit(self, node): def visit_IndexedTensor(self, node): return node + +class SetDatatype(Transformer): + def generic_visit(self, node): + super().generic_visit(node) + assert(len(node) > 0) + assert(all(child.datatype == node[0].datatype for child in node)) + node.datatype = node[0].datatype + return node + + def visit_IndexedTensor(self, node): + node.datatype = node.tensor.datatype + return node + + def visit_DatatypeCast(self, node): + node.datatype = node.newDatatype + return node From 3193d89fd2d8ae2dfaf75e8e9adb63b0a27efd1a Mon Sep 17 00:00:00 2001 From: David Schneller Date: Fri, 31 Jan 2025 01:08:17 +0100 Subject: [PATCH 03/18] Begin adjusting the TensorDescriptions --- yateto/codegen/common.py | 45 ++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index e267494..dce5edd 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -5,9 +5,8 @@ from .tiny_tensor_language import Dump, Function, IntegerType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst import hashlib - class TensorDescription(object): - def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temporary=False): + def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temporary=False, values=None, datatype=None, addressing=None): """ Args: @@ -24,31 +23,32 @@ def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temp self.eqspp = eqspp self.is_compute_constant = is_compute_constant self.is_temporary = is_temporary - BoundingBox(eqspp) + self.values = values + self.datatype = datatype + 
self.addressing = addressing @classmethod def fromNode(cls, name, node): return cls(name, node.memoryLayout(), node.eqspp()) class IndexedTensorDescription(TensorDescription): - def __init__(self, name, indices, memoryLayout, eqspp, is_compute_constant=False, is_temporary=False): - super().__init__(name, memoryLayout, eqspp, is_compute_constant, is_temporary) + def __init__(self, name, indices, memoryLayout, eqspp, is_compute_constant=False, is_temporary=False, values=None, datatype=None, addressing=None): + super().__init__(name, memoryLayout, eqspp, is_compute_constant, is_temporary, values, datatype, addressing) self.indices = indices @classmethod def fromNode(cls, var, node): is_const = False + values = None + datatype = None + addressing = None if hasattr(node, 'tensor'): is_const = node.tensor.is_compute_constant() - return cls(str(var), node.indices, var.memoryLayout(), node.eqspp(), is_const, var.is_temporary) - - @classmethod - def fromVar(cls, var, indices): - is_const = False - if hasattr(var, 'tensor'): - if var.tensor is not None: - is_const = var.tensor.is_compute_constant() - return cls(str(var), indices, var.memoryLayout(), var.eqspp(), is_const, var.is_temporary) + if is_const: + values = node.tensor.values() + datatype = node.datatype + addressing = node.tensor.addressing + return cls(str(var), node.indices, var.memoryLayout(), node.eqspp(), is_const, var.is_temporary, values, datatype, addressing) def forLoops(cpp, indexNames, ranges, body, pragmaSimd=True, prefix='_', indexNo=None): flops = 0 @@ -108,22 +108,27 @@ class BatchedOperationsAux: def __init__(self, underlying_data_type): self.underlying_data_type = underlying_data_type - def _get_ptr_type(self, addressing): - return '**' if addressing == 'pointer_based' else '*' + def _get_ptr_type(self, addressing: AddressingMode): + return addressing.pointer_type() def deduce_addresing(self, term): + if term.addressing is not None: + return term.addressing + + # default deduction if 
term.is_compute_constant: - return 'none' + return AddressingMode.DIRECT if term.is_temporary: - return 'strided' + return AddressingMode.STRIDED else: - return 'pointer_based' + return AddressingMode.INDIRECT def deduce_ptr_arg(self, term, as_const=False): if as_const: addressing = self.deduce_addresing(term) ptr = self._get_ptr_type(addressing) - const_ptr_type = f'const {self.underlying_data_type} {ptr}' + datatype = self.underlying_data_type if term.datatype is None else term.datatype.ctype() + const_ptr_type = f'const {datatype} {ptr}' return f'const_cast<{const_ptr_type}>({term.name})' else: return f'{term.name}' From b17385618c2e05534909d6ffb1ab616574456744 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Wed, 16 Apr 2025 15:04:48 +0200 Subject: [PATCH 04/18] Start adding (elementwise) nonlinear function support --- yateto/ast/node.py | 126 +++++++++++- yateto/ast/transformer.py | 11 ++ yateto/ast/visitor.py | 5 + yateto/codegen/elementwise/__init__.py | 1 + yateto/codegen/elementwise/factory.py | 40 ++++ yateto/codegen/elementwise/generic.py | 40 ++++ yateto/codegen/factory.py | 35 +++- yateto/functions.py | 41 ++++ yateto/generator.py | 1 + yateto/ops.py | 257 +++++++++++++++++++++++++ 10 files changed, 549 insertions(+), 8 deletions(-) create mode 100644 yateto/codegen/elementwise/__init__.py create mode 100644 yateto/codegen/elementwise/factory.py create mode 100644 yateto/codegen/elementwise/generic.py create mode 100644 yateto/functions.py create mode 100644 yateto/ops.py diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 93da011..780d469 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -2,7 +2,7 @@ from ..memory import DenseMemoryLayout from .indices import BoundingBox, Indices, LoGCost from abc import ABC, abstractmethod -from .. import aspp +from .. 
import aspp, ops class Node(ABC): def __init__(self): @@ -112,6 +112,12 @@ def __sub__(self, other): def __le__(self, other): return Assign(self, other) + + def __truediv__(self, other): + return Elementwise(ops.Div(), self, other) + + def __rtruediv__(self, other): + return Elementwise(ops.Div(), other, self) class IndexedTensor(Node): def __init__(self, tensor, indexNames): @@ -141,7 +147,7 @@ def __deepcopy__(self, memo): return it def __str__(self): - return '{}[{}]'.format(self.tensor.name(), str(self.indices)) + return f'{self.tensor.name()}[{str(self.indices)}]' class Op(Node): def __init__(self, *args): @@ -158,7 +164,7 @@ def setMemoryLayout(self, memLayout): def computeMemoryLayout(self): alignStride = False - if len(self.indices) > 0: + if self.indices is not None and len(self.indices) > 0: for child in self: if self.indices[0] in child.indices: position = child.indices.find(self.indices[0]) @@ -270,7 +276,24 @@ def setChildren(self, children): raise ValueError('BinOp node must have exactly 2 children.') super().setChildren(children) -class Assign(BinOp): +class Assign(Op): + def __init__(self, lTerm, rTerm, condition=None): + if isinstance(condition, Node): + super().__init__(lTerm, rTerm, condition) + else: + super().__init__(lTerm, rTerm) + + self._condition = condition + + def leftTerm(self): + return self._children[0] + + def rightTerm(self): + return self._children[1] + + def condition(self): + return self._condition + def setChildren(self, children): if not isinstance(children[0], IndexedTensor): raise ValueError('First child of Assign node must be an IndexedTensor: ' + str(children[0])) @@ -280,7 +303,7 @@ def nonZeroFlops(self): return 0 def computeSparsityPattern(self, *spps): - spp = spps[1] if len(spps) == 2 else self.rightTerm().eqspp() + spp = spps[1] if len(spps) >= 2 else self.rightTerm().eqspp() return self.permute(self.rightTerm().indices, spp) class Permute(UnaryOp): @@ -455,7 +478,6 @@ def is_pure_gemm(self): return True if 
len(left_indices - right_indices) == 1 else False - class FusedGEMMs(Op): def __init__(self): super().__init__() @@ -479,4 +501,94 @@ def nonZeroFlops(self): return nzFlops def is_empty(self): - return len(self._children) == 0 \ No newline at end of file + return len(self._children) == 0 + +class IfThenElse(Op): + def __init__(self, condition, yesTerm, noTerm): + if isinstance(condition, Node): + super().__init__(yesTerm, noTerm, condition) + else: + super().__init__(yesTerm, noTerm) + + self._condition = condition + + def condition(self): + return self._condition + + def nonZeroFlops(self): + return 0 + + def computeSparsityPattern(self, *spps): + # TODO: yesTerm OR noTerm + spp = spps[0] if len(spps) >= 2 else self.term().eqspp() + return spp + + def __str__(self): + indices = self.indices if self.indices != None else '' + return f'{type(self).__name__}[{indices}]' + +class Elementwise(Op): + def __init__(self, optype: ops.Operation, *terms): + nodeTerms = [term for term in terms if isinstance(term, Node)] + super().__init__(*nodeTerms) + + self.nodeTermIndices = [None] * len(terms) + self.termTemplate = [None] * len(terms) + index = 0 + for i, term in enumerate(terms): + if isinstance(term, Node): + self.nodeTermIndices[i] = index + index += 1 + else: + self.nodeTermIndices[i] = None + self.termTemplate[i] = term + + self.optype = optype + self.terms = terms + + self.indices = Indices() + for nodeTerm in nodeTerms: + nodeIndices = nodeTerm.indices if nodeTerm.indices is not None else Indices() + K = self.indices & nodeIndices + # assert self.indices.subShape(K) == nodeTerm.subShape(K) + self.indices = self.indices.merged(nodeIndices - K) + + def nonZeroFlops(self): + return self.eqspp().count_nonzero() + + def fillTerms(self, terms): + assert len(terms) == len(self) + return [terms[index] if template is None else template for template, index in zip(self.termTemplate, self.nodeTermIndices)] + + def computeSparsityPattern(self, *spps): + if len(spps) == 0: + spps =
[node.eqspp() for node in self] + xspp = spps[0] + return spps[0] + + def __str__(self): + indices = self.indices if self.indices != None else '' + return f'{type(self).__name__}({self.optype})[{indices}]' + +class Reduction(UnaryOp): + def __init__(self, optype, term, sumIndex): + # TODO: what if we datatype/field does not match the operation? (w.r.t. the sparsity patterns) + super().__init__(term) + self.indices = term.indices - set([sumIndex]) + self._reductionIndex = term.indices.extract(sumIndex) + self.optype = optype + + def nonZeroFlops(self): + return self.term().eqspp().count_nonzero() - self.eqspp().count_nonzero() + + def reductionIndex(self): + return self._reductionIndex + + def computeSparsityPattern(self, *spps): + assert len(spps) <= 1 + spp = spps[0] if len(spps) == 1 else self.term().eqspp() + return spp.indexSum(self.term().indices, self.indices) + + def __str__(self): + indices = self.indices if self.indices != None else '' + return f'{type(self).__name__}({self.optype})[{indices}]' diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index 152ce50..3f8dce4 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -87,6 +87,12 @@ def visit_ScalarMultiplication(self, node, bound): self.visit(node.term(), bound) node.indices = deepcopy(node.term().indices) return node + + def visit_Elementwise(self, node, bound): + for child in node: + self.visit(child, bound) + node.indices = deepcopy(node[0].indices) + return node def visit_Assign(self, node, bound): lhs = node[0] @@ -196,6 +202,11 @@ def visit_Assign(self, node): node.setEqspp( node.computeSparsityPattern() ) return node + def visit_Elementwise(self, node): + self.generic_visit(node) + node.setEqspp( node.computeSparsityPattern() ) + return node + def getEqspp(self, terms, targetIndices): # Shortcut if all terms have dense eqspps if all(term.eqspp().is_dense() for term in terms): diff --git a/yateto/ast/visitor.py b/yateto/ast/visitor.py index dd6562c..b32004b 
100644 --- a/yateto/ast/visitor.py +++ b/yateto/ast/visitor.py @@ -306,6 +306,11 @@ def visit_IndexedTensor(self, node): term = node.tensor.values_as_ndarray(self._dtype) assert term is not None, '{} may only be used when all involved tensors are constant.'.format(self.__class__.__name__) return term + + def visit_Elementwise(self, node): + terms = self.generic_visit(node) + fullTerms = node.fillTerms(terms) + return node.optype.call(*fullTerms) class ComputeIndexSet(CachedVisitor): def generic_visit(self, node): diff --git a/yateto/codegen/elementwise/__init__.py b/yateto/codegen/elementwise/__init__.py new file mode 100644 index 0000000..fb914e0 --- /dev/null +++ b/yateto/codegen/elementwise/__init__.py @@ -0,0 +1 @@ +from .factory import Description, generator diff --git a/yateto/codegen/elementwise/factory.py b/yateto/codegen/elementwise/factory.py new file mode 100644 index 0000000..009fe41 --- /dev/null +++ b/yateto/codegen/elementwise/factory.py @@ -0,0 +1,40 @@ +from ...memory import CSCMemoryLayout +from ..common import * +from .generic import Generic + +from ...ops import Operation + +from typing import Union + +class Description(object): + def __init__(self, alpha, add: bool, optype: Operation, result: IndexedTensorDescription, terms: list[IndexedTensorDescription], termTemplate, nodeTermIndices): + self.alpha = alpha + self.add = add + self.result = result + self.terms = terms + self.optype = optype + self.termTemplate = termTemplate + self.nodeTermIndices = nodeTermIndices + + self.isSparse = [isinstance(term.memoryLayout, CSCMemoryLayout) for term in terms] + + rR = loopRanges(self.result, self.result.indices) + + # TODO: shall we allow boundingboxing? 
+ if len(terms) == 0: + self.loopRanges = rR + else: + self.loopRanges = loopRanges(self.terms[0], self.result.indices) + assert testLoopRangesAContainedInB(self.loopRanges, rR) + for term in self.terms[1:]: + newRange = loopRanges(term, self.result.indices) + assert testLoopRangesEqual(newRange, self.loopRanges) + assert testLoopRangesAContainedInB(newRange, rR) + + self.loopRanges(newRange) + +def generator(arch, descr, target): + if target == 'cpu': + return Generic(arch, descr) + elif target == 'gpu': + raise RuntimeError("Elementwise operation has not been implemented for GPU-like architectures. At least not like this.") diff --git a/yateto/codegen/elementwise/generic.py b/yateto/codegen/elementwise/generic.py new file mode 100644 index 0000000..92187fb --- /dev/null +++ b/yateto/codegen/elementwise/generic.py @@ -0,0 +1,40 @@ +from ..common import * + +class Generic(object): + def __init__(self, arch, descr): + self._arch = arch + self._descr = descr + + def _affine(self, add, alpha): + flops = 1 + scale = f'{alpha} * ' if alpha != 1.0 else '' + assign = '+=' if add else '=' + + if alpha != 1.0: flops += 1 + if add: flops += 1 + + return flops, lambda left, right: f'{left} {assign} {scale}{right};' + + def _generateDenseDense(self, cpp): + d = self._descr + + if not d.add: + writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) + initializeWithZero(cpp, self._arch, d.result, writeBB) + + flops, assigner = self._affine(d.add, d.alpha) + + class ProductBody(object): + def __call__(s): + args = [f'{arg.name}[{arg.memoryLayout.addressString(arg.indices)}]' for arg in d.terms] + fullArgs = [args[index] if template is None else template for index, template in zip(d.nodeTermIndices, d.termTemplate)] + opstr = d.optype.callstr(*fullArgs) + resultstr = f'{d.result.name}[{d.result.memoryLayout.addressString(d.result.indices)}]' + cpp(assigner(resultstr, opstr)) + return flops + return forLoops(cpp, d.result.indices, d.loopRanges, ProductBody()) + + def 
generate(self, cpp, routineCache): + d = self._descr + + return self._generateDenseDense(cpp) diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index b07ee8f..2cfb332 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -3,7 +3,7 @@ from ..ast.node import IndexedTensor from ..memory import DenseMemoryLayout from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux -from . import copyscaleadd, indexsum, log, product, fused_gemms +from . import copyscaleadd, indexsum, log, product, fused_gemms, elementwise class KernelFactory(object): ERROR_NAME = '_error' @@ -148,6 +148,19 @@ def create_Permute(self, node, result, arguments, add, scalar, prefetchName, rou generator = copyscaleadd.generator(self._arch, description, gemm_cfg, self._target) return generator.generate(self._cpp, routineCache) + def create_Elementwise(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + description = elementwise.Description( + alpha = scalar, + add = add, + result = IndexedTensorDescription.fromNode(result, node), + terms = [IndexedTensorDescription.fromNode(argument, term) for argument, term in zip(arguments, node)], + optype = node.optype, + termTemplate = node.termTemplate, + nodeTermIndices = node.nodeTermIndices + ) + generator = elementwise.generator(self._arch, description, self._target) + return generator.generate(self._cpp, routineCache) + def simple(self, result, term, add, scalar, routineCache, gemm_cfg): description = copyscaleadd.Description( alpha = scalar, @@ -200,6 +213,26 @@ def create_Permute(self, node, result, arguments, add, scalar, prefetchName, rou resultTerm = self._formatTerm(result, node.indices) termTerm = self._formatTerm(arguments[0], node.term().indices) return self._simpleBody(resultTerm, termTerm, add, scalar, node.indices) + + def create_Elementwise(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + g = 
self._indices(result) + resultTerm = self._formatTerm(result, node.indices) + + argTerms = [self._formatTerm(argument, term.indices) for argument, term in zip(arguments, node)] + termTerm = node.optype.callstr(*node.fillTerms(argTerms)) + + return self._simpleBody(resultTerm, termTerm, add, scalar, g) + + def create_IfThenElse(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + g = self._indices(result) + resultTerm = self._formatTerm(result, node.indices) + yesTerm = self._formatTerm(arguments[0], node.yesTerm().indices) + noTerm = self._formatTerm(arguments[1], node.noTerm().indices) + conditionTerm = self._formatTerm(arguments[2], node.condition().indices) + + termTerm = f'(({conditionTerm}) ? ({yesTerm}) : ({noTerm}))' + + return self._simpleBody(resultTerm, termTerm, add, scalar, g) def _simpleBody(self, resultTerm, termTerm, add, scalar, indices): ranges = {idx: Range(0, indices.indexSize(idx)) for idx in indices} diff --git a/yateto/functions.py b/yateto/functions.py new file mode 100644 index 0000000..e5037b9 --- /dev/null +++ b/yateto/functions.py @@ -0,0 +1,41 @@ +from . 
import ops +from .ast import node + +def sin(x): return node.Elementwise(ops.Sin(), x) +def cos(x): return node.Elementwise(ops.Cos(), x) +def tan(x): return node.Elementwise(ops.Tan(), x) +def asin(x): return node.Elementwise(ops.Asin(), x) +def acos(x): return node.Elementwise(ops.Acos(), x) +def atan(x): return node.Elementwise(ops.Atan(), x) + +def sinh(x): return node.Elementwise(ops.Sinh(), x) +def cosh(x): return node.Elementwise(ops.Cosh(), x) +def tanh(x): return node.Elementwise(ops.Tanh(), x) +def asinh(x): return node.Elementwise(ops.Asinh(), x) +def acosh(x): return node.Elementwise(ops.Acosh(), x) +def atanh(x): return node.Elementwise(ops.Atanh(), x) + +def log(x): return node.Elementwise(ops.Log(), x) +def log1p(x): return node.Elementwise(ops.Log1p(), x) +def exp(x): return node.Elementwise(ops.Exp(), x) +def expm1(x): return node.Elementwise(ops.Expm1(), x) +def sqrt(x): return node.Elementwise(ops.Sqrt(), x) +def cbrt(x): return node.Elementwise(ops.Cbrt(), x) + +def abs(x): return node.Elementwise(ops.Abs(), x) + +def max(x, y): return node.Elementwise(ops.Max(), x, y) +def min(x, y): return node.Elementwise(ops.Min(), x, y) +def pow(x, y): return node.Elementwise(ops.Pow(), x, y) + +def assign(lhs, rhs): return node.Assign(lhs, rhs) +def assignIf(condition, lhs, rhs): return node.Assign(lhs, rhs, condition) + +# def where(condition, yes, no): return node.IfThenElse(condition, yes, no) +def where(condition, yes, no): return node.Elementwise(ops.Ternary(), yes, no, condition) + +# extra reduction functions; e.g. 
for input to `where` +def reductionSum(term, indices): return node.Reduction(ops.Add(), term, indices) +def reductionMul(term, indices): return node.Reduction(ops.Mul(), term, indices) +def reductionAnd(term, indices): return node.Reduction(ops.And(), term, indices) +def reductionOr(term, indices): return node.Reduction(ops.Or(), term, indices) diff --git a/yateto/generator.py b/yateto/generator.py index 49c2559..61c3b2f 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -356,6 +356,7 @@ def unit_test_body(cpp, testFramework): kernelSourceContent = '' with Cpp(kernelSource) as cpp: cpp.includeSys('cassert') + cpp.includeSys('cmath') cpp.includeSys('cstring') cpp.includeSys('cstdlib') cpp.includeSys('limits') diff --git a/yateto/ops.py b/yateto/ops.py new file mode 100644 index 0000000..24493d6 --- /dev/null +++ b/yateto/ops.py @@ -0,0 +1,257 @@ +import numpy as np + +class Operation: + #def callstr(self, *args) -> str: + # raise NotImplementedError() + + def call(self, *args): + raise NotImplementedError() + + def datatypeArgs(self): + raise NotImplementedError() # TODO + + def datatypeResult(self): + raise NotImplementedError() # TODO + + def __str__(self): + return type(self).__name__ + + def __eq__(self, other): + # we're more or less using "dummy" types here + return type(self).__name__ == type(other).__name__ + +class CommutativeMonoidMixin: + def neutral(self): + pass + +class RingMixin: + def formsRing(self, op): + pass + +class UnaryArgsMixin: + pass + +class BinaryArgsMixin: + pass + +class CFunctionMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self, *args) -> str: + return f'{self.cppname()}({", ".join(str(arg) for arg in args)})' + +class CUnaryOperatorMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self, *args) -> str: + return f'{self.cppname()}({args[0]})' + +class CBinaryOperatorMixin: + def cppname(self) -> str: + raise NotImplementedError() + + def callstr(self,
*args) -> str: + return f'({args[0]}) {self.cppname()} ({args[1]})' + +class Sin(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::sin' + def call(self, *args): + return np.sin(args[0]) +class Cos(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::cos' + def call(self, *args): + return np.cos(args[0]) +class Tan(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::tan' + def call(self, *args): + return np.tan(args[0]) +class Asin(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::asin' + def call(self, *args): + return np.asin(args[0]) +class Acos(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::acos' + def call(self, *args): + return np.acos(args[0]) +class Atan(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::atan' + def call(self, *args): + return np.atan(args[0]) + +class Sinh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::sinh' + def call(self, *args): + return np.sinh(args[0]) +class Cosh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::cosh' + def call(self, *args): + return np.cosh(args[0]) +class Tanh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::tanh' + def call(self, *args): + return np.tanh(args[0]) +class Asinh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::asinh' + def call(self, *args): + return np.asinh(args[0]) +class Acosh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::acosh' + def call(self, *args): + return np.acosh(args[0]) +class Atanh(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::atanh' + def call(self, *args): + return np.atanh(args[0]) + +class Log(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::log' + def call(self, *args): + return 
np.log(args[0]) +class Exp(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::exp' + def call(self, *args): + return np.exp(args[0]) +class Log1p(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::log1p' + def call(self, *args): + return np.log1p(args[0]) +class Expm1(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::expm1' + def call(self, *args): + return np.expm1(args[0]) +class Sqrt(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::sqrt' + def call(self, *args): + return np.sqrt(args[0]) +class Cbrt(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::cbrt' + def call(self, *args): + return np.cbrt(args[0]) + +class Abs(Operation, CFunctionMixin, UnaryArgsMixin): + def cppname(self): + return 'std::abs' + def call(self, *args): + return np.abs(args[0]) + + +class Max(Operation, CFunctionMixin, BinaryArgsMixin): + def cppname(self): + return 'std::max' + def call(self, *args): + return max(args[0], args[1]) +class Min(Operation, CFunctionMixin, BinaryArgsMixin): + def cppname(self): + return 'std::min' + def call(self, *args): + return min(args[0], args[1]) +class Pow(Operation, CFunctionMixin, BinaryArgsMixin): + def cppname(self): + return 'std::pow' + def call(self, *args): + return pow(args[0], args[1]) + +class Div(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '/' + def call(self, *args): + return args[0] / args[1] + +class Add(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin): + def cppname(self, *args): + return '+' + def call(self, *args): + return args[0] + args[1] + def neutral(self): + return 0 +class Mul(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def cppname(self, *args): + return '*' + def call(self, *args): + return args[0] * args[1] + def neutral(self): + return 1 + def formsRing(self, op): +
return op == Add() + +class And(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def cppname(self, *args): + return '&' + def call(self, *args): + return args[0] & args[1] + def neutral(self): + return True + def formsRing(self, op): + return op == Or() +class Or(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): + def def cppname(self, *args): + return '|' + def call(self, *args): + return args[0] | args[1] + def neutral(self): + return False + def formsRing(self, op): + return op == And() +class Not(Operation, CUnaryOperatorMixin, UnaryArgsMixin): + def cppname(self, *args): + return '~' + def call(self, *args): + return ~args[0] + +class CmpEq(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '==' + def call(self, *args): + return args[0] == args[1] +class CmpNe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '!=' + def call(self, *args): + return args[0] != args[1] +class CmpLt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '<' + def call(self, *args): + return args[0] < args[1] +class CmpLe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '<=' + def call(self, *args): + return args[0] <= args[1] +class CmpGt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '>' + def call(self, *args): + return args[0] > args[1] +class CmpGe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): + def cppname(self, *args): + return '>=' + def call(self, *args): + return args[0] >= args[1] + +# replacement; however it'll execute both code paths, regardless of the result +class Ternary(Operation): + def callstr(self, *args): + return f'(({args[2]}) ?
({args[0]}) : ({args[1]}))' + def call(self, *args): + return np.where(args[2], args[0], args[1]) From 075853d85a3dc18a6c131bc9f8769cfcb8674be1 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Wed, 16 Apr 2025 22:23:30 +0200 Subject: [PATCH 05/18] Fix bugs; propagate datatype --- yateto/ast/node.py | 3 ++ yateto/ast/transformer.py | 26 ++++++++++++-- yateto/codegen/common.py | 27 ++++++++++---- yateto/codegen/copyscaleadd/generic.py | 2 +- yateto/codegen/factory.py | 50 +++++++++++++------------- yateto/codegen/gemm/gemmgen.py | 20 ++++++++--- yateto/codegen/indexsum/generic.py | 12 +++---- yateto/codegen/log/generic.py | 8 ++--- yateto/codegen/product/generic.py | 4 +-- yateto/codegen/visitor.py | 14 ++++---- yateto/controlflow/graph.py | 3 +- yateto/controlflow/transformer.py | 5 +-- yateto/controlflow/visitor.py | 5 +-- yateto/generator.py | 14 +++++--- yateto/type.py | 17 +++++++-- 15 files changed, 139 insertions(+), 71 deletions(-) diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 3c0761e..e88eadc 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -113,6 +113,9 @@ def __sub__(self, other): def __le__(self, other): return Assign(self, other) + + def cast(self, datatype): + return DatatypeCast(self, datatype) class IndexedTensor(Node): def __init__(self, tensor, indexNames): diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index e11e99e..578997c 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -239,7 +239,10 @@ def generic_visit(self, node): def visit_IndexedTensor(self, node): return node -class SetDatatype(Transformer): +class SetDatatype1(Transformer): + def __init__(self, arch): + self.arch = arch + def generic_visit(self, node): super().generic_visit(node) assert(len(node) > 0) @@ -248,9 +251,28 @@ def generic_visit(self, node): return node def visit_IndexedTensor(self, node): - node.datatype = node.tensor.datatype + super().generic_visit(node) + node.datatype = 
node.tensor.getDatatype(self.arch) return node def visit_DatatypeCast(self, node): + super().generic_visit(node) + node.datatype = node.newDatatype + return node + +class SetDatatype2(Transformer): + def generic_visit(self, node): + super().generic_visit(node) + assert(len(node) > 0) + assert(all(child.datatype == node[0].datatype for child in node)) + node.datatype = node[0].datatype + return node + + def visit_IndexedTensor(self, node): + super().generic_visit(node) + return node + + def visit_DatatypeCast(self, node): + super().generic_visit(node) node.datatype = node.newDatatype return node diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index dce5edd..0325aa9 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -38,17 +38,32 @@ def __init__(self, name, indices, memoryLayout, eqspp, is_compute_constant=False @classmethod def fromNode(cls, var, node): + datatype = node.datatype + is_const = False values = None - datatype = None addressing = None if hasattr(node, 'tensor'): is_const = node.tensor.is_compute_constant() if is_const: values = node.tensor.values() - datatype = node.datatype addressing = node.tensor.addressing return cls(str(var), node.indices, var.memoryLayout(), node.eqspp(), is_const, var.is_temporary, values, datatype, addressing) + + @classmethod + def fromVar(cls, var, indices): + datatype = var.datatype + + is_const = False + values = None + addressing = None + if hasattr(var, 'tensor'): + if var.tensor is not None: + is_const = var.tensor.is_compute_constant() + if is_const: + values = var.tensor.values() + addressing = var.tensor.addressing + return cls(str(var), indices, var.memoryLayout(), var.eqspp(), is_const, var.is_temporary, values, datatype, addressing) def forLoops(cpp, indexNames, ranges, body, pragmaSimd=True, prefix='_', indexNo=None): flops = 0 @@ -85,17 +100,17 @@ def boundingBoxFromLoopRanges(indices, loopRanges): def reduceSpp(spp, sourceIndices, targetIndices): return 
spp.indexSum(sourceIndices, targetIndices) -def initializeWithZero(cpp, arch, result: TensorDescription, writeBB = None): +def initializeWithZero(cpp, result: TensorDescription, writeBB = None): if writeBB: addresses = sorted(result.memoryLayout.notWrittenAddresses(writeBB)) if len(addresses) > 0: regions = splitByDistance(addresses) for region in regions: m, M = min(region), max(region) - initialAddress = '{} + {}'.format(result.name, m) - cpp.memset(initialAddress, M-m+1, arch.typename) + initialAddress = f'{result.name} + {m}' + cpp.memset(initialAddress, M-m+1, result.datatype.ctype()) else: - cpp.memset(result.name, result.memoryLayout.requiredReals(), arch.typename) + cpp.memset(result.name, result.memoryLayout.requiredReals(), result.datatype.ctype()) class BatchedOperationsAux: diff --git a/yateto/codegen/copyscaleadd/generic.py b/yateto/codegen/copyscaleadd/generic.py index 44bf704..3f47658 100644 --- a/yateto/codegen/copyscaleadd/generic.py +++ b/yateto/codegen/copyscaleadd/generic.py @@ -20,7 +20,7 @@ def generate(self, cpp, routineCache): if d.beta == 0.0: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) - initializeWithZero(cpp, self._arch, d.result, writeBB) + initializeWithZero(cpp, d.result, writeBB) class CopyScaleAddBody(object): def __call__(s): diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index b07ee8f..19f33ec 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -4,6 +4,7 @@ from ..memory import DenseMemoryLayout from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux from . 
import copyscaleadd, indexsum, log, product, fused_gemms +from ..type import Datatype class KernelFactory(object): ERROR_NAME = '_error' @@ -25,22 +26,20 @@ def generic_create(self, node, *args): def simple(self, result, term, add, scalar, routineCache, gemm_cfg): raise NotImplementedError - def temporary(self, bufname, size, iniZero=False, memory=list()): + def temporary(self, bufname, size, datatype, iniZero=False, memory=list()): assert(iniZero == False or len(memory) == 0) + if datatype is None: + datatype = Datatype.I8 + if self._target == 'cpu': if self._arch.onHeap(size): if len(self._freeList) == 0: - self._cpp('int {};'.format(self.ERROR_NAME)) - self._cpp('{}* {};'.format(self._arch.typename, bufname)) - self._cpp('{} = posix_memalign(reinterpret_cast(&{}), {}, {}*sizeof({}));'.format( - self.ERROR_NAME, - bufname, - self._arch.alignment, - size, - self._arch.typename)) + self._cpp(f'int {self.ERROR_NAME};') + self._cpp(f'{datatype.ctype()}* {bufname};') + self._cpp(f'{self.ERROR_NAME} = posix_memalign(reinterpret_cast(&{bufname}), {self._arch.alignment}, {size}*sizeof({datatype.ctype()}));') if iniZero: - self._cpp.memset(bufname, size, self._arch.typename) + self._cpp.memset(bufname, size, datatype.ctype()) if memory: for i, data in enumerate(memory): self._cpp(f'{bufname}[{i}] = {data};') @@ -51,13 +50,12 @@ def temporary(self, bufname, size, iniZero=False, memory=list()): ini = ' = {}' elif memory: ini = ' = {{{}}}'.format(', '.join(memory)) - self._cpp(f'alignas({self._arch.alignment}) {self._arch.typename} {bufname}[{size}] {ini};') + self._cpp(f'alignas({self._arch.alignment}) {datatype.ctype()} {bufname}[{size}] {ini};') else: - declaration = f'{self._arch.typename}* {bufname}' + declaration = f'{datatype.ctype()}* {bufname}' total_size = f'{BatchedOperationsAux.NUM_ELEMENTS_NAME} * {size}' self._cpp(f'{declaration} = linearAllocator.allocate({total_size});') - def freeTmp(self): if self._target == 'cpu': for free in self._freeList: @@ -167,7 
+165,7 @@ def __init__(self, cpp, arch, nameFun, testFramework): def _formatTerm(self, var, indices): address = var.memoryLayout().addressString(indices) - return '{}[{}]'.format(self._name(var), address) + return f'{self._name(var)}[{address}]' def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): g = node.indices @@ -187,7 +185,7 @@ def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, rout class EinsumBody(object): def __call__(s): - self._cpp( '{} += {};'.format(resultTerm, ' * '.join(terms)) ) + self._cpp(f"{resultTerm} += {' * '.join(terms)};") return len(terms) return forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False) @@ -205,11 +203,11 @@ def _simpleBody(self, resultTerm, termTerm, add, scalar, indices): ranges = {idx: Range(0, indices.indexSize(idx)) for idx in indices} if scalar and scalar != 1.0: - termTerm = '{} * {}'.format(scalar, termTerm) + termTerm = f'{scalar} * {termTerm}' class AssignBody(object): def __call__(s): - self._cpp( '{} {} {};'.format(resultTerm, '+=' if add else '=', termTerm) ) + self._cpp(f"{resultTerm} {'+=' if add else '='} {termTerm};") return 1 if add else 0 return forLoops(self._cpp, indices, ranges, AssignBody(), pragmaSimd=False) @@ -229,8 +227,8 @@ def compare(self, ref, target, epsMult = 100.0): class CompareBody(object): def __call__(s): - self._cpp( 'double ref = {};'.format(refTerm) ) - self._cpp( 'double diff = ref - {};'.format(targetTerm) ) + self._cpp( f'double ref = {refTerm};' ) + self._cpp( f'double diff = ref - {targetTerm};' ) self._cpp( 'error += diff * diff;' ) self._cpp( 'refNorm += ref * ref;' ) return 0 @@ -247,19 +245,19 @@ def tensor(self, node, resultName, maxValue = 512): ml = node.memoryLayout() size = ml.requiredReals() + datatype = node.getDatatype(self._arch) + spp = node.spp() isDense = spp.count_nonzero() == size if isDense: - self.temporary(resultName, size) - with self._cpp.For('int i = 0; i < {}; 
++i'.format(size)): - self._cpp('{}[i] = static_cast<{}>((i + {}) % {} + 1);'.format(resultName, self._arch.typename, self._rand, maxValue)) + self.temporary(resultName, size, node.getDatatype(self._arch)) + with self._cpp.For(f'int i = 0; i < {size}; ++i'): + self._cpp(f'{resultName}[i] = static_cast<{datatype.ctype()}>((i + {self._rand}) % {maxValue} + 1);') else: - memory = ['0.0']*size + memory = [datatype.literal(0)]*size nz = spp.nonzero() for entry in zip(*nz): addr = ml.address(entry) memory[addr] = str(float((addr + self._rand) % maxValue)+1.0) - self.temporary(resultName, size, memory=memory) + self.temporary(resultName, size, datatype, memory=memory) self._rand += 1 - - diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index b59797a..0b09411 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -64,9 +64,8 @@ def generateRoutineName(self, gemm, sppA, sppB): sha = hashlib.md5() sha.update(str(sppB).encode()) name += '_' + sha.hexdigest() - return '{name}_{datatype}_m{M}_n{N}_k{K}_ldA{LDA}_ldB{LDB}_ldC{LDC}_alpha{alphaSubs}_beta{betaSubs}_alignedA{alignedA}_alignedC{alignedC}_transA{transA}_transB{transB}_{prefetch}'.format( + return '{name}_{datatypeA}_{datatypeB}_{datatypeC}_m{M}_n{N}_k{K}_ldA{LDA}_ldB{LDB}_ldC{LDC}_alpha{alphaSubs}_beta{betaSubs}_alignedA{alignedA}_alignedC{alignedC}_transA{transA}_transB{transB}_{prefetch}'.format( name=name, - datatype=self._arch.typename, alphaSubs=self._alpha(gemm['alpha']), betaSubs=self._beta(gemm['beta']), **gemm @@ -240,7 +239,9 @@ def call_arg(name, term, modified, offset): 'prefetch': 'BL2viaC' if self._arch.enablePrefetch and d.prefetchName is not None else 'nopf', 'transA': d.transA, 'transB': d.transB, - + 'datatypeA': d.leftTerm.datatype, + 'datatypeB': d.rightTerm.datatype, + 'datatypeC': d.result.datatype, } routineName = self.generateRoutineName(gemm, sppA, sppB) @@ -320,6 +321,15 @@ def _callGenerator(self, argList): def __call__(self, 
routineName, fileName): cpu_arch = self._arch.host_name if self._arch.host_name else self._arch.name + assert self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeA'] + assert self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeB'] + assert self._gemmDescr['datatypeC'] in [Datatype.F32, Datatype.F64] + + precision = { + Datatype.F32: 'F', + Datatype.F64: 'D' + }[self._gemmDescr['datatypeC']] + if self._mode == 'pspamm': pspamm_arch = cpu_arch if cpu_arch == 'a64fx': @@ -351,7 +361,7 @@ def __call__(self, routineName, fileName): '--output_filename', fileName, '--precision', - self._arch.precision + precision ] if self._gemmDescr['prefetch'] != 'nopf': argList.extend(['--prefetching', self._gemmDescr['prefetch']]) @@ -386,7 +396,7 @@ def __call__(self, routineName, fileName): self._gemmDescr['alignedC'], libxsmm_arch, # libxsmm has no support for rome, hsw works well in practice self._gemmDescr['prefetch'], - self._arch.precision + 'P' + precision + 'P' ] class SparsityWrapper: def __init__(self, shape, spp): diff --git a/yateto/codegen/indexsum/generic.py b/yateto/codegen/indexsum/generic.py index 36df9f5..2df7c54 100644 --- a/yateto/codegen/indexsum/generic.py +++ b/yateto/codegen/indexsum/generic.py @@ -10,19 +10,19 @@ def generate(self, cpp, routineCache): if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) - initializeWithZero(cpp, self._arch, d.result, writeBB) + initializeWithZero(cpp, d.result, writeBB) sumIndex = d.term.indices - d.result.indices assert len(sumIndex) == 1 class IndexSumBody(object): def __call__(s): target = '{}[{}]'.format(d.result.name, d.result.memoryLayout.addressString(d.result.indices)) - initialValue = target if d.add else self._arch.formatConstant(0.0) - cpp( '{} sum = {};'.format(self._arch.typename, initialValue) ) + initialValue = target if d.add else d.result.datatype.literal(0.0) + cpp(f'{d.result.datatype.ctype()} sum = {initialValue};') with cpp.For('int {0} = {1}; {0} < {2}; 
++{0}'.format(sumIndex, d.sumLoopRange.start, d.sumLoopRange.stop)): - cpp( 'sum += {}[{}];'.format(d.term.name, d.term.memoryLayout.addressString(d.term.indices)) ) - mult = '{} * '.format(d.alpha) if d.alpha != 1.0 else '' - cpp( '{} = {}sum;'.format(target, mult) ) + cpp( f'sum += {d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}];' ) + mult = f'{d.alpha} * ' if d.alpha != 1.0 else '' + cpp( f'{target} = {mult}sum;' ) flop = 1 if d.alpha != 1.0 else 0 return d.sumLoopRange.size() + flop diff --git a/yateto/codegen/log/generic.py b/yateto/codegen/log/generic.py index 68c2e50..366d041 100644 --- a/yateto/codegen/log/generic.py +++ b/yateto/codegen/log/generic.py @@ -80,9 +80,9 @@ def generate(self, cpp, routineCache, gemm_cfg): Ceqspp = self._reduce(d.result, C, CmemLayout) gemmDescr = gemm.Description( - leftTerm = TensorDescription(innerAname, AmemLayout, Aeqspp, d.leftTerm.is_compute_constant, d.leftTerm.is_temporary), - rightTerm = TensorDescription(innerBname, BmemLayout, Beqspp, d.rightTerm.is_compute_constant, d.rightTerm.is_temporary), - result = TensorDescription(innerCname, CmemLayout, Ceqspp, d.result.is_compute_constant, d.result.is_temporary), + leftTerm = TensorDescription(innerAname, AmemLayout, Aeqspp, d.leftTerm.is_compute_constant, d.leftTerm.is_temporary, datatype=d.leftTerm.datatype), + rightTerm = TensorDescription(innerBname, BmemLayout, Beqspp, d.rightTerm.is_compute_constant, d.rightTerm.is_temporary, datatype=d.rightTerm.datatype), + result = TensorDescription(innerCname, CmemLayout, Ceqspp, d.result.is_compute_constant, d.result.is_temporary, datatype=d.result.datatype), transA = d.transA, transB = d.transB, alpha = d.alpha, @@ -100,7 +100,7 @@ def generate(self, cpp, routineCache, gemm_cfg): lr.update( self._defuse(m, d.leftTerm, Im) ) lr.update( self._defuse(n, d.rightTerm, In) ) writeBB = boundingBoxFromLoopRanges(d.result.indices, lr) - initializeWithZero(cpp, self._arch, d.result, writeBB) + initializeWithZero(cpp, 
d.result, writeBB) class LoGBody(object): def __call__(s): diff --git a/yateto/codegen/product/generic.py b/yateto/codegen/product/generic.py index 5cbba15..5e11437 100644 --- a/yateto/codegen/product/generic.py +++ b/yateto/codegen/product/generic.py @@ -21,7 +21,7 @@ def _generateDenseDense(self, cpp): if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) - initializeWithZero(cpp, self._arch, d.result, writeBB) + initializeWithZero(cpp, d.result, writeBB) class ProductBody(object): def __call__(s): @@ -46,7 +46,7 @@ def _generateSparseSparse(self, cpp): assert d.isACsc and d.isBCsc if not d.add: - initializeWithZero(cpp, self._arch, d.result) + initializeWithZero(cpp, d.result) left = d.result.indices.positions(d.leftTerm.indices, sort=False) right = d.result.indices.positions(d.rightTerm.indices, sort=False) diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index f0e67f5..bc71aa1 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -77,15 +77,15 @@ def generate(self, cpp, cfg, factory, routineCache, gemm_cfg): localPtrs = set() for pp in cfg: localPtrs.update(pp.bufferMap.keys()) - if localPtrs: - cpp( '{}{};'.format(self._arch.typename, ','.join(map(lambda x: ' *' + str(x), localPtrs))) ) + for localPtr in localPtrs: + cpp(f'{localPtr.datatype.ctype()}* {localPtr};') for pp in cfg: for buf, size in pp.initBuffer.items(): - required_tmp_mem += size * self._arch.bytesPerReal + required_tmp_mem += size bufname = self._bufferName(buf) - factory.temporary(bufname, size) + factory.temporary(bufname, size, None) for local, buf in pp.bufferMap.items(): - cpp('{} = {};'.format(local, self._bufferName(buf))) + cpp(f'{local} = reinterpret_cast<{localPtr.datatype.ctype()}*>({self._bufferName(buf)});') action = pp.action if action: scalar = self.deduce_scalar(action) @@ -490,7 +490,7 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, for var in variables: 
factory.tensor(var.tensor, self._tensorName(var)) - factory.temporary(self._name(var), var.memoryLayout().requiredReals(), iniZero=True) + factory.temporary(self._name(var), var.memoryLayout().requiredReals(), var.datatype, iniZero=True) shape = var.memoryLayout().shape() cpp('{supportNS}::DenseTensorView<{dim},{arch.typename},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{shape}}});'.format( @@ -564,7 +564,7 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, for var in variables: if var.writable: - factory.compare(var, Variable(self._tensorName(var), False, var.tensor.memoryLayout())) + factory.compare(var, Variable(self._tensorName(var), False, var.tensor.memoryLayout(), datatype=var.datatype)) factory.freeTmp() diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index f48b87e..771964a 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -4,13 +4,14 @@ class Variable(object): - def __init__(self, name, writable, memoryLayout, eqspp=None, tensor=None, is_temporary=False): + def __init__(self, name, writable, memoryLayout, eqspp=None, tensor=None, is_temporary=False, datatype=None): self.name = name self.writable = writable self.tensor = tensor self._memoryLayout = memoryLayout self._eqspp = eqspp self.is_temporary = is_temporary + self.datatype = datatype def variables(self): return {self} diff --git a/yateto/controlflow/transformer.py b/yateto/controlflow/transformer.py index 741ba36..605ba0c 100644 --- a/yateto/controlflow/transformer.py +++ b/yateto/controlflow/transformer.py @@ -128,7 +128,7 @@ def visit(self, cfg): # assign buffer if ua and not ua.isCompound() and ua.result.isLocal(): if ua.result in usedBuffers: - buf = usedBuffers[ua.result] + buf = usedBuffers[ua.result] elif len(freeBuffers) > 0: buf = freeBuffers.pop() else: @@ -137,7 +137,8 @@ def visit(self, cfg): cfg[i].bufferMap[ua.result] = buf usedBuffers[ua.result] = buf - size = 
ua.result.memoryLayout().requiredReals() + # NOTE: size in bytes + size = ua.result.memoryLayout().requiredReals() * ua.result.datatype.size() if buf in bufferSize: bufferSize[buf] = max(bufferSize[buf], size) else: diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index b06d301..d5a0ddd 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -25,6 +25,7 @@ def _addPermuteIfRequired(self, indices, term, variable): if not self._simpleMemoryLayout: permute.setEqspp( permute.computeSparsityPattern() ) permute.computeMemoryLayout() + permute.datatype = term.datatype result = self._nextTemporary(permute) action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False) self._addAction(action) @@ -76,7 +77,7 @@ def visit_Assign(self, node): return variables[0] def visit_IndexedTensor(self, node): - return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor) + return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor, datatype=node.datatype) def _addAction(self, action): self._cfg.append(ProgramPoint(action)) @@ -84,7 +85,7 @@ def _addAction(self, action): def _nextTemporary(self, node): name = '{}{}'.format(self.TEMPORARY_RESULT, self._tmp) self._tmp += 1 - return Variable(name, True, self._ml(node), node.eqspp(), is_temporary=True) + return Variable(name, True, self._ml(node), node.eqspp(), is_temporary=True, datatype=node.datatype) def updateWritable(self, name): self._writable = self._writable | {name} diff --git a/yateto/generator.py b/yateto/generator.py index d5959f2..2598e1c 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -57,8 +57,11 @@ def __init__(self, name, ast, prefetch=None, namespace=None, target='cpu'): def isValidName(cls, name): return re.match(cls.VALID_NAME, name) is not None - def prepareUntilUnitTest(self): + def prepareUntilUnitTest(self, arch): self.ast = 
[DeduceIndices().visit(ast) for ast in self.ast] + dtd = SetDatatype1(arch) + for ast in self.ast: + dtd.visit(ast) ast2cf = AST2ControlFlow(simpleMemoryLayout=True) for ast in self.ast: ast2cf.visit(ast) @@ -84,6 +87,7 @@ def prepareUntilCodeGen(self, cost_estimator, enableFusedGemm: bool): permutationVariants = FindIndexPermutations().visit(ast) ast = SelectIndexPermutations(permutationVariants).visit(ast) ast = ImplementContractions().visit(ast) + ast = SetDatatype2().visit(ast) if self._prefetch is not None: prefetchCapabilities = FindPrefetchCapabilities().visit(ast) assignPf = AssignPrefetch(prefetchCapabilities, prefetch) @@ -172,9 +176,9 @@ def add(self, name, ast, prefetch=None, namespace=None, target='cpu'): def kernels(self): return self._kernels.values() - def prepareUntilUnitTest(self): + def prepareUntilUnitTest(self, arch): for kernel in self._kernels.values(): - kernel.prepareUntilUnitTest() + kernel.prepareUntilUnitTest(arch) def prepareUntilCodeGen(self, costEstimator, enableFusedGemm: bool): for kernel in self._kernels.values(): @@ -273,9 +277,9 @@ def generate(self, print('Deducing indices...') for kernel in self._kernels: - kernel.prepareUntilUnitTest() + kernel.prepareUntilUnitTest(self._arch) for family in self._kernelFamilies.values(): - family.prepareUntilUnitTest() + family.prepareUntilUnitTest(self._arch) fUTdoctest = self.FileNames(outputDir, self.DOCTEST_FILE_NAME) fUTcxxtest = self.FileNames(outputDir, self.CXXTEST_FILE_NAME) diff --git a/yateto/type.py b/yateto/type.py index 4ed0980..e1b085e 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -16,6 +16,19 @@ class Datatype(Enum): F16 = 7 BF16 = 8 + def __str__(self): + return { + Datatype.BOOL: 'bool', + Datatype.I8: 'i8', + Datatype.I16: 'i16', + Datatype.I32: 'i32', + Datatype.I64: 'i64', + Datatype.F32: 'f32', + Datatype.F64: 'f64', + Datatype.F16: 'f16', + Datatype.BF16: 'bf16', + }[self] + def ctype(self): return { Datatype.BOOL: 'bool', @@ -90,11 +103,11 @@ def __init__(self, 
name, namespace=None, datatype=None): self._name = name self.namespace = namespace - # datatype == None is treated as datatype == self._arch.datatype + # datatype == None is treated as datatype == arch.datatype self.datatype = datatype def getDatatype(self, arch): - return self._arch.datatype if self.datatype is None else self.datatype + return arch.datatype if self.datatype is None else self.datatype def __str__(self): return self._name From 197575345c1f60c5c2ab10a0d4badbcc5fce2081 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Thu, 17 Apr 2025 01:08:17 +0200 Subject: [PATCH 06/18] Fix build --- yateto/codegen/elementwise/generic.py | 2 +- yateto/functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/yateto/codegen/elementwise/generic.py b/yateto/codegen/elementwise/generic.py index 92187fb..c0a0f4f 100644 --- a/yateto/codegen/elementwise/generic.py +++ b/yateto/codegen/elementwise/generic.py @@ -20,7 +20,7 @@ def _generateDenseDense(self, cpp): if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) - initializeWithZero(cpp, self._arch, d.result, writeBB) + initializeWithZero(cpp, d.result, writeBB) flops, assigner = self._affine(d.add, d.alpha) diff --git a/yateto/functions.py b/yateto/functions.py index e5037b9..bbcbe38 100644 --- a/yateto/functions.py +++ b/yateto/functions.py @@ -36,6 +36,6 @@ def where(condition, yes, no): return node.Elementwise(ops.Ternary(), yes, no, c # extra reduction functions; e.g. 
for input to `where` def reductionSum(term, indices): return node.Reduction(ops.Add(), term, indices) -def reductionMul(term, indices): return node.Reduction(ops.Mul(), term, indices)ass +def reductionMul(term, indices): return node.Reduction(ops.Mul(), term, indices) def reductionAnd(term, indices): return node.Reduction(ops.And(), term, indices) def reductionOr(term, indices): return node.Reduction(ops.Or(), term, indices) From 4c0fdb6841d9cfa43a448beb8180a51f2eca2637 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Thu, 17 Apr 2025 15:57:55 +0200 Subject: [PATCH 07/18] Fix bugs; propagate datatypes --- yateto/ast/node.py | 45 ++++++--- yateto/ast/transformer.py | 22 +++-- yateto/codegen/copyscaleadd/csa_gen.py | 4 +- yateto/codegen/elementwise/factory.py | 2 +- .../codegen/fused_gemms/external_generator.py | 7 +- yateto/codegen/gemm/gemmgen.py | 7 +- yateto/codegen/log/generic.py | 2 +- yateto/codegen/reduction/__init__.py | 1 + yateto/codegen/reduction/factory.py | 30 ++++++ yateto/codegen/reduction/generic.py | 31 ++++++ yateto/codegen/visitor.py | 21 +++- yateto/functions.py | 10 ++ yateto/ops.py | 98 ++++++++++++++++++- yateto/type.py | 17 +++- 14 files changed, 256 insertions(+), 41 deletions(-) create mode 100644 yateto/codegen/reduction/__init__.py create mode 100644 yateto/codegen/reduction/factory.py create mode 100644 yateto/codegen/reduction/generic.py diff --git a/yateto/ast/node.py b/yateto/ast/node.py index f99fe54..be51efb 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -120,9 +120,6 @@ def __truediv__(self, other): def __rtruediv__(self, other): return Elementwise(ops.Div(), other, self) - def cast(self, datatype): - return DatatypeCast(self, datatype) - class IndexedTensor(Node): def __init__(self, tensor, indexNames): super().__init__() @@ -281,7 +278,7 @@ def setChildren(self, children): super().setChildren(children) class Assign(Op): - def __init__(self, lTerm, rTerm, condition=None): + def __init__(self, lTerm, rTerm, 
condition=True): if isinstance(condition, Node): super().__init__(lTerm, rTerm, condition) else: @@ -309,6 +306,12 @@ def nonZeroFlops(self): def computeSparsityPattern(self, *spps): spp = spps[1] if len(spps) >= 2 else self.rightTerm().eqspp() return self.permute(self.rightTerm().indices, spp) + + def __str__(self): + selfname = type(self).__name__ + indices = self.indices if self.indices != None else '' + condition = '' if isinstance(self.condition(), bool) and self.condition() else f' if {self.condition()}' + return f'{selfname}[{indices}]: {self.leftTerm()} <- {self.rightTerm()}{condition}' class Permute(UnaryOp): def __init__(self, term, targetIndices): @@ -364,18 +367,6 @@ def computeSparsityPattern(self, *spps): spp = spps[0] if len(spps) == 1 else self.term().eqspp() return spp.indexSum(self.term().indices, self.indices) -class DatatypeCast(UnaryOp): - def __init__(self, term, datatype): - super().__init__(term) - self.indices = term.indices - self.newDatatype = datatype - - def nonZeroFlops(self): - return self.term().eqspp().count_nonzero() - - def computeSparsityPattern(self, *spps): - return self.term().indices - class Contraction(BinOp): def __init__(self, indices, lTerm, rTerm, sumIndices): super().__init__(lTerm, rTerm) @@ -608,3 +599,25 @@ def computeSparsityPattern(self, *spps): def __str__(self): indices = self.indices if self.indices != None else '' return f'{type(self).__name__}({self.optype})[{indices}]' + +class Accumulate(Op): + def __init__(self, optype, *operands): + super().__init__(*operands) + + self.optype = optype + + def computeSparsityPattern(self, *spps): + if len(spps) == 0: + spps = [node.eqspp() for node in self] + permute_summand = lambda i: self.permute(self[i].indices, spps[i]) + spp = permute_summand(0) + for i in range(1, len(spps)): + add_spp = permute_summand(i) + spp = aspp.add(spp, add_spp) + return spp + + def nonZeroFlops(self): + nzFlops = 0 + for child in self: + nzFlops += child.eqspp().count_nonzero() + return 
nzFlops - self.eqspp().count_nonzero() diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index 9ad2218..157a42d 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -265,10 +265,15 @@ def visit_IndexedTensor(self, node): super().generic_visit(node) node.datatype = node.tensor.getDatatype(self.arch) return node - - def visit_DatatypeCast(self, node): + + def visit_Elementwise(self, node): + super().generic_visit(node) + node.datatype = node.optype.datatypeResult([c.datatype for c in node]) + return node + + def visit_Assign(self, node): super().generic_visit(node) - node.datatype = node.newDatatype + node.datatype = node[0].datatype return node class SetDatatype2(Transformer): @@ -282,8 +287,13 @@ def generic_visit(self, node): def visit_IndexedTensor(self, node): super().generic_visit(node) return node - - def visit_DatatypeCast(self, node): + + def visit_Elementwise(self, node): + super().generic_visit(node) + node.datatype = node.optype.datatypeResult([c.datatype for c in node]) + return node + + def visit_Assign(self, node): super().generic_visit(node) - node.datatype = node.newDatatype + node.datatype = node[0].datatype return node diff --git a/yateto/codegen/copyscaleadd/csa_gen.py b/yateto/codegen/copyscaleadd/csa_gen.py index 8f8c6d5..12707b7 100644 --- a/yateto/codegen/copyscaleadd/csa_gen.py +++ b/yateto/codegen/copyscaleadd/csa_gen.py @@ -72,7 +72,7 @@ def generate(self, cpp, routineCache): n = d.loopRanges[d.result.indices[1]] alpha = d.alpha - aux = BatchedOperationsAux(self._arch.typename) + aux = BatchedOperationsAux(d.result.datatype.ctype()) matrix_a = gf.YatetoInterface.produce_dense_matrix((m, n), d.term.memoryLayout.bbox(), addressing=aux.deduce_addresing(d.term), @@ -84,7 +84,7 @@ def generate(self, cpp, routineCache): transpose=False) try: - vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=self._arch.typename) + vm = gf.vm_factory(self._arch.name, self._arch.backend, 
fp_type=d.result.datatype.ctype()) forge_generator = gf.CsaGenerator(vm) forge_generator.set(matrix_a, matrix_b, alpha, d.beta) routine_name = forge_generator.get_base_name() diff --git a/yateto/codegen/elementwise/factory.py b/yateto/codegen/elementwise/factory.py index 009fe41..8771498 100644 --- a/yateto/codegen/elementwise/factory.py +++ b/yateto/codegen/elementwise/factory.py @@ -31,7 +31,7 @@ def __init__(self, alpha, add: bool, optype: Operation, result: IndexedTensorDes assert testLoopRangesEqual(newRange, self.loopRanges) assert testLoopRangesAContainedInB(newRange, rR) - self.loopRanges(newRange) + self.loopRanges.update(newRange) def generator(arch, descr, target): if target == 'cpu': diff --git a/yateto/codegen/fused_gemms/external_generator.py b/yateto/codegen/fused_gemms/external_generator.py index 4fea176..3c5037e 100644 --- a/yateto/codegen/fused_gemms/external_generator.py +++ b/yateto/codegen/fused_gemms/external_generator.py @@ -11,7 +11,8 @@ class FusedGemms: def __init__(self, arch, descr): self._arch = arch self._descr = descr - self._batch_aux = BatchedOperationsAux(self._arch.typename) + self._datatype = self._descr[0].node.datatype + self._batch_aux = BatchedOperationsAux(self._datatype.ctype()) self._cache = {} self._tmp_matrices = {} @@ -39,7 +40,7 @@ def generate(self, cpp, routineCache, cfg): context = Context(arch=self._arch.name, backend=self._arch.backend, - fp_type=FloatingPointType.str2enum(self._arch.typename)) + fp_type=FloatingPointType.str2enum(self._datatype.ctype())) chainforge_generator = ChainForgeGenerator(gemm_list, context) chainforge_generator.register() @@ -123,7 +124,7 @@ def _gen_call_site(self, generator): offset_name_map = {} for name, matrix in self._cache.items(): if matrix.direction == DataFlowDirection.SOURCE: - ptr_type = f'{self._arch.typename} {Addressing.addr2ptr_type(matrix.addressing)}' + ptr_type = f'{self._datatype.ctype()} {Addressing.addr2ptr_type(matrix.addressing)}' mat_name_map[name] = 
f'const_cast<{ptr_type}>({name})' else: mat_name_map[name] = name diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 0c90612..86bca4c 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -9,6 +9,7 @@ from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, tinytc from ..common import BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper from ..tiny_tensor_language import * +from ...type import Datatype import importlib.util @@ -448,7 +449,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._callGenerator(argList) if self._mode == 'pspamm': - return 'void {name}(const {type}* A, const {type}* B, {type}* C, {type} alpha, {type} beta, const {type}* prefetch);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {atype}* A, const {btype}* B, {ctype}* C, {ctype} alpha, {ctype} beta, const {ctype}* prefetch);'.format(name=routineName, + atype=self._gemmDescr['datatypeA'].ctype(), + btype=self._gemmDescr['datatypeB'].ctype(), + ctype=self._gemmDescr['datatypeC'].ctype(), + ) # LIBXSMM header if self._gemmDescr['prefetch'] == 'nopf': diff --git a/yateto/codegen/log/generic.py b/yateto/codegen/log/generic.py index 366d041..10f22fb 100644 --- a/yateto/codegen/log/generic.py +++ b/yateto/codegen/log/generic.py @@ -14,7 +14,7 @@ def _pointer(self, cpp, targetName, baseName, term, loopIndices, const=True): addressStr = term.memoryLayout.addressString(term.indices, indices) if len(indices) > 0 else '' if len(addressStr) > 0: addressStr = ' + ' + addressStr - cpp('{} {}* {} = {}{};'.format(self._arch.typename, 'const' if const else '', targetName, baseName, addressStr)) + cpp('{} {}* {} = {}{};'.format(term.datatype.ctype(), 'const' if const else '', targetName, baseName, addressStr)) def _alignedStart(self, term, loopIndices): if len(loopIndices) == 0: diff --git a/yateto/codegen/reduction/__init__.py b/yateto/codegen/reduction/__init__.py 
new file mode 100644 index 0000000..fb914e0 --- /dev/null +++ b/yateto/codegen/reduction/__init__.py @@ -0,0 +1 @@ +from .factory import Description, generator diff --git a/yateto/codegen/reduction/factory.py b/yateto/codegen/reduction/factory.py new file mode 100644 index 0000000..69f8fea --- /dev/null +++ b/yateto/codegen/reduction/factory.py @@ -0,0 +1,30 @@ +from ..common import * +from .generic import Generic + +from ...ops import Operation + +class Description(object): + def __init__(self, alpha, add: bool, result: IndexedTensorDescription, term: IndexedTensorDescription, optype: Operation): + self.alpha = alpha + self.add = add + self.result = result + self.term = term + self.optype = optype + + rA = loopRanges(self.term, self.result.indices) + rB = loopRanges(self.result, self.result.indices) + assert testLoopRangesAContainedInB(rA, rB) + + self.loopRanges = rA + + self.sumIndex = self.term.indices - self.result.indices + assert len(self.sumIndex) == 1 + + self.sumLoopRange = loopRanges(self.term, self.sumIndex)[str(self.sumIndex)] + + +def generator(arch, descr, target): + if target == 'cpu': + return Generic(arch, descr) + elif target == 'gpu': + raise RuntimeError("IndexSum operation has not been implemented for GPU-like architectures") \ No newline at end of file diff --git a/yateto/codegen/reduction/generic.py b/yateto/codegen/reduction/generic.py new file mode 100644 index 0000000..d257bd2 --- /dev/null +++ b/yateto/codegen/reduction/generic.py @@ -0,0 +1,31 @@ +from ..common import * + +class Generic(object): + def __init__(self, arch, descr): + self._arch = arch + self._descr = descr + + def generate(self, cpp, routineCache): + d = self._descr + + if not d.add: + writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) + initializeWithZero(cpp, d.result, writeBB) + + sumIndex = d.term.indices - d.result.indices + assert len(sumIndex) == 1 + class IndexSumBody(object): + def __call__(s): + target = '{}[{}]'.format(d.result.name, 
d.result.memoryLayout.addressString(d.result.indices)) + initialValue = target if d.add else d.result.datatype.literal(d.optype.neutral()) + cpp(f'{d.result.datatype.ctype()} acc = {initialValue};') + with cpp.For('int {0} = {1}; {0} < {2}; ++{0}'.format(sumIndex, d.sumLoopRange.start, d.sumLoopRange.stop)): + argstr = f'{d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}]' + cpp( 'acc = {};'.format(d.optype.callstr('acc', argstr)) ) + mult = f'{d.alpha} * ' if d.alpha != 1.0 else '' + cpp( f'{target} = {mult}acc;' ) + + flop = 1 if d.alpha != 1.0 else 0 + return d.sumLoopRange.size() + flop + + return forLoops(cpp, d.result.indices, d.loopRanges, IndexSumBody()) diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index a873a4e..3312d2c 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -134,6 +134,7 @@ def __init__(self, function, tmp_mem_size, is_compute_constant_tensors, + datatype, target): self.nonZeroFlops = nonZeroFlops @@ -145,6 +146,7 @@ def __init__(self, self.function = function self.tmp_mem_size = tmp_mem_size self.is_compute_constant_tensors = is_compute_constant_tensors + self.datatype = datatype self.target = target @classmethod @@ -166,12 +168,16 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): writable = dict() is_compute_constant_tensors = dict() scalars = collections.OrderedDict() + datatype = dict() for scalar in scalarsP: self.KernelOutline._addTensor(scalar, scalars) + datatype[scalar.baseNameWithNamespace()] = scalar.datatype for var in variables: self.KernelOutline._addTensor(var.tensor, tensors) bn = var.tensor.baseNameWithNamespace() + datatype[bn] = var.datatype + if bn in writable: if var.writable: writable[bn] = True @@ -204,6 +210,7 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): function, tmp_memory, is_compute_constant_tensors, + datatype, target) @classmethod @@ -221,6 +228,7 @@ def generate(self, cpp, header, name, kernelOutlines,
familyStride=None): writable = dict() scalars = collections.OrderedDict() is_compute_constant_tensors = dict() + datatype = dict() for ko in kernelOutlines: if ko: self._addFromKO(ko.scalars, scalars) @@ -228,6 +236,7 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None): self._addFromKO(ko.writable, writable) self._addFromKO(ko.prefetch, prefetch) self._addFromKO(ko.is_compute_constant_tensors, is_compute_constant_tensors) + self._addFromKO(ko.datatype, datatype) target = kernelOutlines[-1].target is_same_target = True @@ -283,9 +292,9 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None): header.emptyline() - def kernelArgs(base_name_with_namespace, groups, writable, is_constant, target): + def kernelArgs(base_name_with_namespace, groups, writable, is_constant, datatype, target): prefix, base_name = Tensor.splitBasename(base_name_with_namespace) - typ = self._arch.typename + typ = datatype.ctype() ptr_type = '**' if not is_constant and target == 'gpu' else '*' if not writable: typ += ' const' @@ -294,11 +303,11 @@ def kernelArgs(base_name_with_namespace, groups, writable, is_constant, target): container_type = f'{InitializerGenerator.CONTAINER_CLASS_NAME}<{typ}{ptr_type}>' header(f'{class_name}::{container_type} {base_name};') else: - header(f'{typ}{ptr_type} {base_name}{{}};') + header(f'{typ}{ptr_type} {base_name}{{nullptr}};') - def scalarArgs(base_name_with_namespace, groups): + def scalarArgs(base_name_with_namespace, datatype, groups): prefix, base_name = Tensor.splitBasename(base_name_with_namespace) - typ = self._arch.typename + typ = datatype.ctype() if len(next(iter(groups))) > 0: class_name = f'{prefix}{InitializerGenerator.TENSOR_NAMESPACE}::{base_name}' container_type = f'{InitializerGenerator.CONTAINER_CLASS_NAME}<{typ}>' @@ -308,12 +317,14 @@ def scalarArgs(base_name_with_namespace, groups): for baseName, groups in scalars.items(): scalarArgs(baseName, + datatype[baseName], groups) for baseName, groups in 
tensors.items(): kernelArgs(baseName, groups, writable[baseName], is_compute_constant_tensors[baseName], + datatype[baseName], target) header.emptyline() diff --git a/yateto/functions.py b/yateto/functions.py index bbcbe38..b4dfa37 100644 --- a/yateto/functions.py +++ b/yateto/functions.py @@ -1,5 +1,6 @@ from . import ops from .ast import node +from .type import Datatype def sin(x): return node.Elementwise(ops.Sin(), x) def cos(x): return node.Elementwise(ops.Cos(), x) @@ -34,8 +35,17 @@ def assignIf(condition, lhs, rhs): return node.Assign(lhs, rhs, condition) # def where(condition, yes, no): return node.IfThenElse(condition, yes, no) def where(condition, yes, no): return node.Elementwise(ops.Ternary(), yes, no, condition) +def equal(x, y): return node.Elementwise(ops.CmpEq(), x, y) +def not_equal(x, y): return node.Elementwise(ops.CmpNe(), x, y) +def less(x, y): return node.Elementwise(ops.CmpLt(), x, y) +def less_equal(x, y): return node.Elementwise(ops.CmpLe(), x, y) +def greater(x, y): return node.Elementwise(ops.CmpGt(), x, y) +def greater_equal(x, y): return node.Elementwise(ops.CmpGe(), x, y) + # extra reduction functions; e.g. 
for input to `where` def reductionSum(term, indices): return node.Reduction(ops.Add(), term, indices) def reductionMul(term, indices): return node.Reduction(ops.Mul(), term, indices) def reductionAnd(term, indices): return node.Reduction(ops.And(), term, indices) def reductionOr(term, indices): return node.Reduction(ops.Or(), term, indices) + +def cast(x, dtype): return node.Elementwise(ops.Typecast(dtype), x) diff --git a/yateto/ops.py b/yateto/ops.py index 24493d6..2b94525 100644 --- a/yateto/ops.py +++ b/yateto/ops.py @@ -1,4 +1,5 @@ import numpy as np +from .type import Datatype class Operation: #def callstr(self, *args) -> str: @@ -7,10 +8,7 @@ class Operation: def call(self, *args): raise NotImplementedError() - def datatypeArgs(self): - raise NotImplementedError() # TODO - - def datatypeResult(self): + def datatypeResult(self, argtypes): raise NotImplementedError() # TODO def __str__(self): @@ -60,99 +58,137 @@ def cppname(self): return 'std::sin' def call(self, *args): return np.sin(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Cos(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::cos' def call(self, *args): return np.cos(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Tan(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::tan' def call(self, *args): return np.tan(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Asin(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::asin' def call(self, *args): return np.asin(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Acos(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::acos' def call(self, *args): return np.acos(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Atan(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::atan' def call(self, *args): return 
np.atan(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Sinh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::sinh' def call(self, *args): return np.sinh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Cosh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::cosh' def call(self, *args): return np.cosh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Tanh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::tanh' def call(self, *args): return np.tanh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Asinh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::asinh' def call(self, *args): return np.asinh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Acosh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::acosh' def call(self, *args): return np.acosh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Atanh(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::atanh' def call(self, *args): return np.atanh(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Log(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::log' def call(self, *args): return np.log(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Exp(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::exp' def call(self, *args): return np.exp(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Log1p(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::log1p' def call(self, *args): return np.log1p(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Expm1(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::expm1' def call(self, *args): return 
np.expm1(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Sqrt(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::sqrt' def call(self, *args): return np.sqrt(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Cbrt(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::cbrt' def call(self, *args): return np.cbrt(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] -class Cbrt(Operation, CFunctionMixin, UnaryArgsMixin): +class Abs(Operation, CFunctionMixin, UnaryArgsMixin): def cppname(self): return 'std::abs' def call(self, *args): return np.abs(args[0]) + def datatypeResult(self, argtypes): + return argtypes[0] class Max(Operation, CFunctionMixin, BinaryArgsMixin): @@ -160,22 +196,33 @@ def cppname(self): return 'std::max' def call(self, *args): return max(args[0], args[1]) + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class Min(Operation, CFunctionMixin, BinaryArgsMixin): def cppname(self): return 'std::min' def call(self, *args): return min(args[0], args[1]) + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class Pow(Operation, CFunctionMixin, BinaryArgsMixin): def cppname(self): return 'std::pow' def call(self, *args): return pow(args[0], args[1]) + def datatypeResult(self, argtypes): + return argtypes[0] class Div(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '/' def call(self, *args): return args[0] / args[1] + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class Add(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin): def cppname(self, *args): @@ -184,6 +231,9 @@ def call(self, *args): return args[0] + args[1] def neutral(self): return 0 + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class 
Mul(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): def cppname(self, *args): return '*' @@ -193,6 +243,9 @@ def neutral(self): return 1 def formsRing(self, op): return op == Add() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class And(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): def cppname(self, *args): @@ -203,6 +256,9 @@ def neutral(self): return True def formsRing(self, op): return op == Or() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class Or(Operation, CBinaryOperatorMixin, BinaryArgsMixin, CommutativeMonoidMixin, RingMixin): def cppname(self, *args): return '|' @@ -212,42 +268,60 @@ def neutral(self): return False def formsRing(self, op): return op == And() + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class Not(Operation, CUnaryOperatorMixin, UnaryArgsMixin): def cppname(self, *args): return '~' def call(self, *args): return ~args[0] + def datatypeResult(self, argtypes): + # assert argtypes[0] == argtypes[1] + return argtypes[0] class CmpEq(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '==' def call(self, *args): return args[0] == args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL class CmpNe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '!=' def call(self, *args): return args[0] != args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL class CmpLt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '<' def call(self, *args): return args[0] < args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL class CmpLe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '<=' def call(self, *args): return args[0] <= args[1] + def datatypeResult(self, 
argtypes): + return Datatype.BOOL class CmpGt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '>' def call(self, *args): return args[0] > args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL class CmpGe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '>=' def call(self, *args): return args[0] >= args[1] + def datatypeResult(self, argtypes): + return Datatype.BOOL # replacement; however it'll execute both code paths, regardless of the result class Ternary(Operation): @@ -255,3 +329,17 @@ def callstr(self, *args): return f'(({args[2]}) ? ({args[0]}) : ({args[1]}))' def call(self, *args): return np.where(args[2], args[0], args[1]) + def datatypeResult(self, argtypes): + assert argtypes[0] == argtypes[1] + return argtypes[0] +class Typecast(Operation, CFunctionMixin, UnaryArgsMixin): + def __init__(self, target: Datatype): + self.target = target + def cppname(self, *args): + return f'static_cast<{self.target.ctype()}>' + def call(self, *args): + return np.asarray(args[0]).astype(self.target.nptype()) + def datatypeResult(self, argtypes): + return self.target + def __str__(self): + return f'Cast<{self.target}>' diff --git a/yateto/type.py b/yateto/type.py index e1b085e..1c00db3 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -1,10 +1,11 @@ import re -from .ast.node import Node, IndexedTensor from numpy import ndarray, zeros, float64 from .memory import DenseMemoryLayout from .
import aspp from enum import Enum +import numpy as np + class Datatype(Enum): BOOL = 0 I8 = 1 @@ -42,6 +43,19 @@ def ctype(self): Datatype.BF16: 'int16_t', }[self] + def nptype(self): + return { + Datatype.BOOL: np.bool_, + Datatype.I8: np.int8, + Datatype.I16: np.int16, + Datatype.I32: np.int32, + Datatype.I64: np.int64, + Datatype.F32: np.float32, + Datatype.F64: np.float64, + Datatype.F16: np.float16, + Datatype.BF16: np.float32, # NYI + }[self] + def size(self): # unpacked size return { @@ -231,6 +245,7 @@ def setGroupSpp(self, spp): self.setMemoryLayout(self._memoryLayout.__class__, alignStride=self._memoryLayout.alignedStride()) def __getitem__(self, indexNames): + from .ast.node import IndexedTensor return IndexedTensor(self, indexNames) def shape(self): From 8c5afc4f7fc6eb0eb3337c51f719cc30c61ca4de Mon Sep 17 00:00:00 2001 From: David Schneller Date: Thu, 17 Apr 2025 16:18:59 +0200 Subject: [PATCH 08/18] Add primitive conditional execution --- yateto/codegen/factory.py | 109 ++++++++++++++------- yateto/codegen/visitor.py | 4 +- yateto/controlflow/graph.py | 157 +++++++++++++++++++++++++++++- yateto/controlflow/transformer.py | 10 +- yateto/controlflow/visitor.py | 43 ++++++-- 5 files changed, 268 insertions(+), 55 deletions(-) diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index da5646e..02758e6 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -23,7 +23,7 @@ def create(self, node, *args): def generic_create(self, node, *args): raise NotImplementedError - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): raise NotImplementedError def temporary(self, bufname, size, datatype, iniZero=False, memory=list()): @@ -92,12 +92,22 @@ def reset_flags(self): def _indices(self, var): shape = var.memoryLayout().shape() return Indices(string.ascii_lowercase[:len(shape)], shape) + + def _conditional(self, condition, generate):
if isinstance(condition, bool): + if condition: + return generate() + else: + return 0 + else: + with self._cpp.If(f'{condition.ccode()}'): + return generate() class OptimizedKernelFactory(KernelFactory): def __init__(self, cpp, arch, target): super().__init__(cpp, arch, target) - def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 description = log.Description( alpha = scalar, @@ -111,14 +121,14 @@ def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName prefetchName = prefetchName ) generator = log.generator(self._arch, description, self._target) - return generator.generate(self._cpp, routineCache, gemm_cfg) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache, gemm_cfg)) - def create_FusedGEMMs(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): - description = fused_gemms.Description(node, result, arguments, add, scalar) + def create_FusedGEMMs(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + description = fused_gemms.Description(node, result, arguments, condition, add, scalar) generator = fused_gemms.generator(self._arch, description, gemm_cfg, self._target) - return generator.generate(self._cpp, routineCache, gemm_cfg) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache, gemm_cfg)) - def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 1 description = indexsum.Description( alpha = scalar, @@ -127,9 +137,9 @@ def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, ro term = 
IndexedTensorDescription.fromNode(arguments[0], node.term()) ) generator = indexsum.generator(self._arch, description, self._target) - return generator.generate(self._cpp, routineCache) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - def create_Product(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 description = product.Description( alpha = scalar, @@ -139,9 +149,9 @@ def create_Product(self, node, result, arguments, add, scalar, prefetchName, rou rightTerm = IndexedTensorDescription.fromNode(arguments[1], node.rightTerm()) ) generator = product.generator(self._arch, description, self._target) - return generator.generate(self._cpp, routineCache) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = arguments[0] description = copyscaleadd.Description( alpha = scalar, @@ -150,9 +160,9 @@ def create_Permute(self, node, result, arguments, add, scalar, prefetchName, rou term = IndexedTensorDescription.fromVar(term, node.term().indices) ) generator = copyscaleadd.generator(self._arch, description, gemm_cfg, self._target) - return generator.generate(self._cpp, routineCache) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - def create_Elementwise(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): description = elementwise.Description( alpha = scalar, add = add, @@ -163,9 +173,20 @@ def 
create_Elementwise(self, node, result, arguments, add, scalar, prefetchName, nodeTermIndices = node.nodeTermIndices ) generator = elementwise.generator(self._arch, description, self._target) - return generator.generate(self._cpp, routineCache) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) + + def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + description = reduction.Description( + alpha = scalar, + add = add, + result = IndexedTensorDescription.fromNode(result, node), + term = IndexedTensorDescription.fromNode(arguments[0], node.term()), + optype = node.optype, + ) + generator = reduction.generator(self._arch, description, self._target) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): description = copyscaleadd.Description( alpha = scalar, beta = 1.0 if add else 0.0, @@ -173,7 +194,7 @@ def simple(self, result, term, add, scalar, routineCache, gemm_cfg): term = IndexedTensorDescription.fromVar(term, self._indices(term)) ) generator = copyscaleadd.generator(self._arch, description, gemm_cfg, self._target) - return generator.generate(self._cpp, routineCache) + return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) class UnitTestFactory(KernelFactory): def __init__(self, cpp, arch, nameFun, testFramework): @@ -186,7 +207,7 @@ def _formatTerm(self, var, indices): address = var.memoryLayout().addressString(indices) return f'{self._name(var)}[{address}]' - def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Einsum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = node.indices for child in node: g = g.merged(child.indices - g) @@ 
-200,34 +221,52 @@ def create_Einsum(self, node, result, arguments, add, scalar, prefetchName, rout terms.insert(0, str(scalar)) if not add: - self._cpp.memset(self._name(result), result.memoryLayout().requiredReals(), self._arch.typename) + self._cpp.memset(self._name(result), result.memoryLayout().requiredReals(), result.datatype.ctype()) class EinsumBody(object): def __call__(s): self._cpp(f"{resultTerm} += {' * '.join(terms)};") return len(terms) - return forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False) + return self._conditional(condition, lambda: forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False)) - def create_ScalarMultiplication(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): - return self.simple(result, arguments[0], add, scalar, routineCache) + def create_ScalarMultiplication(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + return self.simple(result, arguments[0], condition, add, scalar, routineCache, gemm_cfg) - def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert node.indices <= node.term().indices and node.term().indices <= node.indices resultTerm = self._formatTerm(result, node.indices) termTerm = self._formatTerm(arguments[0], node.term().indices) - return self._simpleBody(resultTerm, termTerm, add, scalar, node.indices) + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)) - def create_Elementwise(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + g = self._indices(result) + resultTerm = self._formatTerm(result, node.indices) + + argTerms =
[self._formatTerm(argument, term.indices) for argument, term in zip(arguments, node)] + termTerm = f'({argTerms[0]}) * ({argTerms[1]})' + + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) + + def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, node.indices) argTerms = [self._formatTerm(argument, term.indices) for argument, term in zip(arguments, node)] termTerm = node.optype.callstr(*node.fillTerms(argTerms)) - return self._simpleBody(resultTerm, termTerm, add, scalar, g) + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) + + def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + g = self._indices(result) + resultTerm = self._formatTerm(result, node.indices) + argTerms = [self._formatTerm(arguments[0], node.term().indices)] + + termTerm = node.optype.callstr(*node.fillTerms(argTerms)) + + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) - def create_IfThenElse(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_IfThenElse(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, node.indices) yesTerm = self._formatTerm(arguments[0], node.yesTerm().indices) @@ -236,9 +275,9 @@ def create_IfThenElse(self, node, result, arguments, add, scalar, prefetchName,
({yesTerm}) : ({noTerm}))' - return self._simpleBody(resultTerm, termTerm, add, scalar, g) + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) - def _simpleBody(self, resultTerm, termTerm, add, scalar, indices): + def _simpleBody(self, resultTerm, termTerm, add, scalar, indices, reduceIdx = None): ranges = {idx: Range(0, indices.indexSize(idx)) for idx in indices} if scalar and scalar != 1.0: @@ -251,13 +290,13 @@ def __call__(s): return forLoops(self._cpp, indices, ranges, AssignBody(), pragmaSimd=False) - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, g) termTerm = self._formatTerm(term, g) - return self._simpleBody(resultTerm, termTerm, add, scalar, g) + return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) def compare(self, ref, target, epsMult = 100.0): g = self._indices(ref) @@ -328,29 +367,29 @@ def post_generate(self, routine_cache): def allocateTemporary(self): return False - def create_LoopOverGEMM(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 makeNode = IndexedTensorDescription.fromNode argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())] return self.handleLinear(makeNode(result, node), argnodes, add, scalar, node.transA(), node.transB()) - def create_IndexSum(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 1 makeNode = IndexedTensorDescription.fromNode argnodes = [makeNode(arguments[0], node.term())] 
return self.handleLinear(makeNode(result, node), argnodes, add, scalar, False, False) - def create_Product(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 makeNode = IndexedTensorDescription.fromNode argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())] return self.handleLinear(makeNode(result, node), argnodes, add, scalar, False, False) - def create_Permute(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): + def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = arguments[0] return self.handleLinear(IndexedTensorDescription(str(result), node.indices, result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), node.term().indices, term.memoryLayout(), term.eqspp())], add, scalar, False, False) - def simple(self, result, term, add, scalar, routineCache, gemm_cfg): + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): return self.handleLinear(IndexedTensorDescription(str(result), self._indices(result), result.memoryLayout(), result.eqspp()), [IndexedTensorDescription(str(term), self._indices(term), term.memoryLayout(), term.eqspp())], add, scalar, False, False) def getIndices(self, dest, ops): diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index 3312d2c..e1a26d7 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -93,9 +93,9 @@ def generate(self, cpp, cfg, factory, routineCache, gemm_cfg): scalar = self.deduce_scalar(action) if action.isRHSExpression(): prefetchName = '{}.{}'.format(self.PREFETCHVAR_NAME, action.term.node.prefetch.name()) if action.term.node.prefetch is not None else None - hwFlops += factory.create(action.term.node, action.result, 
action.term.variableList(), action.add, scalar, prefetchName, routineCache, gemm_cfg) + hwFlops += factory.create(action.term.node, action.result, action.term.variableList(), action.condition, action.add, scalar, prefetchName, routineCache, gemm_cfg) else: - hwFlops += factory.simple(action.result, action.term, action.add, scalar, routineCache, gemm_cfg) + hwFlops += factory.simple(action.result, action.term, action.condition, action.add, scalar, routineCache, gemm_cfg) return hwFlops, required_tmp_mem class OptimizedKernelGenerator(KernelGenerator): diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index 771964a..b2045a9 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -1,7 +1,8 @@ from ..ast.node import Node, FusedGEMMs, LoopOverGEMM +from ..ast.indices import Indices from collections import OrderedDict from typing import Dict, List - +from ..type import Scalar class Variable(object): def __init__(self, name, writable, memoryLayout, eqspp=None, tensor=None, is_temporary=False, datatype=None): @@ -97,11 +98,12 @@ def setWritable(self, name): class ProgramAction(object): - def __init__(self, result, term, add, scalar=None): + def __init__(self, result, term, add, scalar=None, condition=True): self.result = result self.term = term self.add = add self.scalar = scalar + self.condition = condition def isRHSExpression(self): return isinstance(self.term, Expression) @@ -135,7 +137,7 @@ def maySubstitute(self, when, by, result = True, term = True): def substituted(self, when, by, result = True, term = True): rsubs = self.result.substituted(when, by) if result else self.result tsubs = self.term.substituted(when, by, rsubs.memoryLayout()) if term else self.term - return ProgramAction(rsubs, tsubs, self.add, self.scalar) + return ProgramAction(rsubs, tsubs, self.add, self.scalar, self.condition) def setVariablesWritable(self, name): self.result.setWritable(name) @@ -149,6 +151,7 @@ def __init__(self): self._variables: 
List[Variable] = [] self._adds: List[bool] = [] self._scalars = [] + self._conditions = [] def add(self, action: ProgramAction) -> None: if not isinstance(action.term.node, LoopOverGEMM): @@ -160,13 +163,15 @@ def add(self, action: ProgramAction) -> None: self._variables.extend(action.term.variableList()) self._adds.append(action.add) self._scalars.append(action.scalar) + self._conditions.append(action.condition) def gen_program_action(self) -> ProgramAction: last_action: ProgramAction = self._actions[-1] return ProgramAction(result=last_action.result, term=self._gen_expr(), add=self._adds, - scalar=self._scalars) + scalar=self._scalars, + condition=self._conditions) def _gen_expr(self) -> Expression: node = FusedGEMMs() @@ -189,6 +194,150 @@ def __init__(self, action): self.initBuffer = None self.bufferMap = None +# a rather primitive CNF implementation. +# do not overuse (i.e. avoid conditional assigns where possible) + +class CNFClause: + def __init__(self, variables): + if isinstance(variables, list): + self.variables = {var: True for var in variables} + else: + self.variables = variables + self.fulfilled = False + + def negateVariables(self): + return {var: ~self.variables[var] for var in self.variables} + + def unite(self, clause): + output = CNFClause([]) + for v in self.variables: + if v in clause and clause.variables[v] != self.variables[v]: + output.fulfilled = True + if not output.fulfilled: + output.variables = {**self.variables, **clause.variables} + return output + + def __repr__(self): + formatvar = lambda name: f'{name}' if self.variables[name] else f'~{name}' + return f'[{", ".join(formatvar(var) for var in self.variables)}]' + + def ccode(self): + # for now, only allow scalarly-indexed variables + printvar = lambda var: f'{var}' if isinstance(var, Scalar) else f'{var}[{var.memoryLayout().addressString(Indices())}]' + formatvar = lambda name: f'{printvar(name)}' if self.variables[name] else f'!{printvar(name)}' + return f'({" || 
".join(formatvar(var) for var in self.variables)})' + +class CNFCondition: + def __init__(self, data): + if isinstance(data, bool): + if data == True: + self.clauses = [] + elif data == False: + self.clauses = [CNFClause([])] + else: + self.clauses = [CNFClause([data])] + + def _prune(self): + newclauses = [] + for clause in self.clauses: + if clause.fulfilled: + newclauses = [] + break + else: + if len(clause.variables) == 0: + newclauses = [clause] + break + else: + newclauses += [clause] + self.clauses = newclauses + + def tautology(self): + return len(self.clauses) == 0 + + def unfulfillable(self): + return any(not clause.fulfilled and len(clause.variables) == 0 for clause in self.clauses) + + def __not__(self): + if self.tautology(): + return CNFCondition(False) + + # this is the actually painful step (as it's also pretty inefficient right now) + result = CNFCondition(True) + for clause in self.clauses: + clauseInv = CNFCondition(True) + negVar = clause.negateVariables() + clauseInv.clauses = [CNFClause({var: negVar[var]}) for var in negVar] + + result = result | clauseInv + return result + + def __rand__(self, other): + if not isinstance(other, CNFCondition): + other = CNFCondition(other) + + clauses = self.clauses + other.clauses + + condition = CNFCondition(True) + condition.clauses = clauses + condition._prune() + + return condition + + def __ror__(self, other): + if not isinstance(other, CNFCondition): + other = CNFCondition(other) + + clauses = [clause.unite(oclause) for clause in self.clauses for oclause in other.clauses] + condition = CNFCondition(True) + condition.clauses = clauses + condition._prune() + + return condition + + def __repr__(self): + return f'{self.clauses}' + + def ccode(self): + return f'({" && ".join(clause.ccode() for clause in self.clauses)})' + +class LiveSet: + def __init__(self, data: dict): + self.data = data + + def __sub__(self, other): + if isinstance(other, dict): + other = LiveSet(other) + + result = {k:self.data[k] for k 
in self.data} + + for var in other.data: + if var in result: + result[var] &= ~other.data[var] + if not result[var]: + result.remove(var) + + return LiveSet(result) + + def __or__(self, other): + if isinstance(other, dict): + other = LiveSet(other) + + result = {k:self.data[k] for k in self.data} + + for var in other.data: + if var in result: + result[var] |= other.data[var] + + return LiveSet(result) + + def __contains__(self, element): + if isinstance(element, tuple): + return element[0] in self.data and (~self.data[element[0]] & element[1]).unfulfillable() + else: + return element in self.data + + def variables(self): + return set(k for k in self.data) class FusedProgramPoint(ProgramPoint): def __init__(self, action: FusedActions): diff --git a/yateto/controlflow/transformer.py b/yateto/controlflow/transformer.py index 605ba0c..9610c40 100644 --- a/yateto/controlflow/transformer.py +++ b/yateto/controlflow/transformer.py @@ -23,9 +23,9 @@ def visit(self, cfg): class LivenessAnalysis(object): def visit(self, cfg): - cfg[-1].live = set() + cfg[-1].live = LiveSet({}) for i in reversed(range(len(cfg)-1)): - cfg[i].live = (cfg[i+1].live - {cfg[i].action.result}) | cfg[i].action.variables() + cfg[i].live = (cfg[i+1].live - {cfg[i].action.result: cfg[i].action.condition}) | {var:cfg[i].action.condition for var in cfg[i].action.variables()} return cfg class SubstituteForward(object): @@ -34,7 +34,7 @@ def visit(self, cfg): for i in range(n): ua = cfg[i].action v = cfg[i+1] - if not ua.isCompound() and ua.isRHSVariable() and ua.term.writable and ua.result.isLocal() and ua.term not in v.live: + if not ua.isCompound() and ua.isRHSVariable() and ua.term.writable and ua.result.isLocal() and (ua.term, ua.condition) not in v.live: when = ua.result by = ua.term maySubs = all([cfg[j].action.maySubstitute(when, by) for j in range(i, n)]) @@ -54,7 +54,7 @@ def visit(self, cfg): found = -1 for j in range(i): u = cfg[j] - if by not in u.live and not u.action.isCompound() and 
u.action.result == va.term: + if (va.result, va.condition) not in u.live and not u.action.isCompound() and u.action.result == va.term: found = j break if found >= 0: @@ -145,7 +145,7 @@ def visit(self, cfg): bufferSize[buf] = size # free buffers - free = cfg[i].live - cfg[i+1].live + free = cfg[i].live.variables() - cfg[i+1].live.variables() for local in free: if local in usedBuffers: freeBuffers.appendleft(usedBuffers.pop(local)) diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index d5a0ddd..8c73249 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -2,7 +2,7 @@ from yateto import Scalar from .graph import * from ..memory import DenseMemoryLayout -from ..ast.node import Permute +from ..ast.node import Permute, Node class AST2ControlFlow(Visitor): TEMPORARY_RESULT = '_tmp' @@ -12,6 +12,7 @@ def __init__(self, simpleMemoryLayout=False): self._cfg = [] self._writable = set() self._simpleMemoryLayout = simpleMemoryLayout + self._condition = [True] def cfg(self): return self._cfg + [ProgramPoint(None)] @@ -27,7 +28,7 @@ def _addPermuteIfRequired(self, indices, term, variable): permute.computeMemoryLayout() permute.datatype = term.datatype result = self._nextTemporary(permute) - action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False) + action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False, condition=self._condition[-1]) self._addAction(action) return result return variable @@ -36,7 +37,7 @@ def generic_visit(self, node): variables = [self.visit(child) for child in node] result = self._nextTemporary(node) - action = ProgramAction(result, Expression(node, self._ml(node), variables), False) + action = ProgramAction(result, Expression(node, self._ml(node), variables), False, condition=self._condition[-1]) self._addAction(action) return result @@ -51,7 +52,7 @@ def visit_Add(self, node): add = False for i,var in enumerate(variables): rhs = 
self._addPermuteIfRequired(node.indices, node[i], var) - action = ProgramAction(tmp, rhs, add) + action = ProgramAction(tmp, rhs, add, condition=self._condition[-1]) self._addAction(action) add = True @@ -61,24 +62,48 @@ def visit_ScalarMultiplication(self, node): variable = self.visit(node.term()) result = self._nextTemporary(node) - action = ProgramAction(result, variable, False, node.scalar()) + action = ProgramAction(result, variable, False, node.scalar(), condition=self._condition[-1]) self._addAction(action) return result def visit_Assign(self, node): + condition = self._condition[-1] + if isinstance(node.condition(), Node): + myCondition = self.visit(node[2]) + else: + myCondition = node.condition() + self.updateWritable(node[0].name()) - variables = [self.visit(child) for child in node] - rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), variables[1]) - action = ProgramAction(variables[0], rhs, False) + newCondition = condition & CNFCondition(myCondition) + self._condition.append(newCondition) + lVar = self.visit(node[0]) + rVar = self.visit(node[1]) + self._condition = self._condition[:-1] + + rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), rVar) + action = ProgramAction(lVar, rhs, False, condition=newCondition) self._addAction(action) - return variables[0] + return lVar def visit_IndexedTensor(self, node): return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor, datatype=node.datatype) + def visit_IfThenElse(self, node): + if len(self._condition) > 0: + condition = self._condition.top() + else: + condition = True + self.visit(node.yesTerm()) + self.visit(node.noTerm()) + myCondition = node.condition() + self._condition.push(condition & myCondition) + self._condition.pop() + self._addAction(ProgramAction()) + return self.visit(node.term()) + def _addAction(self, action): self._cfg.append(ProgramPoint(action)) From 52c0fb348abaa93d748925142e76ee81d54339fe Mon Sep 17 00:00:00 2001 
From: David Schneller Date: Thu, 24 Apr 2025 15:51:16 +0200 Subject: [PATCH 09/18] Continue the datatype and elementwise propagation --- yateto/arch.py | 16 +-- yateto/codegen/common.py | 53 ++++++--- yateto/codegen/copyscaleadd/csa_gen.py | 6 +- yateto/codegen/copyscaleadd/tinytc.py | 16 ++- yateto/codegen/factory.py | 55 +++++++-- .../codegen/fused_gemms/external_generator.py | 2 +- yateto/codegen/fused_gemms/tinytc.py | 13 +-- yateto/codegen/gemm/factory.py | 3 + yateto/codegen/gemm/gemmgen.py | 108 ++++++++++++------ yateto/codegen/visitor.py | 98 +++++++++------- yateto/controlflow/graph.py | 20 +++- yateto/controlflow/visitor.py | 2 +- yateto/functions.py | 14 ++- yateto/gemm_configuration.py | 105 ++++++++++------- yateto/type.py | 59 ++++++---- 15 files changed, 360 insertions(+), 210 deletions(-) diff --git a/yateto/arch.py b/yateto/arch.py index a2a87b4..27fb6ea 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -75,8 +75,8 @@ def __init__(self, else: raise ValueError(f'Unknown precision type {self.precision}') self.alignment = alignment - assert self.alignment % self.bytesPerReal == 0 - self.alignedReals = self.alignment // self.bytesPerReal + assert self.alignment % self.datatype.size() == 0 + self.alignedReals = self.alignment // self.datatype.size() self.enablePrefetch = enablePrefetch self.uintTypename = 'unsigned' @@ -105,20 +105,12 @@ def checkAlignment(self, offset): def formatConstant(self, constant): return self.datatype.literal(constant) - def onHeap(self, numReals): - return (numReals * self.bytesPerReal) > self._tmpStackLimit + def onHeap(self, byteCount): + return byteCount > self._tmpStackLimit def __eq__(self, other): return self.name == other.name - @property - def typename(self): - return self.datatype.ctype() - - @property - def bytesPerReal(self): - return self.datatype.size() - def _get_name_and_precision(ident): return ident[1:], ident[0].upper() diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index 0325aa9..299fd28 
100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -120,13 +120,12 @@ class BatchedOperationsAux: FLAGS_NAME = 'flags' FORBIDDEN_STREAM_PTR = 'reinterpret_cast(std::numeric_limits::max())' - def __init__(self, underlying_data_type): - self.underlying_data_type = underlying_data_type - - def _get_ptr_type(self, addressing: AddressingMode): + @classmethod + def _get_ptr_type(cls, addressing: AddressingMode): return addressing.pointer_type() - def deduce_addresing(self, term): + @classmethod + def deduce_addresing(cls, term): if term.addressing is not None: return term.addressing @@ -138,17 +137,20 @@ def deduce_addresing(self, term): else: return AddressingMode.INDIRECT - def deduce_ptr_arg(self, term, as_const=False): + @classmethod + def deduce_ptr_arg(cls, term, as_const=False): if as_const: - addressing = self.deduce_addresing(term) - ptr = self._get_ptr_type(addressing) - datatype = self.underlying_data_type if term.datatype is None else term.datatype.ctype() + addressing = cls.deduce_addresing(term) + ptr = cls._get_ptr_type(addressing) + assert term.datatype is not None + datatype = term.datatype.ctype() const_ptr_type = f'const {datatype} {ptr}' return f'const_cast<{const_ptr_type}>({term.name})' else: return f'{term.name}' - def deduce_offset_arg(self, term): + @classmethod + def deduce_offset_arg(cls, term): if term.is_compute_constant or term.is_temporary: return '0' else: @@ -156,11 +158,12 @@ def deduce_offset_arg(self, term): class TinytcKernelArgument: - def __init__(self, name: str, call_expr: str, constant: bool, temporary: bool, modified: bool, offset: int = 0): + def __init__(self, name: str, datatype: str, call_expr: str, constant: bool, temporary: bool, modified: bool, offset: int = 0): """Kernel argument for TinytcWrapper. 
Arguments: name -- Argument name + datatype -- Argument datatype in C/C++ call_expr -- Expression used in calling wrapper constant -- Whether a tensor is invariant to group id temporary -- Whether a tensor is stored in a temporary buffer @@ -181,7 +184,7 @@ def __init__(self, name: str, call_expr: str): class TinytcWrapper: - def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | TinytcScalarKernelArgument], real_type: str, name: str = ''): + def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | TinytcScalarKernelArgument], name: str = ''): self.kernel_name = kernel.name self.source = Dump().visit(kernel) if name: @@ -196,13 +199,13 @@ def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | Tiny self.call_args = [] for arg in arguments: if isinstance(arg, TinytcScalarKernelArgument): - self.wrapper_args.append(f'{real_type} {arg.name}') + self.wrapper_args.append(f'{arg.datatype} {arg.name}') self.wrapper_call_args.append(arg.name) self.call_args.append(arg.call_expr) else: ptr2ptr = '*' if not (arg.constant or arg.temporary) else '' const = ' const' if not (arg.modified or arg.temporary) else '' - wrapper_type = f'{real_type}{const}*{ptr2ptr}' + wrapper_type = f'{arg.datatype}{const}*{ptr2ptr}' self.wrapper_args.append(f'{wrapper_type} {arg.name}') self.wrapper_call_args.append(arg.name) self.call_args.append(f'const_cast<{wrapper_type}>({arg.call_expr})') @@ -280,3 +283,25 @@ def makeLoad(bb, operand, gid, isComputeConstant: bool, isTemporary: bool): return bb.add(SubviewInst(operand, offsetList, sizeList)) else: return bb.add(LoadInst(operand, [gid])) + +def toTinyTCType(datatype: Datatype): + return { + Datatype.BOOL: ScalarType(IntegerType.i1), # presumably, maybe i8 + Datatype.I8: ScalarType(IntegerType.i8), + Datatype.I16: ScalarType(IntegerType.i16), + Datatype.I32: ScalarType(IntegerType.i32), + Datatype.I64: ScalarType(IntegerType.i64), + Datatype.F32: ScalarType(IntegerType.f32), + 
Datatype.F64: ScalarType(IntegerType.f64) + }[datatype] + +def toTinyTCImmediate(datatype: Datatype, value): + immtype = { + Datatype.BOOL: lambda value: IntImmValue(IntegerType.i1, value), + Datatype.I8: lambda value: IntImmValue(IntegerType.i8, value), + Datatype.I16: lambda value: IntImmValue(IntegerType.i16, value), + Datatype.I32: lambda value: IntImmValue(IntegerType.i32, value), + Datatype.I64: lambda value: IntImmValue(IntegerType.i64, value), + Datatype.F32: lambda value: FloatImmValue(IntegerType.f32, value), + Datatype.F64: lambda value: FloatImmValue(IntegerType.f64, value), + }[datatype](value) diff --git a/yateto/codegen/copyscaleadd/csa_gen.py b/yateto/codegen/copyscaleadd/csa_gen.py index 12707b7..fad4c81 100644 --- a/yateto/codegen/copyscaleadd/csa_gen.py +++ b/yateto/codegen/copyscaleadd/csa_gen.py @@ -53,9 +53,9 @@ def _formatTerm(self, alpha, term): if alpha == 1.0: prefix = term.name else: - prefix = '{} * {}'.format(alpha, term.name) + prefix = f'{alpha} * {term.name}' - return '{}[{}]'.format(prefix, term.memoryLayout.addressString(term.indices)) + return f'{prefix}[{term.memoryLayout.addressString(term.indices)}]' def generate(self, cpp, routineCache): """Generates a tensor equation of a form: B = beta * B + alpha * A @@ -72,7 +72,7 @@ def generate(self, cpp, routineCache): n = d.loopRanges[d.result.indices[1]] alpha = d.alpha - aux = BatchedOperationsAux(d.result.datatype.ctype()) + aux = BatchedOperationsAux() matrix_a = gf.YatetoInterface.produce_dense_matrix((m, n), d.term.memoryLayout.bbox(), addressing=aux.deduce_addresing(d.term), diff --git a/yateto/codegen/copyscaleadd/tinytc.py b/yateto/codegen/copyscaleadd/tinytc.py index 3068b73..c94e203 100644 --- a/yateto/codegen/copyscaleadd/tinytc.py +++ b/yateto/codegen/copyscaleadd/tinytc.py @@ -11,17 +11,15 @@ class CopyScaleAddTinytc: def __init__(self, arch, descr): self._arch = arch self._descr = descr - self._scalar_type = 'f64' if self._arch.bytesPerReal == 8 else 'f32' - self._ty = 
ScalarType( - FloatingType.f64) if self._arch.bytesPerReal == 8 else ScalarType( - FloatingType.f32) def generate(self, cpp, routineCache): d = self._descr + ty = toTinyTCType(d.result.datatype) + # Order can be 1 or 2 def MakeLoopOverAxpby(d, order, transpose, A, B): - beta = FloatImmValue(self._ty, d.beta) + beta = FloatImmValue(ty, d.beta) A_offset_list = [None] * len(d.term.indices) A_size_list = [None] * len(d.term.indices) B_offset_list = [None] * len(d.result.indices) @@ -61,13 +59,13 @@ def MakeLoopOverAxpby(d, order, transpose, A, B): [ForInst(B_offset_list[j], start, stop, csa_region)]) return csa_region - alpha = LocalValue(self._ty, 'alpha') + alpha = LocalValue(ty, 'alpha') Abatch = LocalValue( - makeBatchType(self._ty, d.term.memoryLayout, + makeBatchType(ty, d.term.memoryLayout, d.term.is_compute_constant, d.term.is_temporary), 'A') Bbatch = LocalValue( - makeBatchType(self._ty, d.result.memoryLayout, + makeBatchType(ty, d.result.memoryLayout, d.result.is_compute_constant, d.result.is_temporary), 'B') kernel = Function('copyscaleadd', [alpha, Abatch, Bbatch], None) @@ -102,7 +100,7 @@ def MakeLoopOverAxpby(d, order, transpose, A, B): d.result.is_compute_constant, d.result.is_temporary, True) ] - wrapper = TinytcWrapper(kernel, args, self._arch.typename) + wrapper = TinytcWrapper(kernel, args) prototype = wrapper.prototype() routineCache.addRoutine(prototype, diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index 02758e6..e78b5a1 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -100,8 +100,13 @@ def _conditional(self, condition, generate): else: return 0 else: - with self._cpp.If(f'{condition.ccode()}'): + if condition.tautology(): return generate() + elif condition.unfulfillable(): + return 0 + else: + with self._cpp.If(f'{condition.ccode()}'): + return generate() class OptimizedKernelFactory(KernelFactory): def __init__(self, cpp, arch, target): @@ -351,6 +356,9 @@ def generate(self, cpp, cache): def 
add_linear_operation(self, dest, ops, target, permute, add): pass + + def add_operation(self, description): + pass class ExportFactory(KernelFactory): @classmethod @@ -367,6 +375,41 @@ def post_generate(self, routine_cache): def allocateTemporary(self): return False + def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + result = IndexedTensorDescription.fromNode(result, node) + preArgs = [IndexedTensorDescription.fromNode(argument, term) for argument, term in zip(arguments, node)] + args = node.fillTerms(preArgs) + + description = { + 'type': 'elementwise', + 'result': result, + 'args': args, + 'linear': { + 'alpha': scalar, + 'add': add, + }, + 'optype': node.optype + } + return self.generator.add_operation(description) + + def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): + assert len(arguments) == 1 + makeNode = IndexedTensorDescription.fromNode + result = makeNode(result, node) + argnodes = [makeNode(arguments[0], node.term())] + + description = { + 'type': 'reduction', + 'result': result, + 'args': argnodes, + 'linear': { + 'alpha': scalar, + 'add': add, + }, + 'optype': node.optype + } + return self.generator.add_operation(description) + def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 makeNode = IndexedTensorDescription.fromNode @@ -374,16 +417,10 @@ def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, p return self.handleLinear(makeNode(result, node), argnodes, add, scalar, node.transA(), node.transB()) def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): - assert len(arguments) == 1 - makeNode = IndexedTensorDescription.fromNode - argnodes = [makeNode(arguments[0], node.term())] - return self.handleLinear(makeNode(result, node), argnodes, add, 
scalar, False, False) + return create_Reduction(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg) def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): - assert len(arguments) == 2 - makeNode = IndexedTensorDescription.fromNode - argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())] - return self.handleLinear(makeNode(result, node), argnodes, add, scalar, False, False) + return create_Elementwise(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg) def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = arguments[0] diff --git a/yateto/codegen/fused_gemms/external_generator.py b/yateto/codegen/fused_gemms/external_generator.py index 3c5037e..04a5d7b 100644 --- a/yateto/codegen/fused_gemms/external_generator.py +++ b/yateto/codegen/fused_gemms/external_generator.py @@ -12,7 +12,7 @@ def __init__(self, arch, descr): self._arch = arch self._descr = descr self._datatype = self._descr[0].node.datatype - self._batch_aux = BatchedOperationsAux(self._datatype.ctype()) + self._batch_aux = BatchedOperationsAux() self._cache = {} self._tmp_matrices = {} diff --git a/yateto/codegen/fused_gemms/tinytc.py b/yateto/codegen/fused_gemms/tinytc.py index 18a59e0..14cbfa7 100644 --- a/yateto/codegen/fused_gemms/tinytc.py +++ b/yateto/codegen/fused_gemms/tinytc.py @@ -13,9 +13,6 @@ class FusedGemmsTinytc: def __init__(self, arch, descr): self._arch = arch self._descr = descr - self._ty = ScalarType( - FloatingType.f64) if self._arch.bytesPerReal == 8 else ScalarType( - FloatingType.f32) def generate(self, cpp, routineCache, cfg): args = dict() @@ -39,7 +36,7 @@ def addVal(var, node): is_constant[var] = node.tensor.is_compute_constant( ) if isinstance(node, IndexedTensor) else False arg = LocalValue( - makeBatchType(self._ty, node.memoryLayout(), + 
makeBatchType(toTinyTCType(var.datatype), node.memoryLayout(), is_constant[var], var.is_temporary), name) args[var] = arg vals[var] = makeLoad(bb, arg, gid, is_constant[var], var.is_temporary) @@ -59,7 +56,7 @@ def addVal(var, node): if res.is_temporary: res_val = bb.add( AllocaInst( - makeMemrefType(self._ty, res.memoryLayout(), False))) + makeMemrefType(toTinyTCType(res.datatype), res.memoryLayout(), False))) vals[res] = res_val else: modified.add(res) @@ -100,8 +97,8 @@ def offsetSizeLists(ml, range0, range1): *offsetSizeLists(node.memoryLayout(), m, n))) trans = lambda t: Transpose.t if t else Transpose.n - alpha = FloatImmValue(self._ty, scalar) - beta = FloatImmValue(self._ty, 1.0 if add else 0.0) + alpha = FloatImmValue(toTinyTCType(res.datatype), scalar) + beta = FloatImmValue(toTinyTCType(res.datatype), 1.0 if add else 0.0) bb.add( GemmInst(trans(node.transA()), trans(node.transB()), alpha, op1_sub, op2_sub, beta, res_sub)) @@ -117,7 +114,7 @@ def offsetSizeLists(ml, range0, range1): wrapper_args.append( TinytcKernelArgument(name, str(key), is_constant[key], key.is_temporary, key in modified)) - wrapper = TinytcWrapper(kernel, wrapper_args, self._arch.typename) + wrapper = TinytcWrapper(kernel, wrapper_args) cpp(wrapper.call()) prototype = wrapper.prototype() routineCache.addRoutine(prototype, diff --git a/yateto/codegen/gemm/factory.py b/yateto/codegen/gemm/factory.py index a2a98c0..f357c9a 100644 --- a/yateto/codegen/gemm/factory.py +++ b/yateto/codegen/gemm/factory.py @@ -85,6 +85,9 @@ def generator(arch, descr, gemm_cfg, target): descr.beta, descr.alignedA, descr.alignedC, + descr.leftTerm.datatype, + descr.rightTerm.datatype, + descr.result.datatype, target) if gemmTool: return GemmGen(arch, descr, gemmTool) diff --git a/yateto/codegen/gemm/gemmgen.py b/yateto/codegen/gemm/gemmgen.py index 86bca4c..10714bc 100644 --- a/yateto/codegen/gemm/gemmgen.py +++ b/yateto/codegen/gemm/gemmgen.py @@ -7,7 +7,7 @@ from ..cache import RoutineGenerator, 
GpuRoutineGenerator, TinytcWriter from ...gemm_configuration import BLASlike, CodeGenerator, GemmForge, tinytc -from ..common import BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper +from ..common import BatchedOperationsAux, TinytcKernelArgument, TinytcScalarKernelArgument, TinytcWrapper, toTinyTCType, toTinyTCImmediate from ..tiny_tensor_language import * from ...type import Datatype import importlib.util @@ -139,11 +139,18 @@ def generate(self, cpp, routineCache): d.beta, ptr_c, ldC, alignedA=d.alignedA, alignedC=d.alignedC, - prefetchName=d.prefetchName)) + prefetchName=d.prefetchName, + datatypeA=d.leftTerm.datatype, + datatypeB=d.rightTerm.datatype, + datatypeC=d.result.datatype)) elif isinstance(self._gemm_cfg, GemmForge): + assert d.result.datatype == d.leftTerm.datatype + assert d.result.datatype == d.rightTerm.datatype + ctype = d.result.datatype.ctype() + if gf_spec: - aux = BatchedOperationsAux(self._arch.typename) + aux = BatchedOperationsAux() matrix_a = gf.YatetoInterface.produce_dense_matrix((m, k), d.leftTerm.memoryLayout.bbox(), @@ -161,7 +168,7 @@ def generate(self, cpp, routineCache): transpose=False) try: - vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=self._arch.typename) + vm = gf.vm_factory(self._arch.name, self._arch.backend, fp_type=ctype) forge_generator = gf.GemmGenerator(vm) forge_generator.set(d.transA, d.transB, matrix_a, matrix_b, matrix_c, d.alpha, d.beta) routine_name = forge_generator.get_base_name() @@ -191,7 +198,7 @@ def generate(self, cpp, routineCache): raise RuntimeError('gemmforge module is not found. You can install it with pip3. 
' 'e.g., pip3 install gemmforge') elif isinstance(self._gemm_cfg, tinytc): - aux = BatchedOperationsAux(self._arch.typename) + aux = BatchedOperationsAux() gemm = { 'M': m.size(), 'N': n.size(), @@ -209,12 +216,15 @@ def generate(self, cpp, routineCache): 'beta': self._beta(d.beta), 'transA': d.transA, 'transB': d.transB, + 'datatypeA': d.leftTerm.datatype, + 'datatypeB': d.rightTerm.datatype, + 'datatypeC': d.result.datatype, } kernel = tinytcGemmGen(self._arch, gemm) def call_arg(name, term, modified, offset): - return TinytcKernelArgument(name, term.name, term.is_compute_constant, term.is_temporary, modified, offset) + return TinytcKernelArgument(name, term.datatype, term.name, term.is_compute_constant, term.is_temporary, modified, offset) offset_a = self._offset(term=d.leftTerm, offset2=(m.start, k.start), transpose=d.transA) offset_b = self._offset(term=d.rightTerm, offset2=(k.start, n.start), transpose=d.transB) offset_c = self._offset(term=d.result, offset2=(m.start, n.start), transpose=False) @@ -222,8 +232,8 @@ def call_arg(name, term, modified, offset): call_arg('A', d.leftTerm, False, offset_a), call_arg('B', d.rightTerm, False, offset_b), call_arg('C', d.result, True, offset_c)] - routine_name = 'tinytc_wrapper_{typename}_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_tA{transA}_tB{transB}'.format(typename=self._arch.typename, **gemm) - wrapper = TinytcWrapper(kernel, args, self._arch.typename, routine_name) + routine_name = 'tinytc_wrapper_{datatypeA}_{datatypeB}_{datatypeC}_m{M}_n{N}_k{K}_ldA{LDA}_{addrA}_{distA}_ldB{LDB}_{addrB}_{distB}_ldC{LDC}_{addrC}_{distC}_alpha{alpha}_beta{beta}_tA{transA}_tB{transB}'.format(**gemm) + wrapper = TinytcWrapper(kernel, args, routine_name) cpp(wrapper.call()) prototype = wrapper.prototype() @@ -327,14 +337,17 @@ def __call__(self, routineName, fileName): assert self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeA'] assert 
self._gemmDescr['datatypeC'] == self._gemmDescr['datatypeB'] - assert self._gemmDescr['datatypeC'] in [Datatype.F32, Datatype.F64] - - precision = { - Datatype.F32: 'F', - Datatype.F64: 'D' - }[self._gemmDescr['datatypeC']] if self._mode == 'pspamm': + assert self._gemmDescr['datatypeC'] in [Datatype.BF16, Datatype.F16, Datatype.F32, Datatype.F64] + + precision = { + Datatype.BF16: 'BF16', + Datatype.F16: 'H', + Datatype.F32: 'S', + Datatype.F64: 'D' + }[self._gemmDescr['datatypeC']] + pspamm_arch = cpu_arch if cpu_arch == 'a64fx': pspamm_arch = 'arm_sve512' @@ -380,6 +393,14 @@ def __call__(self, routineName, fileName): for key, val in self._blockSize.items(): argList.extend(['--' + key, val]) else: + assert self._gemmDescr['datatypeC'] in [Datatype.I16, Datatype.F32, Datatype.F64] + + precision = { + Datatype.I16: 'I16', + Datatype.F32: 'SP', + Datatype.F64: 'DP' + }[self._gemmDescr['datatypeC']] + libxsmm_arch = cpu_arch if cpu_arch in ['naples', 'rome', 'milan']: # names are Zen1, Zen2, Zen3, respectively @@ -404,7 +425,7 @@ def __call__(self, routineName, fileName): self._gemmDescr['alignedC'], libxsmm_arch, # libxsmm has no support for rome, hsw works well in practice self._gemmDescr['prefetch'], - precision + 'P' + precision ] class SparsityWrapper: def __init__(self, shape, spp): @@ -456,10 +477,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) # LIBXSMM header + datatype = self._gemmDescr['datatypeC'].ctype() if self._gemmDescr['prefetch'] == 'nopf': - return 'void {name}(const {type}* A, const {type}* B, {type}* C);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {type}* A, const {type}* B, {type}* C);'.format(name=routineName, type=datatype) else: - return 'void {name}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch);'.format(name=routineName, type=self._arch.typename) + return 'void {name}(const {type}* A, const {type}* B, {type}* C, const 
{type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch);'.format(name=routineName, type=datatype) class GemmForgeWriter(GpuRoutineGenerator): def __init__(self, forge_generator, headers): @@ -515,6 +537,9 @@ def _kernel(self, routine_name): prefetch = self._gemmDescr['prefetch'] transA = self._gemmDescr['transA'] transB = self._gemmDescr['transB'] + datatypeA = self._gemmDescr['datatypeA'] + datatypeB = self._gemmDescr['datatypeB'] + datatypeC = self._gemmDescr['datatypeC'] flags = ["LIBXSMM_GEMM_FLAG_NONE"] if transA: @@ -550,7 +575,7 @@ def _kernel(self, routine_name): {prefetch_flag} // prefetch ); """.format(kernel_var_name=kernel_var_name, - prec=self._arch.typename, M=M, N=N, K=K, + prec=datatypeC.ctype(), M=M, N=N, K=K, ldA=ldA, ldB=ldB, ldC=ldC, alpha=alpha, beta=beta, flag=libxsmm_flag_str, @@ -578,7 +603,8 @@ def _call(self, routineName): """ def _functionSignature(self, routineName): - return 'void {routineName}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch)'.format(routineName=routineName, type=self._arch.typename) + datatypeC = self._gemmDescr['datatypeC'].ctype() + return 'void {routineName}(const {type}* A, const {type}* B, {type}* C, const {type}* A_prefetch, const {type}* B_prefetch, const {type}* C_prefetch)'.format(routineName=routineName, type=datatypeC) def __call__(self, routineName, fileName): func_signature = self._functionSignature(routineName) @@ -589,50 +615,56 @@ def __call__(self, routineName, fileName): return func_signature + ";" def tinytcGemmGen(arch, gd): - scalar_ty = ScalarType(FloatingType.f64) if arch.bytesPerReal == 8 else ScalarType(FloatingType.f32) - - Operand = namedtuple('Operand', ['name', 'addr', 'rows', 'cols', 'ld', 'dist']) + Operand = namedtuple('Operand', ['name', 'addr', 'rows', 'cols', 'ld', 'dist', 'datatype']) + def scalar_type(op): + return toTinyTCType(op.datatype) def data_type(op): - if op.addr == 'pointer_based': + 
scalar_ty = scalar_type(op)
+    if op.addr == AddressingMode.INDIRECT:
       return GroupType(MemrefType(scalar_ty, (op.rows, op.cols), (1, op.ld)), DYNAMIC)
-    elif op.addr == 'strided':
+    elif op.addr == AddressingMode.STRIDED:
       return MemrefType(scalar_ty, (op.rows, op.cols, DYNAMIC), (1, op.ld, op.dist))
-    elif op.addr == 'none':
+    elif op.addr == AddressingMode.DIRECT:
       return MemrefType(scalar_ty, (op.rows, op.cols), (1, op.ld))
     else:
       raise NameError(op.addr)
   def load_inst(op, batch, gid):
     zero = IntImmValue(IntegerType.index, 0)
     dyn = IntImmValue(IntegerType.index, DYNAMIC)
-    if op.addr == 'pointer_based':
+    if op.addr == AddressingMode.INDIRECT:
       return LoadInst(batch, [gid])
-    elif op.addr == 'strided':
+    elif op.addr == AddressingMode.STRIDED:
       return SubviewInst(batch, [zero,zero,gid], [dyn,dyn,None])
-    elif op.addr == 'none':
+    elif op.addr == AddressingMode.DIRECT:
       return SubviewInst(batch, [zero,zero], [dyn,dyn])
     else:
       raise NameError(op.addr)
-  opA = Operand('A', gd['addrA'], gd['M'], gd['K'], gd['LDA'], gd['distA'])
-  opB = Operand('B', gd['addrB'], gd['K'], gd['N'], gd['LDB'], gd['distB'])
-  opC = Operand('C', gd['addrC'], gd['M'], gd['N'], gd['LDC'], gd['distC'])
+  opA = Operand('A', gd['addrA'], gd['M'], gd['K'], gd['LDA'], gd['distA'], gd['datatypeA'])
+  opB = Operand('B', gd['addrB'], gd['K'], gd['N'], gd['LDB'], gd['distB'], gd['datatypeB'])
+  opC = Operand('C', gd['addrC'], gd['M'], gd['N'], gd['LDC'], gd['distC'], gd['datatypeC'])
   T = lambda x: Transpose.t if x else Transpose.n
   tA = T(gd['transA'])
   tB = T(gd['transB'])
-  beta = gd['beta']
-
-  alpha = LocalValue(scalar_ty, 'alpha')
-  Abatch = LocalValue(data_type(opA), 'Abatch')
-  Bbatch = LocalValue(data_type(opB), 'Bbatch')
-  Cbatch = LocalValue(data_type(opC), 'Cbatch')
+  alphaV = gd['alpha']
+  betaV = gd['beta']
+
+  scalar_ty_a = toTinyTCType(opA.datatype)
+  scalar_ty_b = toTinyTCType(opB.datatype)
+  scalar_ty_c = toTinyTCType(opC.datatype)
+  alpha = toTinyTCImmediate(scalar_ty_c, alphaV) if 
isinstance(alphaV, (int, float)) else LocalValue(scalar_ty_c, 'alpha') + beta = toTinyTCImmediate(scalar_ty_c, betaV) if isinstance(betaV, (int, float)) else LocalValue(scalar_ty_c, 'beta') + Abatch = LocalValue(scalar_ty_a, 'Abatch') + Bbatch = LocalValue(scalar_ty_b, 'Bbatch') + Cbatch = LocalValue(scalar_ty_c, 'Cbatch') kernel = Function('gemm', [alpha, Abatch, Bbatch, Cbatch], None) bb = RegionBuilder() gid = bb.add(GroupIdInst()) A = bb.add(load_inst(opA, Abatch, gid)) B = bb.add(load_inst(opB, Bbatch, gid)) C = bb.add(load_inst(opC, Cbatch, gid)) - bb.add(GemmInst(tA, tB, alpha, A, B, FloatImmValue(scalar_ty, beta), C)) + bb.add(GemmInst(tA, tB, alpha, A, B, beta, C)) kernel.body = bb.get_product() AssignIdentifiers().visit(kernel) diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index e1a26d7..c3e77a1 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -171,7 +171,7 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target): datatype = dict() for scalar in scalarsP: self.KernelOutline._addTensor(scalar, scalars) - datatype[scalar.baseNameWithNamespace()] = scalar.datatype + datatype[scalar.baseNameWithNamespace()] = scalar.getDatatype(self._arch) for var in variables: self.KernelOutline._addTensor(var.tensor, tensors) bn = var.tensor.baseNameWithNamespace() @@ -218,7 +218,7 @@ def _addFromKO(cls, koEntries, entries): for key, value in koEntries.items(): if key not in entries: entries[key] = value - else: + elif entries[key] != value: entries[key] = entries[key] | value @@ -288,7 +288,8 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None): if target == 'gpu': # LinearAllocatorT controls external extra mem. allocated on gpu for tmp. 
variables - header(f'yateto::LinearAllocatorT<{self._arch.typename}> linearAllocator;') + # back-casted to char for now + header(f'yateto::LinearAllocatorT linearAllocator;') header.emptyline() @@ -303,7 +304,7 @@ def kernelArgs(base_name_with_namespace, groups, writable, is_constant, datatype container_type = f'{InitializerGenerator.CONTAINER_CLASS_NAME}<{typ}{ptr_type}>' header(f'{class_name}::{container_type} {base_name};') else: - header(f'{typ}{ptr_type} {base_name}{{nullptr}};') + header(f'{typ}{ptr_type} {base_name}{"{"}nullptr{"}"};') def scalarArgs(base_name_with_namespace, datatype, groups): prefix, base_name = Tensor.splitBasename(base_name_with_namespace) @@ -352,7 +353,7 @@ def generate_extra_offset_args(base_name_with_namespace, groups): if len(prefetch) > 0: with header.Struct(self.PREFETCHSTRUCT_NAME): for baseName, groups in prefetch.items(): - kernelArgs(baseName, groups, writable=False, is_constant=False, target='any') + kernelArgs(baseName, groups, writable=False, is_constant=False, datatype=self._arch.datatype, target='any') header('{} {};'.format(self.PREFETCHSTRUCT_NAME, self.PREFETCHVAR_NAME)) header.emptyline() @@ -465,7 +466,7 @@ def _devTensorKernelArgument(self, var, writable): elif writable[var.tensor.baseNameWithNamespace()]: return self._devPtrTensorName(var) else: - return f'const_cast<{self._arch.typename} const**>({self._devPtrTensorName(var)})' + return f'const_cast<{var.datatype.ctype()} const**>({self._devPtrTensorName(var)})' @classmethod def _nameS(cls, var): @@ -509,16 +510,17 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, factory = UnitTestFactory(cpp, self._arch, self._name, testFramework) for i,scalar in enumerate(scalars): - cpp('{} {} = {};'.format(self._arch.typename, self._tensorNameS(scalar), float(i+2))) + cpp('{} {} = {};'.format(scalar.getDatatype(self._arch).ctype(), self._tensorNameS(scalar), float(i+2))) for var in variables: factory.tensor(var.tensor, self._tensorName(var)) 
factory.temporary(self._name(var), var.memoryLayout().requiredReals(), var.datatype, iniZero=True) shape = var.memoryLayout().shape() - cpp('{supportNS}::DenseTensorView<{dim},{arch.typename},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{shape}}});'.format( + cpp('{supportNS}::DenseTensorView<{dim},{datatype},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{shape}}});'.format( supportNS = SUPPORT_LIBRARY_NAMESPACE, dim=len(shape), + datatype=var.datatype.ctype(), arch = self._arch, utName=self._name(var), viewName=self._viewName(var), @@ -550,15 +552,14 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, kernelTensorName = lambda var: self._devTensorKernelArgument(var, writable) cpp ( f'auto {self.QUEUE} = sycl::queue{{sycl::property::queue::in_order()}};' ) - cpp ( f'auto {self.TMP_MEM} = ({self._arch.typename}*) sycl::malloc_device({self.TMP_SIZE}, {self.QUEUE});' ) + cpp ( f'auto {self.TMP_MEM} = reinterpret_cast(sycl::malloc_device({self.TMP_SIZE}, {self.QUEUE}));' ) for var in variables: - cpp( f'auto {self._devTensorName(var)} = ({self._arch.typename}*) sycl::malloc_device(sizeof({self._tensorName(var)}), {self.QUEUE});' ) - cpp( f'auto {self._devPtrTensorName(var)} = ({self._arch.typename}**) sycl::malloc_device(sizeof({self._arch.typename}*), {self.QUEUE});' ) + cpp( f'auto {self._devTensorName(var)} = reinterpret_cast<{var.datatype.ctype()}*>(sycl::malloc_device(sizeof({self._tensorName(var)}), {self.QUEUE}));' ) + cpp( f'auto {self._devPtrTensorName(var)} = reinterpret_cast<{var.datatype.ctype()}**>(sycl::malloc_device(sizeof({var.datatype.ctype()}*), {self.QUEUE}));' ) cpp( f'{self.QUEUE}.memcpy({self._devTensorName(var)}, {self._tensorName(var)}, sizeof({self._tensorName(var)})).wait();' ) - cpp( f'{self.QUEUE}.memcpy({self._devPtrTensorName(var)}, &{self._devTensorName(var)}, sizeof({self._arch.typename}*)).wait();' ) + cpp( 
f'{self.QUEUE}.memcpy({self._devPtrTensorName(var)}, &{self._devTensorName(var)}, sizeof({var.datatype.ctype()}*)).wait();' ) cpp.emptyline() - cpp( '{}{}::{} {};'.format(kernel_prefix, OptimizedKernelGenerator.NAMESPACE, kernelClass, self.KERNEL_VAR) ) for var in scalars: cpp( '{}.{}{} = {};'.format(self.KERNEL_VAR, var.baseName(), self._groupIndex(var), self._tensorNameS(var)) ) @@ -608,12 +609,14 @@ class InitializerGenerator(object): class TensorView(object): ARGUMENT_NAME = 'values' + def __init__(self, datatype): + self._datatype = datatype + def typename(self, dim, arch): - return '::{}::{}<{},{},{}>'.format(SUPPORT_LIBRARY_NAMESPACE, type(self).__name__, dim, arch.typename, arch.uintTypename) + return f'::{SUPPORT_LIBRARY_NAMESPACE}::{type(self).__name__}<{dim},{self._datatype.ctype()},{arch.uintTypename}>' - @classmethod - def arguments(cls, arch): - return '{}* {}'.format(arch.typename, cls.ARGUMENT_NAME) + def arguments(self): + return f'{self._datatype.ctype()}* {self.ARGUMENT_NAME}' def generate(cpp, group, memLayout): raise NotImplementedError @@ -622,10 +625,10 @@ def listToInitializerList(self, lst): return '{{{}}}'.format(', '.join([str(l) for l in lst])) def formatArray(self, numberType, name, values, declarationOnly): - lhs = '{} {}[]'.format(numberType, name) + lhs = f'{numberType} {name}[]' if declarationOnly: - return '{} {};'.format(CONSTEXPR, lhs) - return '{} {} = {};'.format(MODIFIERS, lhs, self.listToInitializerList(values)) + return f'{CONSTEXPR} {lhs};' + return f'{MODIFIERS} {lhs} = {self.listToInitializerList(values)};' class DenseTensorView(TensorView): START_NAME = 'Start' @@ -650,7 +653,7 @@ class CSCMatrixView(TensorView): COLPTR_NAME = 'ColPtr' def typename(self, dim, arch): - return '::{}::{}<{},{}>'.format(SUPPORT_LIBRARY_NAMESPACE, type(self).__name__, arch.typename, arch.uintTypename) + return f'::{SUPPORT_LIBRARY_NAMESPACE}::{type(self).__name__}<{self._datatype.ctype()},{arch.uintTypename}>' def generate(self, cpp, 
memLayout, arch, index): cpp( 'return {}({}, {}, {}, {});'.format( @@ -667,9 +670,9 @@ def arrays(self, cpp, memLayout, arch, namespace, index, numberType, declaration def __init__(self, arch, tensors, scalars): self._arch = arch - self._numberType = '{} const'.format(self._arch.uintTypename) - self._realType = '{} const'.format(self._arch.typename) - self._realPtrType = self._realType + '*' + self._numberType = f'{self._arch.uintTypename} const'.format(self._arch.uintTypename) + self._realType = lambda datatype: f'{datatype.ctype()} const' + self._realPtrType = lambda datatype: self._realType(datatype) + '*' self._scalarCollect = collections.OrderedDict() self._collect = collections.OrderedDict() for tensor in tensors: @@ -704,12 +707,13 @@ def __init__(self, arch, tensors, scalars): maxIndexScalar = {baseName: tuple(map(max, *groups.keys())) if len(groups) > 1 else next(iter(groups.keys())) for baseName, groups in self._scalarCollect.items()} self._groupSizeScalar = {baseName: tuple(map(lambda x: x+1, mi)) for baseName, mi in maxIndexScalar.items()} - def _tensorViewGenerator(self, memoryLayout): + def _tensorViewGenerator(self, tensor): + memoryLayout = tensor.memoryLayout() memLayoutMap = { 'DenseMemoryLayout': self.DenseTensorView, 'CSCMemoryLayout': self.CSCMatrixView } - return memLayoutMap[type(memoryLayout).__name__]() + return memLayoutMap[type(memoryLayout).__name__](tensor.getDatatype(self._arch)) def iterate_collect(self): cur_namespace = '' @@ -831,49 +835,54 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat if declarationOnly: for group,tensor in tensors.items(): ml = tensor.memoryLayout() - tv = self._tensorViewGenerator(ml) + tv = self._tensorViewGenerator(tensor) tv.arrays(cpp, ml, self._arch, name, index(group), self._numberType, True) valueNames = dict() for group,tensor in tensors.items(): values = tensor.values() memLayout = tensor.memoryLayout() + datatype = tensor.getDatatype(self._arch) if values is not 
None: - memory = ['0.']*memLayout.requiredReals() + memory = [datatype.literal(0)]*memLayout.requiredReals() for idx,x in values.items(): - memory[memLayout.address(idx)] = x - valuesName = '{}{}{}'.format(name, self.VALUES_BASENAME, index(group)) - valueNames[group] = ['&{}[0]'.format(valuesName)] - cpp('{} {}[] = {{{}}};'.format(self._realType, valuesName, ', '.join(memory))) + memory[memLayout.address(idx)] = datatype.literal(x) + valuesName = f'{name}{self.VALUES_BASENAME}{index(group)}' + valueNames[group] = [f'&{valuesName}[0]'] + cpp('{} {}[] = {{{}}};'.format(self._realType(datatype), valuesName, ', '.join(memory))) if len(valueNames) > 1: - self._array(cpp, self._realPtrType, name + self.VALUES_BASENAME, valueNames, groupSize, alwaysArray=False, constexpr=False, static=False) + _,prototensor = next(iter(tensors.items())) + datatype = prototensor.getDatatype(self._arch) + self._array(cpp, self._realPtrType(datatype), name + self.VALUES_BASENAME, valueNames, groupSize, alwaysArray=False, constexpr=False, static=False) else: with cpp.Struct('{0} : {1}::{0}'.format(baseNameWithoutNamespace, self.TENSOR_NAMESPACE)): for group,tensor in tensors.items(): ml = tensor.memoryLayout() - tv = self._tensorViewGenerator(ml) + tv = self._tensorViewGenerator(tensor) tv.arrays(cpp, ml, self._arch, name, index(group), self._numberType, False) nValueArrays = 0 for group,tensor in tensors.items(): values = tensor.values() + datatype = tensor.getDatatype(self._arch) if values is not None: - name = '{}{}'.format(self.VALUES_BASENAME, index(group)) + name = f'{self.VALUES_BASENAME}{index(group)}' aligned = '' if tensor.memoryLayout().alignedStride(): - aligned = ' __attribute__((aligned({})))'.format(self._arch.alignment) - cpp('{} {} {}[]{};'.format(STATIC, self._realType, name, aligned)) + aligned = f' __attribute__((aligned({self._arch.alignment})))' + cpp(f'{STATIC} {self._realType(datatype)} {name}[]{aligned};') nValueArrays += 1 if nValueArrays > 1: - cpp('{} {} 
{}[];'.format(STATIC, self._realPtrType, self.VALUES_BASENAME)) + cpp(f'{STATIC} {self._realPtrType(datatype)} {self.VALUES_BASENAME}[];') cpp.emptyline() - viewArgs = self.TensorView.arguments(self._arch) if len(groupSize) == 0: - ml = next(iter(tensors.values())).memoryLayout() - tv = self._tensorViewGenerator(ml) + prototensor = next(iter(tensors.values())) + ml = prototensor.memoryLayout() + tv = self._tensorViewGenerator(prototensor) + viewArgs = tv.arguments() with cpp.Struct(self.VIEW_STRUCT_NAME): - cpp('typedef {} {};'.format(tv.typename(len(ml.shape()), self._arch), self.VIEW_TYPE_NAME)) - with cpp.Function(self.VIEW_FUN_NAME, arguments=viewArgs, returnType='{} {}'.format(STATIC_INLINE, self.VIEW_TYPE_NAME)): + cpp(f'using {self.VIEW_TYPE_NAME} = {tv.typename(len(ml.shape()), self._arch)};') + with cpp.Function(self.VIEW_FUN_NAME, arguments=viewArgs, returnType=f'{STATIC_INLINE} {self.VIEW_TYPE_NAME}'): tv.generate(cpp, ml, self._arch, None) else: typedArgs = typedNdArgs(len(groupSize), self._arch.uintTypename) @@ -882,7 +891,8 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat if len(groupSize) > 0: for group,tensor in tensors.items(): ml = tensor.memoryLayout() - tv = self._tensorViewGenerator(ml) + tv = self._tensorViewGenerator(tensor) + viewArgs = tv.arguments() typename = tv.typename(len(ml.shape()), self._arch) special = ','.join(str(g) for g in group) cpp('template<>') diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index b2045a9..9fb6212 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -96,7 +96,6 @@ def setWritable(self, name): for v in self._variables: v.setWritable(name) - class ProgramAction(object): def __init__(self, result, term, add, scalar=None, condition=True): self.result = result @@ -142,6 +141,12 @@ def substituted(self, when, by, result = True, term = True): def setVariablesWritable(self, name): self.result.setWritable(name) 
self.term.setWritable(name) + + def getCondition(self): + if isinstance(self.condition, CNFCondition): + return self.condition + else: + return CNFCondition(self.condition) # TODO: probably should be a subclass of ProgramAction @@ -194,6 +199,7 @@ def __init__(self, action): self.initBuffer = None self.bufferMap = None + # a rather primitive CNF implementation. # do not overuse (i.e. avoid conditional assigns where possible) @@ -222,10 +228,15 @@ def __repr__(self): return f'[{", ".join(formatvar(var) for var in self.variables)}]' def ccode(self): + if not self.fulfilled and len(self.variables) == 0: + return 'false' # for now, only allow scalarly-indexed variables printvar = lambda var: f'{var}' if isinstance(var, Scalar) else f'{var}[{var.memoryLayout().addressString(Indices())}]' formatvar = lambda name: f'{printvar(name)}' if self.variables[name] else f'!{printvar(name)}' return f'({" || ".join(formatvar(var) for var in self.variables)})' + + def variableIterator(self): + return (var for var in self.variables if isinstance(var, Variable)) class CNFCondition: def __init__(self, data): @@ -298,7 +309,14 @@ def __repr__(self): return f'{self.clauses}' def ccode(self): + if self.tautology(): + return 'true' + elif self.unfulfillable(): + return 'false' return f'({" && ".join(clause.ccode() for clause in self.clauses)})' + + def variables(self): + return {var for clause in self.clauses for var in clause.variableIterator()} class LiveSet: def __init__(self, data: dict): diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index 8c73249..db1e346 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -124,7 +124,7 @@ def visit(self, cfg): V = set() for pp in cfg: if pp.action: - V = V | pp.action.result.variables() | pp.action.variables() + V = V | pp.action.result.variables() | pp.action.variables() | pp.action.getCondition().variables() return sorted([var for var in V if var.isGlobal()], key=lambda x: str(x)) class 
SortedPrefetchList(object):
diff --git a/yateto/functions.py b/yateto/functions.py
index b4dfa37..e9d01c8 100644
--- a/yateto/functions.py
+++ b/yateto/functions.py
@@ -43,9 +43,15 @@ def greater(x, y): return node.Elementwise(ops.CmpGt(), x, y)
 def greater_equal(x, y): return node.Elementwise(ops.CmpGe(), x, y)
 
 # extra reduction functions; e.g. for input to `where`
-def reductionSum(term, indices): return node.Reduction(ops.Add(), term, indices)
-def reductionMul(term, indices): return node.Reduction(ops.Mul(), term, indices)
-def reductionAnd(term, indices): return node.Reduction(ops.And(), term, indices)
-def reductionOr(term, indices): return node.Reduction(ops.Or(), term, indices)
+def reduction(op, term, indices):
+  if len(indices) == 0:
+    return term
+  else:
+    return reduction(op, node.Reduction(op, term, indices[0]), indices[1:])
+
+def sum(term, indices): return reduction(ops.Add(), term, indices)
+def product(term, indices): return reduction(ops.Mul(), term, indices)
+def all(term, indices): return reduction(ops.And(), term, indices)
+def any(term, indices): return reduction(ops.Or(), term, indices)
 
 def cast(x, dtype): return node.Elementwise(ops.Typecast(dtype), x)
diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py
index 2afbb78..95aebef 100644
--- a/yateto/gemm_configuration.py
+++ b/yateto/gemm_configuration.py
@@ -1,5 +1,6 @@
 from typing import List
 from abc import ABC, abstractmethod
+from .type import Datatype
 import operator
 
 class Preference(object):
@@ -18,31 +19,40 @@ def archSupported(self):
     return True
 
   @abstractmethod
-  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC):
+  def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target):
     pass
 
   @abstractmethod
   def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha,
-                beta, alignedA, alignedC, target):
+                beta, alignedA, alignedC, datatypeA, datatypeB, 
datatypeC, target): pass + + # shortcut for legacy reasons + @classmethod + def _equalType(cls, datatypeA, datatypeB, datatypeC, types=(Datatype.F32, Datatype.F64)): + return datatypeA == datatypeC and datatypeB == datatypeC and datatypeC in types class BLASlike(GemmTool): - def __init__(self, operation_name: str, includes: List[str], c_code_init: str = ''): - super().__init__(operation_name, includes) + def __init__(self, prefix, includes: List[str], c_code_init: str = ''): + super().__init__(prefix, includes) self.c_code_init = c_code_init - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): return Preference.MODERATE def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return (not sparseA and not sparseB and target == 'cpu') + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return (not sparseA and not sparseB and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC)) def bool2Trans(self, trans): return 'Cblas{}Trans'.format('' if trans else 'No') def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, - alignedA, alignedC, prefetchName): + alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName): + precision = { + Datatype.F32: 's', + Datatype.F64: 'd' + }[datatypeC] parameters = [ 'CblasColMajor', self.bool2Trans(transA), @@ -51,39 +61,43 @@ def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, alpha, A, ldA, B, ldB, beta, C, ldC] - return '{}({});'.format(self.operation_name, ', '.join(str(p) for p in parameters)) + return '{}_{}gemm({});'.format(self.prefix, precision, ', '.join(str(p) for p in parameters)) class MKL(BLASlike): def __init__(self, arch): self._arch = arch - 
super().__init__('cblas_{}gemm'.format(arch.precision.lower()), ['mkl_cblas.h']) + super().__init__('cblas', ['mkl_cblas.h']) def archSupported(self): return self._arch.host_name.lower() in {'snb', 'hsw', 'skx', 'knl'} or self._arch.host_name.lower().startswith('avx') class OpenBLAS(BLASlike): def __init__(self, arch): - super().__init__('cblas_{}gemm'.format(arch.precision.lower()), ['cblas.h']) + super().__init__('cblas', ['cblas.h']) class BLIS(BLASlike): def __init__(self, arch): - super().__init__('bli_{}gemm'.format(arch.precision.lower()), ['blis.h'], '{0} _blis_alpha; {0} _blis_beta;'.format(arch.typename)) - self._typename = arch.typename + super().__init__('bli', ['blis.h']) def bool2Trans(self, trans): return 'BLIS{}TRANSPOSE'.format('_' if trans else '_NO_') def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, - alignedA, alignedC, prefetchName): - init = '_blis_alpha = {}; _blis_beta = {};'.format(alpha, beta) + alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName): + precision = { + Datatype.F32: 's', + Datatype.F64: 'd' + }[datatypeC] + initA = f'{datatypeC.ctype()} _blis_alpha = {alpha};' + initB = f'{datatypeC.ctype()} _blis_beta = {beta};' parameters = [ self.bool2Trans(transA), self.bool2Trans(transB), M, N, K, - '&_blis_alpha', 'const_cast<{}*>({})'.format(self._typename, A), 1, ldA, - 'const_cast<{}*>({})'.format(self._typename, B), 1, ldB, + '&_blis_alpha', f'const_cast<{datatypeA.ctype()}*>({A})', 1, ldA, + f'const_cast<{datatypeB.ctype()}*>({B})', 1, ldB, '&_blis_beta', C, 1, ldC] - return '{} {}({});'.format(init, self.operation_name, ', '.join(str(p) for p in parameters)) + return '{{ {}{} {}_{}gemm({}); }}'.format(initA, initB, self.prefix, precision, ', '.join(str(p) for p in parameters)) class Eigen(BLASlike): def __init__(self, arch): @@ -91,24 +105,24 @@ def __init__(self, arch): self._arch = arch def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, 
target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): return (not sparseA and not sparseB and target == 'cpu') def bool2Trans(self, trans): return '.transpose()' if trans else '' def sizeTrans(self, rows, cols, trans): - return '{},{}'.format(cols,rows) if trans else '{},{}'.format(rows,cols) + return f'{cols},{rows}' if trans else f'{rows},{cols}' def align(self, ld): aligned = 'Unaligned' if self._arch.checkAlignment(ld) and self._arch.alignment in [16,32,64,128]: - aligned = 'Aligned{}'.format(self._arch.alignment) + aligned = f'Aligned{self._arch.alignment}' return aligned def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, - alignedA, alignedC, prefetchName): + alignedA, alignedC, datatypeA, datatypeB, datatypeC, prefetchName): AxB = '{alpha}_mapA{transA}*_mapB{transB}'.format( alpha=str(alpha) + '*' if alpha != 1.0 else '', transA=self.bool2Trans(transA), transB=self.bool2Trans(transB), @@ -124,12 +138,15 @@ def call(self, transA, transB, M, N, K, alpha, A, ldA, B, ldB, beta, C, ldC, using Eigen::Matrix; using Eigen::Map; using Eigen::Stride; - Map,Eigen::{alignA},Stride<{ldA},1>> _mapA(const_cast<{prec}*>({A})); - Map,Eigen::Unaligned,Stride<{ldB},1>> _mapB(const_cast<{prec}*>({B})); - Map,Eigen::{alignC},Stride<{ldC},1>> _mapC({C}); + Map,Eigen::{alignA},Stride<{ldA},1>> _mapA(const_cast<{precA}*>({A})); + Map,Eigen::Unaligned,Stride<{ldB},1>> _mapB(const_cast<{precB}*>({B})); + Map,Eigen::{alignC},Stride<{ldC},1>> _mapC({C}); {code} }} - """.format(prec=self._arch.typename, M=M, N=N, + """.format(precA=datatypeA.ctype(TypeFlavor.EIGEN), + precB=datatypeB.ctype(TypeFlavor.EIGEN), + precC=datatypeC.ctype(TypeFlavor.EIGEN), + M=M, N=N, sizeA=self.sizeTrans(M,K,transA), sizeB=self.sizeTrans(K,N,transB), ldA=ldA, ldB=ldB, ldC=ldC, A=A, B=B, C=C, @@ -163,7 +180,7 @@ def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 1 self._threshold = threshold self._arch = arch - def preference(self, 
m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if (m*n*k)**(1./3.) <= self._threshold: return Preference.HIGH return Preference.LOW @@ -173,13 +190,13 @@ def archSupported(self): return self._arch.host_name.lower() in supported_set def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): # Note: # Libxsmm falls back to blas for transA and more general alpha/beta # See e.g. here: # https://libxsmm.readthedocs.io/en/latest/libxsmm_qna/#what-is-a-small-matrix-multiplication # https://github.com/hfp/libxsmm/issues/396#issuecomment-674741063 - return self.archSupported() and not (sparseA or sparseB) and (not transA) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' + return self.archSupported() and not (sparseA or sparseB) and (not transA) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC) # TODO: no, there's more class LIBXSMM(CodeGenerator): def __init__(self, arch, cmd: str = 'libxsmm_gemm_generator', threshold: int = 128): @@ -191,10 +208,10 @@ def archSupported(self): return self._arch.host_name.lower() in supported_set def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return self.archSupported() and not (sparseA and sparseB) and (not transA and not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return self.archSupported() and not (sparseA and sparseB) and (not transA and not transB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'cpu' and (self._equalType(datatypeA, datatypeB, datatypeC) or (self._equalType(datatypeA, datatypeB, datatypeC, 
(Datatype.I16,)) and not sparseA and not sparseB)) - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if sparseA: return Preference.LOW if sparseB: @@ -213,11 +230,11 @@ def archSupported(self): return self._arch.host_name.lower() in supported_set def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): # NOTE: PSpaMM 0.3.0+ supports SIMD-aligned block sparsity in A (which is currently covered by sparseA + alignedA) - return self.archSupported() and alignedC and alignedA and (not transA and not transB) and target == 'cpu' + return self.archSupported() and alignedC and alignedA and (not transA and not transB) and target == 'cpu' and self._equalType(datatypeA, datatypeB, datatypeC, [Datatype.BF16, Datatype.F16, Datatype.F32, Datatype.F64]) - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if sparseB: return Preference.HIGH if sparseA and alignedA: @@ -243,10 +260,10 @@ def archSupported(self): return self._arch.backend.lower() in {'cuda', 'hip', 'oneapi', 'acpp', 'hipsycl'} def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return self.archSupported() and not (sparseA or sparseB) and target == 'gpu' + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return self.archSupported() and not (sparseA or sparseB) and target == 'gpu' and self._equalType(datatypeA, datatypeB, datatypeC) - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, 
m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): if sparseA and sparseB: return Preference.LOWEST if not transA: @@ -260,15 +277,15 @@ def __init__(self, arch): super().__init__('', [], '', arch) self._arch = arch - def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC): + def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): return Preference.HIGHEST def archSupported(self): return self._arch.backend.lower() in {'oneapi'} def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): - return self.archSupported() and not (sparseA or sparseB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): + return self.archSupported() and not (sparseA or sparseB) and alpha == 1.0 and beta in [0.0, 1.0] and target == 'gpu' and self._equalType(datatypeA, datatypeB, datatypeC) # TODO: really? 
class GeneratorCollection(object): @@ -277,13 +294,13 @@ def __init__(self, gemmTools: List[GemmTool]): self.selected = set() def getGemmTool(self, m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): tools = dict() for gemmTool in reversed(self.gemmTools): if gemmTool.supported(m, n, k, sparseA, sparseB, transA, transB, alpha, - beta, alignedA, alignedC, target): + beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): tools[gemmTool.preference(m, n, k, sparseA, sparseB, transA, transB, alpha, beta, - alignedA, alignedC)] = gemmTool + alignedA, alignedC, datatypeA, datatypeB, datatypeC, target)] = gemmTool select = None if tools: diff --git a/yateto/type.py b/yateto/type.py index 1c00db3..1c8d18f 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -72,15 +72,16 @@ def size(self): def literal(self, value): # TODO: BF16, F16 + # (note: the extra lambda mapping is needed to prevent type errors) return { - Datatype.BOOL: 'true' if value else 'false', - Datatype.I8: f'{int(value)}', - Datatype.I16: f'{int(value)}', - Datatype.I32: f'{int(value)}', - Datatype.I64: f'{int(value)}LL', - Datatype.F32: f'{float(value):.16}f', - Datatype.F64: f'{float(value):.16}' - }[self] + Datatype.BOOL: lambda value: 'true' if value else 'false', + Datatype.I8: lambda value: f'{int(value)}', + Datatype.I16: lambda value: f'{int(value)}', + Datatype.I32: lambda value: f'{int(value)}', + Datatype.I64: lambda value: f'{int(value)}LL', + Datatype.F32: lambda value: f'{float(value):.16}f', + Datatype.F64: lambda value: f'{float(value):.16}' + }[self](value) class AddressingMode(Enum): DIRECT = 0 @@ -96,7 +97,27 @@ def pointer_type(self): AddressingMode.SCALAR: '', }[self] -class AbstractType(object): +class Symbol(object): + def __init__(self, datatype): + # datatype == None is treated as datatype == arch.datatype + self.datatype = datatype + + def getDatatype(self, arch): + 
return arch.datatype if self.datatype is None else self.datatype + +class ScalarMixin: + pass + +class ImmediateScalar(Symbol, ScalarMixin): + def __init__(self, data, datatype): + super().__init__(datatype) + self.data = data + +class AbstractType(Symbol): + def __init__(self, name, datatype): + super().__init__(datatype) + self._name = name + @classmethod def isValidName(cls, name): return re.match(cls.VALID_NAME, name) is not None @@ -107,21 +128,15 @@ def name(self): class IdentifiedType(AbstractType): BASE_NAME = r'[a-zA-Z]\w*' GROUP_INDEX = r'(0|[1-9]\d*)' - GROUP_INDICES = r'\(({0}(,{0})*)\)'.format(GROUP_INDEX) - VALID_NAME = r'^{}({})?$'.format(BASE_NAME, GROUP_INDICES) + GROUP_INDICES = rf'\(({GROUP_INDEX}(,{GROUP_INDEX})*)\)' + VALID_NAME = rf'^{BASE_NAME}({GROUP_INDICES})?$' def __init__(self, name, namespace=None, datatype=None): + super().__init__(name, datatype) if not self.isValidName(name): - raise ValueError('Invalid name (must match regexp {}): {}'.format(self.VALID_NAME, name)) - - self._name = name - self.namespace = namespace - - # datatype == None is treated as datatype == arch.datatype - self.datatype = datatype + raise ValueError(f'Invalid name (must match regexp {self.VALID_NAME}): {name}') - def getDatatype(self, arch): - return arch.datatype if self.datatype is None else self.datatype + self.namespace = namespace def __str__(self): return self._name @@ -165,7 +180,7 @@ def nameWithNamespace(self): def __hash__(self): return hash(self._name) -class Scalar(IdentifiedType): +class Scalar(IdentifiedType, ScalarMixin): def __init__(self, name, namespace=None, datatype=None): super().__init__(name, namespace=namespace, datatype=datatype) @@ -190,7 +205,7 @@ def __init__(self, raise ValueError('shape must not contain entries smaller than 1') if not self.isValidName(name): - raise ValueError('Tensor name invalid (must match regexp {}): {}'.format(self.VALID_NAME, name)) + raise ValueError(f'Tensor name invalid (must match regexp 
{self.VALID_NAME}): {name}') self._name = name self._shape = shape From 96e66ed7239c50e2b6e0d5eb01c5f40ece837cc2 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Tue, 12 Aug 2025 18:16:31 +0200 Subject: [PATCH 10/18] Fix a post-merge bug --- yateto/codegen/visitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yateto/codegen/visitor.py b/yateto/codegen/visitor.py index ae721b3..734f4f7 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -584,7 +584,7 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, kernelTensorName = lambda var: self._devTensorKernelArgument(var, writable) stream_new(self.STREAM) - data_malloc(self.TMP_MEM, self.TMP_SIZE, f'{char}*', self.STREAM) + data_malloc(self.TMP_MEM, self.TMP_SIZE, f'char*', self.STREAM) for var in variables: data_malloc(self._devTensorName(var), f'sizeof({self._tensorName(var)})', f'{var.datatype.ctype()}*', self.STREAM) data_malloc(self._devPtrTensorName(var), f'sizeof({var.datatype.ctype()}*)', f'{var.datatype.ctype()}**', self.STREAM) From 8fa3d00607388d4bb4e4f628a47645624050f80e Mon Sep 17 00:00:00 2001 From: David Schneller Date: Thu, 14 Aug 2025 03:34:35 +0200 Subject: [PATCH 11/18] Ease the allocation initialization --- include/yateto/LinearAllocator.h | 8 +++++++- yateto/codegen/common.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/include/yateto/LinearAllocator.h b/include/yateto/LinearAllocator.h index d3621ab..7fee7dd 100644 --- a/include/yateto/LinearAllocator.h +++ b/include/yateto/LinearAllocator.h @@ -13,6 +13,12 @@ struct LinearAllocatorT { userSpaceMem = ptr; } + template <typename S> + void initialize(S* ptr) { + isInit = true; + userSpaceMem = reinterpret_cast<T*>(ptr); + } + T* allocate(size_t size) { assert(isInit && "YATETO: Temporary-Memory manager hasn't been initialized"); int currentByteCount = byteCount; @@ -31,5 +37,5 @@ struct LinearAllocatorT { bool isInit{false}; T *userSpaceMem{nullptr}; }; -} // yateto
+} // namespace yateto #endif // YATETO_LINEAR_ALLOCATED_H_ \ No newline at end of file diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index 299fd28..8985456 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -1,5 +1,6 @@ from __future__ import annotations from .. import aspp +from ..type import AddressingMode from ..ast.indices import BoundingBox from ..ast.log import splitByDistance from .tiny_tensor_language import Dump, Function, IntegerType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst From 7b0f5348b4835bea456e31b674914f0eb5ef6ced Mon Sep 17 00:00:00 2001 From: David Schneller Date: Sat, 23 Aug 2025 20:32:58 +0200 Subject: [PATCH 12/18] Begin updating the Yateto<->TensorForge interface --- yateto/codegen/factory.py | 192 ++++++++++++++++++++++++++++++---- yateto/controlflow/visitor.py | 7 +- 2 files changed, 173 insertions(+), 26 deletions(-) diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index 52cd692..d930852 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -4,7 +4,7 @@ from ..memory import DenseMemoryLayout from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux from . 
import copyscaleadd, indexsum, log, product, fused_gemms, elementwise -from ..type import Datatype +from ..type import Datatype, AddressingMode, Scalar class KernelFactory(object): ERROR_NAME = '_error' @@ -341,7 +341,7 @@ def tensor(self, node, resultName, maxValue = 512): nz = spp.nonzero() for entry in zip(*nz): addr = ml.address(entry) - memory[addr] = str(float((addr + self._rand) % maxValue)+1.0) + memory[addr] = datatype.literal(((addr + self._rand) % maxValue)+1.0) self.temporary(resultName, size, datatype, memory=memory) self._rand += 1 @@ -360,6 +360,9 @@ def add_linear_operation(self, dest, ops, target, permute, add): def add_operation(self, description): pass + def add_tensor(self, description): + pass + class ExportFactory(KernelFactory): @classmethod def makeFactory(cls, generator): @@ -368,6 +371,8 @@ def makeFactory(cls, generator): def __init__(self, generator, cpp, arch, target): super().__init__(cpp, arch, target) self.generator = generator + self.tensors = {} + self.scalarcounter = 0 def post_generate(self, routine_cache): self.generator.generate(self._cpp, routine_cache) @@ -375,46 +380,174 @@ def post_generate(self, routine_cache): def allocateTemporary(self): return False + def _nodeTensor(self, tensor, node): + return self._handleTensorDesc(IndexedTensorDescription.fromNode(tensor, node)) + + def _varTensor(self, var, indices): + return self._handleTensorDesc(IndexedTensorDescription.fromVar(var, indices)) + + def _handleAddressing(self, desc): + if desc.addressing is None: + addressing = BatchedOperationsAux.deduce_addresing(desc) + else: + addressing = desc.addressing + + # & == deref + # n == current element + # N == element size + # o == extraOffset + # *,+ == default add and mul + # read left to right + if addressing == AddressingMode.DIRECT: + return '&' + elif addressing == AddressingMode.STRIDED: + return 'n*N+o&' + elif addressing == AddressingMode.INDIRECT: + return 'n&+o&' + elif addressing == AddressingMode.SCALAR: + return '' + 
+ raise NotImplementedError(addressing) + + def _handleTensorDesc(self, tensorIndexed: IndexedTensorDescription): + if isinstance(tensorIndexed.memoryLayout, DenseMemoryLayout): + shape = list(tensorIndexed.memoryLayout.shape()) + shapeXt = [max(rng.stop - rng.start, shp) for rng, shp in zip(tensorIndexed.memoryLayout.bbox(), shape)] + storage = { + 'shape': shapeXt, + 'type': 'bbox', + 'start': [rng.start for rng in tensorIndexed.memoryLayout.bbox()], + 'sizes': [rng.stop - rng.start for rng in tensorIndexed.memoryLayout.bbox()] + } + else: + assert False + + eqsppnz = tensorIndexed.eqspp.nonzero() + spp = [elem for elem in zip(*eqsppnz)] + + values = None if tensorIndexed.values is None else list(tensorIndexed.values) + + tensor = { + 'name': tensorIndexed.name, + 'addressing': self._handleAddressing(tensorIndexed), + #'eqspp': spp, + 'datatype': str(tensorIndexed.datatype), + 'storage': storage, + 'values': values, + 'flags': { + 'temporary': tensorIndexed.is_temporary, + 'constant': tensorIndexed.is_compute_constant + } + } + + return self._handleTensor(tensor, spp, tensorIndexed.indices) + + def _scalarTensor(self, scalar): + if isinstance(scalar, (int, float)): # TODO numpy types + name = f'_scalar{self.scalarcounter}' + self.scalarcounter += 1 + + tensor = { + 'name': name, + 'addressing': '', + 'eqspp': (), + 'datatype': str(self._arch.datatype), + 'storage': { + 'shape': (), + 'type': 'full' + }, + 'values': { + (): scalar + }, + 'flags': { + 'temporary': False, + 'constant': True + } + } + elif isinstance(scalar, Scalar): + tensor = { + 'name': scalar.name(), + 'addressing': '', + 'eqspp': (), + 'datatype': str(scalar.getDatatype(self._arch)), + 'storage': { + 'shape': (), + 'type': 'full' + }, + 'values': None, + 'flags': { + 'temporary': False, + 'constant': True + } + } + else: + assert False + + return self._handleTensor(tensor, (), ()) + + def _handleTensor(self, tensor, eqspp, indices): + if tensor['name'] not in self.tensors: + 
self.tensors[tensor['name']] = tensor + self.generator.add_tensor(tensor) + else: + assert tensor == self.tensors[tensor['name']] + + return { + 'name': tensor['name'], + 'spp': eqspp, + 'indices': indices + } + + def _handleCondition(self, condition): + out = [] + for clause in condition.clauses: + outclause = [] + for var in clause.variables: + tensor = self._varTensor(clause.variables[var], ()) + outclause += [tensor] + out += [outclause] + return out + def create_Elementwise(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): - result = IndexedTensorDescription.fromNode(result, node) - preArgs = [IndexedTensorDescription.fromNode(argument, term) for argument, term in zip(arguments, node)] + result = self._nodeTensor(result, node) + preArgs = [self._nodeTensor(argument, term) for argument, term in zip(arguments, node)] args = node.fillTerms(preArgs) description = { 'type': 'elementwise', 'result': result, 'args': args, + 'condition': self._handleCondition(condition), 'linear': { - 'alpha': scalar, + 'alpha': self._scalarTensor(scalar), 'add': add, }, - 'optype': node.optype + 'optype': str(node.optype) } return self.generator.add_operation(description) def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 1 - makeNode = IndexedTensorDescription.fromNode - result = makeNode(result, node) - argnodes = [makeNode(arguments[0], node.term())] + result = self._nodeTensor(result, node) + argnodes = [self._nodeTensor(arguments[0], node.term())] description = { 'type': 'reduction', 'result': result, 'args': argnodes, + 'condition': self._handleCondition(condition), 'linear': { - 'alpha': scalar, + 'alpha': self._scalarTensor(scalar), 'add': add, }, - 'optype': node.optype + 'optype': str(node.optype) } return self.generator.add_operation(description) def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, prefetchName, 
routineCache, gemm_cfg): assert len(arguments) == 2 - makeNode = IndexedTensorDescription.fromNode - argnodes = [makeNode(arguments[0], node.leftTerm()), makeNode(arguments[1], node.rightTerm())] - return self.handleLinear(makeNode(result, node), argnodes, add, scalar, node.transA(), node.transB()) + argnodes = [self._nodeTensor(arguments[0], node.leftTerm()), self._nodeTensor(arguments[1], node.rightTerm())] + return self.handleLinear(self._nodeTensor(result, node), argnodes, condition, add, scalar, node.transA(), node.transB()) def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): return create_Reduction(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg) @@ -424,39 +557,52 @@ def create_Product(self, node, result, arguments, condition, add, scalar, prefet def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = arguments[0] - return self.handleLinear(IndexedTensorDescription.fromVar(result, node.indices), [IndexedTensorDescription.fromVar(term, node.term().indices)], add, scalar, False, False) + return self.handleLinear(self._varTensor(result, node.indices), [self._varTensor(term, node.term().indices)], condition, add, scalar, False, False) def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): - return self.handleLinear(IndexedTensorDescription.fromVar(result, self._indices(result)), [IndexedTensorDescription.fromVar(term, self._indices(term))], add, scalar, False, False) + return self.handleLinear(self._varTensor(result, self._indices(result)), [self._varTensor(term, self._indices(term))], condition, add, scalar, False, False) def getIndices(self, dest, ops): if dest is None: target_indices = [] else: - target_indices = dest.indices + target_indices = dest['indices'] indexindex = {index:i for i, index in enumerate(target_indices)} contract_counter = -1 for op in ops: - for index 
in op.indices: + for index in op['indices']: if index not in indexindex: indexindex[index] = contract_counter contract_counter -= 1 - target = [[indexindex[index] for index in op.indices] for op in ops] - permute = [[i for i,_ in enumerate(op.indices)] for op in ops] + target = [[indexindex[index] for index in op['indices']] for op in ops] + permute = [[i for i,_ in enumerate(op['indices'])] for op in ops] return target, permute - def handleLinear(self, dest, ops, add, scalar, transposeA, transposeB): + def handleLinear(self, dest, ops, condition, add, scalar, transposeA, transposeB): # convert indices to loop numbers target, permute = self.getIndices(dest, ops) if not (scalar == 1 or scalar == 1.0): - ops += [scalar] + ops += [self._scalarTensor(scalar)] target += [[]] permute += [[]] - return self.generator.add_linear_operation(dest, ops, target, permute, add) + description = { + 'type': 'multilinear', + 'result': dest, + 'args': ops, + 'condition': self._handleCondition(condition), + 'permute': permute, + 'target': target, + 'linear': { + 'alpha': self._scalarTensor(scalar), + 'add': add, + }, + # 'optype': node.optype + } + return self.generator.add_operation(description) diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index db1e346..d8dcaa3 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -78,11 +78,12 @@ def visit_Assign(self, node): newCondition = condition & CNFCondition(myCondition) self._condition.append(newCondition) - lVar = self.visit(node[0]) - rVar = self.visit(node[1]) self._condition = self._condition[:-1] - + + rVar = self.visit(node[1]) rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), rVar) + + lVar = self.visit(node[0]) action = ProgramAction(lVar, rhs, False, condition=newCondition) self._addAction(action) From 6ba8b4699ae4b8c5d86f1df00672202b4439acfd Mon Sep 17 00:00:00 2001 From: David Schneller Date: Sat, 22 Nov 2025 17:28:43 +0100 Subject: [PATCH 13/18] Improve FP 
format support --- include/yateto.h | 1 + include/yateto/Type.h | 39 +++++++++++++++++++++++++++++++++++++++ yateto/ast/node.py | 2 +- yateto/type.py | 17 ++++++++++++----- 4 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 include/yateto/Type.h diff --git a/include/yateto.h b/include/yateto.h index 46c68db..003b5ba 100644 --- a/include/yateto.h +++ b/include/yateto.h @@ -5,5 +5,6 @@ #include "yateto/InitTools.h" #include "yateto/LinearAllocator.h" #include "yateto/Misc.h" +#include "yateto/Type.h" #endif diff --git a/include/yateto/Type.h b/include/yateto/Type.h new file mode 100644 index 0000000..61b5a72 --- /dev/null +++ b/include/yateto/Type.h @@ -0,0 +1,39 @@ +#ifndef YATETO_TYPE_H_ +#define YATETO_TYPE_H_ + +#include + +// C++23 include +#if __has_include(<stdfloat>) +#include <stdfloat> +#endif + +// cf. https://stackoverflow.com/a/70868019 +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include <float.h> + +namespace yateto { + +#ifdef __STDCPP_FLOAT128_T__ +using f128_ty = std::float128_t; +#elif defined(FLT128_MIN) +using f128_ty = _Float128; +#else +using f128_ty = __float128; +#endif +#ifdef __STDCPP_FLOAT16_T__ +using f16_ty = std::float16_t; +#elif defined(FLT16_MIN) +using f16_ty = _Float16; +#else +using f16_ty = __fp16; +#endif +#ifdef __STDCPP_BFLOAT16_T__ +using bf16_ty = std::bfloat16_t; +#else +using bf16_ty = __bf16; +#endif + +} // yateto + +#endif // YATETO_TYPE_H_ diff --git a/yateto/ast/node.py b/yateto/ast/node.py index be51efb..30a545a 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -615,7 +615,7 @@ def computeSparsityPattern(self, *spps): add_spp = permute_summand(i) spp = aspp.add(spp, add_spp) return spp - + def nonZeroFlops(self): nzFlops = 0 for child in self: diff --git a/yateto/type.py b/yateto/type.py index 1c8d18f..a77b768 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -16,6 +16,7 @@ class Datatype(Enum): F64 = 6 F16 = 7 BF16 = 8 + F128 = 9 def __str__(self): return { @@ -26,6 +27,7 @@ def __str__(self): Datatype.I64: 'i64', Datatype.F32: 'f32', Datatype.F64: 'f64', + Datatype.F128: 'f128', Datatype.F16: 'f16', Datatype.BF16: 'bf16', }[self] @@ -54,6 +56,7 @@ def nptype(self): Datatype.F64: np.float64, Datatype.F16: np.float16, Datatype.BF16: np.float32, # NYI + Datatype.F128: np.float128, }[self] def size(self): @@ -68,6 +71,7 @@ def size(self): Datatype.F64: 8, Datatype.F16: 2, Datatype.BF16: 2, + Datatype.F128: 16, }[self] def literal(self, value): @@ -75,12 +79,15 @@ def literal(self, value): # (note: the extra lambda mapping is needed to prevent type errors) return { Datatype.BOOL: lambda value: 'true' if value else 'false', - Datatype.I8: lambda value: f'{int(value)}', - Datatype.I16: lambda value: f'{int(value)}', - Datatype.I32: lambda value: f'{int(value)}', - Datatype.I64: lambda value: f'{int(value)}LL', + Datatype.I8: lambda value: f'static_cast<int8_t>({int(value)}LL)', + Datatype.I16: lambda value: f'static_cast<int16_t>({int(value)}LL)', + Datatype.I32: lambda value: f'static_cast<int32_t>({int(value)}LL)', + Datatype.I64: lambda value: f'static_cast<int64_t>({int(value)}LL)', Datatype.F32: lambda value: f'{float(value):.16}f', - Datatype.F64: lambda value: f'{float(value):.16}' + Datatype.F64: lambda value: f'{float(value):.16}', + Datatype.F16: lambda value: f'static_cast<f16_ty>({float(value):.16})', + Datatype.BF16: lambda value: f'static_cast<bf16_ty>({float(value):.16})', + Datatype.F128: lambda value: f'static_cast<f128_ty>({float(value):.32}q)', }[self](value) class AddressingMode(Enum): From a7df3eed590ed78e7e970a7cf5187b4b1e27e736 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Mon, 22 Dec 2025 08:54:11 +0100 Subject: [PATCH 14/18] Tests and bugfixes --- tests/code-gen/elementwise.py | 41 ++++++++++++++++++++++++++++ tests/code-gen/reduction.py | 42 +++++++++++++++++++++++++++++ yateto/ast/node.py | 5 +++- yateto/ast/transformer.py | 18 ++++++++++--- yateto/codegen/factory.py | 6 ++--- yateto/codegen/reduction/generic.py | 6 ++--- yateto/functions.py | 16 ++++++++--- yateto/ops.py | 10 ++++--- yateto/type.py | 24
++++++++++------- 9 files changed, 139 insertions(+), 29 deletions(-) create mode 100644 tests/code-gen/elementwise.py create mode 100644 tests/code-gen/reduction.py diff --git a/tests/code-gen/elementwise.py b/tests/code-gen/elementwise.py new file mode 100644 index 0000000..a289d3e --- /dev/null +++ b/tests/code-gen/elementwise.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A = Tensor('A', (N, N)) + B = Tensor('B', (N, N)) + C = Tensor('C', (N, N)) + + AI = Tensor('AI', (N, N), datatype=Datatype.I32) + BI = Tensor('BI', (N, N), datatype=Datatype.I32) + CI = Tensor('CI', (N, N), datatype=Datatype.I32) + + AB = Tensor('AB', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(A['ij'] <= yf.sqrt(B['ij'])) + _(A['ij'] <= yf.sqrt(B['ij']) + yf.sin(C['ij'])) + _(A['ij'] <= yf.sqrt(B['ij']) * yf.sin(C['ij'])) + _(A['ij'] <= yf.minimum(B['ij'], C['ij'])) + _(A['ij'] <= yf.minimum(B['ij'], C['ij'] + yf.atanh(B['ij']))) + + _(AI['ij'] <= BI['ij'] + CI['ij']) + _(AI['ij'] <= yf.bitwise_and(BI['ij'], CI['ij'])) + + _(AB['ij'] <= yf.greater_equal(BI['ij'], CI['ij'])) + _(A['ij'] <= yf.where(yf.greater_equal(BI['ij'], CI['ij']), B['ij'], C['ij'])) + + _(AI['ij'] <= yf.cast(A['ij'], Datatype.I32)) diff --git a/tests/code-gen/reduction.py b/tests/code-gen/reduction.py new file mode 100644 index 0000000..ef0399d --- /dev/null +++ b/tests/code-gen/reduction.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A0 = Tensor('A0', ()) + A1 = Tensor('A1', (N,)) + A2 = Tensor('A2', (N, N)) + + AI0 = Tensor('AI0', (), datatype=Datatype.I32) + AI1 = Tensor('AI1', (N,), datatype=Datatype.I32) + AI2 = Tensor('AI2', (N, N), datatype=Datatype.I32) + + AB0 = Tensor('AB0', (), 
datatype=Datatype.BOOL) + AB1 = Tensor('AB1', (N,), datatype=Datatype.BOOL) + AB2 = Tensor('AB2', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(A0[''] <= yf.sum(A1['i'], 'i')) + _(A0[''] <= yf.sum(A2['ij'], 'ij')) + _(A0[''] <= yf.min(A2['ij'], 'ij')) + + _(AI0[''] <= yf.sum(AI2['ij'], 'ij')) + _(AI0[''] <= yf.min(AI2['ij'], 'ij')) + _(AI0[''] <= yf.all(AI2['ij'], 'ij')) + _(AI0[''] <= yf.any(AI2['ij'], 'ij')) + + _(AB0[''] <= yf.all(AB2['ij'], 'ij')) + _(AB0[''] <= yf.any(AB2['ij'], 'ij')) + diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 03ae16b..ebaa484 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -109,7 +109,7 @@ def __rmul__(self, other): def __add__(self, other): if not isinstance(other, Node): - raise ValueError('Unsupported operation: Cannot add {} to {}.'.format(self, other)) + raise ValueError(f'Unsupported operation: Cannot add {self} to {other}.') return self._binOp(other, Add) def __radd__(self, other): @@ -689,6 +689,9 @@ def nonZeroFlops(self): def reductionIndex(self): return self._reductionIndex + def reductionIndices(self): + return [self._reductionIndex] + def computeSparsityPattern(self, *spps): assert len(spps) <= 1 spp = spps[0] if len(spps) == 1 else self.term().eqspp() diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index d21d69a..4af340b 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -34,16 +34,16 @@ def visit(self, node, bound=None): elif isinstance(self._targetIndices, Indices): node.indices = self._targetIndices else: - raise ValueError('Target indices type ({}) is not supported.'.format(self._targetIndices.__class__.__name__)) + raise ValueError(f'Target indices type ({self._targetIndices.__class__.__name__}) is not supported.') if not (node.indices <= oldIndices and oldIndices <= node.indices): - raise 
ValueError('Target index dimensions do not match: {} != {}'.format(node.indices.__repr__(), oldIndices.__repr__())) + raise ValueError(f'Target index dimensions do not match: {node.indices.__repr__()} != {oldIndices.__repr__()}') return node def visit_IndexedTensor(self, node, bound): if set(node.indices) > bound: free = node.indices - bound - raise ValueError('The indices {} are not bound in {}.'.format(free.__repr__(), node)) + raise ValueError(f'The indices {free.__repr__()} are not bound in {node}.') return node def visit_Einsum(self, node, bound): @@ -103,6 +103,11 @@ def visit_Elementwise(self, node, bound): self.visit(child, bound) node.indices = deepcopy(node[0].indices) return node + + def visit_Reduction(self, node, bound): + subbound = bound | set(node.reductionIndices()) + self.visit(node.term(), subbound) + return node def visit_SliceView(self, node, bound): self.visit(node.term(), bound) @@ -128,7 +133,7 @@ def visit_Assign(self, node, bound): node.indices = lhs.indices if not (rhs.indices <= lhs.indices): - raise ValueError('Index dimensions do not match: {} != {}'.format(lhs.indices.__repr__(), rhs.indices.__repr__())) + raise ValueError(f'Index dimensions do not match: {lhs.indices.__repr__()} != {rhs.indices.__repr__()}') return node @@ -230,6 +235,11 @@ def visit_Elementwise(self, node): node.setEqspp( node.computeSparsityPattern() ) return node + def visit_Reduction(self, node): + self.generic_visit(node) + node.setEqspp( node.computeSparsityPattern() ) + return node + def getEqspp(self, terms, targetIndices): # Shortcut if all terms have dense eqspps if all(term.eqspp().is_dense() for term in terms): diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index d939e48..de6cc79 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -3,7 +3,7 @@ from ..ast.node import IndexedTensor from ..memory import DenseMemoryLayout from .common import forLoops, TensorDescription, IndexedTensorDescription, BatchedOperationsAux 
-from . import copyscaleadd, indexsum, log, product, fused_gemms, elementwise +from . import copyscaleadd, indexsum, log, product, fused_gemms, elementwise, reduction from ..type import Datatype, AddressingMode, Scalar class KernelFactory(object): @@ -275,9 +275,7 @@ def create_Elementwise(self, node, result, arguments, condition, add, scalar, pr def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, node.indices) - argTerm = self._formatTerm(arguments[0], node.term()) - - termTerm = node.optype.callstr(*node.fillTerms(argTerms)) + termTerm = self._formatTerm(arguments[0], node.term().indices) return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) diff --git a/yateto/codegen/reduction/generic.py b/yateto/codegen/reduction/generic.py index d257bd2..c092201 100644 --- a/yateto/codegen/reduction/generic.py +++ b/yateto/codegen/reduction/generic.py @@ -16,11 +16,11 @@ def generate(self, cpp, routineCache): assert len(sumIndex) == 1 class IndexSumBody(object): def __call__(s): - target = '{}[{}]'.format(d.result.name, d.result.memoryLayout.addressString(d.result.indices)) + target = f'{d.result.name}[{d.result.memoryLayout.addressString(d.result.indices)}]' initialValue = target if d.add else d.result.datatype.literal(d.optype.neutral()) cpp(f'{d.result.datatype.ctype()} acc = {initialValue};') - with cpp.For('int {0} = {1}; {0} < {2}; ++{0}'.format(sumIndex, d.sumLoopRange.start, d.sumLoopRange.stop)): - argstr = {d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}] + with cpp.For(f'int {sumIndex} = {d.sumLoopRange.start}; {sumIndex} < {d.sumLoopRange.stop}; ++{sumIndex}'): + argstr = f'{d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}]' cpp( f'acc = {d.optype.callstr('acc', argstr)};' ) mult = f'{d.alpha} * ' if d.alpha != 1.0 else '' cpp( f'{target} = {mult}acc;' ) diff 
--git a/yateto/functions.py b/yateto/functions.py index e9d01c8..c5a9196 100644 --- a/yateto/functions.py +++ b/yateto/functions.py @@ -2,6 +2,12 @@ from .ast import node from .type import Datatype +def add(x, y): return node.Elementwise(ops.Add(), x, y) +def mul(x, y): return node.Elementwise(ops.Mul(), x, y) +def bitwise_or(x, y): return node.Elementwise(ops.Or(), x, y) +def bitwise_and(x, y): return node.Elementwise(ops.And(), x, y) +def bitwise_xor(x, y): return node.Elementwise(ops.Xor(), x, y) + def sin(x): return node.Elementwise(ops.Sin(), x) def cos(x): return node.Elementwise(ops.Cos(), x) def tan(x): return node.Elementwise(ops.Tan(), x) @@ -25,8 +31,8 @@ def cbrt(x): return node.Elementwise(ops.Cbrt(), x) def abs(x): return node.Elementwise(ops.Abs(), x) -def max(x, y): return node.Elementwise(ops.Max(), x, y) -def min(x, y): return node.Elementwise(ops.Min(), x, y) +def maximum(x, y): return node.Elementwise(ops.Max(), x, y) +def minimum(x, y): return node.Elementwise(ops.Min(), x, y) def pow(x, y): return node.Elementwise(ops.Pow(), x, y) def assign(lhs, rhs): return node.Assign(lhs, rhs) @@ -47,11 +53,13 @@ def reduction(op, term, indices): if len(indices) == 0: return term else: - reduction(op, node.Reduction(op, term, indices[0]), indices[1:]) + return reduction(op, node.Reduction(op, term, indices[0]), indices[1:]) def sum(term, indices): return reduction(ops.Add(), term, indices) -def product(term, indices): return reduction(ops.Mul(), term, indices) +def prod(term, indices): return reduction(ops.Mul(), term, indices) def all(term, indices): return reduction(ops.And(), term, indices) def any(term, indices): return reduction(ops.Or(), term, indices) +def min(term, indices): return reduction(ops.Min(), term, indices) +def max(term, indices): return reduction(ops.Max(), term, indices) def cast(x, dtype): return node.Elementwise(ops.Typecast(dtype), x) diff --git a/yateto/ops.py b/yateto/ops.py index 2b94525..2167c24 100644 --- a/yateto/ops.py +++ 
b/yateto/ops.py @@ -191,7 +191,9 @@ def datatypeResult(self, argtypes): return argtypes[0] -class Max(Operation, CFunctionMixin, BinaryArgsMixin): +class Max(Operation, CFunctionMixin, BinaryArgsMixin, CommutativeMonoidMixin): + def neutral(self): + return -float('inf') def cppname(self): return 'std::max' def call(self, *args): @@ -199,7 +201,9 @@ def call(self, *args): return args[0] def datatypeResult(self, argtypes): # assert argtypes[0] == argtypes[1] return argtypes[0] -class Min(Operation, CFunctionMixin, BinaryArgsMixin): +class Min(Operation, CFunctionMixin, BinaryArgsMixin, CommutativeMonoidMixin): + def neutral(self): + return float('inf') def cppname(self): return 'std::min' def call(self, *args): @@ -315,7 +319,7 @@ def call(self, *args): return args[0] > args[1] def datatypeResult(self, argtypes): return Datatype.BOOL -class CmpGt(Operation, CBinaryOperatorMixin, BinaryArgsMixin): +class CmpGe(Operation, CBinaryOperatorMixin, BinaryArgsMixin): def cppname(self, *args): return '>=' def call(self, *args): diff --git a/yateto/type.py b/yateto/type.py index 59767bb..9f2c47f 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -41,8 +41,9 @@ def ctype(self): Datatype.I64: 'int64_t', Datatype.F32: 'float', Datatype.F64: 'double', - Datatype.F16: 'int16_t', - Datatype.BF16: 'int16_t', + Datatype.F16: 'yateto::f16_ty', + Datatype.BF16: 'yateto::bf16_ty', + Datatype.F128: 'yateto::f128_ty', }[self] def nptype(self): @@ -74,20 +75,23 @@ def size(self): Datatype.F128: 16, }[self] + def safeint(self, value): + # allow inf/-inf to be treated as int + return int(max(-2**64, min(2**64, value))) + def literal(self, value): - # TODO: BF16, F16 # (note: the extra lambda mapping is needed to prevent type errors) return { Datatype.BOOL: lambda value: 'true' if value else 'false', - Datatype.I8: lambda value: f'static_cast<int8_t>({int(value)}LL)', - Datatype.I16: lambda value: f'static_cast<int16_t>({int(value)}LL)', - Datatype.I32: lambda value: f'static_cast<int32_t>({int(value)}LL)', - Datatype.I64: lambda value: f'static_cast<int64_t>({int(value)}LL)', + Datatype.I8: lambda value: f'static_cast<int8_t>({self.safeint(value)}LL)', + Datatype.I16: lambda value: f'static_cast<int16_t>({self.safeint(value)}LL)', + Datatype.I32: lambda value: f'static_cast<int32_t>({self.safeint(value)}LL)', + Datatype.I64: lambda value: f'static_cast<int64_t>({self.safeint(value)}LL)', Datatype.F32: lambda value: f'{float(value):.16}f', Datatype.F64: lambda value: f'{float(value):.16}', - Datatype.F16: lambda value: f'static_cast<f16_ty>({float(value):.16})', - Datatype.BF16: lambda value: f'static_cast<bf16_ty>({float(value):.16})', - Datatype.F128: lambda value: f'static_cast<f128_ty>({float(value):.32}q)', + Datatype.F16: lambda value: f'static_cast<yateto::f16_ty>({float(value):.16})', + Datatype.BF16: lambda value: f'static_cast<yateto::bf16_ty>({float(value):.16})', + Datatype.F128: lambda value: f'static_cast<yateto::f128_ty>({float(value):.32}q)', }[self](value) class AddressingMode(Enum): From f2afc9061b56c3032660ee8a790c7235b15e318a Mon Sep 17 00:00:00 2001 From: David Schneller Date: Mon, 22 Dec 2025 10:59:56 +0100 Subject: [PATCH 15/18] Adjust index permutation computation for the new ops --- yateto/ast/visitor.py | 62 ++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/yateto/ast/visitor.py b/yateto/ast/visitor.py index fc59e87..81f4b28 100644 --- a/yateto/ast/visitor.py +++ b/yateto/ast/visitor.py @@ -1,4 +1,4 @@ -from numpy import ndindex, arange, float64, add, einsum +from numpy import ndindex, arange, float64, add, einsum, apply_along_axis import math import collections import itertools @@ -110,29 +110,31 @@ def variantsFixedRootPermutation(self, node, fixedPerm, permutationVariants): variants[fixedPerm] = self.Variant(minCost, minInd) return variants - def allPermutationsNoCostBinaryOp(self, node): + def allPermutationsNoCostNAryOp(self, node): permutationVariants = self.findVariants(node) - lV = permutationVariants[node.leftTerm()] - rV = permutationVariants[node.rightTerm()] + V = [permutationVariants[child] for child in node]
minCost = LoGCost() - minAind = None - minBind = None - for Aind in sorted(lV): - for Bind in sorted(rV): - cost = lV[Aind]._cost + rV[Bind]._cost - if cost < minCost: - minCost = cost - minAind = Aind - minBind = Bind - assert minAind is not None and minBind is not None + minInd = None + + for ind in itertools.product(*V): + + cost = sum((V[i][Vind]._cost for i,Vind in enumerate(ind)), + LoGCost.addIdentity()) + + if cost < minCost: + minCost = cost + minInd = ind + + assert minInd is not None + iterator = itertools.permutations(node.indices) - permutationVariants[node] = {''.join(Cs): self.Variant(minCost, [minAind, minBind]) for Cs in iterator} + permutationVariants[node] = {''.join(Cs): self.Variant(minCost, list(minInd)) for Cs in iterator} return permutationVariants def generic_visit(self, node): permutationVariants = self.findVariants(node) variants = self.variantsFixedRootPermutation(node, str(node.indices), permutationVariants) - assert variants, 'Could not find implementation for {}.'.format(type(node)) + assert variants, f'Could not find implementation for {node}.' 
permutationVariants[node] = variants return permutationVariants @@ -152,8 +154,11 @@ def visit_ScalarMultiplication(self, node): return permutationVariants def visit_Product(self, node): - return self.allPermutationsNoCostBinaryOp(node) - + return self.allPermutationsNoCostNAryOp(node) + + def visit_Elementwise(self, node): + return self.allPermutationsNoCostNAryOp(node) + def visit_IndexSum(self, node): permutationVariants = self.findVariants(node) tV = permutationVariants[node.term()] @@ -168,6 +173,9 @@ def visit_IndexSum(self, node): iterator = itertools.permutations(node.indices) permutationVariants[node] = {''.join(Cs): self.Variant(minCost, [minTind]) for Cs in iterator} return permutationVariants + + def visit_Reduction(self, node): + return self.visit_IndexSum(node) def visit_Contraction(self, node): permutationVariants = self.findVariants(node) @@ -299,32 +307,42 @@ def generic_visit(self, node): def visit_Einsum(self, node): terms = self.generic_visit(node) childIndices = [child.indices for child in node] - assert None not in childIndices and node.indices is not None, 'Use DeduceIndices before {}.'.format(self.__class__.__name__) + assert None not in childIndices and node.indices is not None, f'Use DeduceIndices before {self.__class__.__name__}.' 
einsumDescription = ','.join(indices.tostring() for indices in childIndices) einsumDescription = '{}->{}'.format(einsumDescription, node.indices.tostring()) return einsum(einsumDescription, *terms) def visit_Add(self, node): terms = self.generic_visit(node) - assert len(terms) > 1 permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices)) return reduce(add, [permute(child.indices, terms[i]) for i,child in enumerate(node)]) def visit_ScalarMultiplication(self, node): - assert node.is_constant() is not None, '{} may only be used when all involved scalars are constant.'.format(self.__class__.__name__) + assert node.is_constant() is not None, f'{self.__class__.__name__} may only be used when all involved scalars are constant.' terms = self.generic_visit(node) assert len(terms) == 1 return node.scalar() * terms[0] def visit_IndexedTensor(self, node): term = node.tensor.values_as_ndarray(self._dtype) - assert term is not None, '{} may only be used when all involved tensors are constant.'.format(self.__class__.__name__) + assert term is not None, f'{self.__class__.__name__} may only be used when all involved tensors are constant.' 
return term def visit_Elementwise(self, node): terms = self.generic_visit(node) fullTerms = node.fillTerms(terms) return node.optype.call(*fullTerms) + + def visit_Reduction(self, node): + terms = self.generic_visit(node) + assert len(terms) == 1 + indexpos = node.term().indices.find(node.reductionIndex()) + return np.apply_along_axis(lambda x: reduce(node.optype.call, x), indexpos, terms[0]) + + def visit_Accumulate(self, node): + terms = self.generic_visit(node) + permute = lambda indices, tensor: tensor.transpose(tuple(indices.find(idx) for idx in node.indices)) + return reduce(node.optype.call, [permute(child.indices, terms[i]) for i,child in enumerate(node)]) class ComputeIndexSet(CachedVisitor): def generic_visit(self, node): From 3ff353aa2d6d83de1e25e3c2ec2714c76fada88b Mon Sep 17 00:00:00 2001 From: David Schneller Date: Mon, 22 Dec 2025 11:15:02 +0100 Subject: [PATCH 16/18] Fix codegen --- yateto/controlflow/visitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index eeb0297..b9c241b 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -24,7 +24,7 @@ def _addTransformOp(self, permute, variable): if not self._simpleMemoryLayout: permute.setEqspp( permute.computeSparsityPattern() ) permute.computeMemoryLayout() - permute.datatype = term.datatype + permute.datatype = permute[0].datatype result = self._nextTemporary(permute) action = ProgramAction(result, Expression(permute, self._ml(permute), [variable]), False, condition=self._condition[-1]) self._addAction(action) From 4a907486761c3138e28c9db844c2bae25915c4f6 Mon Sep 17 00:00:00 2001 From: David Schneller Date: Sat, 21 Feb 2026 16:49:03 +0100 Subject: [PATCH 17/18] Add some more tests --- .github/workflows/yateto-ci.yml | 2 +- tests/code-gen/conditional.py | 44 +++++++++++++++++++++++++++++++ tests/code-gen/datatype.py | 29 ++++++++++++++++++++ yateto/codegen/common.py | 13 +++++---- 
yateto/controlflow/graph.py | 11 ++++++-- yateto/controlflow/transformer.py | 8 +++--- 6 files changed, 95 insertions(+), 12 deletions(-) create mode 100644 tests/code-gen/conditional.py create mode 100644 tests/code-gen/datatype.py diff --git a/.github/workflows/yateto-ci.yml b/.github/workflows/yateto-ci.yml index b637aba..7647cbd 100644 --- a/.github/workflows/yateto-ci.yml +++ b/.github/workflows/yateto-ci.yml @@ -73,7 +73,7 @@ jobs: - name: Codegen Tests run: | cd ./tests/code-gen - for example in matmul minimal indices slicing; do + for example in matmul minimal indices slicing elementwise reduction datatype conditional; do for build_type in Debug Release; do for precision in single double; do echo " ====== Test Config: ======" diff --git a/tests/code-gen/conditional.py b/tests/code-gen/conditional.py new file mode 100644 index 0000000..7295a26 --- /dev/null +++ b/tests/code-gen/conditional.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A = Tensor('A', (N, N)) + B = Tensor('B', (N, N)) + C = Tensor('C', (N, N)) + + AI = Tensor('AI', (N, N), datatype=Datatype.I32) + BI = Tensor('BI', (N, N), datatype=Datatype.I32) + CI = Tensor('CI', (N, N), datatype=Datatype.I32) + + AB = Tensor('AB', (N, N), datatype=Datatype.BOOL) + + X = Tensor('X', (), datatype=Datatype.BOOL) + X1 = Tensor('X1', (), datatype=Datatype.BOOL) + X2 = Tensor('X2', (), datatype=Datatype.BOOL) + X3 = Tensor('X3', (), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij']))) + _(yf.assignIf(yf.all(AB['ij'], 'ij'), A['ij'], yf.sqrt(B['ij']))) + _([ + yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij'])), + yf.assignIf(X[''], AI['ij'], -BI['ij']) + ]) + _([ + yf.assignIf(X1[''], A['ij'], B['ik'] * C['kj'] + C['ij']), + yf.assignIf(X1[''], 
A['ij'], A['ij'] + B['ik'] * C['kj'] + C['ij']), + yf.assignIf(X2[''], C['ij'], yf.sqrt(B['ij'])) + ]) diff --git a/tests/code-gen/datatype.py b/tests/code-gen/datatype.py new file mode 100644 index 0000000..c2f7581 --- /dev/null +++ b/tests/code-gen/datatype.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +from yateto import * + +import yateto.functions as yf + +def add(g): + N = 8 + A = Tensor('A', (N, N)) + B = Tensor('B', (N, N)) + C = Tensor('C', (N, N)) + + AI = Tensor('AI', (N, N), datatype=Datatype.I32) + BI = Tensor('BI', (N, N), datatype=Datatype.I32) + CI = Tensor('CI', (N, N), datatype=Datatype.I32) + + AB = Tensor('AB', (N, N), datatype=Datatype.BOOL) + + class Counter: + def __init__(self): + self.counter = 0 + + counter = Counter() + + def _(kernel): + counter.counter += 1 + g.add(f'kernel{counter.counter}', kernel) + + _(AI['ij'] <= yf.cast(A['ij'], Datatype.I32)) diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index a2e78e0..d3e4aa1 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -3,7 +3,7 @@ from ..type import AddressingMode from ..ast.indices import BoundingBox from ..ast.log import splitByDistance -from .tiny_tensor_language import Dump, Function, IntegerType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst +from .tiny_tensor_language import Dump, Function, IntegerType, FloatingType, MemrefType, GroupType, IntImmValue, DYNAMIC, SubviewInst, LoadInst import hashlib class TensorDescription(object): @@ -18,6 +18,9 @@ def __init__(self, name, memoryLayout, eqspp, is_compute_constant=False, is_temp elements are known at compile time is_temporary (bool): if true then the description is for a temporary tensor which usually results from a result of an intermediate computation + values (Union[np.ndarray, None]): the values of the compute_constant tensor, if they are known at compile time + datatype (Datatype): the datatype of the tensor elements + addressing (AddressingMode): the addressing mode 
for the tensor """ self.name = name self.memoryLayout = memoryLayout @@ -301,8 +304,8 @@ def toTinyTCType(datatype: Datatype): Datatype.I16: ScalarType(IntegerType.i16), Datatype.I32: ScalarType(IntegerType.i32), Datatype.I64: ScalarType(IntegerType.i64), - Datatype.F32: ScalarType(IntegerType.f32), - Datatype.F64: ScalarType(IntegerType.f64) + Datatype.F32: ScalarType(FloatingType.f32), + Datatype.F64: ScalarType(FloatingType.f64) }[datatype] def toTinyTCImmediate(datatype: Datatype, value): @@ -312,6 +315,6 @@ def toTinyTCImmediate(datatype: Datatype, value): Datatype.I16: lambda value: IntImmValue(IntegerType.i16, value), Datatype.I32: lambda value: IntImmValue(IntegerType.i32, value), Datatype.I64: lambda value: IntImmValue(IntegerType.i64, value), - Datatype.F32: lambda value: FloatImmValue(IntegerType.f32, value), - Datatype.F64: lambda value: FloatImmValue(IntegerType.f64, value), + Datatype.F32: lambda value: FloatImmValue(FloatingType.f32, value), + Datatype.F64: lambda value: FloatImmValue(FloatingType.f64, value), }[datatype](value) diff --git a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index cca763a..58f167e 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -196,10 +196,11 @@ def maySubstitute(self, when, by, result = True, term = True): return (not term or maySubsTerm) and (not result or maySubsResult) and compatible - def substituted(self, when, by, result = True, term = True): + def substituted(self, when, by, cond, result = True, term = True): rsubs = self.result.substituted(when, by) if result else self.result tsubs = self.term.substituted(when, by, rsubs.memoryLayout()) if term else self.term - return ProgramAction(rsubs, tsubs, self.add, self.scalar, self.condition) + csubs = self.condition & cond + return ProgramAction(rsubs, tsubs, self.add, self.scalar, csubs) def setVariablesWritable(self, name): self.result.setWritable(name) @@ -344,6 +345,9 @@ def __not__(self): result = result | clauseInv return 
result + + def __and__(self, other): + return self.__rand__(other) def __rand__(self, other): if not isinstance(other, CNFCondition): @@ -356,6 +360,9 @@ def __rand__(self, other): condition._prune() return condition + + def __or__(self, other): + return self.__ror__(other) def __ror__(self, other): if not isinstance(other, CNFCondition): diff --git a/yateto/controlflow/transformer.py b/yateto/controlflow/transformer.py index cfddd8a..019c4fe 100644 --- a/yateto/controlflow/transformer.py +++ b/yateto/controlflow/transformer.py @@ -47,7 +47,7 @@ def visit(self, cfg): maySubs = all([cfg[j].action.maySubstitute(when, by) for j in range(i, n)]) if maySubs: for j in range(i, n): - cfg[j].action = cfg[j].action.substituted(when, by) + cfg[j].action = cfg[j].action.substituted(when, by, ua.condition) cfg = LivenessAnalysis().visit(cfg) return cfg @@ -69,9 +69,9 @@ def visit(self, cfg): when = u.action.result maySubs = cfg[found].action.maySubstitute(when, by, term=False) and all([cfg[j].action.maySubstitute(when, by) for j in range(found+1,i+1)]) if maySubs: - cfg[found].action = cfg[found].action.substituted(when, by, term=False) + cfg[found].action = cfg[found].action.substituted(when, by, va.condition, term=False) for j in range(found+1,i+1): - cfg[j].action = cfg[j].action.substituted(when, by) + cfg[j].action = cfg[j].action.substituted(when, by, va.condition) cfg = LivenessAnalysis().visit(cfg) return cfg @@ -109,7 +109,7 @@ def visit(self, cfg): if found >= 0: va = cfg[found].action if ua.maySubstitute(ua.result, va.result, term=False): - cfg[i].action = ua.substituted(ua.result, va.result, term=False) + cfg[i].action = ua.substituted(ua.result, va.result, va.condition, term=False) cfg[i].action.add = va.add if not va.hasTrivialScalar(): cfg[i].action.scalar = va.scalar From 4cf025ca673a7f9daee5a9d6baf6b8aff099328d Mon Sep 17 00:00:00 2001 From: David Schneller Date: Sat, 21 Feb 2026 19:40:56 +0100 Subject: [PATCH 18/18] Fix CI --- include/yateto.h | 2 +- 
include/yateto/LinearAllocator.h | 6 ++-- include/yateto/Type.h | 2 +- tests/code-gen/reduction.py | 1 - yateto/arch.py | 11 +++--- yateto/ast/node.py | 42 +++++++++++----------- yateto/ast/transformer.py | 18 +++++----- yateto/ast/visitor.py | 8 ++--- yateto/codegen/common.py | 4 +-- yateto/codegen/elementwise/generic.py | 2 +- yateto/codegen/factory.py | 50 +++++++++++++-------------- yateto/codegen/indexsum/generic.py | 4 +-- yateto/codegen/log/generic.py | 2 +- yateto/codegen/reduction/factory.py | 10 +++--- yateto/codegen/reduction/generic.py | 6 ++-- yateto/codegen/visitor.py | 27 ++++++++------- yateto/controlflow/graph.py | 42 +++++++++++----------- yateto/controlflow/visitor.py | 12 +++---- yateto/gemm_configuration.py | 4 +-- yateto/generator.py | 4 +-- yateto/metagen.py | 8 ++--- yateto/ops.py | 14 ++++---- yateto/type.py | 14 ++++---- 23 files changed, 148 insertions(+), 145 deletions(-) diff --git a/include/yateto.h b/include/yateto.h index 5bf2d52..ccc005e 100644 --- a/include/yateto.h +++ b/include/yateto.h @@ -4,7 +4,7 @@ #include "yateto/InitTools.h" #include "yateto/LinearAllocator.h" #include "yateto/Misc.h" -#include "yateto/Type.h" #include "yateto/TensorView.h" +#include "yateto/Type.h" #endif diff --git a/include/yateto/LinearAllocator.h b/include/yateto/LinearAllocator.h index 01f4595..cfc45b5 100644 --- a/include/yateto/LinearAllocator.h +++ b/include/yateto/LinearAllocator.h @@ -13,10 +13,10 @@ struct LinearAllocatorT { userSpaceMem = ptr; } - template + template void initialize(S* ptr) { - isInit = true; - userSpaceMem = reinterpret_cast(ptr); + isInit = true; + userSpaceMem = reinterpret_cast(ptr); } T* allocate(size_t size) { diff --git a/include/yateto/Type.h b/include/yateto/Type.h index 61b5a72..1e1a798 100644 --- a/include/yateto/Type.h +++ b/include/yateto/Type.h @@ -34,6 +34,6 @@ using bf16_ty = std::bfloat16_t; using bf16_ty = __bf16; #endif -} // yateto +} // namespace yateto #endif // YATETO_TYPE_H_ diff --git 
a/tests/code-gen/reduction.py b/tests/code-gen/reduction.py index ef0399d..420a246 100644 --- a/tests/code-gen/reduction.py +++ b/tests/code-gen/reduction.py @@ -39,4 +39,3 @@ def _(kernel): _(AB0[''] <= yf.all(AB2['ij'], 'ij')) _(AB0[''] <= yf.any(AB2['ij'], 'ij')) - diff --git a/yateto/arch.py b/yateto/arch.py index 36f0ce8..3fd2f76 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -70,11 +70,14 @@ def __init__(self, self.host_name = host_name self.precision = precision.upper() - if self.precision == 'D': - self.epsilon = 2.22e-16 + if self.precision == 'Q': + self.epsilon = 2**-112 + self.datatype = Datatype.F128 + elif self.precision == 'D': + self.epsilon = 2**-52 self.datatype = Datatype.F64 elif self.precision == 'S': - self.epsilon = 1.19e-7 + self.epsilon = 2**-23 self.datatype = Datatype.F32 else: raise ValueError(f'Unknown precision type {self.precision}') @@ -113,7 +116,7 @@ def formatConstant(self, constant): def onHeap(self, byteCount): return byteCount > self._tmpStackLimit - + def __eq__(self, other): return self.name == other.name diff --git a/yateto/ast/node.py b/yateto/ast/node.py index 1bc77f0..4d3e2b0 100644 --- a/yateto/ast/node.py +++ b/yateto/ast/node.py @@ -124,10 +124,10 @@ def __sub__(self, other): def __le__(self, other): return Assign(self, other) - + def __truediv__(self, other): return Elementwise(ops.Div(), self, other) - + def __rtruediv__(self, other): return Elementwise(ops.Div(), other, self) @@ -355,18 +355,18 @@ def __init__(self, lTerm, rTerm, condition=True): super().__init__(lTerm, rTerm, condition) else: super().__init__(lTerm, rTerm) - + self._condition = condition - + def leftTerm(self): return self._children[0] - + def rightTerm(self): return self._children[1] - + def condition(self): return self._condition - + def setChildren(self, children): if not isinstance(children[0].viewed(), IndexedTensor): raise ValueError('First child of Assign node must be an IndexedTensor: ' + str(children[0].viewed())) @@ -378,7 +378,7 @@ 
def nonZeroFlops(self): def computeSparsityPattern(self, *spps): spp = spps[1] if len(spps) >= 2 else self.rightTerm().eqspp() return self.broadcast(self.rightTerm().indices, self.permute(self.rightTerm().indices, spp, False)) - + def __str__(self): selfname = type(self).__name__ indices = self.indices if self.indices != None else '' @@ -614,20 +614,20 @@ def __init__(self, condition, yesTerm, noTerm): super().__init__(yesTerm, noTerm, condition) else: super().__init__(yesTerm, noTerm) - + self._condition = condition - + def condition(self): return condition - + def nonZeroFlops(self): return 0 - + def computeSparsityPattern(self, *spps): # TODO: yesTerm OR noTerm spp = spps[0] if len(spps) >= 2 else self.term().eqspp() return spp - + def __str__(self): indices = self.indices if self.indices != None else '' return f'{type(self).__name__}[{indices}]' @@ -660,17 +660,17 @@ def __init__(self, optype: ops.Operation, *terms): def nonZeroFlops(self): return self.eqspp().count_nonzero() - + def fillTerms(self, terms): assert len(terms) == len(self) return [terms[index] if template is None else template for template, index in zip(self.termTemplate, self.nodeTermIndices)] - + def computeSparsityPattern(self, *spps): if len(spps) == 0: spps = [node.eqspp() for node in self] xspp = spps[0] return spps[0] - + def __str__(self): indices = self.indices if self.indices != None else '' return f'{type(self).__name__}({self.optype})[{indices}]' @@ -682,21 +682,21 @@ def __init__(self, optype, term, sumIndex): self.indices = term.indices - set([sumIndex]) self._reductionIndex = term.indices.extract(sumIndex) self.optype = optype - + def nonZeroFlops(self): return self.term().eqspp().count_nonzero() - self.eqspp().count_nonzero() - + def reductionIndex(self): return self._reductionIndex - + def reductionIndices(self): return [self._reductionIndex] - + def computeSparsityPattern(self, *spps): assert len(spps) <= 1 spp = spps[0] if len(spps) == 1 else self.term().eqspp() return 
spp.indexSum(self.term().indices, self.indices) - + def __str__(self): indices = self.indices if self.indices != None else '' return f'{type(self).__name__}({self.optype})[{indices}]' diff --git a/yateto/ast/transformer.py b/yateto/ast/transformer.py index 244924f..13385b7 100644 --- a/yateto/ast/transformer.py +++ b/yateto/ast/transformer.py @@ -97,13 +97,13 @@ def visit_ScalarMultiplication(self, node, bound): self.visit(node.term(), bound) node.indices = deepcopy(node.term().indices) return node - + def visit_Elementwise(self, node, bound): for child in node: self.visit(child, bound) node.indices = deepcopy(node[0].indices) return node - + def visit_Reduction(self, node, bound): subbound = bound | set(node.reductionIndices()) self.visit(node.term(), subbound) @@ -229,17 +229,17 @@ def visit_Assign(self, node): self.generic_visit(node) node.setEqspp( node.computeSparsityPattern() ) return node - + def visit_Elementwise(self, node): self.generic_visit(node) node.setEqspp( node.computeSparsityPattern() ) return node - + def visit_Reduction(self, node): self.generic_visit(node) node.setEqspp( node.computeSparsityPattern() ) return node - + def getEqspp(self, terms, targetIndices): # Shortcut if all terms have dense eqspps if all(term.eqspp().is_dense() for term in terms): @@ -303,12 +303,12 @@ def visit_IndexedTensor(self, node): super().generic_visit(node) node.datatype = node.tensor.getDatatype(self.arch) return node - + def visit_Elementwise(self, node): super().generic_visit(node) node.datatype = node.optype.datatypeResult([c.datatype for c in node]) return node - + def visit_Assign(self, node): super().generic_visit(node) node.datatype = node[0].datatype @@ -325,12 +325,12 @@ def generic_visit(self, node): def visit_IndexedTensor(self, node): super().generic_visit(node) return node - + def visit_Elementwise(self, node): super().generic_visit(node) node.datatype = node.optype.datatypeResult([c.datatype for c in node]) return node - + def visit_Assign(self, 
node): super().generic_visit(node) node.datatype = node[0].datatype diff --git a/yateto/ast/visitor.py b/yateto/ast/visitor.py index 7d3768d..12106b9 100644 --- a/yateto/ast/visitor.py +++ b/yateto/ast/visitor.py @@ -155,7 +155,7 @@ def visit_ScalarMultiplication(self, node): def visit_Product(self, node): return self.allPermutationsNoCostNAryOp(node) - + def visit_Elementwise(self, node): return self.allPermutationsNoCostNAryOp(node) @@ -173,7 +173,7 @@ def visit_IndexSum(self, node): iterator = itertools.permutations(node.indices) permutationVariants[node] = {''.join(Cs): self.Variant(minCost, [minTind]) for Cs in iterator} return permutationVariants - + def visit_Reduction(self, node): return self.visit_IndexSum(node) @@ -327,12 +327,12 @@ def visit_IndexedTensor(self, node): term = node.tensor.values_as_ndarray(self._dtype) assert term is not None, f'{self.__class__.__name__} may only be used when all involved tensors are constant.' return term - + def visit_Elementwise(self, node): terms = self.generic_visit(node) fullTerms = node.fillTerms(terms) return node.optype.call(*fullTerms) - + def visit_Reduction(self, node): terms = self.generic_visit(node) assert len(terms) == 1 diff --git a/yateto/codegen/common.py b/yateto/codegen/common.py index 302e034..1537137 100644 --- a/yateto/codegen/common.py +++ b/yateto/codegen/common.py @@ -55,7 +55,7 @@ def fromNode(cls, var, node): values = baseNode.tensor.values() addressing = baseNode.tensor.addressing return cls(str(var), node.indices, var.memoryLayout(), node.eqspp(), is_const, var.is_temporary, values, datatype, addressing) - + @classmethod def fromVar(cls, var, indices): datatype = var.datatype @@ -134,7 +134,7 @@ def _get_ptr_type(cls, addressing: AddressingMode): def deduce_addresing(cls, term): if term.addressing is not None: return term.addressing - + # default deduction if term.is_compute_constant: return AddressingMode.DIRECT diff --git a/yateto/codegen/elementwise/generic.py 
b/yateto/codegen/elementwise/generic.py index c0a0f4f..a64cf4e 100644 --- a/yateto/codegen/elementwise/generic.py +++ b/yateto/codegen/elementwise/generic.py @@ -21,7 +21,7 @@ def _generateDenseDense(self, cpp): if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) initializeWithZero(cpp, d.result, writeBB) - + flops, assigner = self._affine(d.add, d.alpha) class ProductBody(object): diff --git a/yateto/codegen/factory.py b/yateto/codegen/factory.py index 2d84c3d..c6bb2f2 100644 --- a/yateto/codegen/factory.py +++ b/yateto/codegen/factory.py @@ -92,7 +92,7 @@ def reset_flags(self): def _indices(self, var): shape = var.memoryLayout().shape() return Indices(string.ascii_lowercase[:len(shape)], shape) - + def _conditional(self, condition, generate): if isinstance(condition, bool): if condition: @@ -132,7 +132,7 @@ def create_FusedGEMMs(self, node, result, arguments, condition, add, scalar, pre description = fused_gemms.Description(node, result, arguments, condition, add, scalar) generator = fused_gemms.generator(self._arch, description, gemm_cfg, self._target) return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache, gemm_cfg)) - + def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 1 description = indexsum.Description( @@ -143,7 +143,7 @@ def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefe ) generator = indexsum.generator(self._arch, description, self._target) return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): assert len(arguments) == 2 description = product.Description( @@ -168,7 +168,7 @@ def create_Elementwise(self, node, result, arguments, condition, add, scalar, pr ) generator = elementwise.generator(self._arch, description, self._target) 
return self._conditional(condition, lambda: generator.generate(self._cpp, routineCache)) - + def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): description = reduction.Description( alpha = scalar, @@ -184,17 +184,17 @@ def create_Permute(self, node, result, arguments, condition, add, scalar, prefet result = IndexedTensorDescription.fromNode(result, node) term = IndexedTensorDescription.fromNode(arguments[0], node.term()) return self._csa(result, term, add, condition, scalar, routineCache, gemm_cfg) - + def create_Broadcast(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): result = IndexedTensorDescription.fromNode(result, node) term = IndexedTensorDescription.fromNode(arguments[0], node.term()) return self._csa(result, term, add, condition, scalar, routineCache, gemm_cfg) - + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): result = IndexedTensorDescription.fromVar(result, self._indices(result)) term = IndexedTensorDescription.fromVar(term, self._indices(term)) return self._csa(result, term, condition, add, scalar, routineCache, gemm_cfg) - + def _csa(self, result, term, condition, add, scalar, routineCache, gemm_cfg): description = copyscaleadd.Description( alpha = scalar, @@ -215,7 +215,7 @@ def __init__(self, cpp, arch, nameFun, testFramework): def _formatTerm(self, var, indices): address = var.memoryLayout().addressString(indices) return f'{self._name(var)}[{address}]' - + def create_Einsum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = node.indices for child in node: @@ -231,14 +231,14 @@ def create_Einsum(self, node, result, arguments, condition, add, scalar, prefetc if not add: self._cpp.memset(self._name(result), result.memoryLayout().requiredReals(), result.datatype.ctype()) - + class EinsumBody(object): def __call__(s): self._cpp(f"{resultTerm} += {' * 
'.join(terms)};") return len(terms) return self._conditional(condition, lambda: forLoops(self._cpp, g, ranges, EinsumBody(), pragmaSimd=False)) - + def create_ScalarMultiplication(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): return self._conditional(condition, lambda: self.simple(result, arguments[0], add, scalar, routineCache)) @@ -247,13 +247,13 @@ def create_Permute(self, node, result, arguments, condition, add, scalar, prefet resultTerm = self._formatTerm(result, node.indices) termTerm = self._formatTerm(arguments[0], node.term().indices) return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)) - + def create_Broadcast(self, node, result, arguments, add, scalar, prefetchName, routineCache, gemm_cfg): assert node.term().indices <= node.indices resultTerm = self._formatTerm(result, node.indices) termTerm = self._formatTerm(arguments[0], node.term().indices) return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, node.indices)) - + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, node.indices) @@ -271,7 +271,7 @@ def create_Elementwise(self, node, result, arguments, condition, add, scalar, pr termTerm = node.optype.callstr(*node.fillTerms(argTerms)) return self._conditional(condition, lambda: self._simpleBody(resultTerm, termTerm, add, scalar, g)) - + def create_Reduction(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): g = self._indices(result) resultTerm = self._formatTerm(result, node.indices) @@ -364,7 +364,7 @@ def generate(self, cpp, cache): def add_linear_operation(self, dest, ops, target, permute, add): pass - + def add_operation(self, description): pass @@ -381,19 +381,19 @@ def __init__(self, generator, cpp, arch, target): self.generator = 
generator self.tensors = {} self.scalarcounter = 0 - + def post_generate(self, routine_cache): self.generator.generate(self._cpp, routine_cache) def allocateTemporary(self): return False - + def _nodeTensor(self, tensor, node): return self._handleTensorDesc(IndexedTensorDescription.fromNode(tensor, node)) def _varTensor(self, var, indices): return self._handleTensorDesc(IndexedTensorDescription.fromVar(var, indices)) - + def _handleAddressing(self, desc): if desc.addressing is None: addressing = BatchedOperationsAux.deduce_addresing(desc) @@ -429,7 +429,7 @@ def _handleTensorDesc(self, tensorIndexed: IndexedTensorDescription): } else: assert False - + eqsppnz = tensorIndexed.eqspp.nonzero() spp = [elem for elem in zip(*eqsppnz)] @@ -499,13 +499,13 @@ def _handleTensor(self, tensor, eqspp, indices): self.generator.add_tensor(tensor) else: assert tensor == self.tensors[tensor['name']] - + return { 'name': tensor['name'], 'spp': eqspp, 'indices': indices } - + def _handleCondition(self, condition): out = [] for clause in condition.clauses: @@ -556,21 +556,21 @@ def create_LoopOverGEMM(self, node, result, arguments, condition, add, scalar, p assert len(arguments) == 2 argnodes = [self._nodeTensor(arguments[0], node.leftTerm()), self._nodeTensor(arguments[1], node.rightTerm())] return self.handleLinear(self._nodeTensor(result, node), argnodes, condition, add, scalar, node.transA(), node.transB()) - + def create_IndexSum(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): return create_Reduction(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg) - + def create_Product(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): return create_Elementwise(node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg) def create_Permute(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = 
arguments[0] return self.handleLinear(self._varTensor(result, node.indices), [self._varTensor(term, node.term().indices)], condition, add, scalar, False, False) - + def create_Broadcast(self, node, result, arguments, condition, add, scalar, prefetchName, routineCache, gemm_cfg): term = arguments[0] return self.handleLinear(self._varTensor(result, node.indices), [self._varTensor(term, node.term().indices)], condition, add, scalar, False, False) - + def simple(self, result, term, condition, add, scalar, routineCache, gemm_cfg): return self.handleLinear(self._varTensor(result, self._indices(result)), [self._varTensor(term, self._indices(term))], condition, add, scalar, False, False) @@ -603,7 +603,7 @@ def handleLinear(self, dest, ops, condition, add, scalar, transposeA, transposeB ops += [self._scalarTensor(scalar)] target += [[]] permute += [[]] - + description = { 'type': 'multilinear', 'result': dest, diff --git a/yateto/codegen/indexsum/generic.py b/yateto/codegen/indexsum/generic.py index 06fd86d..5a57764 100644 --- a/yateto/codegen/indexsum/generic.py +++ b/yateto/codegen/indexsum/generic.py @@ -11,7 +11,7 @@ def generate(self, cpp, routineCache): if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) initializeWithZero(cpp, d.result, writeBB) - + sumIndex = d.term.indices - d.result.indices assert len(sumIndex) == 1 class IndexSumBody(object): @@ -23,7 +23,7 @@ def __call__(s): cpp( f'sum += {d.term.name}[{d.term.memoryLayout.addressString(d.term.indices)}];' ) mult = f'{d.alpha} * ' if d.alpha != 1.0 else '' cpp( f'{target} = {mult}sum;' ) - + flop = 1 if d.alpha != 1.0 else 0 return d.sumLoopRange.size() + flop diff --git a/yateto/codegen/log/generic.py b/yateto/codegen/log/generic.py index bcae573..59177c0 100644 --- a/yateto/codegen/log/generic.py +++ b/yateto/codegen/log/generic.py @@ -97,7 +97,7 @@ def generate(self, cpp, routineCache, gemm_cfg): lr.update( self._defuse(n, d.rightTerm, In) ) writeBB = 
boundingBoxFromLoopRanges(d.result.indices, lr) initializeWithZero(cpp, d.result, writeBB) - + class LoGBody(object): def __call__(s): if hasInnerLoops: diff --git a/yateto/codegen/reduction/factory.py b/yateto/codegen/reduction/factory.py index 69f8fea..4133201 100644 --- a/yateto/codegen/reduction/factory.py +++ b/yateto/codegen/reduction/factory.py @@ -10,21 +10,21 @@ def __init__(self, alpha, add: bool, result: IndexedTensorDescription, term: Ind self.result = result self.term = term self.optype = optype - + rA = loopRanges(self.term, self.result.indices) rB = loopRanges(self.result, self.result.indices) assert testLoopRangesAContainedInB(rA, rB) - + self.loopRanges = rA - + self.sumIndex = self.term.indices - self.result.indices assert len(self.sumIndex) == 1 self.sumLoopRange = loopRanges(self.term, self.sumIndex)[str(self.sumIndex)] - + def generator(arch, descr, target): if target == 'cpu': return Generic(arch, descr) elif target == 'gpu': - raise RuntimeError("IndexSum operation has not been implemented for GPU-like architectures") \ No newline at end of file + raise RuntimeError("IndexSum operation has not been implemented for GPU-like architectures") diff --git a/yateto/codegen/reduction/generic.py b/yateto/codegen/reduction/generic.py index c092201..522ef21 100644 --- a/yateto/codegen/reduction/generic.py +++ b/yateto/codegen/reduction/generic.py @@ -7,11 +7,11 @@ def __init__(self, arch, descr): def generate(self, cpp, routineCache): d = self._descr - + if not d.add: writeBB = boundingBoxFromLoopRanges(d.result.indices, d.loopRanges) initializeWithZero(cpp, d.result, writeBB) - + sumIndex = d.term.indices - d.result.indices assert len(sumIndex) == 1 class IndexSumBody(object): @@ -24,7 +24,7 @@ def __call__(s): cpp( f'acc = {d.optype.callstr('acc', argstr)};' ) mult = f'{d.alpha} * ' if d.alpha != 1.0 else '' cpp( f'{target} = {mult}acc;' ) - + flop = 1 if d.alpha != 1.0 else 0 return d.sumLoopRange.size() + flop diff --git a/yateto/codegen/visitor.py 
b/yateto/codegen/visitor.py index b085e11..3ee650d 100644 --- a/yateto/codegen/visitor.py +++ b/yateto/codegen/visitor.py @@ -305,7 +305,7 @@ def kernelArgs(base_name_with_namespace, groups, writable, is_constant, datatype header(f'{class_name}::{container_type} {base_name};') else: header(f'{typ}{ptr_type} {base_name}{"{"}nullptr{"}"};') - + def scalarArgs(base_name_with_namespace, datatype, groups): prefix, base_name = Tensor.splitBasename(base_name_with_namespace) typ = datatype.ctype() @@ -542,11 +542,11 @@ def generate(self, cpp, namespace, testName, kernelClass, cfg, target, gemm_cfg, for i,scalar in enumerate(scalars): cpp('{} {} = {};'.format(scalar.getDatatype(self._arch).ctype(), self._tensorNameS(scalar), float(i+2))) - + for var in variables: factory.tensor(var.tensor, self._tensorName(var)) factory.temporary(self._name(var), var.memoryLayout().requiredReals(), var.datatype, iniZero=True) - + shape = var.memoryLayout().shape() cpp('{supportNS}::DenseTensorView<{dim},{datatype},{arch.uintTypename}> {viewName}({utName}, {{{shape}}}, {{{start}}}, {{{stop}}});'.format( supportNS = SUPPORT_LIBRARY_NAMESPACE, @@ -652,13 +652,14 @@ class TensorView(object): def __init__(self, datatype): self._datatype = datatype - def typename(self, dim, arch): + def typename(self, dim, arch, const): + constStr = 'true' if const else 'false' return f'::{SUPPORT_LIBRARY_NAMESPACE}::{type(self).__name__}<{dim},{self._datatype.ctype()},{arch.uintTypename},{constStr}>' - + def arguments(self, const): conststr = ' const*' if const else '*' return f'{self._datatype.ctype()}{conststr} {self.ARGUMENT_NAME}' - + def generate(cpp, group, memLayout): raise NotImplementedError @@ -670,7 +671,7 @@ def formatArray(self, numberType, name, values, declarationOnly): if declarationOnly: return '' return f'{MODIFIERS} {lhs} = {self.listToInitializerList(values)};' - + class DenseTensorView(TensorView): START_NAME = 'Start' STOP_NAME = 'Stop' @@ -748,7 +749,7 @@ def __init__(self, arch, tensors, 
scalars): self._groupSize = {baseName: tuple(map(lambda x: x+1, mi)) for baseName, mi in maxIndex.items()} maxIndexScalar = {baseName: tuple(map(max, *groups.keys())) if len(groups) > 1 else next(iter(groups.keys())) for baseName, groups in self._scalarCollect.items()} self._groupSizeScalar = {baseName: tuple(map(lambda x: x+1, mi)) for baseName, mi in maxIndexScalar.items()} - + def _tensorViewGenerator(self, tensor): memoryLayout = tensor.memoryLayout() memLayoutMap = { @@ -756,7 +757,7 @@ def _tensorViewGenerator(self, tensor): 'CSCMemoryLayout': self.CSCMatrixView } return memLayoutMap[type(memoryLayout).__name__](tensor.getDatatype(self._arch)) - + def iterate_collect(self): cur_namespace = '' cur_dict = collections.OrderedDict() @@ -923,13 +924,12 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat cpp(f'{STATIC} {self._realPtrType(datatype)} {self.VALUES_BASENAME}[];') cpp.emptyline() - viewArgs = self.TensorView.arguments(False) - viewArgsConst = self.TensorView.arguments(True) if len(groupSize) == 0: prototensor = next(iter(tensors.values())) ml = prototensor.memoryLayout() tv = self._tensorViewGenerator(prototensor) - viewArgs = tv.arguments() + viewArgs = tv.arguments(False) + viewArgsConst = tv.arguments(True) with cpp.Struct(self.VIEW_STRUCT_NAME): cpp(f'using {self.VIEW_TYPE_NAME} = {tv.typename(len(ml.shape()), self._arch, False)};') cpp(f'using {self.VIEW_TYPE_NAME_CONST} = {tv.typename(len(ml.shape()), self._arch, True)};') @@ -945,7 +945,8 @@ def _init(self, cpp, baseName, baseNameWithoutNamespace, name, tensors, declarat for group,tensor in tensors.items(): ml = tensor.memoryLayout() tv = self._tensorViewGenerator(tensor) - viewArgs = tv.arguments() + viewArgs = tv.arguments(False) + viewArgsConst = tv.arguments(True) typename = tv.typename(len(ml.shape()), self._arch, False) typenameConst = tv.typename(len(ml.shape()), self._arch, True) special = ','.join(str(g) for g in group) diff --git 
a/yateto/controlflow/graph.py b/yateto/controlflow/graph.py index a55114a..3eb5dfb 100644 --- a/yateto/controlflow/graph.py +++ b/yateto/controlflow/graph.py @@ -205,7 +205,7 @@ def substituted(self, when, by, cond, result = True, term = True): def setVariablesWritable(self, name): self.result.setWritable(name) self.term.setWritable(name) - + def getCondition(self): if isinstance(self.condition, CNFCondition): return self.condition @@ -274,7 +274,7 @@ def __init__(self, variables): else: self.variables = variables self.fulfilled = False - + def negateVariables(self): return {var: ~self.variables[var] for var in self.variables} @@ -286,11 +286,11 @@ def unite(self, clause): if not output.fulfilled: output.variables = {**self.variables, **clause.variables} return output - + def __repr__(self): formatvar = lambda name: f'{name}' if self.variables[name] else f'~{name}' return f'[{", ".join(formatvar(var) for var in self.variables)}]' - + def ccode(self): if not self.fulfilled and len(self.variables) == 0: return 'false' @@ -298,7 +298,7 @@ def ccode(self): printvar = lambda var: f'{var}' if isinstance(var, Scalar) else f'{var}[{var.memoryLayout().addressString(Indices())}]' formatvar = lambda name: f'{printvar(name)}' if self.variables[name] else f'!{printvar(name)}' return f'({" || ".join(formatvar(var) for var in self.variables)})' - + def variableIterator(self): return (var for var in self.variables if isinstance(var, Variable)) @@ -311,7 +311,7 @@ def __init__(self, data): self.clauses = [CNFClause([])] else: self.clauses = [CNFClause([data])] - + def _prune(self): newclauses = [] for clause in self.clauses: @@ -328,14 +328,14 @@ def _prune(self): def tautology(self): return len(self.clauses) == 0 - + def unfulfillable(self): return any(not clause.fulfilled and len(clause.variables) == 0 for clause in self.clauses) def __not__(self): if self.tautology(): return CNFCondition(False) - + # this is the actually painful step (as it's also pretty inefficient right now) 
result = CNFCondition(True) for clause in self.clauses: @@ -345,14 +345,14 @@ def __not__(self): result = result | clauseInv return result - + def __and__(self, other): return self.__rand__(other) def __rand__(self, other): if not isinstance(other, CNFCondition): other = CNFCondition(other) - + clauses = self.clauses + other.clauses condition = CNFCondition(True) @@ -360,42 +360,42 @@ def __rand__(self, other): condition._prune() return condition - + def __or__(self, other): return self.__ror__(other) def __ror__(self, other): if not isinstance(other, CNFCondition): other = CNFCondition(other) - + clauses = [clause.unite(oclause) for clause in self.clauses for oclause in other.clauses] condition = CNFCondition(True) condition.clauses = clauses condition._prune() return condition - + def __repr__(self): return f'{self.clauses}' - + def ccode(self): if self.tautology(): return 'true' elif self.unfulfillable(): return 'false' return f'({" && ".join(clause.ccode() for clause in self.clauses)})' - + def variables(self): return {var for clause in self.clauses for var in clause.variableIterator()} class LiveSet: def __init__(self, data: dict): self.data = data - + def __sub__(self, other): if isinstance(other, dict): other = LiveSet(other) - + result = {k:self.data[k] for k in self.data} for var in other.data: @@ -405,11 +405,11 @@ def __sub__(self, other): result.remove(var) return LiveSet(result) - + def __or__(self, other): if isinstance(other, dict): other = LiveSet(other) - + result = {k:self.data[k] for k in self.data} for var in other.data: @@ -417,13 +417,13 @@ def __or__(self, other): result[var] |= other.data[var] return LiveSet(result) - + def __contains__(self, element): if isinstance(element, tuple): return element[0] in self.data and (~self.data[element[0]] & element[1]).unfulfillable() else: return element in self.data - + def variables(self): return set(k for k in self.data) diff --git a/yateto/controlflow/visitor.py b/yateto/controlflow/visitor.py index 
4a34662..707b562 100644 --- a/yateto/controlflow/visitor.py +++ b/yateto/controlflow/visitor.py @@ -13,7 +13,7 @@ def __init__(self, simpleMemoryLayout=False): self._writable = set() self._simpleMemoryLayout = simpleMemoryLayout self._condition = [True] - + def cfg(self): return self._cfg + [ProgramPoint(None)] @@ -105,19 +105,19 @@ def visit_Assign(self, node): newCondition = condition & CNFCondition(myCondition) self._condition.append(newCondition) self._condition = self._condition[:-1] - + rVar = self.visit(node[1]) rhs = self._addPermuteIfRequired(node.indices, node.rightTerm(), rVar) lVar = self.visit(node[0]) action = ProgramAction(lVar, rhs, False, condition=newCondition) self._addAction(action) - + return lVar - + def visit_IndexedTensor(self, node): return Variable(node.name(), node.name() in self._writable, self._ml(node), node.eqspp(), node.tensor, datatype=node.datatype, is_temporary=node.tensor.temporary) - + def visit_IfThenElse(self, node): if len(self._condition) > 0: condition = self._condition.top() @@ -130,7 +130,7 @@ def visit_IfThenElse(self, node): self._condition.pop() self._addAction(ProgramAction()) return self.visit(node.term()) - + def _addAction(self, action): self._cfg.append(ProgramPoint(action)) diff --git a/yateto/gemm_configuration.py b/yateto/gemm_configuration.py index d111753..22ea731 100644 --- a/yateto/gemm_configuration.py +++ b/yateto/gemm_configuration.py @@ -26,7 +26,7 @@ def preference(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, ali def supported(self, m, n, k, sparseA, sparseB, transA, transB, alpha, beta, alignedA, alignedC, datatypeA, datatypeB, datatypeC, target): pass - + # shortcut for legacy reasons @classmethod def _equalType(cls, datatypeA, datatypeB, datatypeC, types=(Datatype.F32, Datatype.F64)): @@ -67,7 +67,7 @@ class MKL(BLASlike): def __init__(self, arch): self._arch = arch super().__init__('cblas', ['mkl_cblas.h']) - + def archSupported(self): return self._arch.host_name.lower() in {'snb', 
'hsw', 'skx', 'knl'} or self._arch.host_name.lower().startswith('avx') diff --git a/yateto/generator.py b/yateto/generator.py index 09ebcb2..9f97d58 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -178,7 +178,7 @@ def kernels(self): def prepareUntilUnitTest(self, arch): for kernel in self._kernels.values(): kernel.prepareUntilUnitTest(arch) - + def prepareUntilCodeGen(self, costEstimator, enableFusedGemm: bool): for kernel in self._kernels.values(): kernel.prepareUntilCodeGen(costEstimator, enableFusedGemm) @@ -460,7 +460,7 @@ def unit_test_body(cpp, testFramework): cpp.include(fInit.hName) with cpp.Namespace(namespace): initGen.generateInitCpp(cpp) - + prefixnsp = lambda a: a.name if a.namespace == '' else f'{a.namespace}::{a.name}' return { 'namespace': namespace, diff --git a/yateto/metagen.py b/yateto/metagen.py index cc910d2..fc1d36d 100644 --- a/yateto/metagen.py +++ b/yateto/metagen.py @@ -95,7 +95,7 @@ def headerForward(name, data): for entry in data: self.template(header, entry, data[entry], f'{name}') - + headerForward('tensor', tensors) headerForward('init', tensors) headerForward('kernel', kernels) @@ -105,7 +105,7 @@ def cppForward(name): for gendata in self.generators: outdirname = f'metagen_{gendata["name"]}' header.include(f'{outdirname}/{name}.cpp') - + cppForward('tensor') cppForward('init') cppForward('kernel') @@ -131,12 +131,12 @@ def inner(): templatetypes = ', '.join(f'{typ} Arg{i}' for i, typ in enumerate(self.templateType)) templateargs = ', '.join(f'Arg{i}' for i, _ in enumerate(self.templateType)) - + with header.Namespace('internal'): header(f'template<{templatetypes}> struct {internalName} {"{"} using Type = void; {"}"};') for gnsp, spec in foundin: spectext = ', '.join(str(specpart) for specpart in spec) header(f'template<> struct {internalName}<{spectext}> {"{"} using Type = ::{gnsp}::{fullname}; {"}"};') header(f'template<{templatetypes}> using {name} = typename internal::{internalName}<{templateargs}>::Type;') - + 
self.namespacing(header, splitname[:-1] + [subnsp], inner) diff --git a/yateto/ops.py b/yateto/ops.py index 2167c24..825a05b 100644 --- a/yateto/ops.py +++ b/yateto/ops.py @@ -4,16 +4,16 @@ class Operation: #def callstr(self, *args) -> str: # raise NotImplementedError() - + def call(self, *args): raise NotImplementedError() - + def datatypeResult(self, argtypes): raise NotImplementedError() # TODO - + def __str__(self): return type(self).__name__ - + def __eq__(self, other): # we're more or less using "dummy" types here return type(self).__name__ == type(other).__name__ @@ -35,21 +35,21 @@ class BinaryArgsMixin: class CFunctionMixin: def cppname(self) -> str: raise NotImplementedError() - + def callstr(self, *args) -> str: return f'{self.cppname()}({", ".join(str(arg) for arg in args)})' class CUnaryOperatorMixin: def cppname(self) -> str: raise NotImplementedError() - + def callstr(self, *args) -> str: return f'{self.cppname()}({args[0]})' class CBinaryOperatorMixin: def cppname(self) -> str: raise NotImplementedError() - + def callstr(self, *args) -> str: return f'({args[0]}) {self.cppname()} ({args[1]})' diff --git a/yateto/type.py b/yateto/type.py index 1341dfb..ed6505c 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -45,7 +45,7 @@ def ctype(self): Datatype.BF16: 'yateto::bf16_ty', Datatype.F128: 'yateto::f128_ty', }[self] - + def nptype(self): return { Datatype.BOOL: np.bool, @@ -59,7 +59,7 @@ def nptype(self): Datatype.BF16: np.float32, # NYI Datatype.F128: np.float128, }[self] - + def size(self): # unpacked size return { @@ -74,7 +74,7 @@ def size(self): Datatype.BF16: 2, Datatype.F128: 16, }[self] - + def safeint(self, value): # allow inf/-inf to be treated as int return int(max(-2**64, min(2**64, value))) @@ -112,7 +112,7 @@ class Symbol(object): def __init__(self, datatype): # datatype == None is treated as datatype == arch.datatype self.datatype = datatype - + def getDatatype(self, arch): return arch.datatype if self.datatype is None else 
self.datatype @@ -128,7 +128,7 @@ class AbstractType(Symbol): def __init__(self, name, datatype): super().__init__(datatype) self._name = name - + @classmethod def isValidName(cls, name): return re.match(cls.VALID_NAME, name) is not None @@ -194,10 +194,10 @@ def nameWithNamespace(self): def __hash__(self): return hash(self._name) -class Scalar(IdentifiedType, ScalarMixin): +class Scalar(IdentifiedType, ScalarMixin): def __init__(self, name, namespace=None, datatype=None): super().__init__(name, namespace=namespace, datatype=datatype) - + def __hash__(self): return hash(self._name)