diff --git a/.github/workflows/runtime-build.yml b/.github/workflows/runtime-build.yml index 2a79eaef9e..5b4c7faefd 100644 --- a/.github/workflows/runtime-build.yml +++ b/.github/workflows/runtime-build.yml @@ -1,6 +1,12 @@ name: runtime-build -on: [ push, pull_request ] +on: + push: + paths: + - 'ntt/**' + pull_request: + paths: + - 'ntt/**' concurrency: group: runtime-build-${{ github.ref }} diff --git a/conanfile.py b/conanfile.py index b203641206..0aa78bd093 100644 --- a/conanfile.py +++ b/conanfile.py @@ -73,9 +73,16 @@ def configure(self): if not self.options.runtime: if self.settings.os == 'Windows' and self.settings.build_type == 'Debug': self.options["nethost"].shared = True + else: + # For Linux and other platforms, use static linking to avoid auditwheel issues + self.options["nethost"].shared = False + + # Configure fmt to be static for Linux builds to avoid auditwheel issues + if self.settings.os == 'Linux': + self.options["fmt"].shared = False if self.options.tests: - self.options["ortki"].shared = True + self.options["ortki"].shared = False self.options["date"].header_only = True def validate(self): diff --git a/ntt/CMakeLists.txt b/ntt/CMakeLists.txt index 1de56acea3..4854869b80 100644 --- a/ntt/CMakeLists.txt +++ b/ntt/CMakeLists.txt @@ -7,5 +7,5 @@ if(BUILD_TESTING) endif() if(BUILD_BENCHMARK) - add_subdirectory(test/benchmark_test) + # add_subdirectory(test/benchmark_test) endif() diff --git a/ntt/include/nncase/bfloat16.h b/ntt/include/nncase/bfloat16.h index 88c3750603..8bb9f63c5d 100644 --- a/ntt/include/nncase/bfloat16.h +++ b/ntt/include/nncase/bfloat16.h @@ -43,6 +43,12 @@ struct bfloat16 { constexpr operator __bf16() const noexcept { return std::bit_cast<__bf16>(value_); } +// #else +// constexpr operator float() const noexcept { +// uint32_t value = raw() << 16; +// return std::bit_cast<float>(value); +// } + #endif constexpr bfloat16() noexcept = default; @@ -53,25 +59,6 @@ struct bfloat16 { constexpr explicit bfloat16(const T &v) noexcept : value_(round_to_bfloat16(v).value_) {} - constexpr bfloat16(from_raw_t, uint16_t value) noexcept : value_(value) {} - - constexpr operator float() const noexcept { - uint32_t value = raw() << 16; - return std::bit_cast<float>(value); - } - - constexpr uint16_t raw() const noexcept { return value_; } - - static constexpr bfloat16 from_raw(uint16_t v) noexcept { - return bfloat16(nncase::from_raw, v); - } - - static constexpr bfloat16 truncate_to_bfloat16(float v) noexcept { - return !std::isnan(v) ? from_raw(static_cast<uint16_t>( - std::bit_cast<uint32_t>(v) >> 16)) - : nan(); - } - // Converts a floating point number to bfloat16, with round-nearest-to-even as // rounding method.
static constexpr bfloat16 round_to_bfloat16(float v) { @@ -93,6 +80,90 @@ struct bfloat16 { } } + // Integer conversion constructors + constexpr explicit bfloat16(int x) noexcept + : value_(round_to_bfloat16(float(x)).value_) {} + + constexpr explicit bfloat16(int64_t x) noexcept + : value_(round_to_bfloat16(float(x)).value_) {} + + constexpr explicit bfloat16(uint32_t x) noexcept + : value_(round_to_bfloat16(float(x)).value_) {} + + constexpr explicit bfloat16(uint64_t x) noexcept + : value_(round_to_bfloat16(double(x)).value_) {} + + // Floating point conversion constructors + constexpr explicit bfloat16(float x) noexcept + : value_(round_to_bfloat16(x).value_) {} + + constexpr explicit bfloat16(double x) noexcept + : value_(round_to_bfloat16(float(x)).value_) {} + + constexpr bfloat16(from_raw_t, uint16_t value) noexcept : value_(value) {} + + constexpr operator float() const noexcept { + uint32_t value = raw() << 16; + return std::bit_cast<float>(value); + } + + constexpr uint16_t raw() const noexcept { return value_; } + + static constexpr bfloat16 from_raw(uint16_t v) noexcept { + return bfloat16(nncase::from_raw, v); + } + + // Type conversion operators + constexpr explicit operator double() const noexcept { + return double(float(*this)); + } + + constexpr explicit operator int() const noexcept { + return int(float(*this)); + } + + constexpr explicit operator int64_t() const noexcept { + return int64_t(float(*this)); + } + + constexpr explicit operator uint32_t() const noexcept { + return uint32_t(float(*this)); + } + + constexpr explicit operator uint64_t() const noexcept { + return uint64_t(double(*this)); + } + + constexpr explicit operator uint8_t() const noexcept { + return uint8_t(float(*this)); + } + + constexpr explicit operator int8_t() const noexcept { + return int8_t(float(*this)); + } + + constexpr explicit operator int16_t() const noexcept { + return int16_t(float(*this)); + } + + constexpr explicit operator uint16_t() const noexcept { + return uint16_t(float(*this)); + } + + constexpr explicit operator bool() const noexcept { + return bool(std::bit_cast<uint16_t>(*this)); + } + + static constexpr bfloat16 truncate_to_bfloat16(float v) noexcept { + return !std::isnan(v) ?
from_raw(static_cast<uint16_t>( + std::bit_cast<uint32_t>(v) >> 16)) + : nan(); + } + + static constexpr bfloat16 epsilon() noexcept { // 0x1.0p-7 return from_raw(0x3c00); } @@ -297,3 +368,4 @@ template <> struct is_arithmetic<nncase::bfloat16> : public true_type {}; inline nncase::bfloat16 operator"" _bf16(long double x) { return nncase::bfloat16(float(x)); } + diff --git a/ntt/include/nncase/float8.h b/ntt/include/nncase/float8.h index 3cc06ed154..d262988bfd 100644 --- a/ntt/include/nncase/float8.h +++ b/ntt/include/nncase/float8.h @@ -79,7 +79,6 @@ // #include "nncase/nncase.h" #include "bfloat16.h" #include "half.h" -#include "bfloat16.h" #ifndef CUTLASS_HOST_DEVICE #define CUTLASS_HOST_DEVICE inline #define CUTLASS_DEVICE inline @@ -493,9 +492,6 @@ struct alignas(1) float_e4m3_t : float8_base { CUTLASS_HOST_DEVICE explicit float_e4m3_t(float x) { storage = from_float(x).storage; } - CUTLASS_HOST_DEVICE - explicit float_e4m3_t(bfloat16 x) : float_e4m3_t(float(x)) {} - CUTLASS_HOST_DEVICE explicit float_e4m3_t(half x) { storage = from_half(x).storage; } @@ -508,7 +504,17 @@ struct alignas(1) float_e4m3_t : float8_base { explicit float_e4m3_t(int x) : float_e4m3_t(float(x)) {} CUTLASS_HOST_DEVICE - explicit float_e4m3_t(size_t x) : float_e4m3_t(float(x)) {} + explicit float_e4m3_t(int64_t x) : float_e4m3_t(float(x)) {} + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(bfloat16 x) : float_e4m3_t(float(x)) {} + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(uint64_t x) : float_e4m3_t(double(x)) {} + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(uint32_t x) : float_e4m3_t(float(x)) {} + /// E5M2 conversion. Defined after float_e5m2_t is defined. CUTLASS_HOST_DEVICE @@ -704,11 +710,17 @@ struct alignas(1) float_e5m2_t : float8_base { explicit float_e5m2_t(int x) : float_e5m2_t(float(x)) {} CUTLASS_HOST_DEVICE - explicit float_e5m2_t(size_t x) : float_e5m2_t(float(x)) {} + explicit float_e5m2_t(uint64_t x) : float_e5m2_t(float(x)) {} CUTLASS_HOST_DEVICE explicit float_e5m2_t(bfloat16 x) : float_e5m2_t(float(x)) {} + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(int64_t x) : float_e5m2_t(float(x)) {} + + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(uint32_t x) : float_e5m2_t(float(x)) {} + /// E4M3 conversion CUTLASS_HOST_DEVICE explicit float_e5m2_t(float_e4m3_t x); @@ -1025,7 +1037,8 @@ half operator*(float_e5m2_t const &lhs, float_e4m3_t const &rhs) { return half(float(lhs) * float(rhs)); } -/////////////////////////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////////////////////////// // // float_e4m3_t <=> float_e5m2_t conversions // diff --git a/ntt/include/nncase/half.h b/ntt/include/nncase/half.h index 0be9e6be30..0a42541351 100644 --- a/ntt/include/nncase/half.h +++ b/ntt/include/nncase/half.h @@ -52,23 +52,62 @@ struct half { constexpr explicit half(const T &v) noexcept : value_(round_to_half(v).value_) {} - constexpr half(fp16_from_raw_t, uint16_t value) noexcept - : value_(std::bit_cast<_Float16>(value)) {} - constexpr operator _Float16() const noexcept { return value_; } - constexpr operator float() const noexcept { + static constexpr half round_to_half(float v) { if (std::is_constant_evaluated()) { - return (float)value_; + return (_Float16)v; } else { #ifdef __F16C__ - // To avoid extendhfdf2 - return _cvtsh_ss(raw()); + // To avoid truncsfhf2 + return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT)); #else - return (float)value_; + return (_Float16)v; #endif } + + return (_Float16)v; } + static constexpr half epsilon()
noexcept { return from_raw(0x0800); } + + // Integer conversion constructors + constexpr explicit half(int x) noexcept + : value_(round_to_half(float(x)).value_) {} + + constexpr explicit half(int64_t x) noexcept + : value_(round_to_half(float(x)).value_) {} + + constexpr explicit half(uint32_t x) noexcept + : value_(round_to_half(float(x)).value_) {} + + constexpr explicit half(uint64_t x) noexcept + : value_(round_to_half(double(x)).value_) {} + + // Floating point conversion constructors + constexpr explicit half(double x) noexcept + : value_(round_to_half(float(x)).value_) {} + + // bfloat16 conversion constructor + constexpr explicit half(bfloat16 x) noexcept + : value_(round_to_half(float(x)).value_) {} + + constexpr half(fp16_from_raw_t, uint16_t value) noexcept + : value_(std::bit_cast<_Float16>(value)) {} + + constexpr operator _Float16() const noexcept { return value_; } +// constexpr operator float() const noexcept { +// if (std::is_constant_evaluated()) { +// return (float)value_; +// } else { +// #ifdef __F16C__ +// // To avoid extendhfdf2 +// return _cvtsh_ss(raw()); +// #else +// return (float)value_; +// #endif +// } +// } constexpr uint16_t raw() const noexcept { return std::bit_cast<uint16_t>(value_); } @@ -77,22 +116,48 @@ struct half { return half(nncase::fp16_from_raw, v); } - static constexpr half round_to_half(float v) { - if (std::is_constant_evaluated()) { - return (_Float16)v; - } else { -#ifdef __F16C__ - // To avoid truncsfhf2 - return from_raw(_cvtss_sh(v, _MM_FROUND_NEARBYINT)); -#else - return (_Float16)v; -#endif - } + // Type conversion operators + constexpr explicit operator double() const noexcept { + return double(float(*this)); + } - return (_Float16)v; + constexpr explicit operator int8_t() const noexcept { + return int8_t(float(*this)); } - static constexpr half epsilon() noexcept { return from_raw(0x0800); } + constexpr explicit operator uint8_t() const noexcept { + return uint8_t(float(*this)); + } + + constexpr explicit operator int16_t() const noexcept { + return int16_t(float(*this)); + } + + constexpr explicit operator uint16_t() const noexcept { + return uint16_t(float(*this)); + } + + constexpr explicit operator int() const noexcept { + return int(float(*this)); + } + + constexpr explicit operator int64_t() const noexcept { + return int64_t(float(*this)); + } + + constexpr explicit operator uint32_t() const noexcept { + return uint32_t(float(*this)); + } + + constexpr explicit operator uint64_t() const noexcept { + return uint64_t(double(*this)); + } + + constexpr explicit operator bool() const noexcept { + return bool(std::bit_cast<uint16_t>(*this)); + } static constexpr half highest() noexcept { return from_raw(0x7bff); } diff --git a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h index 9b2b4f3eaa..31a3448588 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h +++ b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h @@ -859,17 +859,20 @@ REGISTER_RVV_BINARY_OP(max, float, max_float32) inline vfloat32m##lmul##_t pow_float32(const vfloat32m##lmul##_t &v1, \ const vfloat32m##lmul##_t &v2, \ const size_t vl) { \ + COMPILER_BARRIER(); \ return pow_ps(v1, v2, vl); \ } \ \ inline vfloat32m##lmul##_t pow_float32(const vfloat32m##lmul##_t &v1, \ const float &s, const size_t vl) { \ + COMPILER_BARRIER(); \ auto v2 = __riscv_vfmv_v_f_f32m##lmul(s, vl); \ return pow_ps(v1, v2, vl); \ } \ \ inline vfloat32m##lmul##_t pow_float32( \ const float &s, const vfloat32m##lmul##_t &v2, const size_t vl) { \ +
COMPILER_BARRIER(); \ auto v1 = __riscv_vfmv_v_f_f32m##lmul(s, vl); \ return pow_ps(v1, v2, vl); \ } @@ -882,38 +885,73 @@ REGISTER_RVV_BINARY_OP(pow, float, pow_float32) inline vint32m##lmul##_t floor_mod_int32(const vint32m##lmul##_t &v1, \ const vint32m##lmul##_t &v2, \ const size_t vl) { \ + /* without a fence, the result would be incorrect on large test cases */ \ auto remainder = __riscv_vrem_vv_i32m##lmul(v1, v2, vl); \ auto tmp = __riscv_vxor_vv_i32m##lmul(v1, v2, vl); \ auto mask1 = __riscv_vmsne_vx_i32m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i32m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ - remainder = __riscv_vadd_vv_i32m##lmul##_m(mask1, remainder, v2, vl); \ + /*remainder = __riscv_vadd_vv_i32m##lmul##_m(mask1, remainder, v2, vl);*/ \ + asm volatile ( \ + "vmv.v.v v0, %[mask]\n\t" \ + "vadd.vv %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "vr" (v2) \ + : "v0" \ + ); \ + /* Debug output mask values */ \ + /* + std::cout << "=== FLOOR_MOD_INT32 DEBUG ===" << std::endl; \ + print_rvv_vector_i32(v1, "v1", vl); \ + print_rvv_vector_i32(v2, "v2", vl); \ + print_rvv_vector_i32(remainder, "remainder", vl); \ + print_rvv_vector_i32(tmp, "tmp (v1^v2)", vl); \ + print_rvv_vector_i32(remainder, "final result", vl); \ + std::cout << "=== END DEBUG ===" << std::endl; \ + */ \ return remainder; \ } \ \ inline vint32m##lmul##_t floor_mod_int32( \ const vint32m##lmul##_t &v1, const int32_t &s, const size_t vl) { \ + /* without a fence, the result would be incorrect on large test cases */ \ auto remainder = __riscv_vrem_vx_i32m##lmul(v1, s, vl); \ auto tmp = __riscv_vxor_vx_i32m##lmul(v1, s, vl); \ auto mask1 = __riscv_vmsne_vx_i32m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i32m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ - remainder = __riscv_vadd_vx_i32m##lmul##_m(mask1, remainder, s, vl); \ + asm volatile ( \ + "vmv.v.v v0, %[mask]\n\t" \ + "vadd.vx %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "r" (s) \ + : "v0" \ + ); \ return remainder; \ } \ \ inline vint32m##lmul##_t floor_mod_int32( \ const int32_t &s, const vint32m##lmul##_t &v2, const size_t vl) { \ + /* without a fence, the result would be incorrect on large test cases */ \ auto v1 = __riscv_vmv_v_x_i32m##lmul(s, vl); \ auto remainder = __riscv_vrem_vv_i32m##lmul(v1, v2, vl); \ auto tmp = __riscv_vxor_vv_i32m##lmul(v1, v2, vl); \ auto mask1 = __riscv_vmsne_vx_i32m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i32m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ - remainder = __riscv_vadd_vv_i32m##lmul##_m(mask1, remainder, v2, vl); \ + asm volatile ( \ + "vmv.v.v v0, %[mask]\n\t" \ + "vadd.vv %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "vr" (v2) \ + : "v0" \ + ); \ return remainder; \ } - +// Workaround for a compiler or QEMU bug in the RVV int32 floor_mod kernel: the masked-add intrinsic is replaced with inline assembly.
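+// For reference, the masked vadd above implements the floor-mod sign fixup
+// sketched below in scalar form (floor_mod_ref is illustrative only, not part
+// of this kernel):
+//   int32_t floor_mod_ref(int32_t a, int32_t b) {
+//       int32_t r = a % b;            // truncated remainder (rounds toward zero)
+//       if (r != 0 && ((a ^ b) < 0))  // nonzero remainder, operands differ in sign
+//           r += b;                   // give the result the sign of b
+//       return r;                     // equals a - floor(a / b) * b
+//   }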
REGISTER_RVV_KERNEL(FLOOR_MOD_INT32) REGISTER_RVV_BINARY_OP(floor_mod, int32_t, floor_mod_int32) diff --git a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops_half.h b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops_half.h index 245d1f55e8..001ca48c51 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops_half.h +++ b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops_half.h @@ -17,6 +17,12 @@ namespace nncase::ntt::ops { kernel(1, 16) kernel(2, 8) kernel(4, 4) kernel(8, 2) #endif +// float32 intermediate +#ifndef REGISTER_RVV_FP16_KERNEL_FP32_IM +#define REGISTER_RVV_FP16_KERNEL_FP32_IM(kernel) \ + kernel(1, 16) kernel(2, 8) kernel(4, 4) +#endif + #define RVV_UNARY_FP16_OP(op, dtype, vl, kernel) \ template <> struct op<ntt::vector<dtype, vl>> { \ ntt::vector<dtype, vl> \ @@ -260,6 +266,7 @@ REGISTER_RVV_UNARY_FP16_OP(cosh, half, cosh_float16) auto vi = __riscv_vfcvt_x_f_v_i16m##lmul(v, vl); \ auto vf = __riscv_vfcvt_f_x_v_f16m##lmul(vi, vl); \ auto mask = __riscv_vmfgt_vv_f16m##lmul##_b##mlen(vf, v, vl); \ + __asm__ volatile("" ::: "memory"); \ vf = __riscv_vfsub_vf_f16m##lmul##_m(mask, vf, 1.f16, vl); \ return vf; \ } @@ -535,6 +542,12 @@ REGISTER_RVV_UNARY_FP16_OP(erf, half, erf_float16) RVV_BINARY_fp16_OP(op, dtype, NTT_VL(sizeof(dtype) * 8, *, 8), \ kernel) +// FP32 as intermediate result +#define REGISTER_RVV_BINARY_FP16_OPS_FP32_IM(op, dtype, kernel) \ + RVV_BINARY_fp16_OP(op, dtype, NTT_VL(sizeof(dtype) * 8, *, 1), kernel) \ + RVV_BINARY_fp16_OP(op, dtype, NTT_VL(sizeof(dtype) * 8, *, 2), kernel) \ + RVV_BINARY_fp16_OP(op, dtype, NTT_VL(sizeof(dtype) * 8, *, 4), \ + kernel) // add #define ADD_FLOAT16(lmul, mlen) \ inline vfloat16m##lmul##_t add_float16(const vfloat16m##lmul##_t &v1, \ @@ -642,39 +655,60 @@ REGISTER_RVV_BINARY_FP16_OP(div, half, div_float16) REGISTER_RVV_FP16_KERNEL(POW_FLOAT16) REGISTER_RVV_BINARY_FP16_OP(pow, half, pow_float16) +#define LMUL_DBL_1 2 +#define LMUL_DBL_2 4 +#define LMUL_DBL_4 8 + +#define CONCAT_IMPL(a, b) a##b +#define CONCAT(a, b) CONCAT_IMPL(a, b) + +#define DOUBLE_LMUL(lmul) CONCAT(LMUL_DBL_, lmul) +#define CALL_DBL_LMUL(name, lmul) CONCAT(name, DOUBLE_LMUL(lmul)) + // mod #define MOD_FLOAT16(lmul, mlen) \ - inline vfloat16m##lmul##_t mod_float16(const vfloat16m##lmul##_t &v1, \ + inline vfloat16m##lmul##_t mod_float16(const vfloat16m##lmul##_t &v1, \ const vfloat16m##lmul##_t &v2, \ const size_t vl) { \ - auto quotient = __riscv_vfcvt_f_x_v_f16m##lmul( \ - __riscv_vfcvt_rtz_x_f_v_i16m##lmul( \ - __riscv_vfdiv_vv_f16m##lmul(v1, v2, vl), vl), \ - vl); \ - return __riscv_vfnmsub_vv_f16m##lmul(quotient, v2, v1, vl); \ - } \ + auto v1_f32 = CALL_DBL_LMUL(__riscv_vfwcvt_f_f_v_f32m, lmul)(v1, vl); \ + auto v2_f32 = CALL_DBL_LMUL(__riscv_vfwcvt_f_f_v_f32m, lmul)(v2, vl); \ + auto division_f32 = CALL_DBL_LMUL(__riscv_vfdiv_vv_f32m, lmul)(v1_f32, v2_f32, vl); \ + auto quotient_int = CALL_DBL_LMUL(__riscv_vfcvt_rtz_x_f_v_i32m, lmul)(division_f32, vl); \ + auto quotient_f32 = CALL_DBL_LMUL(__riscv_vfcvt_f_x_v_f32m, lmul)(quotient_int, vl); \ + auto result_f32 = CALL_DBL_LMUL(__riscv_vfnmsub_vv_f32m, lmul)(quotient_f32, v2_f32, v1_f32, vl); \ + auto result_f16 = __riscv_vfncvt_f_f_w_f16m##lmul(result_f32, vl); \ + return result_f16; \ + } \ + \ \ inline vfloat16m##lmul##_t mod_float16(const vfloat16m##lmul##_t &v, \ const half &s, const size_t vl) { \ - auto quotient = __riscv_vfcvt_f_x_v_f16m##lmul( \ - __riscv_vfcvt_rtz_x_f_v_i16m##lmul( \ - __riscv_vfdiv_vf_f16m##lmul(v, s, vl), vl), \ - vl); \ - return __riscv_vfnmsub_vf_f16m##lmul(quotient, s, v, vl); \ + float s_f32 = (float)s;
\ + auto v_f32 = CALL_DBL_LMUL(__riscv_vfwcvt_f_f_v_f32m, lmul)(v, vl); \ + auto division_f32 = CALL_DBL_LMUL(__riscv_vfdiv_vf_f32m, lmul)(v_f32, s_f32, vl); \ + auto quotient_int = CALL_DBL_LMUL(__riscv_vfcvt_rtz_x_f_v_i32m, lmul)(division_f32, vl); \ + auto quotient_f32 = CALL_DBL_LMUL(__riscv_vfcvt_f_x_v_f32m, lmul)(quotient_int, vl); \ + auto result_f32 = CALL_DBL_LMUL(__riscv_vfnmsub_vf_f32m, lmul)(quotient_f32, s_f32, v_f32, vl); \ + auto result_f16 = __riscv_vfncvt_f_f_w_f16m##lmul(result_f32, vl); \ + return result_f16; \ } \ \ inline vfloat16m##lmul##_t mod_float16( \ const half &s, const vfloat16m##lmul##_t &v2, const size_t vl) { \ - auto v1 = __riscv_vfmv_v_f_f16m##lmul(s, vl); \ - auto quotient = __riscv_vfcvt_f_x_v_f16m##lmul( \ - __riscv_vfcvt_rtz_x_f_v_i16m##lmul( \ - __riscv_vfrdiv_vf_f16m##lmul(v2, s, vl), vl), \ - vl); \ - return __riscv_vfnmsub_vv_f16m##lmul(quotient, v2, v1, vl); \ - } + float s_f32 = (float)s; \ + auto v1_f32 = CALL_DBL_LMUL(__riscv_vfmv_v_f_f32m, lmul)(s_f32, vl); \ + auto v2_f32 = CALL_DBL_LMUL(__riscv_vfwcvt_f_f_v_f32m, lmul)(v2, vl); \ + auto division_f32 = CALL_DBL_LMUL(__riscv_vfrdiv_vf_f32m, lmul)(v2_f32, s_f32, vl); \ + auto quotient_int = CALL_DBL_LMUL(__riscv_vfcvt_rtz_x_f_v_i32m, lmul)(division_f32, vl); \ + auto quotient_f32 = CALL_DBL_LMUL(__riscv_vfcvt_f_x_v_f32m, lmul)(quotient_int, vl); \ + auto result_f32 = CALL_DBL_LMUL(__riscv_vfnmsub_vv_f32m, lmul)(quotient_f32, v2_f32, v1_f32, vl); \ + auto result_f16 = __riscv_vfncvt_f_f_w_f16m##lmul(result_f32, vl); \ + return result_f16; \ + } + -REGISTER_RVV_FP16_KERNEL(MOD_FLOAT16) -REGISTER_RVV_BINARY_FP16_OP(mod, half, mod_float16) +REGISTER_RVV_FP16_KERNEL_FP32_IM(MOD_FLOAT16) +REGISTER_RVV_BINARY_FP16_OPS_FP32_IM(mod, half, mod_float16) // min // template <> struct min { @@ -748,7 +782,16 @@ REGISTER_RVV_BINARY_FP16_OP(max, half, max_float16) auto mask1 = __riscv_vmsne_vx_i16m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i16m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ - remainder = __riscv_vadd_vv_i16m##lmul##_m(mask1, remainder, v2, vl); \ + __asm__ volatile("" ::: "memory"); \ +/* remainder = __riscv_vadd_vv_i16m##lmul##_m(mask1, remainder, v2, vl); \ */ \ + asm volatile ( \ + "vmv.v.v v0, %[mask]\n\t" \ + "vadd.vv %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "vr" (v2) \ + : "v0" \ + ); \ return remainder; \ } \ \ @@ -759,7 +802,15 @@ REGISTER_RVV_BINARY_FP16_OP(max, half, max_float16) auto mask1 = __riscv_vmsne_vx_i16m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i16m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ -/* remainder = __riscv_vadd_vx_i16m##lmul##_m(mask1, remainder, s, vl); \ */ \ + asm volatile ( \ + "vmv.v.v v0, %[mask]\n\t" \ + "vadd.vx %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "r" (s) \ + : "v0" \ + ); \ return remainder; \ } \ \ @@ -771,7 +822,15 @@ REGISTER_RVV_BINARY_FP16_OP(max, half, max_float16) auto mask1 = __riscv_vmsne_vx_i16m##lmul##_b##mlen(remainder, 0, vl); \ auto mask2 = __riscv_vmslt_vx_i16m##lmul##_b##mlen(tmp, 0, vl); \ mask1 = __riscv_vmand_mm_b##mlen(mask1, mask2, vl); \ -/* remainder = __riscv_vadd_vv_i16m##lmul##_m(mask1, remainder, v2, vl); \ */ \ + asm volatile ( \ + "vmv.v.v
v0, %[mask]\n\t" \ + "vadd.vv %[rem], %[rem], %[val], v0.t" \ + : [rem] "+vr" (remainder) \ + : [mask] "vr" (mask1), \ + [val] "vr" (v2) \ + : "v0" \ + ); \ return remainder; \ } diff --git a/ntt/include/nncase/ntt/arch/riscv64/rvv_mathfun.h b/ntt/include/nncase/ntt/arch/riscv64/rvv_mathfun.h index 72a3f64086..33d7ac360f 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/rvv_mathfun.h +++ b/ntt/include/nncase/ntt/arch/riscv64/rvv_mathfun.h @@ -15,10 +15,98 @@ #pragma once #include +#define COMPILER_BARRIER() __asm__ volatile("" ::: "memory") + #if __riscv_vector #include <riscv_vector.h> +#ifdef DE_BUG +#include <iostream> +#include <iomanip> +#include <limits> + +#define __RVV_PRINT_VECTOR_INT(LMUL, MLEN, TLEN) \ + void print_rvv_vector_i##TLEN(const vint##TLEN##m##LMUL##_t &vec, const char *label, const size_t print_vl){ \ + int##TLEN##_t temp[(LMUL*NTT_VLEN)/TLEN]; \ + __riscv_vse##TLEN##_v_i##TLEN##m##LMUL(temp, vec, print_vl); \ + std::cout << label << ": "; \ + for (size_t i = 0; i < print_vl; ++i) { \ + std::cout << temp[i] << " "; \ + } \ + std::cout << std::endl; \ + } + +__RVV_PRINT_VECTOR_INT(1, 32, 32) +__RVV_PRINT_VECTOR_INT(2, 16, 32) +__RVV_PRINT_VECTOR_INT(4, 8, 32) +__RVV_PRINT_VECTOR_INT(8, 4, 32) + +#define __RVV_PRINT_VECTOR_FLOAT(LMUL, MLEN, TLEN) \ + void print_rvv_vector_f##TLEN(const vfloat##TLEN##m##LMUL##_t &vec, const char *label, const size_t print_vl){ \ + float temp[(LMUL*NTT_VLEN/TLEN)]; \ + __riscv_vse##TLEN##_v_f##TLEN##m##LMUL(temp, vec, print_vl); \ + std::cout << label << ": "; \ + for (size_t i = 0; i < print_vl; ++i) { \ + std::cout << temp[i] << " "; \ + } \ + std::cout << std::endl; \ + } + +__RVV_PRINT_VECTOR_FLOAT(1, 32, 32) +__RVV_PRINT_VECTOR_FLOAT(2, 16, 32) +__RVV_PRINT_VECTOR_FLOAT(4, 8, 32) +__RVV_PRINT_VECTOR_FLOAT(8, 4, 32) + + +#define __RVV_PRINT_VECTOR_HALF(LMUL, MLEN, TLEN) \ + void print_rvv_vector_f##TLEN(const vfloat##TLEN##m##LMUL##_t &vec, const char *label, const size_t print_vl){ \ + _Float16 temp[(LMUL*NTT_VLEN/TLEN)]; \ + __riscv_vse##TLEN##_v_f##TLEN##m##LMUL(temp, vec, print_vl); \ + std::cout << label << ": "; \ + for (size_t i = 0; i < print_vl; ++i) { \ + std::cout << std::setprecision(std::numeric_limits<float>::max_digits10) << temp[i] << " "; \ + } \ + std::cout << std::endl; \ + } + + +__RVV_PRINT_VECTOR_HALF(1, 16, 16) +__RVV_PRINT_VECTOR_HALF(2, 8, 16) +__RVV_PRINT_VECTOR_HALF(4, 4, 16) +__RVV_PRINT_VECTOR_HALF(8, 2, 16) + + +// template <size_t vl> +// void print_rvv_vector_i32(const vint32m1_t &vec, const char *label, const size_t print_vl) { +// int32_t temp[vl]; +// __riscv_vse32_v_i32m1(temp, vec, print_vl); +// std::cout << label << ": "; +// for (size_t i = 0; i < print_vl; ++i) { +// std::cout << temp[i] << " "; +// } +// std::cout << std::endl; +// } + + +#define __RVV_PRINT_MASK(BTYPE, MLEN) \ + void print_rvv_mask_##MLEN(const vbool##MLEN##_t &mask, const char *label, const size_t print_vl) { \ + uint8_t temp[MLEN]; \ + __riscv_vsm_v_b##MLEN(temp, mask, print_vl); \ + std::cout << label << ": "; \ + for (size_t i = 0; i < print_vl; ++i) { \ + std::cout << static_cast<int>(temp[i]) << " "; \ + } \ + std::cout << std::endl; \ + } + +__RVV_PRINT_MASK(32, 32) +__RVV_PRINT_MASK(16, 16) +__RVV_PRINT_MASK(8, 8) +__RVV_PRINT_MASK(4, 4) + +#endif + #define c_inv_mant_mask ~0x7f800000u #define c_cephes_SQRTHF 0.707106781186547524 #define c_cephes_log_p0 7.0376836292E-2 @@ -95,8 +183,8 @@ _RVV_FLOAT32_LOG_OP(2, 16) _RVV_FLOAT32_LOG_OP(4, 8) _RVV_FLOAT32_LOG_OP(8, 4) -#define c_exp_hi 88.3762626647949f -#define c_exp_lo -88.3762626647949f +#define c_exp_hi 88.0f +#define c_exp_lo -88.0f
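+// Note on the clamp bounds: expf overflows float32 once x > ln(FLT_MAX) ~= 88.7228;
+// the classic cephes bound 88.3762626647949 equals ln(2^127.5) and sits just inside
+// that limit. Lowering the clamp to +/-88.0 leaves extra headroom, since
+// expf(88.0f) ~= 1.65e38 < FLT_MAX ~= 3.40e38.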
#define c_cephes_LOG2EF 1.44269504088896341 #define c_cephes_exp_C1 0.693359375 @@ -382,12 +470,107 @@ _RVV_FLOAT_TANH_OP(2, 16, 32) _RVV_FLOAT_TANH_OP(4, 8, 32) _RVV_FLOAT_TANH_OP(8, 4, 32) -#define _RVV_FLOAT_POW_OP(LMUL, MLEN, TLEN) \ - static inline vfloat##TLEN##m##LMUL##_t pow_ps( \ +#define _RVV_FLOAT_FLOOR_OP(LMUL, MLEN, TLEN) \ + static inline vfloat##TLEN##m##LMUL##_t vfloor_v_f##TLEN##m##LMUL( \ + vfloat##TLEN##m##LMUL##_t val, size_t vl) { \ + /* 1. Convert float to int (round towards zero) */ \ + vint##TLEN##m##LMUL##_t i_val = \ + __riscv_vfcvt_rtz_x_f_v_i##TLEN##m##LMUL(val, vl); \ + /* 2. Convert int back to float */ \ + return __riscv_vfcvt_f_x_v_f##TLEN##m##LMUL(i_val, vl); \ + } +_RVV_FLOAT_FLOOR_OP(1, 32, 32) +_RVV_FLOAT_FLOOR_OP(2, 16, 32) + +const float fp32_inf = std::numeric_limits<float>::infinity(); +// To reuse this block for other float widths, the following should be done: +// 1. replace {i/f}32 with {i/f}TLEN +// 2. use another macro to obtain the "twoPow24" threshold for the given float length +#define __RVV_FLOAT32_IS_INTEGER(LMUL, MLEN) \ + static inline vbool##MLEN##_t __vfloat32_is_integer_##LMUL( \ + vfloat32m##LMUL##_t v, size_t vl) { \ + const float twoPow24 = 16777216.0f; \ + /* a float this large must have an integer value */ \ + auto v_abs = __riscv_vfabs_v_f32m##LMUL(v, vl); \ + auto huge_float_flag = __riscv_vmfgt_vf_f32m##LMUL##_b##MLEN(v_abs, twoPow24, vl); \ + auto v_is_not_inf_flag = __riscv_vmfne_vf_f32m##LMUL##_b##MLEN(v, fp32_inf, vl); \ + huge_float_flag = __riscv_vmand_mm_b##MLEN(huge_float_flag, v_is_not_inf_flag, vl); \ + auto v_to_int = __riscv_vfcvt_rtz_x_f_v_i32m##LMUL(v, vl); \ + auto back_to_float = __riscv_vfcvt_f_x_v_f32m##LMUL(v_to_int, vl); \ + auto is_int_flag = __riscv_vmfeq_vv_f32m##LMUL##_b##MLEN(v, back_to_float, vl); \ + return __riscv_vmor_mm_b##MLEN(huge_float_flag, is_int_flag, vl); \ + } + +__RVV_FLOAT32_IS_INTEGER(1, 32) +__RVV_FLOAT32_IS_INTEGER(2, 16) +__RVV_FLOAT32_IS_INTEGER(4, 8) +__RVV_FLOAT32_IS_INTEGER(8, 4) + +#define __RVV_FLOAT32_IS_EVEN(LMUL, MLEN) \ + static inline vbool##MLEN##_t __vfloat32_is_even_##LMUL( \ + vfloat32m##LMUL##_t v, size_t vl) { \ + const float twoPow24 = 16777216.0f; \ + auto v_abs = __riscv_vfabs_v_f32m##LMUL(v, vl); \ + auto huge_float_flag = __riscv_vmfgt_vf_f32m##LMUL##_b##MLEN(v_abs, twoPow24, vl); \ + auto v_is_not_inf_flag = __riscv_vmfne_vf_f32m##LMUL##_b##MLEN(v, fp32_inf, vl); \ + huge_float_flag = __riscv_vmand_mm_b##MLEN(huge_float_flag, v_is_not_inf_flag, vl); \ + /* test if v == ((int)v / 2 * 2) */ \ + auto v_to_int = __riscv_vfcvt_rtz_x_f_v_i32m##LMUL(v, vl); \ + auto v_to_int_div2 = __riscv_vsra_vx_i32m##LMUL(v_to_int, 1, vl); \ + auto v_div_mul_2 = __riscv_vsll_vx_i32m##LMUL(v_to_int_div2, 1, vl); \ + auto is_even_flag = __riscv_vmseq_vv_i32m##LMUL##_b##MLEN(v_to_int, v_div_mul_2, vl); \ + return __riscv_vmor_mm_b##MLEN(huge_float_flag, is_even_flag, vl); \ + } + +__RVV_FLOAT32_IS_EVEN(1, 32) +__RVV_FLOAT32_IS_EVEN(2, 16) +__RVV_FLOAT32_IS_EVEN(4, 8) +__RVV_FLOAT32_IS_EVEN(8, 4) + +#define _RVV_FLOAT_POW_OP(LMUL, MLEN, TLEN) \ + static inline vfloat##TLEN##m##LMUL##_t pow_ps( \ vfloat##TLEN##m##LMUL##_t a, vfloat##TLEN##m##LMUL##_t b, size_t vl) { \ - /* pow(x, m) = exp(m * log(x)) */ \ - return exp_ps(__riscv_vfmul_vv_f##TLEN##m##LMUL(b, log_ps(a, vl), vl), \ - vl); \ + /* --- constants --- */ \ + float scalar_nan = nanf(""); \ + auto nan_vector = __riscv_vfmv_v_f_f##TLEN##m##LMUL(scalar_nan, vl); \ + COMPILER_BARRIER(); \ + /* --- input a --- */ \ + auto neg_a_mask = __riscv_vmflt_vf_f##TLEN##m##LMUL##_b##MLEN(a,
0.f, vl); \ + auto abs_a = __riscv_vfabs_v_f##TLEN##m##LMUL(a, vl); \ + \ + /* --- |a|^b --- */ \ + auto result = exp_ps(__riscv_vfmul_vv_f##TLEN##m##LMUL(b, log_ps(abs_a, vl), vl), vl); \ + COMPILER_BARRIER(); \ + \ + /* --- handle a < 0 --- */ \ + if(__riscv_vcpop_m_b##MLEN(neg_a_mask, vl) != 0) { \ + auto b_int_mask = __vfloat32_is_integer_##LMUL(b, vl); \ + \ + auto b_even_mask = __vfloat32_is_even_##LMUL(b, vl); \ + auto b_not_even_mask = __riscv_vmnot_m_b##MLEN(b_even_mask, vl); \ + \ + /* set to neg, a < 0 AND b is int AND b is not even*/ \ + auto flip_sign_mask = __riscv_vmand_mm_b##MLEN(neg_a_mask, b_int_mask, vl); \ + \ + flip_sign_mask = __riscv_vmand_mm_b##MLEN(flip_sign_mask, b_not_even_mask, vl); \ + \ + COMPILER_BARRIER(); \ + /* set to NaN, a < 0 AND b is not an integer */ \ + auto is_not_int_mask = __riscv_vmnot_m_b##MLEN(b_int_mask, vl); \ + auto set_nan_mask = __riscv_vmand_mm_b##MLEN(neg_a_mask, is_not_int_mask, vl); \ + \ + COMPILER_BARRIER(); \ + /* --- use the masks to adjust the result --- */ \ + /* a. set to neg */ \ + auto neg_result = __riscv_vfneg_v_f##TLEN##m##LMUL##_m(flip_sign_mask, result, vl); \ + \ + auto signed_result = __riscv_vmerge_vvm_f##TLEN##m##LMUL(result, neg_result, flip_sign_mask, vl); \ + /* b. set to NaN */ \ + result = __riscv_vmerge_vvm_f##TLEN##m##LMUL(signed_result, nan_vector, set_nan_mask, vl); \ + \ + } \ + \ + return result; \ } _RVV_FLOAT_POW_OP(1, 32, 32) @@ -722,4 +905,4 @@ _RVV_FLOAT_ERF_OP(1, 32, 32) _RVV_FLOAT_ERF_OP(2, 16, 32) _RVV_FLOAT_ERF_OP(4, 8, 32) _RVV_FLOAT_ERF_OP(8, 4, 32) -#endif \ No newline at end of file +#endif diff --git a/ntt/include/nncase/ntt/arch/x86_64/avx_mathfun.h b/ntt/include/nncase/ntt/arch/x86_64/avx_mathfun.h index 0a23248cb6..c8fb4bd0d1 100644 --- a/ntt/include/nncase/ntt/arch/x86_64/avx_mathfun.h +++ b/ntt/include/nncase/ntt/arch/x86_64/avx_mathfun.h @@ -68,6 +68,10 @@ _PI32AVX_CONST(4, 4); _PS256_CONST(1, 1.0f); _PS256_CONST(0p5, 0.5f); +_PS256_CONST(2, 2.0f); +_PS256_CONST(nan, NAN); + + /* the smallest non denormalized float number */ _PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); _PS256_CONST_TYPE(mant_mask, int, 0x7f800000); @@ -75,6 +79,7 @@ _PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); _PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); _PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); +_PS256_CONST_TYPE(all_bits, int, -1); _PI32_CONST256(0, 0); _PI32_CONST256(1, 1); @@ -748,9 +753,54 @@ static inline __m256 tan256_ps(__m256 x) { return ytan; } + +// static inline __m256 pow256_ps(__m256 a, __m256 b) { +// // pow(x, m) = exp(m * log(x)) +// return exp256_ps(_mm256_mul_ps(b, log256_ps(a))); +// } static inline __m256 pow256_ps(__m256 a, __m256 b) { - // pow(x, m) = exp(m * log(x)) - return exp256_ps(_mm256_mul_ps(b, log256_ps(a))); + // --- constants --- + const __m256 zero = _mm256_setzero_ps(); + const __m256 two = *(__m256*)_ps256_2; + const __m256 half = *(__m256*)_ps256_0p5; + const __m256 nan_val = *(__m256*)_ps256_nan; + const __m256 abs_mask = *(__m256*)_ps256_inv_sign_mask; + const __m256 sign_mask= *(__m256*)_ps256_sign_mask; + const __m256 all_bits = *(__m256*)_ps256_all_bits; + + // --- input a --- + __m256 neg_a_mask = _mm256_cmp_ps(a, zero, _CMP_LT_OS); + __m256 abs_a = _mm256_and_ps(a, abs_mask); + + // --- |a|^b --- + __m256 result = exp256_ps(_mm256_mul_ps(b, log256_ps(abs_a))); + + // --- handle a < 0 --- + if (_mm256_movemask_ps(neg_a_mask) != 0) { + __m256 b_floor = _mm256_floor_ps(b); + __m256 is_int_mask = _mm256_cmp_ps(b, b_floor, _CMP_EQ_OQ); + + __m256 
b_div_2_floor = _mm256_floor_ps(_mm256_mul_ps(b, half)); + __m256 is_odd_mask = _mm256_cmp_ps(_mm256_mul_ps(b_div_2_floor, two), b_floor, _CMP_NEQ_UQ); + + // set to neg, a < 0 AND b is odd + __m256 flip_sign_mask = _mm256_and_ps(neg_a_mask, is_int_mask); + flip_sign_mask = _mm256_and_ps(flip_sign_mask, is_odd_mask); + + // set to NaN, a < 0 AND b is not an integer + __m256 is_not_int_mask = _mm256_xor_ps(is_int_mask, all_bits); + __m256 set_nan_mask = _mm256_and_ps(neg_a_mask, is_not_int_mask); + + // --- use the masks to adjust the result --- + // a. set to neg + __m256 sign_flipper = _mm256_and_ps(flip_sign_mask, sign_mask); + result = _mm256_xor_ps(result, sign_flipper); + + // b. set to NaN + result = _mm256_blendv_ps(result, nan_val, set_nan_mask); + } + + return result; } static inline __m256 asin256_ps(__m256 x) { diff --git a/ntt/include/nncase/ntt/kernels/cast.h b/ntt/include/nncase/ntt/kernels/cast.h index cb56762154..fbc7929c77 100644 --- a/ntt/include/nncase/ntt/kernels/cast.h +++ b/ntt/include/nncase/ntt/kernels/cast.h @@ -18,6 +18,8 @@ #include "../post_ops.h" #include "../tensor_ops.h" #include "../ukernels.h" +#include <cassert> +#include <type_traits> #include "../utility.h" #include "nncase/ntt/shape.h" @@ -27,20 +29,24 @@ template class TPostOp> class cast_impl { inline static constexpr size_t rank = TIn::rank(); - - // FIXME: vector of x86 may fail. + // !! For vectors, the element count must be the same as that of the other cast operand. using InElemType = element_or_scalar_t<TIn>; using OutElemType = element_or_scalar_t<TOut>; static_assert((Vector<InElemType> && Vector<OutElemType>) || (Scalar<InElemType> && Scalar<OutElemType>), "input & output must have the same type."); inline static constexpr auto in_ele_size = - sizeof(std::conditional_t<Vector<InElemType>, + sizeof(std::conditional_t<Vector<InElemType>, // if vector + element_or_scalar_t<InElemType>, size_t>); inline static constexpr auto out_ele_size = sizeof(std::conditional_t<Vector<OutElemType>, element_or_scalar_t<OutElemType>, size_t>); - inline static constexpr float scale = (float)in_ele_size / out_ele_size; + + inline static constexpr bool is_bool_vector = + Vector<InElemType> && (std::is_same_v<element_or_scalar_t<InElemType>, bool> || + std::is_same_v<element_or_scalar_t<OutElemType>, bool>); + + inline static constexpr float scale = is_bool_vector ? 1.0f : (float)in_ele_size / out_ele_size; inline static constexpr auto in_offset_scale = scale > 1.0f ?
(size_t)scale : (size_t)1; @@ -69,11 +75,18 @@ class cast_impl { #endif constexpr VectorizedAxes vectorizedAxes; if constexpr (scale >= 1.f) { + if constexpr (VectorizedAxes::rank() == 1) { + assert( + (dim_value(input.shape()[fixed_dim_v]) == + dim_value(output.shape()[fixed_dim_v]) * scale) + ); + } ntt::apply(output.shape(), [&](auto index) { auto in_index = index; - if constexpr (vectorizedAxes.rank() == 1) + if constexpr (VectorizedAxes::rank() == 1) in_index[fixed_dim_v] *= in_offset_scale; + __asm__ volatile("" ::: "memory"); ntt::u_cast( &input(in_index), vectorizedAxes.rank() == 1 @@ -82,11 +95,18 @@ class cast_impl { &output(index), 1, 1); }); } else { + if constexpr (VectorizedAxes::rank() == 1) { + assert( + (float)dim_value(input.shape()[fixed_dim_v]) == + (float)dim_value(output.shape()[fixed_dim_v]) * scale + ); + } ntt::apply(input.shape(), [&](auto index) { auto out_index = index; if constexpr (vectorizedAxes.rank() == 1) out_index[fixed_dim_v] *= out_offset_scale; + __asm__ volatile("" ::: "memory"); ntt::u_cast( &input(index), 1, &output(out_index), vectorizedAxes.rank() == 1 @@ -101,6 +121,7 @@ class cast_impl { #if 0 template <class TContiguousDims, class TRestDims> constexpr void + // rest_dims are the dimensions of the tensor to be cast apply(const TContiguousDims &conti_dims, const TRestDims &rest_dims, dynamic_shape_t &index, const TIn &input, TOut &output) { if (conti_dims == rest_dims.rank()) { diff --git a/ntt/include/nncase/ntt/primitive_ops.h b/ntt/include/nncase/ntt/primitive_ops.h index c04ac77bcc..2884006e2d 100644 --- a/ntt/include/nncase/ntt/primitive_ops.h +++ b/ntt/include/nncase/ntt/primitive_ops.h @@ -180,12 +180,13 @@ template <class T1, class T2> struct mul { template <class T1, class T2> struct div { constexpr auto operator()(const T1 &v1, const T2 &v2) const noexcept { - static_assert(std::is_same_v<T1, T2>, "T1 and T2 must be same type"); return v1 / v2; } }; template <class T1, class T2> struct ceil_div { + static_assert(std::is_integral_v<T1> && std::is_integral_v<T2>, + "T1 and T2 must be integral types"); constexpr auto operator()(const T1 &v1, const T2 &v2) const noexcept { return (v1 + (v2 - 1)) / v2; } @@ -197,9 +198,20 @@ template <class T1, class T2> struct ceil_div { */ template <class T1, class T2> struct floor_mod { constexpr auto operator()(const T1 &v1, const T2 &v2) const noexcept { - return (T1)((double)v1 - std::floor(static_cast<double>(v1) / - static_cast<double>(v2)) * - (double)v2); + return (T1)(double(v1) - + std::floor(static_cast<double>(v1) / static_cast<double>(v2)) * + static_cast<double>(v2)); + } +}; + + +template <class T> +requires (std::is_same_v<T, half> || std::is_same_v<T, bfloat16>) +struct floor_mod<T, T> { + constexpr auto operator()(T v1, + T v2) const noexcept { + + return T(v1 - (std::floor(float(v1) / float(v2)) * v2)); } }; @@ -220,10 +232,23 @@ template <class T1, class T2> struct outer_product { */ template <class T1, class T2> struct mod { constexpr auto operator()(const T1 &v1, const T2 &v2) const noexcept { - return std::fmod((float)v1, (float)v2); + return (T1)std::fmod((double)v1, (double)v2); } }; + +template <class T> +requires (std::is_same_v<T, half> || std::is_same_v<T, bfloat16>) +struct mod<T, T> { + constexpr auto operator()(T v1, + T v2) const noexcept { + return T( + std::fmod(static_cast<float>(v1), static_cast<float>(v2))); + } +}; + + + template <class T1, class T2> struct min { constexpr auto operator()(const T1 &v1, const T2 &v2) const noexcept { return std::min(v1, v2); @@ -310,7 +335,9 @@ template <class T1, class T2> struct clamp { template <class T1, class T2> struct cast { constexpr T2 operator()(const T1 &v) const noexcept { + // printf("cast from %f to %f\n", (double)(float)v, (double)static_cast<T2>(v)); return static_cast<T2>(v); + } }; @@ -551,7 +578,8 @@ template <class T> constexpr T swish<T>::operator()(const T &v) const noexcept { // swishb(v, beta) = v / (exp(-v*beta) + 1) template <class T, class B>
constexpr T swishb<T, B>::operator()(const T &v, const B &beta) const noexcept { - return v / (ntt::exp(-v * beta) + 1); + // -(double)v is needed for unsigned input types. + return static_cast<T>(double(v) / (ntt::exp((-(double)v) * (double)beta) + (double)1)); } template diff --git a/ntt/include/nncase/ntt/ukernels/u_cast.h b/ntt/include/nncase/ntt/ukernels/u_cast.h index 9c9bdbeeaa..c835e8f527 100644 --- a/ntt/include/nncase/ntt/ukernels/u_cast.h +++ b/ntt/include/nncase/ntt/ukernels/u_cast.h @@ -32,8 +32,32 @@ struct u_cast { size_t output_stride, size_t count) noexcept { using policy_t = u_cast_policy; constexpr auto unroll = policy_t::unroll; + + if constexpr (in_offset_scale == 8 && out_offset_scale == 1) { + while (count / unroll) { + for (size_t i = 0; i < unroll; i++) { + *output = + ntt::ops::cast<T1, T2>()(*(input + 0 * input_stride), *(input + 1 * input_stride), + *(input + 2 * input_stride), *(input + 3 * input_stride), + *(input + 4 * input_stride), *(input + 5 * input_stride), + *(input + 6 * input_stride), *(input + 7 * input_stride)); + input += input_stride * in_offset_scale; + output += output_stride * out_offset_scale; + count--; + } + } - if constexpr (in_offset_scale == 4 && out_offset_scale == 1) { + for (size_t i = 0; i < count; i++) { + *output = ntt::ops::cast<T1, T2>()( + *(input + 0 * input_stride), *(input + 1 * input_stride), + *(input + 2 * input_stride), *(input + 3 * input_stride), + *(input + 4 * input_stride), *(input + 5 * input_stride), + *(input + 6 * input_stride), *(input + 7 * input_stride)); + input += input_stride * in_offset_scale; + output += output_stride * out_offset_scale; + } + } + else if constexpr (in_offset_scale == 4 && out_offset_scale == 1) { while (count / unroll) { for (size_t i = 0; i < unroll; i++) { *output = @@ -78,7 +102,6 @@ struct u_cast { } } else if constexpr (in_offset_scale == 1 && out_offset_scale > 1) { using value_type = typename T2::element_type; - constexpr auto lanes = T2::shape(); while (count / unroll) { for (size_t i = 0; i < unroll; i++) { @@ -104,6 +127,7 @@ } } else { + while (count / unroll) { for (size_t i = 0; i < unroll; i++) { *output = ntt::ops::cast<T1, T2>()(*input); diff --git a/ntt/include/nncase/ntt/vector.h b/ntt/include/nncase/ntt/vector.h index 2c01fcbf9e..4388eba6a7 100644 --- a/ntt/include/nncase/ntt/vector.h +++ b/ntt/include/nncase/ntt/vector.h @@ -145,4 +145,32 @@ template <class T> struct vector_rank { }; template <class T> constexpr inline auto vector_rank_v = vector_rank<T>::value; + +template <class Shape> +struct last_lane; + +template +struct last_lane> { + static constexpr size_t value = D::value; +}; + +template +struct last_lane> { + static constexpr size_t value = last_lane>::value; +}; + +template <class TVec> +struct get_last_lane_vector { + using element_type = typename TVec::element_type; + using shape_type = typename TVec::shape_type; + + static constexpr size_t last_dim = last_lane<shape_type>::value; + + using type = nncase::ntt::replace_lanes_t; +}; + +template <class TVec> +using get_last_lane_vector_t = typename get_last_lane_vector<TVec>::type; + + } // namespace nncase::ntt diff --git a/ntt/include/nncase/ntt/vector_ops.h b/ntt/include/nncase/ntt/vector_ops.h index de2d35c02d..5a47839eb5 100644 --- a/ntt/include/nncase/ntt/vector_ops.h +++ b/ntt/include/nncase/ntt/vector_ops.h @@ -64,12 +64,17 @@ struct tensor_unary_impl { Op op_; }; -template