Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4124a58
Add a tensor datatype, introduce tensor addressing
davschneller Jan 31, 2025
3dbe6e9
Start propagating the datatype to the AST
davschneller Jan 31, 2025
3193d89
Begin adjusting the TensorDescriptions
davschneller Jan 31, 2025
b173856
Start adding (elementwise) nonlinear function support
davschneller Apr 16, 2025
075853d
Fix bugs; propagate datatype
davschneller Apr 16, 2025
ed03770
Merge remote-tracking branch 'origin/master' into davschneller/dataty…
davschneller Apr 16, 2025
389cf7a
Merge remote-tracking branch 'origin/davschneller/datatyping' into da…
davschneller Apr 16, 2025
1975753
Fix build
davschneller Apr 16, 2025
4c0fdb6
Fix bugs; propagate datatypes
davschneller Apr 17, 2025
8c5afc4
Add primitive conditional execution
davschneller Apr 17, 2025
52c0fb3
Continue the datatype and elementwise propagation
davschneller Apr 24, 2025
2db4cb9
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Aug 12, 2025
a215877
Merge remote-tracking branch 'origin/davschneller/cuda-gpu-tests' int…
davschneller Aug 12, 2025
96e66ed
Fix a post-merge bug
davschneller Aug 12, 2025
8fa3d00
Ease the allocation initialization
davschneller Aug 14, 2025
7b0f534
Begin updating the Yateto<->TensorForge interface
davschneller Aug 23, 2025
9fcc365
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Sep 10, 2025
aa421b4
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Nov 22, 2025
6ba8b46
Improve FP format support
davschneller Nov 22, 2025
6d40823
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Dec 9, 2025
ade9f62
Merge remote-tracking branch 'origin/davschneller/slicing' into davsc…
davschneller Dec 9, 2025
725b594
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Dec 15, 2025
a7df3ee
Tests and bugfixes
davschneller Dec 22, 2025
f2afc90
Adjust index permutation computation for the new ops
davschneller Dec 22, 2025
3ff353a
Fix codegen
davschneller Dec 22, 2025
4a90748
Add some more tests
davschneller Feb 21, 2026
e650530
Merge remote-tracking branch 'origin/master' into davschneller/nonlin…
davschneller Feb 21, 2026
4cf025c
Fix CI
davschneller Feb 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/yateto-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
- name: Codegen Tests
run: |
cd ./tests/code-gen
for example in matmul minimal indices slicing; do
for example in matmul minimal indices slicing elementwise reduction datatype conditional; do
for build_type in Debug Release; do
for precision in single double; do
echo " ====== Test Config: ======"
Expand Down
1 change: 1 addition & 0 deletions include/yateto.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
#include "yateto/LinearAllocator.h"
#include "yateto/Misc.h"
#include "yateto/TensorView.h"
#include "yateto/Type.h"

#endif
6 changes: 6 additions & 0 deletions include/yateto/LinearAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ struct LinearAllocatorT {
userSpaceMem = ptr;
}

// Overload of initialize() accepting a pointer to an arbitrary element type S.
// The pointer is reinterpreted as T* and stored in userSpaceMem, and the
// allocator is flagged as initialized (isInit) so that later calls pass the
// initialization assert.
template <typename S>
void initialize(S* ptr) {
  userSpaceMem = reinterpret_cast<T*>(ptr);
  isInit = true;
}

