Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
554cc99
Fix pack test on rvv bool
ZenusZhang Jun 26, 2025
bd39ac2
unpack_generator initialized
ZenusZhang Jun 23, 2025
104a904
Add unpack ctest generator
ZenusZhang Jun 25, 2025
2d7d2a8
Added Unpack test generator
ZenusZhang Jun 26, 2025
0448151
First version done of ctest cast
ZenusZhang Jul 3, 2025
a75552e
binary test generator 50%
ZenusZhang Jul 7, 2025
5bcd7bb
Test for binary normal case done
ZenusZhang Jul 8, 2025
acf7e60
Refactor the ctest dir structure
ZenusZhang Jul 15, 2025
6b015b9
binary add passed
ZenusZhang Jul 15, 2025
adf4151
Binary add, pack, unpack passed at rv and x86
ZenusZhang Jul 16, 2025
060f036
Binary add, sub, mul, div passed on riscv and x86
ZenusZhang Jul 16, 2025
4c11875
Binary mod passed on x86 and rv
ZenusZhang Jul 17, 2025
f927ad1
Rewrite mod for fp8
ZenusZhang Jul 17, 2025
08c0681
Passed int32 floor_mod by bypassing avx floor_mod
ZenusZhang Jul 21, 2025
bdc312e
Floor mod passed on almost all types on x86 except float16
ZenusZhang Jul 22, 2025
47a46af
passed floor mod on x86
ZenusZhang Jul 22, 2025
7d25b16
RVV bug in qemu or compiler
ZenusZhang Jul 23, 2025
0758c0d
Passed all ceil_div on x86
ZenusZhang Jul 24, 2025
26b2255
Passed power on x86
ZenusZhang Jul 25, 2025
8aa7194
RVV POW smoke test passed
ZenusZhang Jul 29, 2025
794ddef
RVV power float32 passed
ZenusZhang Jul 30, 2025
5b5b5da
Passed binary test on RVV platform
ZenusZhang Jul 31, 2025
2843239
Passed swishb on x86 except uint8 and int8
ZenusZhang Jul 31, 2025
05e72a0
swishb passed on x86 and riscv
ZenusZhang Aug 1, 2025
b7ec6fb
Passed inner_product on x86 and riscv
ZenusZhang Aug 4, 2025
5be6b0d
Passed outer_product on x86
ZenusZhang Aug 4, 2025
fb3158b
merge prepare, passed on x86
ZenusZhang Aug 5, 2025
19d161a
Reopen rvv intrinsic of floor mod, Pass on rvv
ZenusZhang Aug 5, 2025
46c32c6
prepare for merge
ZenusZhang Aug 5, 2025
0b3216a
Add ulp in compare tensor
ZenusZhang Aug 5, 2025
af80865
Passed all uint32 cast except fp8
ZenusZhang Aug 8, 2025
c375fb6
cast done for normal case, but seems to work wrong for values out of…
ZenusZhang Aug 12, 2025
8655a20
passed all ctests on x86
ZenusZhang Aug 15, 2025
3b72956
Passed on X86, Failed on RISC-V because of the cast behavior
ZenusZhang Aug 18, 2025
d10a29f
Passed sub uint32 on RISC-V
ZenusZhang Aug 18, 2025
7d613b5
Fix conflict
ZenusZhang Aug 18, 2025
03f55fe
commit for merge
ZenusZhang Aug 18, 2025
044b3d7
Fix CI bug
ZenusZhang Aug 19, 2025
d460c6a
Add ulp support for non-formal float types
ZenusZhang Aug 20, 2025
371e2b8
Add non-formal float type support
ZenusZhang Aug 20, 2025
ee5d1c6
merge from origin
ZenusZhang Aug 20, 2025
4d1fa9d
Fix the ambiguous cast for fp16 and bf16
ZenusZhang Aug 21, 2025
31603d3
Add cast for bf16, fp16, and fix the fp16 rvv mod
ZenusZhang Aug 22, 2025
56666e2
Add float cast to bf16
ZenusZhang Aug 22, 2025
c23ed2a
Add more cast functions to fp16 and bf16
ZenusZhang Aug 22, 2025
ac70ee9
fix compare error
ZenusZhang Aug 22, 2025
6819f27
temp to fix the CI bug
ZenusZhang Aug 25, 2025
5a9ab88
Initial plan
Copilot Aug 25, 2025
35ce968
Fix CI auditwheel error by using static linking and excluding problem…
Copilot Aug 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/runtime-build.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: runtime-build

