diff --git a/rmsnorm/build.toml b/rmsnorm/build.toml index 3a1a6bc..83717f5 100644 --- a/rmsnorm/build.toml +++ b/rmsnorm/build.toml @@ -5,6 +5,39 @@ universal = false [torch] src = ["torch-ext/torch_binding.cpp"] +[kernel.rmsnorm_cpu] +backend = "cpu" +depends = ["torch"] +src = [ + "rmsnorm_cpu/rmsnorm_cpu_torch.cpp", + "rmsnorm_cpu/rmsnorm_cpu.cpp", + "rmsnorm_cpu/rmsnorm_cpu.hpp", + "rmsnorm_cpu/cpu_features.hpp", +] +include = ["rmsnorm_cpu"] + +[kernel.rmsnorm_cpu_avx512] +backend = "cpu" +depends = ["torch"] +src = [ + "rmsnorm_cpu/rmsnorm_avx512.cpp", + "rmsnorm_cpu/rmsnorm_avx512.hpp", + "rmsnorm_cpu/cpu_types_avx512.hpp", +] +include = ["rmsnorm_cpu"] +cxx-flags = ["-mfma", "-fopenmp", "-mf16c", "-mavx512f", "-mavx512bf16", "-mavx512vl"] + +[kernel.rmsnorm_cpu_avx2] +backend = "cpu" +depends = ["torch"] +src = [ + "rmsnorm_cpu/rmsnorm_avx2.cpp", + "rmsnorm_cpu/rmsnorm_avx2.hpp", + "rmsnorm_cpu/cpu_types_avx2.hpp", +] +include = ["rmsnorm_cpu"] +cxx-flags = ["-mavx2", "-mfma", "-fopenmp", "-mf16c"] + [kernel.rmsnorm_xpu] backend = "xpu" depends = ["torch"] diff --git a/rmsnorm/flake.lock b/rmsnorm/flake.lock index e05ef3d..b8b968e 100644 --- a/rmsnorm/flake.lock +++ b/rmsnorm/flake.lock @@ -2,26 +2,11 @@ "nodes": { "flake-compat": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-compat_2": { - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -48,61 +33,18 @@ "type": "github" } }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "hf-nix": { - "inputs": { - "flake-compat": "flake-compat_2", - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs" - }, - "locked": { - "lastModified": 1760814603, - "narHash": "sha256-i5uuhnJPxOrd0dC8+btp31WMfzPDL8Uwz0TPG2n6nHE=", - "owner": "huggingface", - "repo": "hf-nix", - "rev": "c0b62ec3d0abb11dd2d960e3dfee3a46fc46d111", - "type": "github" - }, - "original": { - "owner": "huggingface", - "repo": "hf-nix", - "type": "github" - } - }, "kernel-builder": { "inputs": { "flake-compat": "flake-compat", "flake-utils": "flake-utils", - "hf-nix": "hf-nix", - "nixpkgs": [ - "kernel-builder", - "hf-nix", - "nixpkgs" - ] + "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1761645431, - "narHash": "sha256-Ns3m/L+FMAYnmKhwt4vlIf8lq6dOJWHAocFL23HasTM=", + "lastModified": 1763649391, + "narHash": "sha256-XM5ptF3zZ9ZOTMjWOQxD502Jc/46Y2k/nrgDf4i3bxA=", "owner": "huggingface", "repo": "kernel-builder", - "rev": "289788986c318e6ccb92608f011c49d61b25b5b6", + "rev": "7987361891c46a9e3cfe07eba9aa3e52638906fe", "type": "github" }, "original": { @@ -113,17 +55,17 @@ }, "nixpkgs": { "locked": { - 
"lastModified": 1755963616, - "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", - "owner": "nixos", + "lastModified": 1763291491, + "narHash": "sha256-eEYvm+45PPmy+Qe+nZDpn1uhoMUjJwx3PwVVQoO9ksA=", + "owner": "NixOS", "repo": "nixpkgs", - "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", "type": "github" }, "original": { - "owner": "nixos", - "ref": "nixos-unstable-small", + "owner": "NixOS", "repo": "nixpkgs", + "rev": "c543a59edf25ada193719764f3bc0c6ba835f94d", "type": "github" } }, @@ -146,21 +88,6 @@ "repo": "default", "type": "github" } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } } }, "root": "root", diff --git a/rmsnorm/flake.nix b/rmsnorm/flake.nix index 6419526..b8d645d 100644 --- a/rmsnorm/flake.nix +++ b/rmsnorm/flake.nix @@ -8,5 +8,14 @@ kernel-builder.lib.genFlakeOutputs { inherit self; path = ./.; + + # This is a workaround, we should be able to specify flags per arch in + # kernel-builder. + torchVersions = + allVersions: + builtins.map ( + version: + version // { systems = builtins.filter (system: system == "x86_64-linux") version.systems; } + ) allVersions; }; } diff --git a/rmsnorm/rmsnorm_cpu/cpu_features.hpp b/rmsnorm/rmsnorm_cpu/cpu_features.hpp new file mode 100644 index 0000000..18e7c97 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/cpu_features.hpp @@ -0,0 +1,176 @@ +#pragma once + +#ifdef _MSC_VER +#include +#else +#include +#endif +#include +#include +#include +#include +namespace rmsnorm_cpu +{ + + // CPU feature detection + class CPUFeatures + { + public: + static bool hasAVX2() + { + static bool avx2_supported = checkAVX2(); + return avx2_supported; + } + + static bool hasAVX512BF16() + { + static bool bf16_supported = checkAVX512BF16(); + return bf16_supported; + } + + private: + static bool checkAVX2() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + + if (n_ids >= 7) + { + __cpuidex(cpu_info, 7, 0); + return (cpu_info[1] & (1 << 5)) != 0; // EBX bit 5 + } + return false; +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + __cpuid_count(7, 0, eax, ebx, ecx, edx); + return (ebx & (1 << 5)) != 0; // EBX bit 5 +#endif + } + + static bool checkAVX512() + { +#ifdef _MSC_VER + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + + __cpuidex(cpu_info, 7, 0); + bool avx512f = (cpu_info[1] & (1 << 16)) != 0; // EBX bit 16: AVX-512 Foundation + if (!avx512f) + return false; + + __cpuid(cpu_info, 1); + bool osxsave = (cpu_info[2] & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if (!osxsave) + return false; + + // check XCR0: bits 1,2 (SSE/AVX) and 5,6,7 (AVX-512 state) must be enabled by OS + unsigned long long xcr0 = _xgetbv(0); + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#else + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + + __cpuid_count(7, 0, eax, ebx, ecx, edx); + bool avx512f = (ebx & (1 << 16)) != 0; // EBX bit 16: AVX-512 Foundation + if (!avx512f) + return false; + + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) == 0) + { + return false; + } + bool osxsave = (ecx & (1 << 27)) != 0; // ECX bit 27: OSXSAVE + if 
(!osxsave) + return false; + + unsigned int xcr0_lo = 0, xcr0_hi = 0; + __asm__ volatile("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0)); + unsigned long long xcr0 = ((unsigned long long)xcr0_hi << 32) | xcr0_lo; + // require XCR0 bits 1,2,5,6,7 set -> mask 0xE6 (0b11100110) + return ((xcr0 & 0xE6ULL) == 0xE6ULL); +#endif + } + + static bool checkAVX512BF16() + { + // require AVX-512 foundation supported and OS enabled + if (!checkAVX512()) + return false; + +#ifndef _MSC_VER + // First: try Linux /proc/cpuinfo flags (most robust on Linux) + std::ifstream f("/proc/cpuinfo"); + if (f) + { + std::string line; + while (std::getline(f, line)) + { + // flags line contains many space-separated tokens including avx512_bf16 on supported CPUs + if (line.find("avx512_bf16") != std::string::npos || + line.find("avx512bf16") != std::string::npos) + { + return true; + } + } + } + + // Fallback: attempt CPUID subleaf check if available. + // Note: exact bit position for AVX512_BF16 may differ across vendors/CPUID versions. + // This fallback tries CPUID(7,1) and checks some common positions; if uncertain returns false. + if (__get_cpuid_max(0, nullptr) < 7) + { + return false; + } + unsigned int eax, ebx, ecx, edx; + // try subleaf 1 + __cpuid_count(7, 1, eax, ebx, ecx, edx); + // There isn't a universally agreed constant here in this file; check common candidate bits: + // - some implementations report AVX512_BF16 in ECX/EBX of subleaf 1. + // Try commonly used positions conservatively. + const unsigned int candidate_masks[] = { + (1u << 5), // candidate (may collide with other features) + (1u << 26), // another candidate position + }; + for (unsigned m : candidate_masks) + { + if ((ebx & m) || (ecx & m) || (edx & m)) + { + return true; + } + } + return false; +#else + // On MSVC / Windows, use CPUID if available (simple check). If unsure, return false. + int cpu_info[4]; + __cpuid(cpu_info, 0); + int n_ids = cpu_info[0]; + if (n_ids < 7) + return false; + __cpuidex(cpu_info, 7, 1); + // same conservative check as above + const int candidate_masks[] = {(1 << 5), (1 << 26)}; + for (int m : candidate_masks) + { + if ((cpu_info[1] & m) || (cpu_info[2] & m) || (cpu_info[3] & m)) + { + return true; + } + } + return false; +#endif + } + }; + +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/cpu_types_avx2.hpp b/rmsnorm/rmsnorm_cpu/cpu_types_avx2.hpp new file mode 100644 index 0000000..4bff51b --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/cpu_types_avx2.hpp @@ -0,0 +1,541 @@ + +#ifndef CPU_TYPES_AVX_HPP +#define CPU_TYPES_AVX_HPP + +#include +#include + +namespace vec_op_avx2 +{ + +#define DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); +#define CPU_KERNEL_GUARD_OUT(NAME) +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + + namespace + { + template + constexpr void unroll_loop_item(std::integer_sequence, F &&f) + { + (f(std::integral_constant{}), ...); + } + }; // namespace + + template >> + constexpr void unroll_loop(F &&f) + { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); + } + + template + struct Vec + { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } + }; + + struct FP32Vec8; + struct FP32Vec16; + + struct FP16Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit FP16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit FP16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + }; + + struct FP16Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + // normal load + explicit FP16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + // non-temporal load + explicit FP16Vec16(bool, void *ptr) + : reg(_mm256_stream_load_si256((__m256i *)ptr)) {} + + explicit FP16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { _mm256_storeu_si256((__m256i *)ptr, reg); } + + void save(void *ptr, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } + }; + + struct BF16Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + }; + + struct BF16Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + // normal load + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + // non-temporal load + explicit BF16Vec16(bool, void *ptr) + : reg(_mm256_stream_load_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { _mm256_storeu_si256((__m256i *)ptr, reg); } + + void save(void *ptr, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } + }; + + struct BF16Vec32 : public Vec + { + constexpr static int VEC_ELEM_NUM = 32; + + __m256i reg_low; + __m256i reg_high; + + explicit BF16Vec32(const void *ptr) + : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), + reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + + explicit BF16Vec32(__m256i low, __m256i high) + : reg_low(low), reg_high(high) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg_low((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), + reg_high((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} + + void save(void *ptr) const + { + _mm256_storeu_si256((__m256i *)ptr, reg_low); + _mm256_storeu_si256((__m256i *)ptr + 1, reg_high); + } + }; + + struct FP32Vec4 : public Vec + { + constexpr 
static int VEC_ELEM_NUM = 4; + union AliasReg + { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + }; + + struct FP32Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg + { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + + explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {} + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_cvtepu16_epi32(v.reg), 16))) {} + + float reduce_sum() const + { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) + { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } + }; + + struct FP32Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + + union AliasReg + { + __m256 reg; + float values[8]; + }; + + __m256 reg_low; + __m256 reg_high; + + explicit FP32Vec16(float v) + : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {} + + explicit FP32Vec16() + : reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) + : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {} + + explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg_low((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)), + reg_high((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg_low(data.reg), reg_high(data.reg) {} + + explicit FP32Vec16(const FP16Vec16 &v) + { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i 
high = _mm256_extractf128_si256(v.reg, 1); + + reg_low = _mm256_cvtph_ps(low); + reg_high = _mm256_cvtph_ps(high); + } + + explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); + __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); + + __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); + __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); + + reg_low = _mm256_castsi256_ps(v_low_shifted); + reg_high = _mm256_castsi256_ps(v_high_shifted); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const + { + return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), + _mm256_mul_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const + { + return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), + _mm256_add_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const + { + return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), + _mm256_sub_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const + { + return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), + _mm256_div_ps(reg_high, b.reg_high)); + } + + FP32Vec16 max(const FP32Vec16 &b) const + { + return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low), + _mm256_max_ps(reg_high, b.reg_high)); + } + + float reduce_max() const + { + __m256 v = _mm256_max_ps(reg_low, reg_high); + // Permute to compare elements within 128-bit lanes + __m256 v_shuffled = _mm256_permute_ps( + v, 0b00001011); // Swap halves within each 128-bit lane + __m256 v_max = _mm256_max_ps(v, v_shuffled); + + v_shuffled = _mm256_permute_ps( + v_max, 0b00000001); // Shuffle elements within each 128-bit lane + v_max = _mm256_max_ps(v_max, v_shuffled); + + // Permute to compare elements between 128-bit lanes + v_shuffled = + _mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes + v_max = _mm256_max_ps(v_max, v_shuffled); + + // At this point, the maximum value is present in all elements of v_max. + // Extract the first element for the scalar result. 
+ return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float + } + + float reduce_sum() const + { + FP32Vec8 low = FP32Vec8(reg_low); + FP32Vec8 high = FP32Vec8(reg_high); + return low.reduce_sum() + high.reduce_sum(); + } + + template + float reduce_sub_sum(int idx) + { + float sum = 0.0; + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + uint32_t mask = base_mask << (idx * group_size); + + AliasReg ar; + + auto func = [&sum, &mask, &ar](int i) + { + int flag = mask & 0x1; + mask = mask >> 1; + if (flag != 0) + sum += ar.values[i]; + }; + + ar.reg = reg_low; + unroll_loop(func); + + ar.reg = reg_high; + unroll_loop(func); + + return sum; + } + + void save(float *ptr) const + { + _mm256_storeu_ps(ptr, reg_low); + _mm256_storeu_ps(ptr + 8, reg_high); + } + }; + + template + struct VecType + { + using vec_type = void; + }; + + template + using vec_t = typename VecType::vec_type; + + template <> + struct VecType + { + using vec_type = FP32Vec8; + }; + + template <> + struct VecType + { + using vec_type = FP16Vec8; + }; + + template <> + struct VecType + { + using vec_type = BF16Vec8; + }; + + template + void storeFP32(float v, T *ptr) + { + *ptr = v; + } + + inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) + { + acc = acc + a * b; + } + + template <> + inline void storeFP32(float v, c10::Half *ptr) + { + *reinterpret_cast(ptr) = + _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + inline FP16Vec8::FP16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + + inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) + : reg(_mm256_insertf128_si256( + _mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), + FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {} + + template <> + inline void storeFP32(float v, c10::BFloat16 *ptr) + { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); + } + + namespace { + __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) + { + __m256i ai = _mm256_castps_si256(a); + const __m256i lowmask = _mm256_set1_epi32(0x0000FFFF); + const __m256i tieval = _mm256_set1_epi32(0x00008000); // 0x8000 low half = tie + const __m256i bias = _mm256_set1_epi32(0x00007FFF); // rounding bias + + // detect tie (low 16 == 0x8000) + __m256i low = _mm256_and_si256(ai, lowmask); + __m256i tie = _mm256_cmpeq_epi32(low, tieval); + + // hi bits and LSB of hi (for tie->round-to-even) + __m256i hi = _mm256_srli_epi32(ai, 16); + __m256i lsb = _mm256_and_si256(hi, _mm256_set1_epi32(1)); + + // if tie && hi_lsb == 1 then add extra 1 to rounding bias (tie-to-even) + __m256i tiecorr = _mm256_and_si256(tie, lsb); + __m256i tiecorr_shift = _mm256_slli_epi32(tiecorr, 16); // add at bit16 + + // add bias and optional tie correction, then shift down + __m256i tmp = _mm256_add_epi32(ai, bias); + tmp = _mm256_add_epi32(tmp, tiecorr_shift); + tmp = _mm256_srli_epi32(tmp, 16); // now each lane holds rounded bf16 in low16 + + // pack 32->16 and reorder to extract 128-bit with 8 bf16 lanes + tmp = _mm256_packus_epi32(tmp, tmp); + tmp = _mm256_permute4x64_epi64(tmp, 0b00111001); + return _mm256_extracti128_si256(tmp, 0); + } +} + + inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} + + inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + { + BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); + BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); + reg = 
_mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); + } + + inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + + inline void mem_barrier() { _mm_mfence(); } +}; // namespace vec_op_avx2 + +#endif diff --git a/rmsnorm/rmsnorm_cpu/cpu_types_avx512.hpp b/rmsnorm/rmsnorm_cpu/cpu_types_avx512.hpp new file mode 100644 index 0000000..aae3bca --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/cpu_types_avx512.hpp @@ -0,0 +1,469 @@ + +#ifndef CPU_TYPES_AVX512_HPP +#define CPU_TYPES_AVX512_HPP + +#include +#include + +namespace vec_op_avx512 +{ + +#define DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); +#define CPU_KERNEL_GUARD_OUT(NAME) +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + + namespace + { + template + constexpr void unroll_loop_item(std::integer_sequence, F &&f) + { + (f(std::integral_constant{}), ...); + } + }; // namespace + + template >> + constexpr void unroll_loop(F &&f) + { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); + } + + template + struct Vec + { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } + }; + + struct FP32Vec8; + struct FP32Vec16; + + struct FP16Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit FP16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit FP16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + }; + + struct FP16Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + // normal load + explicit FP16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + // non-temporal load + explicit FP16Vec16(bool, void *ptr) + : reg(_mm256_stream_load_si256((__m256i *)ptr)) {} + + explicit FP16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { _mm256_storeu_si256((__m256i *)ptr, reg); } + + void save(void *ptr, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } + }; + + struct BF16Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } + }; + + struct BF16Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + // normal load + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + // non-temporal load + explicit BF16Vec16(bool, void *ptr) + : reg(_mm256_stream_load_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { _mm256_storeu_si256((__m256i *)ptr, reg); } + + void save(void *ptr, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } + }; 
+ + struct BF16Vec32 : public Vec + { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32() : reg(_mm512_setzero_si512()) {} + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } + }; + + struct FP32Vec4 : public Vec + { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg + { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} + }; + + struct FP32Vec8 : public Vec + { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg + { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + + explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {} + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const + { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop( + [&result, &ar](int i) + { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const + { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const + { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } + }; + + struct FP32Vec16 : public Vec + { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg + { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + // normal load + explicit FP32Vec16(const float *ptr) : 
reg(_mm512_loadu_ps(ptr)) {} + + // non-temporal load + explicit FP32Vec16(bool, void *ptr) + : reg((__m512)_mm512_stream_load_si512(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {} + + explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + FP32Vec16 clamp(const FP32Vec16 &min, const FP32Vec16 &max) const + { + return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); + } + + FP32Vec16 max(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_max_ps(reg, b.reg)); + } + + FP32Vec16 max(const FP32Vec16 &b, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16 &b) const + { + return FP32Vec16(_mm512_min_ps(reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16 &b, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + float reduce_max() const { return _mm512_reduce_max_ps(reg); } + + float reduce_min() const { return _mm512_reduce_min_ps(reg); } + + template + float reduce_sub_sum(int idx) + { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + + void save(float *ptr, const int elem_num) const + { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_ps(ptr, mask, reg); + } + }; + + template + struct VecType + { + using vec_type = void; + }; + + template + using vec_t = typename VecType::vec_type; + + template <> + struct VecType + { + using vec_type = FP32Vec16; + }; + + template <> + struct VecType + { + using vec_type = FP16Vec16; + }; + + template <> + struct VecType + { + using vec_type = BF16Vec16; + }; + + template + void storeFP32(float v, T *ptr) + { + *ptr = v; + } + + inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) + { + acc = acc + a * b; + } + + template <> + inline void storeFP32(float v, c10::Half *ptr) + { + 
*reinterpret_cast(ptr) = + _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + } + + inline FP16Vec8::FP16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + + inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + + template <> + inline void storeFP32(float v, c10::BFloat16 *ptr) + { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); + } + + inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + + inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + + inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) + { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); + } + + inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + + inline void mem_barrier() { _mm_mfence(); } +}; // namespace vec_op_avx512 + +#endif diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.cpp b/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.cpp new file mode 100644 index 0000000..f9dbf2e --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.cpp @@ -0,0 +1,207 @@ +// AVX2 implementation - compile with -mavx2 +#include +#include +#include +#include +#include "cpu_types_avx2.hpp" + +namespace rmsnorm_cpu +{ + namespace avx2 + { + + template + void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) + { + using namespace vec_op_avx2; + using scalar_vec_t = vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) + { + FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + FP32Vec8 fp32_x(x); + FP32Vec8 fp32_w(w); + + FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } + } + + void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + const float epsilon) + { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] + { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); + } + + template + void rms_norm_backward_impl(scalar_t *__restrict__ grad_input, + scalar_t *__restrict__ grad_weight, + const scalar_t *__restrict__ grad_out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) + { + using namespace vec_op_avx2; + using scalar_vec_t = vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + + int HS = hidden_size; + int NT = num_tokens; + + // initialize 
grad_weight to zero + for (int j = 0; j < HS; ++j) + { + grad_weight[j] = (scalar_t)0; + } + + // Allocate per-thread accumulators and re-run accumulation serially per-thread. + int max_threads = omp_get_max_threads(); + std::vector> all_acc(max_threads, std::vector(HS, (scalar_t)0)); + + // Parallel over tokens: compute grad_input and accumulate into thread-local + // buffers for grad_weight to avoid atomics. +#pragma omp parallel + { + int tid = omp_get_thread_num(); + int nthreads = omp_get_num_threads(); + int start = (NT * tid) / nthreads; + int end = (NT * (tid + 1)) / nthreads; + + auto &local_acc = all_acc[tid]; + for (int i = start; i < end; ++i) + { + const scalar_t *input_p = input + i * HS; + const scalar_t *grad_out_p = grad_out + i * HS; + scalar_t *grad_input_p = grad_input + i * HS; + + // compute variance and inv_rms for this token + FP32Vec8 variance(0.0f); + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + FP32Vec8 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float inv_rms = 1.0f / sqrtf(variance.reduce_sum() / (float)HS + epsilon); + FP32Vec8 fp32_inv_rms(inv_rms); + + // compute S = sum_k (g * w * x) for this token + FP32Vec8 Svec(0.0f); + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t g_out(grad_out_p + j); + scalar_vec_t w(weight + j); + + FP32Vec8 fp32_x(x); + FP32Vec8 fp32_g_out(g_out); + FP32Vec8 fp32_w(w); + + Svec = Svec + fp32_g_out * fp32_w * fp32_x; + } + float S = Svec.reduce_sum(); + float S_over_H = S / (float)HS; + float inv_rms3 = inv_rms * inv_rms * inv_rms; + FP32Vec8 fp32_inv_rms3(inv_rms3); + FP32Vec8 fp32_S_over_H(S_over_H); + + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t g_out(grad_out_p + j); + scalar_vec_t w(weight + j); + + FP32Vec8 fp32_x(x); + FP32Vec8 fp32_g_out(g_out); + FP32Vec8 fp32_w(w); + + // grad_input = g * w * inv_rms - x * inv_rms^3 * (S/H) + FP32Vec8 term1 = fp32_g_out * fp32_w * fp32_inv_rms; + FP32Vec8 term2 = fp32_x * fp32_inv_rms3 * fp32_S_over_H; + FP32Vec8 fp32_grad_in = term1 - term2; + scalar_vec_t sgrad_in(fp32_grad_in); + sgrad_in.save(grad_input_p + j); + + // accumulate grad_weight += input * grad_out * inv_rms + FP32Vec8 prod = fp32_x * fp32_g_out * fp32_inv_rms; + scalar_vec_t sprod(prod); + scalar_t tmp[VEC_ELEM_NUM]; + sprod.save(tmp); + for (int e = 0; e < VEC_ELEM_NUM; ++e) + { + local_acc[j + e] += tmp[e]; + } + } + } + } + + // reduce all_acc into grad_weight + for (int t = 0; t < (int)all_acc.size(); ++t) + { + for (int j = 0; j < HS; ++j) + { + grad_weight[j] += all_acc[t][j]; + } + } + } + + void rms_norm_backward(torch::Tensor &grad_input, torch::Tensor &grad_weight, + const torch::Tensor &grad_out, const torch::Tensor &input, + const torch::Tensor &weight, const float epsilon) + { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_backward_impl", [&] + { + CPU_KERNEL_GUARD_IN(rms_norm_backward_impl) + rms_norm_backward_impl(grad_input.data_ptr(), + grad_weight.data_ptr(), + grad_out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_backward_impl) }); + } + + } // namespace avx2 +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.hpp b/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.hpp new file mode 100644 index 0000000..538f140 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_avx2.hpp @@ -0,0 +1,21 @@ 
+// AVX2 implementation - compile with -mavx2 +#include +#include +#include +#include +#include "cpu_types_avx2.hpp" + +namespace rmsnorm_cpu +{ + namespace avx2 + { + + void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + const float epsilon); + + void rms_norm_backward(torch::Tensor &grad_input, torch::Tensor &grad_weight, + const torch::Tensor &grad_out, const torch::Tensor &input, + const torch::Tensor &weight, const float epsilon); + + } // namespace avx2 +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.cpp b/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.cpp new file mode 100644 index 0000000..861e182 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.cpp @@ -0,0 +1,208 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#include +#include +#include +#include +#include "cpu_types_avx512.hpp" + +namespace rmsnorm_cpu +{ + namespace avx512 + { + + template + void rms_norm_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) + { + using namespace vec_op_avx512; + using scalar_vec_t = vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) + { + FP32Vec16 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + FP32Vec16 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + FP32Vec16 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + FP32Vec16 fp32_x(x); + FP32Vec16 fp32_w(w); + + FP32Vec16 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out); + out.save(output_p + j); + } + } + } + + void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + const float epsilon) + { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] + { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); + } + + template + void rms_norm_backward_impl(scalar_t *__restrict__ grad_input, + scalar_t *__restrict__ grad_weight, + const scalar_t *__restrict__ grad_out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) + { + using namespace vec_op_avx512; + using scalar_vec_t = vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + + int HS = hidden_size; + int NT = num_tokens; + + // initialize grad_weight to zero + for (int j = 0; j < HS; ++j) + { + grad_weight[j] = (scalar_t)0; + } + + // Allocate per-thread accumulators to avoid atomics/critical sections. + int max_threads = omp_get_max_threads(); + std::vector> all_acc(max_threads, std::vector(HS, (scalar_t)0)); + + // Parallel over tokens: compute grad_input and accumulate into thread-local + // buffers for grad_weight to avoid atomics. 
+#pragma omp parallel + { + int tid = omp_get_thread_num(); + int nthreads = omp_get_num_threads(); + int start = (NT * tid) / nthreads; + int end = (NT * (tid + 1)) / nthreads; + + auto &local_acc = all_acc[tid]; + + for (int i = start; i < end; ++i) + { + const scalar_t *input_p = input + i * HS; + const scalar_t *grad_out_p = grad_out + i * HS; + scalar_t *grad_input_p = grad_input + i * HS; + + // compute variance and inv_rms for this token + FP32Vec16 variance(0.0f); + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + FP32Vec16 fp32_x(x); + variance = variance + fp32_x * fp32_x; + } + + float inv_rms = 1.0f / sqrtf(variance.reduce_sum() / (float)HS + epsilon); + FP32Vec16 fp32_inv_rms(inv_rms); + + // compute S = sum_k (g * w * x) for this token + FP32Vec16 Svec(0.0f); + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t g_out(grad_out_p + j); + scalar_vec_t w(weight + j); + + FP32Vec16 fp32_x(x); + FP32Vec16 fp32_g_out(g_out); + FP32Vec16 fp32_w(w); + + Svec = Svec + fp32_g_out * fp32_w * fp32_x; + } + float S = Svec.reduce_sum(); + float S_over_H = S / (float)HS; + float inv_rms3 = inv_rms * inv_rms * inv_rms; + FP32Vec16 fp32_inv_rms3(inv_rms3); + FP32Vec16 fp32_S_over_H(S_over_H); + + for (int j = 0; j < HS; j += VEC_ELEM_NUM) + { + scalar_vec_t x(input_p + j); + scalar_vec_t g_out(grad_out_p + j); + scalar_vec_t w(weight + j); + + FP32Vec16 fp32_x(x); + FP32Vec16 fp32_g_out(g_out); + FP32Vec16 fp32_w(w); + + // grad_input = g * w * inv_rms - x * inv_rms^3 * (S/H) + FP32Vec16 term1 = fp32_g_out * fp32_w * fp32_inv_rms; + FP32Vec16 term2 = fp32_x * fp32_inv_rms3 * fp32_S_over_H; + FP32Vec16 fp32_grad_in = term1 - term2; + scalar_vec_t sgrad_in(fp32_grad_in); + sgrad_in.save(grad_input_p + j); + + // accumulate grad_weight into thread-local buffer + FP32Vec16 prod = fp32_x * fp32_g_out * fp32_inv_rms; + scalar_vec_t sprod(prod); + scalar_t tmp[VEC_ELEM_NUM]; + sprod.save(tmp); + for (int e = 0; e < VEC_ELEM_NUM; ++e) + { + local_acc[j + e] += tmp[e]; + } + } + } + } + + // Reduce per-thread accumulators into global grad_weight + for (int t = 0; t < (int)all_acc.size(); ++t) + { + for (int j = 0; j < HS; ++j) + { + grad_weight[j] += all_acc[t][j]; + } + } + } + + void rms_norm_backward(torch::Tensor &grad_input, torch::Tensor &grad_weight, + const torch::Tensor &grad_out, const torch::Tensor &input, + const torch::Tensor &weight, const float epsilon) + { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_backward_impl", [&] + { + CPU_KERNEL_GUARD_IN(rms_norm_backward_impl) + rms_norm_backward_impl(grad_input.data_ptr(), + grad_weight.data_ptr(), + grad_out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_backward_impl) }); + } + + } // namespace avx512 +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.hpp b/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.hpp new file mode 100644 index 0000000..b1ac5f1 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_avx512.hpp @@ -0,0 +1,21 @@ +// AVX512 implementation - compile with -mavx512f -mavx512bf16 +#include +#include +#include +#include +#include "cpu_types_avx512.hpp" + +namespace rmsnorm_cpu +{ + namespace avx512 + { + + void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + const float epsilon); + + void rms_norm_backward(torch::Tensor &grad_input, 
torch::Tensor &grad_weight, + const torch::Tensor &grad_out, const torch::Tensor &input, + const torch::Tensor &weight, const float epsilon); + + } // namespace avx512 +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.cpp b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.cpp new file mode 100644 index 0000000..91f6951 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.cpp @@ -0,0 +1,82 @@ +#include "cpu_features.hpp" +#include "rmsnorm_avx2.hpp" +#include "rmsnorm_avx512.hpp" +#include +#include + +namespace rmsnorm_cpu +{ + + // Main dispatcher that selects the best implementation based on runtime CPU features + void rmsnorm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + float epsilon) + { + // Runtime CPU feature detection and dispatch + if (CPUFeatures::hasAVX512BF16()) + { + // Use AVX512 optimized implementation + rmsnorm_cpu::avx512::rms_norm(out, input, weight, epsilon); + } + else if (CPUFeatures::hasAVX2()) + { + // Use AVX2 optimized implementation + rmsnorm_cpu::avx2::rms_norm(out, input, weight, epsilon); + } + else + { + // Fallback to ATen implementation + auto input1 = input.to(at::kFloat); + auto variance = at::mean(at::pow(input1, 2), -1, true); + auto hidden_states = at::rsqrt(at::add(variance, epsilon)); + out = at::mul(weight, at::mul(input1, hidden_states)).to(input.scalar_type()); + } + } + + void rmsnorm_backward( + torch::Tensor &grad_input, + torch::Tensor &grad_weight, + const torch::Tensor &grad_output, + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + float variance_epsilon) + { + // Runtime CPU feature detection and dispatch + if (CPUFeatures::hasAVX512BF16()) + { + // Use AVX512 optimized implementation + rmsnorm_cpu::avx512::rms_norm_backward(grad_input, grad_weight, grad_output, hidden_states, weight, variance_epsilon); + } + else if (CPUFeatures::hasAVX2()) + { + // Use AVX2 optimized implementation + rmsnorm_cpu::avx2::rms_norm_backward(grad_input, grad_weight, grad_output, hidden_states, weight, variance_epsilon); + } + else + { + // Fallback to ATen implementation (compute gradients in FP32) + auto g = grad_output.to(at::kFloat); // (N, H) + auto x = hidden_states.to(at::kFloat); // (N, H) + auto w = weight.to(at::kFloat); // (H,) + + const int64_t H = x.size(-1); + + // inv_rms per token: rsqrt(mean(x^2, -1) + eps) -> shape (N,1) + auto variance = x.pow(2).mean(-1, /*keepdim=*/true); + auto inv_rms = (variance + variance_epsilon).rsqrt(); + + // S = sum_k g * w * x over last dim -> shape (N,1) + auto S = (g * x * w).sum(-1, /*keepdim=*/true); + + // grad_input = g * w * inv_rms - (x * inv_rms^3 * S) / H + auto inv_rms3 = inv_rms.pow(3); + auto grad_input1 = g * w * inv_rms - x * inv_rms3 * S / static_cast(H); + + // grad_weight = sum over tokens of g * x * inv_rms -> shape (H,) + auto grad_weight1 = (g * x * inv_rms).sum(0); + + // copy back to requested dtypes / tensors + grad_input.copy_(grad_input1.to(grad_input.scalar_type())); + grad_weight.copy_(grad_weight1.to(grad_weight.scalar_type())); + } + } +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.hpp b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.hpp new file mode 100644 index 0000000..d8750a5 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "cpu_features.hpp" +#include "rmsnorm_avx2.hpp" +#include "rmsnorm_avx512.hpp" +#include +#include + +namespace rmsnorm_cpu +{ + + // Main dispatcher that selects the best implementation based on runtime CPU features + void 
rmsnorm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weight, + float epsilon); + + void rmsnorm_backward( + torch::Tensor &grad_input, + torch::Tensor &grad_weight, + const torch::Tensor &grad_output, + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + float variance_epsilon); +} // namespace rmsnorm_cpu diff --git a/rmsnorm/rmsnorm_cpu/rmsnorm_cpu_torch.cpp b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu_torch.cpp new file mode 100644 index 0000000..c5bc1e1 --- /dev/null +++ b/rmsnorm/rmsnorm_cpu/rmsnorm_cpu_torch.cpp @@ -0,0 +1,90 @@ +#include +#include "rmsnorm_cpu.hpp" + +// Forward implementation for CPU +torch::Tensor rmsnorm_cpu_forward( + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + double variance_epsilon) +{ + TORCH_CHECK(hidden_states.is_contiguous(), "hidden_states must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + + auto output = torch::empty_like(hidden_states); + + rmsnorm_cpu::rmsnorm( + output, + hidden_states, + weight, + static_cast(variance_epsilon)); + + return output; +} + +// Backward implementation for CPU +std::tuple rmsnorm_cpu_backward( + const torch::Tensor &grad_output, + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + double variance_epsilon) +{ + TORCH_CHECK(grad_output.is_contiguous(), "grad_output must be contiguous"); + TORCH_CHECK(hidden_states.is_contiguous(), "hidden_states must be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + + auto grad_input = torch::empty_like(hidden_states); + auto grad_weight = torch::zeros_like(weight); + + rmsnorm_cpu::rmsnorm_backward( + grad_input, + grad_weight, + grad_output, + hidden_states, + weight, + static_cast(variance_epsilon)); + + return std::make_tuple(grad_input, grad_weight); +} + +// Custom autograd function for CPU RMSNorm +class RMSNormCPUFunction : public torch::autograd::Function +{ +public: + static torch::Tensor forward( + torch::autograd::AutogradContext *ctx, + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + double variance_epsilon) + { + ctx->save_for_backward({hidden_states, weight}); + ctx->saved_data["variance_epsilon"] = variance_epsilon; + return rmsnorm_cpu_forward(hidden_states, weight, variance_epsilon); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_outputs) + { + auto saved = ctx->get_saved_variables(); + auto hidden_states = saved[0]; + auto weight = saved[1]; + auto variance_epsilon = ctx->saved_data["variance_epsilon"].toDouble(); + auto grad_output = grad_outputs[0]; + + auto grads = rmsnorm_cpu_backward(grad_output, hidden_states, weight, variance_epsilon); + auto grad_input = std::get<0>(grads); + auto grad_weight = std::get<1>(grads); + + return {grad_input, grad_weight, torch::Tensor()}; + } +}; + +torch::Tensor apply_rms_norm_cpu( + const torch::Tensor &hidden_states, + const torch::Tensor &weight, + double variance_epsilon) +{ + + auto output = RMSNormCPUFunction::apply(hidden_states, weight, variance_epsilon); + return output; +} diff --git a/rmsnorm/tests/test_rmsnorm.py b/rmsnorm/tests/test_rmsnorm.py index 0561e2c..3d76c9f 100644 --- a/rmsnorm/tests/test_rmsnorm.py +++ b/rmsnorm/tests/test_rmsnorm.py @@ -4,9 +4,11 @@ from rmsnorm.layers import RMSNorm - def test_rmsnorm(): - device = torch.device("xpu") + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device("xpu") + else: + device = 
torch.device("cpu")
     x = torch.randn(1024, 1024, dtype=torch.bfloat16, device=device)
     rmsnorm_layer = RMSNorm()
     rmsnorm_layer.weight = torch.randn(1024, device=device, dtype=torch.bfloat16)
@@ -16,12 +18,15 @@ def test_rmsnorm():
     x = x.to(torch.float32)
     variance = x.pow(2).mean(-1, keepdim=True)
     x = x * torch.rsqrt(variance + rmsnorm_layer.variance_epsilon)
-    ref_out = rmsnorm_layer.weight * x.to(torch.bfloat16)
-    torch.testing.assert_close(output, ref_out)
+    ref_out = x * rmsnorm_layer.weight
+    torch.testing.assert_close(output, ref_out.to(torch.bfloat16))
 
 
 def test_rmsnorm_backward():
-    device = torch.device("xpu")
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        device = torch.device("cpu")
     x = torch.randn(1024, 1024, dtype=torch.float32, device=device, requires_grad=True)
     rmsnorm_layer = RMSNorm()
     rmsnorm_layer.weight = torch.randn(1024, device=device, dtype=torch.float32, requires_grad=True)
diff --git a/rmsnorm/torch-ext/torch_binding.cpp b/rmsnorm/torch-ext/torch_binding.cpp
index c9424d5..7d75981 100644
--- a/rmsnorm/torch-ext/torch_binding.cpp
+++ b/rmsnorm/torch-ext/torch_binding.cpp
@@ -4,23 +4,47 @@
 #include
 #endif
+// Forward declarations for XPU
+#if defined(XPU_KERNEL)
 #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kXPU, #x " must be on XPU")
-
 torch::Tensor _apply_rms_norm(torch::Tensor const &hidden_states, torch::Tensor const &weight, double variance_epsilon);
+#endif
+#if defined(CPU_KERNEL)
+torch::Tensor apply_rms_norm_cpu(
+    torch::Tensor const &hidden_states,
+    torch::Tensor const &weight,
+    double variance_epsilon);
+#endif
+// Unified dispatcher for both CPU and XPU
 torch::Tensor apply_rms_norm(torch::Tensor const &hidden_states, torch::Tensor const &weight, double variance_epsilon)
 {
-  CHECK_DEVICE(hidden_states); CHECK_DEVICE(weight);
-#if defined(XPU_KERNEL)
-  c10::DeviceGuard device_guard{hidden_states.device()};
+#if defined(CPU_KERNEL)
+  if (hidden_states.device().type() == torch::kCPU) {
+    // CPU path with autograd support
+    TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU");
+    return apply_rms_norm_cpu(hidden_states, weight, variance_epsilon);
+  }
+#elif defined(XPU_KERNEL)
+  if (hidden_states.device().type() == torch::kXPU) {
+    // XPU path
+    CHECK_DEVICE(hidden_states); CHECK_DEVICE(weight);
+    c10::DeviceGuard device_guard{hidden_states.device()};
+    return _apply_rms_norm(hidden_states, weight, variance_epsilon);
+  }
 #endif
-  return _apply_rms_norm(hidden_states, weight, variance_epsilon);
+  TORCH_CHECK(false, "Unsupported device type: ", hidden_states.device().type());
 }
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("apply_rms_norm(Tensor hidden_states, Tensor weight, float variance_epsilon) -> Tensor");
-#if defined(XPU_KERNEL)
+#if defined(CPU_KERNEL)
+  // Register CPU implementation
+  ops.impl("apply_rms_norm", torch::kCPU, &apply_rms_norm);
+  ops.impl("apply_rms_norm", c10::DispatchKey::Autograd, &apply_rms_norm);
+#elif defined(XPU_KERNEL)
+  // Register XPU implementation
   ops.impl("apply_rms_norm", torch::kXPU, &apply_rms_norm);
   ops.impl("apply_rms_norm", c10::DispatchKey::Autograd, &apply_rms_norm);
 #endif
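
Note on the backward kernels: rmsnorm_avx2.cpp and rmsnorm_avx512.cpp compute the analytic RMSNorm gradients per token, with inv_rms = rsqrt(mean(x^2) + eps) and S = sum_j g_j * w_j * x_j, giving grad_input = g * w * inv_rms - x * inv_rms^3 * S / H and grad_weight accumulated as g * x * inv_rms over tokens. Below is a minimal plain-PyTorch sketch of the same formulas that cross-checks them against autograd without building the extension; the helper names rms_norm_ref and rms_norm_backward_ref are illustrative only and not part of this patch.

# Reference-only sketch (not part of the patch): checks the analytic RMSNorm
# gradients used by the AVX2/AVX-512 backward kernels against torch.autograd.
import torch


def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    # y = w * x * rsqrt(mean(x^2, -1) + eps), matching rms_norm_impl
    variance = x.pow(2).mean(-1, keepdim=True)
    return w * (x * torch.rsqrt(variance + eps))


def rms_norm_backward_ref(g, x, w, eps):
    # Mirrors rms_norm_backward_impl:
    #   grad_input  = g * w * inv_rms - x * inv_rms^3 * (S / H)
    #   grad_weight = sum_over_tokens(g * x * inv_rms)
    h = x.size(-1)
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    s = (g * w * x).sum(-1, keepdim=True)
    grad_input = g * w * inv_rms - x * inv_rms.pow(3) * s / h
    grad_weight = (g * x * inv_rms).sum(0)
    return grad_input, grad_weight


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(16, 64, dtype=torch.float64, requires_grad=True)
    w = torch.randn(64, dtype=torch.float64, requires_grad=True)
    eps = 1e-6
    g = torch.randn(16, 64, dtype=torch.float64)
    rms_norm_ref(x, w, eps).backward(g)
    grad_input, grad_weight = rms_norm_backward_ref(g, x.detach(), w.detach(), eps)
    torch.testing.assert_close(grad_input, x.grad)
    torch.testing.assert_close(grad_weight, w.grad)
    print("analytic gradients match autograd")

On float64 inputs the analytic gradients should match autograd within assert_close's default tolerances; these are the same formulas implemented by the vectorized kernels and by the ATen fallback in rmsnorm_cpu.cpp.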