T* allocate(size_t size) {
assert(isInit && "YATETO: Temporary-Memory manager hasn't been initialized");
int currentByteCount = byteCount;
Expand Down
39 changes: 39 additions & 0 deletions include/yateto/Type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef YATETO_TYPE_H_
#define YATETO_TYPE_H_

// Ask <cfloat> for the IEC 60559 interchange-type macros (FLT16_MIN,
// FLT128_MIN, ...) that are used below to detect _Float16/_Float128.
// The request is only honoured if it is visible before the first inclusion of
// <float.h>, so it must precede every #include in this header. The #ifndef
// avoids a macro redefinition if the including translation unit already
// defined it (possibly with a value).
// cf. https://stackoverflow.com/a/70868019
#ifndef __STDC_WANT_IEC_60559_TYPES_EXT__
#define __STDC_WANT_IEC_60559_TYPES_EXT__
#endif

#include <cfloat>
#include <cstddef>

// C++23 include; only probed when the preprocessor provides __has_include.
#if defined(__has_include)
#if __has_include(<stdfloat>)
#include <stdfloat>
#endif
#endif

namespace yateto {

// For every alias the first available candidate is chosen:
//   1. the C++23 <stdfloat> type (announced by the __STDCPP_*_T__ macro),
//   2. the C _FloatN type, if <cfloat> exposes the matching FLTN_MIN macro,
//   3. the compiler-specific type as a last resort.

// 128-bit floating-point type.
#ifdef __STDCPP_FLOAT128_T__
using f128_ty = std::float128_t;
#elif defined(FLT128_MIN)
using f128_ty = _Float128;
#else
using f128_ty = __float128;
#endif

// 16-bit floating-point type.
#ifdef __STDCPP_FLOAT16_T__
using f16_ty = std::float16_t;
#elif defined(FLT16_MIN)
using f16_ty = _Float16;
#else
using f16_ty = __fp16;
#endif

// bfloat16 type (no _FloatN candidate is probed for this format).
#ifdef __STDCPP_BFLOAT16_T__
using bf16_ty = std::bfloat16_t;
#else
using bf16_ty = __bf16;
#endif

} // namespace yateto

#endif // YATETO_TYPE_H_
44 changes: 44 additions & 0 deletions tests/code-gen/conditional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3

from yateto import *

import yateto.functions as yf

def add(g):
    """Register the conditional-assignment (``yf.assignIf``) test kernels on ``g``.

    Every kernel is passed to ``g.add`` under a running name
    ``kernel1``, ``kernel2``, ... in the order written below.
    """
    N = 8
    square = (N, N)

    # Tensors created without an explicit datatype.
    A = Tensor('A', square)
    B = Tensor('B', square)
    C = Tensor('C', square)

    # 32-bit integer tensors.
    AI = Tensor('AI', square, datatype=Datatype.I32)
    BI = Tensor('BI', square, datatype=Datatype.I32)
    CI = Tensor('CI', square, datatype=Datatype.I32)

    # Boolean tensor that is reduced to a single condition via yf.all.
    AB = Tensor('AB', square, datatype=Datatype.BOOL)

    # Scalar (zero-dimensional) boolean conditions.
    X = Tensor('X', (), datatype=Datatype.BOOL)
    X1 = Tensor('X1', (), datatype=Datatype.BOOL)
    X2 = Tensor('X2', (), datatype=Datatype.BOOL)
    X3 = Tensor('X3', (), datatype=Datatype.BOOL)

    kernel_count = 0

    def register(kernel):
        # Name the kernel after its 1-based position in this file.
        nonlocal kernel_count
        kernel_count += 1
        g.add(f'kernel{kernel_count}', kernel)

    # Single assignment guarded by a scalar condition.
    register(yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij'])))

    # Condition obtained by reducing a boolean tensor.
    register(yf.assignIf(yf.all(AB['ij'], 'ij'), A['ij'], yf.sqrt(B['ij'])))

    # Two assignments with different datatypes sharing one condition.
    register([
        yf.assignIf(X[''], A['ij'], yf.sqrt(B['ij'])),
        yf.assignIf(X[''], AI['ij'], -BI['ij']),
    ])

    # Several guarded assignments, including a contraction and an accumulation.
    register([
        yf.assignIf(X1[''], A['ij'], B['ik'] * C['kj'] + C['ij']),
        yf.assignIf(X1[''], A['ij'], A['ij'] + B['ik'] * C['kj'] + C['ij']),
        yf.assignIf(X2[''], C['ij'], yf.sqrt(B['ij'])),
    ])
29 changes: 29 additions & 0 deletions tests/code-gen/datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3

from yateto import *

import yateto.functions as yf

def add(g):
    """Register the datatype-conversion test kernel on ``g``.

    Kernels are passed to ``g.add`` under running names ``kernel1``, ``kernel2``, ...
    """
    N = 8
    square = (N, N)

    # Tensors created without an explicit datatype.
    A = Tensor('A', square)
    B = Tensor('B', square)
    C = Tensor('C', square)

    # 32-bit integer tensors.
    AI = Tensor('AI', square, datatype=Datatype.I32)
    BI = Tensor('BI', square, datatype=Datatype.I32)
    CI = Tensor('CI', square, datatype=Datatype.I32)

    # Boolean tensor.
    AB = Tensor('AB', square, datatype=Datatype.BOOL)

    kernel_count = 0

    def register(kernel):
        # Name the kernel after its 1-based position in this file.
        nonlocal kernel_count
        kernel_count += 1
        g.add(f'kernel{kernel_count}', kernel)

    # Cast A element-wise to I32 and store the result in AI.
    register(AI['ij'] <= yf.cast(A['ij'], Datatype.I32))
41 changes: 41 additions & 0 deletions tests/code-gen/elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

from yateto import *

import yateto.functions as yf

def add(g):
    """Register the element-wise function test kernels on ``g``.

    Every kernel is passed to ``g.add`` under a running name
    ``kernel1``, ``kernel2``, ... in the order written below.
    """
    N = 8
    square = (N, N)

    # Tensors created without an explicit datatype.
    A = Tensor('A', square)
    B = Tensor('B', square)
    C = Tensor('C', square)

    # 32-bit integer tensors.
    AI = Tensor('AI', square, datatype=Datatype.I32)
    BI = Tensor('BI', square, datatype=Datatype.I32)
    CI = Tensor('CI', square, datatype=Datatype.I32)

    # Boolean tensor.
    AB = Tensor('AB', square, datatype=Datatype.BOOL)

    kernel_count = 0

    def register(kernel):
        # Name the kernel after its 1-based position in this file.
        nonlocal kernel_count
        kernel_count += 1
        g.add(f'kernel{kernel_count}', kernel)

    # Unary and binary functions on the default-datatype tensors.
    register(A['ij'] <= yf.sqrt(B['ij']))
    register(A['ij'] <= yf.sqrt(B['ij']) + yf.sin(C['ij']))
    register(A['ij'] <= yf.sqrt(B['ij']) * yf.sin(C['ij']))
    register(A['ij'] <= yf.minimum(B['ij'], C['ij']))
    register(A['ij'] <= yf.minimum(B['ij'], C['ij'] + yf.atanh(B['ij'])))

    # Integer arithmetic and bitwise operations.
    register(AI['ij'] <= BI['ij'] + CI['ij'])
    register(AI['ij'] <= yf.bitwise_and(BI['ij'], CI['ij']))

    # Comparison producing booleans, and a selection driven by a comparison.
    register(AB['ij'] <= yf.greater_equal(BI['ij'], CI['ij']))
    register(A['ij'] <= yf.where(yf.greater_equal(BI['ij'], CI['ij']), B['ij'], C['ij']))

    # Datatype conversion.
    register(AI['ij'] <= yf.cast(A['ij'], Datatype.I32))
41 changes: 41 additions & 0 deletions tests/code-gen/reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

from yateto import *

import yateto.functions as yf

def add(g):
    """Register the reduction test kernels on ``g``.

    Every kernel is passed to ``g.add`` under a running name
    ``kernel1``, ``kernel2``, ... in the order written below.
    """
    N = 8
    scalar, vector, matrix = (), (N,), (N, N)

    # Tensors created without an explicit datatype, with 0, 1 and 2 dimensions.
    A0 = Tensor('A0', scalar)
    A1 = Tensor('A1', vector)
    A2 = Tensor('A2', matrix)

    # 32-bit integer tensors.
    AI0 = Tensor('AI0', scalar, datatype=Datatype.I32)
    AI1 = Tensor('AI1', vector, datatype=Datatype.I32)
    AI2 = Tensor('AI2', matrix, datatype=Datatype.I32)

    # Boolean tensors.
    AB0 = Tensor('AB0', scalar, datatype=Datatype.BOOL)
    AB1 = Tensor('AB1', vector, datatype=Datatype.BOOL)
    AB2 = Tensor('AB2', matrix, datatype=Datatype.BOOL)

    kernel_count = 0

    def register(kernel):
        # Name the kernel after its 1-based position in this file.
        nonlocal kernel_count
        kernel_count += 1
        g.add(f'kernel{kernel_count}', kernel)

    # Reductions of the default-datatype tensors into a scalar.
    register(A0[''] <= yf.sum(A1['i'], 'i'))
    register(A0[''] <= yf.sum(A2['ij'], 'ij'))
    register(A0[''] <= yf.min(A2['ij'], 'ij'))

    # Reductions of the integer matrix into an integer scalar.
    register(AI0[''] <= yf.sum(AI2['ij'], 'ij'))
    register(AI0[''] <= yf.min(AI2['ij'], 'ij'))
    register(AI0[''] <= yf.all(AI2['ij'], 'ij'))
    register(AI0[''] <= yf.any(AI2['ij'], 'ij'))

    # Reductions of the boolean matrix into a boolean scalar.
    register(AB0[''] <= yf.all(AB2['ij'], 'ij'))
    register(AB0[''] <= yf.any(AB2['ij'], 'ij'))
26 changes: 14 additions & 12 deletions yateto/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#

from .memory import DenseMemoryLayout
from .type import Datatype
from collections import namedtuple
from typing import Union
import re
Expand Down Expand Up @@ -69,19 +70,20 @@ def __init__(self,
self.host_name = host_name

self.precision = precision.upper()
if self.precision == 'D':
self.bytesPerReal = 8
self.typename = 'double'
self.epsilon = 2.22e-16
if self.precision == 'Q':
self.epsilon = 2**-112
self.datatype = Datatype.F128
elif self.precision == 'D':
self.epsilon = 2**-52
self.datatype = Datatype.F64
elif self.precision == 'S':
self.bytesPerReal = 4
self.typename = 'float'
self.epsilon = 1.19e-7
self.epsilon = 2**-23
self.datatype = Datatype.F32
else:
raise ValueError(f'Unknown precision type {self.precision}')
self.alignment = alignment
assert self.alignment % self.bytesPerReal == 0
self.alignedReals = self.alignment // self.bytesPerReal
assert self.alignment % self.datatype.size() == 0
self.alignedReals = self.alignment // self.datatype.size()
self.enablePrefetch = enablePrefetch

self.uintTypename = 'unsigned'
Expand Down Expand Up @@ -110,10 +112,10 @@ def checkAlignment(self, offset):
return offset % self.alignedReals == 0

def formatConstant(self, constant):
return str(constant) + ('f' if self.precision == 'S' else '')
return self.datatype.literal(constant)

def onHeap(self, numReals):
return (numReals * self.bytesPerReal) > self._tmpStackLimit
def onHeap(self, byteCount):
return byteCount > self._tmpStackLimit

def __eq__(self, other):
return self.name == other.name
Expand Down
Loading
Loading