on: [ push, pull_request ]
on:
push:
paths:
- 'ntt/**'
pull_request:
paths:
- 'ntt/**'

concurrency:
group: runtime-build-${{ github.ref }}
Expand Down
9 changes: 8 additions & 1 deletion conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,16 @@ def configure(self):
if not self.options.runtime:
if self.settings.os == 'Windows' and self.settings.build_type == 'Debug':
self.options["nethost"].shared = True
else:
# For Linux and other platforms, use static linking to avoid auditwheel issues
self.options["nethost"].shared = False

# Configure fmt to be static for Linux builds to avoid auditwheel issues
if self.settings.os == 'Linux':
self.options["fmt"].shared = False

if self.options.tests:
self.options["ortki"].shared = True
self.options["ortki"].shared = False
self.options["date"].header_only = True

def validate(self):
Expand Down
2 changes: 1 addition & 1 deletion ntt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ if(BUILD_TESTING)
endif()

if(BUILD_BENCHMARK)
add_subdirectory(test/benchmark_test)
# add_subdirectory(test/benchmark_test)
endif()
110 changes: 91 additions & 19 deletions ntt/include/nncase/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ struct bfloat16 {
constexpr operator __bf16() const noexcept {
return std::bit_cast<__bf16>(value_);
}
// #else
// constexpr operator float() const noexcept {
// uint32_t value = raw() << 16;
// return std::bit_cast<float>(value);
// }

#endif

constexpr bfloat16() noexcept = default;
Expand All @@ -53,25 +59,6 @@ struct bfloat16 {
constexpr explicit bfloat16(const T &v) noexcept
: value_(round_to_bfloat16(v).value_) {}

constexpr bfloat16(from_raw_t, uint16_t value) noexcept : value_(value) {}

constexpr operator float() const noexcept {
uint32_t value = raw() << 16;
return std::bit_cast<float>(value);
}

constexpr uint16_t raw() const noexcept { return value_; }

static constexpr bfloat16 from_raw(uint16_t v) noexcept {
return bfloat16(nncase::from_raw, v);
}

static constexpr bfloat16 truncate_to_bfloat16(float v) noexcept {
return !std::isnan(v) ? from_raw(static_cast<uint16_t>(
std::bit_cast<uint32_t>(v) >> 16))
: nan();
}

// Converts a float point to bfloat16, with round-nearest-to-even as
// rounding method.
static constexpr bfloat16 round_to_bfloat16(float v) {
Expand All @@ -93,6 +80,90 @@ struct bfloat16 {
}
}

// Integer conversion constructors.
// Each overload widens the integer to a floating-point value and hands it to
// round_to_bfloat16; only the resulting bit pattern is stored.
constexpr explicit bfloat16(int x) noexcept
    : value_(round_to_bfloat16(static_cast<float>(x)).value_) {}

constexpr explicit bfloat16(int64_t x) noexcept
    : value_(round_to_bfloat16(static_cast<float>(x)).value_) {}

constexpr explicit bfloat16(uint32_t x) noexcept
    : value_(round_to_bfloat16(static_cast<float>(x)).value_) {}

// The unsigned 64-bit overload widens to double (not float) before the call.
constexpr explicit bfloat16(uint64_t x) noexcept
    : value_(round_to_bfloat16(static_cast<double>(x)).value_) {}

// Floating point conversion constructors.
// A float is rounded directly; a double is narrowed to float first.
constexpr explicit bfloat16(float x) noexcept
    : value_(round_to_bfloat16(x).value_) {}

constexpr explicit bfloat16(double x) noexcept
    : value_(round_to_bfloat16(static_cast<float>(x)).value_) {}

// Tagged constructor: stores the given 16-bit pattern as-is, without any
// numeric conversion. The from_raw_t argument only selects this overload.
constexpr bfloat16(from_raw_t, uint16_t value) noexcept
    : value_(value) {}

// Widens to float: the stored 16 bits become the upper half of a 32-bit
// word (lower 16 bits zero), which is then reinterpreted as a float.
constexpr operator float() const noexcept {
    // Widen before shifting. raw() on its own is promoted to signed int, and
    // moving a set top bit into the sign position is not portable before
    // C++20 and trips sign-conversion diagnostics.
    const uint32_t bits = static_cast<uint32_t>(raw()) << 16;
    return std::bit_cast<float>(bits);
}

// Returns the stored 16-bit pattern unchanged.
constexpr uint16_t raw() const noexcept {
    return value_;
}

// Builds a bfloat16 directly from a 16-bit pattern by forwarding to the
// tagged (from_raw) constructor.
static constexpr bfloat16 from_raw(uint16_t v) noexcept {
    return bfloat16(nncase::from_raw, v);
}

// Type conversion operators.
// All of these are explicit and go through the float conversion first; the
// final step is an ordinary C++ floating-to-target conversion.
constexpr explicit operator double() const noexcept {
    return static_cast<double>(static_cast<float>(*this));
}

constexpr explicit operator int() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

constexpr explicit operator int64_t() const noexcept {
    return static_cast<int64_t>(static_cast<float>(*this));
}

constexpr explicit operator uint32_t() const noexcept {
    return static_cast<uint32_t>(static_cast<float>(*this));
}

// The unsigned 64-bit conversion uses the double conversion as its
// intermediate step rather than float.
constexpr explicit operator uint64_t() const noexcept {
    return static_cast<uint64_t>(static_cast<double>(*this));
}


// Explicit conversions to the 8- and 16-bit integer types. Each converts to
// float first and then casts the float straight to the target type.
constexpr explicit operator uint8_t() const noexcept {
    return static_cast<uint8_t>(static_cast<float>(*this));
}

constexpr explicit operator int8_t() const noexcept {
    return static_cast<int8_t>(static_cast<float>(*this));
}

constexpr explicit operator int16_t() const noexcept {
    return static_cast<int16_t>(static_cast<float>(*this));
}

constexpr explicit operator uint16_t() const noexcept {
    return static_cast<uint16_t>(static_cast<float>(*this));
}


// Truthiness follows the numeric value: +0 and -0 are false, every other
// bit pattern (including NaN and the infinities) is true.
constexpr explicit operator bool() const noexcept {
    // Ignore the sign bit (bit 15) so that negative zero (0x8000) is not
    // reported as true, matching what bool(-0.0f) yields for a float.
    return (raw() & 0x7fff) != 0;
}

// Converts a float to bfloat16 by keeping only the upper 16 bits of its bit
// pattern (no rounding is applied). A NaN input yields nan() instead.
static constexpr bfloat16 truncate_to_bfloat16(float v) noexcept {
    if (std::isnan(v)) {
        return nan();
    }
    const uint32_t bits = std::bit_cast<uint32_t>(v);
    return from_raw(static_cast<uint16_t>(bits >> 16));
}



static constexpr bfloat16 epsilon() noexcept {
// 0x1.0p-7
return from_raw(0x3c00);
Expand Down Expand Up @@ -297,3 +368,4 @@ template <> struct is_arithmetic<bfloat16> : public true_type {};
inline nncase::bfloat16 operator"" _bf16(long double x) {
return nncase::bfloat16(float(x));
}

27 changes: 20 additions & 7 deletions ntt/include/nncase/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
// #include "nncase/nncase.h"
#include "bfloat16.h"
#include "half.h"
#include "bfloat16.h"
#ifndef CUTLASS_HOST_DEVICE
#define CUTLASS_HOST_DEVICE inline
#define CUTLASS_DEVICE inline
Expand Down Expand Up @@ -493,9 +492,6 @@ struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {
CUTLASS_HOST_DEVICE
explicit float_e4m3_t(float x) { storage = from_float(x).storage; }

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(bfloat16 x) : float_e4m3_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(half x) { storage = from_half(x).storage; }

Expand All @@ -508,7 +504,17 @@ struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {
explicit float_e4m3_t(int x) : float_e4m3_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(size_t x) : float_e4m3_t(float(x)) {}
explicit float_e4m3_t(int64_t x) : float_e4m3_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(bfloat16 x) : float_e4m3_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(uint64_t x) : float_e4m3_t(double(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e4m3_t(uint32_t x) : float_e4m3_t(float(x)) {}


/// E5M2 conversion. Defined after float_e5m2_t is defined.
CUTLASS_HOST_DEVICE
Expand Down Expand Up @@ -704,11 +710,17 @@ struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {
explicit float_e5m2_t(int x) : float_e5m2_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e5m2_t(size_t x) : float_e5m2_t(float(x)) {}
explicit float_e5m2_t(uint64_t x) : float_e5m2_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e5m2_t(bfloat16 x) : float_e5m2_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e5m2_t(int64_t x) : float_e5m2_t(float(x)) {}

CUTLASS_HOST_DEVICE
explicit float_e5m2_t(uint32_t x) : float_e5m2_t(float(x)) {}

/// E4M3 conversion
CUTLASS_HOST_DEVICE
explicit float_e5m2_t(float_e4m3_t x);
Expand Down Expand Up @@ -1025,7 +1037,8 @@ half operator*(float_e5m2_t const &lhs, float_e4m3_t const &rhs) {
return half(float(lhs) * float(rhs));
}

///////////////////////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////////////////////////////
//
// float_e4m3_t <=> float_e5m2_t conversions
//
Expand Down
107 changes: 86 additions & 21 deletions ntt/include/nncase/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,62 @@ struct half {
constexpr explicit half(const T &v) noexcept
: value_(round_to_half(v).value_) {}

constexpr half(fp16_from_raw_t, uint16_t value) noexcept
: value_(std::bit_cast<_Float16>(value)) {}

constexpr operator _Float16() const noexcept { return value_; }
constexpr operator float() const noexcept {
static constexpr half round_to_half(float v) {
if (std::is_constant_evaluated()) {
return (float)value_;
return (_Float16)v;
} else {
#ifdef __F16C__
// To avoid extendhfdf2
return _cvtsh_ss(raw());
// To avoid truncsfhf2
return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT));
#else
return (float)value_;
return (_Float16)v;
#endif
}

return (_Float16)v;
}

static constexpr half epsilon() noexcept { return from_raw(0x0800); }

// Integer conversion constructors
constexpr explicit half(int x) noexcept
: value_(round_to_half(float(x)).value_) {}

constexpr explicit half(int64_t x) noexcept
: value_(round_to_half(float(x)).value_) {}

constexpr explicit half(uint32_t x) noexcept
: value_(round_to_half(float(x)).value_) {}

constexpr explicit half(uint64_t x) noexcept
: value_(round_to_half(double(x)).value_) {}

// Floating point conversion constructors
constexpr explicit half(double x) noexcept
: value_(round_to_half(float(x)).value_) {}

// bfloat16 conversion constructor
constexpr explicit half(bfloat16 x) noexcept
: value_(round_to_half(float(x)).value_) {}

constexpr half(fp16_from_raw_t, uint16_t value) noexcept
: value_(std::bit_cast<_Float16>(value)) {}

constexpr operator _Float16() const noexcept { return value_; }
// constexpr operator float() const noexcept {
// if (std::is_constant_evaluated()) {
// return (float)value_;
// } else {
// #ifdef __F16C__
// // To avoid extendhfdf2
// return _cvtsh_ss(raw());
// #else
// return (float)value_;
// #endif
// }
// }

constexpr uint16_t raw() const noexcept {
return std::bit_cast<uint16_t>(value_);
}
Expand All @@ -77,22 +116,48 @@ struct half {
return half(nncase::fp16_from_raw, v);
}

static constexpr half round_to_half(float v) {
if (std::is_constant_evaluated()) {
return (_Float16)v;
} else {
#ifdef __F16C__
// To avoid truncsfhf2
return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT));
#else
return (_Float16)v;
#endif
}
// Type conversion operators.
// Widens to double by way of an intermediate float conversion.
constexpr explicit operator double() const noexcept {
    return static_cast<double>(static_cast<float>(*this));
}

return (_Float16)v;
// Converts through float and then int; the int result is narrowed to int8_t
// implicitly by the return.
constexpr explicit operator int8_t() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

static constexpr half epsilon() noexcept { return from_raw(0x0800); }
// The 8- and 16-bit conversions go through float and then int; the int
// result is narrowed to the declared return type implicitly by the return.
constexpr explicit operator uint8_t() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

constexpr explicit operator int16_t() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

constexpr explicit operator uint16_t() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

// The wider conversions cast the float (or double) value straight to the
// target type.
constexpr explicit operator int() const noexcept {
    return static_cast<int>(static_cast<float>(*this));
}

constexpr explicit operator int64_t() const noexcept {
    return static_cast<int64_t>(static_cast<float>(*this));
}

constexpr explicit operator uint32_t() const noexcept {
    return static_cast<uint32_t>(static_cast<float>(*this));
}

// The unsigned 64-bit conversion uses the double conversion as its
// intermediate step rather than float.
constexpr explicit operator uint64_t() const noexcept {
    return static_cast<uint64_t>(static_cast<double>(*this));
}

// Truthiness follows the numeric value: +0 and -0 are false, every other
// bit pattern (including NaN and the infinities) is true.
constexpr explicit operator bool() const noexcept {
    // Ignore the sign bit (bit 15) so that negative zero (0x8000) is not
    // reported as true, matching what bool(-0.0f) yields for a float.
    return (raw() & 0x7fff) != 0;
}

static constexpr half highest() noexcept { return from_raw(0x7bff); }

Expand Down
Loading
Loading