From 4c4e6006f7d074b8435203e07d0f011ac87a761f Mon Sep 17 00:00:00 2001
From: m96-chan
Date: Wed, 31 Dec 2025 06:56:38 +0900
Subject: [PATCH 1/2] feat(nn): add positional encoding operations (PoPE,
 ALiBi, YaRN, NTK, ReLU2)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implements issues #136, #167, #169, #171:
- ReLU2 activation (squared ReLU from Primer paper)
- PoPE (sinusoidal positional encoding, additive alternative to RoPE)
- ALiBi (attention with linear biases)
- YaRN/NTK-aware RoPE extensions for context length scaling
- Linear position interpolation for RoPE

New kernels:
- native/ops/nn/activation/relu2.inl
- native/ops/nn/pope/pope_kernels.cuh, pope.inl
- native/ops/nn/alibi/alibi_kernels.cuh, alibi.inl
- native/ops/nn/rope/rope_ext_kernels.cuh, rope_ext.inl

Tests: 22 passed, 2 skipped (GPU-only bf16/f16 tests)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 native/bindings/nn/activation.cpp       |  10 +
 native/bindings/nn/rope.cpp             |  58 ++++++
 native/ops/nn/activation/relu2.inl      |  74 ++++++++
 native/ops/nn/activation_kernels.cuh    |  38 ++++
 native/ops/nn/alibi/alibi.inl           | 109 +++++++++++
 native/ops/nn/alibi/alibi_kernels.cuh   | 135 +++++++++++++
 native/ops/nn/nn.cu                     |   4 +
 native/ops/nn/pope/pope.inl             | 103 ++++++++++
 native/ops/nn/pope/pope_kernels.cuh     | 178 +++++++++++++++++
 native/ops/nn/rope/rope_ext.inl         | 114 +++++++++++
 native/ops/nn/rope/rope_ext_kernels.cuh | 214 +++++++++++++++++++++
 native/ops/ops.cuh                      |  44 +++++
 src/pygpukit/ops/nn/__init__.py         |  24 +++
 src/pygpukit/ops/nn/activation.py       |  48 +++++
 src/pygpukit/ops/nn/rope.py             | 242 ++++++++++++++++++++++++
 tests/test_positional_encoding.py       | 235 +++++++++++++++++++++++
 tests/test_relu2.py                     | 137 ++++++++++++++
 17 files changed, 1767 insertions(+)
 create mode 100644 native/ops/nn/activation/relu2.inl
 create mode 100644 native/ops/nn/alibi/alibi.inl
 create mode 100644 native/ops/nn/alibi/alibi_kernels.cuh
 create mode 100644 native/ops/nn/pope/pope.inl
 create mode 100644 native/ops/nn/pope/pope_kernels.cuh
 create mode 100644 native/ops/nn/rope/rope_ext.inl
 create mode 100644 native/ops/nn/rope/rope_ext_kernels.cuh
 create mode 100644 tests/test_positional_encoding.py
 create mode 100644 tests/test_relu2.py

diff --git a/native/bindings/nn/activation.cpp b/native/bindings/nn/activation.cpp
index 5c6fd95..d17a125 100644
--- a/native/bindings/nn/activation.cpp
+++ b/native/bindings/nn/activation.cpp
@@ -42,4 +42,14 @@ void init_nn_activation(py::module_& m) {
           "Fused linear + bias + GELU: output = gelu(input @ weight^T + bias)\n"
           "Uses CUTLASS TensorCore epilogue fusion for efficiency.\n"
           "input: [batch, in_features], weight: [out_features, in_features], bias: [out_features]");
+
+    // ReLU squared (Primer paper)
+    m.def("relu2", py::overload_cast<const GPUArray&>(&ops::relu2),
+          py::arg("input"),
+          "ReLU squared activation: y = (max(0, x))^2\n"
+          "Introduced in the Primer paper (Google, 2021).");
+
+    m.def("relu2_", py::overload_cast<const GPUArray&, GPUArray&>(&ops::relu2),
+          py::arg("input"), py::arg("out"),
+          "ReLU squared with output buffer (for CUDA Graph capture)");
 }
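For reference, a minimal Python-side sketch of the two bindings above; the `out=` form is the one intended for CUDA Graph capture (array contents here are illustrative):

    import numpy as np
    from pygpukit import from_numpy
    from pygpukit.ops.nn import relu2

    x = from_numpy(np.array([-2.0, 0.5, 3.0], dtype=np.float32))
    y = relu2(x)                 # allocates output: [0.0, 0.25, 9.0]

    buf = from_numpy(np.zeros(3, dtype=np.float32))
    relu2(x, out=buf)            # writes into a pre-allocated buffer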
m.def("rope_init_ntk_aware", &ops::rope_init_ntk_aware, + py::arg("max_seq_len"), py::arg("head_dim"), + py::arg("base") = 10000.0f, py::arg("scale") = 1.0f, + "Initialize RoPE with NTK-aware frequency scaling.\n" + "Scales base frequency for context extension: base' = base * scale^(dim/(dim-2))\n" + "Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]"); + + // YaRN RoPE initialization + m.def("rope_init_yarn", &ops::rope_init_yarn, + py::arg("max_seq_len"), py::arg("head_dim"), + py::arg("base") = 10000.0f, py::arg("scale") = 1.0f, + py::arg("original_max_len") = 4096, py::arg("beta_fast") = 32.0f, + py::arg("beta_slow") = 1.0f, py::arg("mscale") = 0.1f, + "Initialize RoPE with YaRN dimension-wise interpolation.\n" + "Different scaling for different frequency bands (low/mid/high).\n" + "Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]"); + + // Linear position interpolation + m.def("rope_init_linear", &ops::rope_init_linear, + py::arg("max_seq_len"), py::arg("head_dim"), + py::arg("base") = 10000.0f, py::arg("scale") = 1.0f, + "Initialize RoPE with linear position interpolation.\n" + "Simple baseline: pos' = pos / scale. Degrades at high scales.\n" + "Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]"); + + // PoPE (Positional Encoding) - Alternative to RoPE + m.def("pope_init_encoding", &ops::pope_init_encoding, + py::arg("max_seq_len"), py::arg("head_dim"), py::arg("base") = 10000.0f, + "Initialize sinusoidal positional encoding table.\n" + "Returns: encoding tensor [max_seq_len, head_dim]"); + + m.def("pope_inplace", &ops::pope_inplace, + py::arg("q"), py::arg("k"), py::arg("encoding"), py::arg("start_pos") = 0, + "Apply additive positional encoding to Q and K in-place.\n" + "q: [seq_len, n_heads_q, head_dim]\n" + "k: [seq_len, n_heads_k, head_dim]\n" + "encoding: [max_seq_len, head_dim] (f32)"); + + // ALiBi (Attention with Linear Biases) + m.def("alibi_init_slopes", &ops::alibi_init_slopes, + py::arg("num_heads"), + "Initialize ALiBi head-specific slopes.\n" + "m_h = 2^(-8 * h / num_heads)\n" + "Returns: slopes tensor [num_heads]"); + + m.def("alibi_compute_bias", &ops::alibi_compute_bias, + py::arg("seq_len"), py::arg("num_heads"), py::arg("slopes"), + py::arg("causal") = true, + "Compute ALiBi bias matrix for attention.\n" + "Returns: bias tensor [num_heads, seq_len, seq_len]"); + + m.def("alibi_add_bias", &ops::alibi_add_bias, + py::arg("scores"), py::arg("slopes"), py::arg("start_pos") = 0, + "Add ALiBi bias to attention scores in-place.\n" + "scores: [batch, num_heads, q_len, kv_len]\n" + "slopes: [num_heads]"); } diff --git a/native/ops/nn/activation/relu2.inl b/native/ops/nn/activation/relu2.inl new file mode 100644 index 0000000..6c4c7a5 --- /dev/null +++ b/native/ops/nn/activation/relu2.inl @@ -0,0 +1,74 @@ +/** + * ReLU squared (ReLU^2) activation: (max(0, x))^2 + * + * Introduced in the Primer paper (Google, 2021). + * Benefits: stronger sparsity, continuous first derivative. 
diff --git a/native/ops/nn/activation/relu2.inl b/native/ops/nn/activation/relu2.inl
new file mode 100644
index 0000000..6c4c7a5
--- /dev/null
+++ b/native/ops/nn/activation/relu2.inl
@@ -0,0 +1,74 @@
+/**
+ * ReLU squared (ReLU^2) activation: (max(0, x))^2
+ *
+ * Introduced in the Primer paper (Google, 2021).
+ * Benefits: stronger sparsity, continuous first derivative.
+ */
+
+namespace pygpukit {
+namespace ops {
+
+// Internal dispatch helper with capture stream support
+static void relu2_dispatch(const GPUArray& input, GPUArray& result) {
+    size_t n = input.size();
+    const int block_size = 256;
+    const int grid_size = (n + block_size - 1) / block_size;
+
+    // Use capture stream if available
+    cudaStream_t stream = internal::get_capture_stream();
+
+    switch (input.dtype()) {
+        case DataType::Float32:
+            nn::relu2_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<const float*>(input.data()),
+                static_cast<float*>(result.data()),
+                n);
+            break;
+        case DataType::Float16:
+            nn::relu2_f16_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<const __half*>(input.data()),
+                static_cast<__half*>(result.data()),
+                n);
+            break;
+        case DataType::BFloat16:
+            nn::relu2_bf16_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<const __nv_bfloat16*>(input.data()),
+                static_cast<__nv_bfloat16*>(result.data()),
+                n);
+            break;
+        default:
+            break;  // Unreachable: both public entry points validate the dtype first
+    }
+}
+
+GPUArray relu2(const GPUArray& input) {
+    if (input.dtype() != DataType::Float32 &&
+        input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) {
+        throw std::runtime_error("relu2 only supports float32, float16, bfloat16");
+    }
+
+    GPUArray result(input.shape(), input.dtype());
+    relu2_dispatch(input, result);
+    sync_and_check("relu2 kernel failed");
+    return result;
+}
+
+// ReLU squared with output buffer (for CUDA Graph capture)
+void relu2(const GPUArray& input, GPUArray& out) {
+    if (input.dtype() != DataType::Float32 &&
+        input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) {
+        throw std::runtime_error("relu2 only supports float32, float16, bfloat16");
+    }
+    if (input.dtype() != out.dtype()) {
+        throw std::runtime_error("relu2: dtype mismatch between input and output");
+    }
+    if (input.shape() != out.shape()) {
+        throw std::runtime_error("relu2: shape mismatch between input and output");
+    }
+
+    relu2_dispatch(input, out);
+    sync_and_check("relu2 kernel failed");
+}
+
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/nn/activation_kernels.cuh b/native/ops/nn/activation_kernels.cuh
index 5669c1c..2f7db39 100644
--- a/native/ops/nn/activation_kernels.cuh
+++ b/native/ops/nn/activation_kernels.cuh
@@ -45,6 +45,11 @@ __device__ __forceinline__ float sigmoid_f32(float x) {
     return 1.0f / (1.0f + expf(-x));
 }
 
+__device__ __forceinline__ float relu2_f32(float x) {
+    float relu_val = fmaxf(0.0f, x);
+    return relu_val * relu_val;
+}
+
 // ============================================================================
 // Kernel declarations (always available)
 // ============================================================================
@@ -88,6 +93,14 @@ __global__ void tanh_f16_kernel(const __half* __restrict__ input,
 __global__ void tanh_bf16_kernel(const __nv_bfloat16* __restrict__ input,
                                  __nv_bfloat16* __restrict__ output, size_t n);
 
+// ReLU squared (Primer paper)
+__global__ void relu2_f32_kernel(const float* __restrict__ input,
+                                 float* __restrict__ output, size_t n);
+__global__ void relu2_f16_kernel(const __half* __restrict__ input,
+                                 __half* __restrict__ output, size_t n);
+__global__ void relu2_bf16_kernel(const __nv_bfloat16* __restrict__ input,
+                                  __nv_bfloat16* __restrict__ output, size_t n);
+
 // ============================================================================
 // Kernel definitions (only when PYGPUKIT_IMPLEMENT_NN_KERNELS is defined)
 // ============================================================================
@@ -229,6 +242,31 @@ __global__ void tanh_bf16_kernel(const __nv_bfloat16* __restrict__ input,
     }
 }
 
+// ReLU squared kernels
+__global__ void relu2_f32_kernel(const float* __restrict__ input,
+                                 float* __restrict__ output, size_t n) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) output[idx] = relu2_f32(input[idx]);
+}
+
+__global__ void relu2_f16_kernel(const __half* __restrict__ input,
+                                 __half* __restrict__ output, size_t n) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        float x = __half2float(input[idx]);
+        output[idx] = __float2half(relu2_f32(x));
+    }
+}
+
+__global__ void relu2_bf16_kernel(const __nv_bfloat16* __restrict__ input,
+                                  __nv_bfloat16* __restrict__ output, size_t n) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) {
+        float x = __bfloat162float(input[idx]);
+        output[idx] = __float2bfloat16(relu2_f32(x));
+    }
+}
+
 #endif // PYGPUKIT_IMPLEMENT_NN_KERNELS
 
 } // namespace nn
diff --git a/native/ops/nn/alibi/alibi.inl b/native/ops/nn/alibi/alibi.inl
new file mode 100644
index 0000000..e9a1c20
--- /dev/null
+++ b/native/ops/nn/alibi/alibi.inl
@@ -0,0 +1,109 @@
+/**
+ * ALiBi (Attention with Linear Biases) dispatch functions
+ *
+ * Provides:
+ * - alibi_init_slopes: Compute head-specific slopes
+ * - alibi_compute_bias: Create bias matrix for attention
+ * - alibi_add_bias: Add bias to attention scores in-place
+ */
+
+#include "alibi_kernels.cuh"
+
+namespace pygpukit {
+namespace ops {
+
+GPUArray alibi_init_slopes(int num_heads) {
+    // Create slopes tensor: [num_heads]
+    GPUArray slopes({(size_t)num_heads}, DataType::Float32);
+
+    const int block_size = 256;
+    const int grid_size = (num_heads + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::alibi_init_slopes_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(slopes.data()),
+        num_heads);
+
+    sync_and_check("alibi_init_slopes kernel failed");
+    return slopes;
+}
+
+GPUArray alibi_compute_bias(int seq_len, int num_heads, const GPUArray& slopes, bool causal) {
+    // Create bias tensor: [num_heads, seq_len, seq_len]
+    if (slopes.dtype() != DataType::Float32) {
+        throw std::runtime_error("alibi_compute_bias: slopes must be float32");
+    }
+    if (slopes.size() != (size_t)num_heads) {
+        throw std::runtime_error("alibi_compute_bias: slopes size must match num_heads");
+    }
+
+    GPUArray bias({(size_t)num_heads, (size_t)seq_len, (size_t)seq_len}, DataType::Float32);
+
+    int total = num_heads * seq_len * seq_len;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    if (causal) {
+        nn::alibi_compute_bias_causal_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+            static_cast<float*>(bias.data()),
+            static_cast<const float*>(slopes.data()),
+            seq_len,
+            num_heads);
+    } else {
+        nn::alibi_compute_bias_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+            static_cast<float*>(bias.data()),
+            static_cast<const float*>(slopes.data()),
+            seq_len,
+            num_heads);
+    }
+
+    sync_and_check("alibi_compute_bias kernel failed");
+    return bias;
+}
+
+void alibi_add_bias(GPUArray& scores, const GPUArray& slopes, int start_pos) {
+    // scores: [batch, num_heads, q_len, kv_len]
+    // slopes: [num_heads]
+
+    if (scores.ndim() != 4) {
+        throw std::runtime_error("alibi_add_bias: scores must be 4D [batch, heads, q_len, kv_len]");
+    }
+    if (scores.dtype() != DataType::Float32) {
+        throw std::runtime_error("alibi_add_bias: scores must be float32");
+    }
+    if (slopes.dtype() != DataType::Float32) {
+        throw std::runtime_error("alibi_add_bias: slopes must be float32");
+    }
+
+    int batch_size = scores.shape()[0];
+    int num_heads = scores.shape()[1];
+    int q_len = scores.shape()[2];
+    int kv_len = scores.shape()[3];
+
+    if (slopes.size() != (size_t)num_heads) {
+        throw std::runtime_error("alibi_add_bias: slopes size must match num_heads");
+    }
+
+    int total = batch_size * num_heads * q_len * kv_len;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::alibi_add_bias_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(scores.data()),
+        static_cast<const float*>(slopes.data()),
+        batch_size,
+        num_heads,
+        q_len,
+        kv_len,
+        start_pos);
+
+    sync_and_check("alibi_add_bias kernel failed");
+}
+
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/nn/alibi/alibi_kernels.cuh b/native/ops/nn/alibi/alibi_kernels.cuh
new file mode 100644
index 0000000..a69897a
--- /dev/null
+++ b/native/ops/nn/alibi/alibi_kernels.cuh
@@ -0,0 +1,135 @@
+/**
+ * ALiBi (Attention with Linear Biases) kernels
+ *
+ * Adds linear bias to attention scores based on query-key distance.
+ * Paper: "Train Short, Test Long" (Press et al., 2022)
+ *
+ * Formula: attention_scores[i, j] = Q[i] @ K[j]^T - m * |i - j|
+ * Where m is a head-specific slope: m_h = 2^(-8 * (h+1) / num_heads), h = 0..num_heads-1
+ */
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cmath>
+
+namespace pygpukit {
+namespace ops {
+namespace nn {
+
+// ============================================================================
+// ALiBi Init Slopes - Compute head-specific slopes
+// ============================================================================
+
+__global__ void alibi_init_slopes_kernel(
+    float* __restrict__ slopes,
+    int num_heads
+) {
+    int h = blockIdx.x * blockDim.x + threadIdx.x;
+    if (h < num_heads) {
+        // m_h = 2^(-8 * (h+1) / num_heads)
+        // Note: h is 0-indexed, so we use (h+1) to match paper convention
+        float exponent = -8.0f * (float)(h + 1) / (float)num_heads;
+        slopes[h] = powf(2.0f, exponent);
+    }
+}
+
+// ============================================================================
+// ALiBi Compute Bias - Create bias matrix for attention
+// ============================================================================
+
+__global__ void alibi_compute_bias_f32_kernel(
+    float* __restrict__ bias,
+    const float* __restrict__ slopes,
+    int seq_len,
+    int num_heads
+) {
+    // bias: [num_heads, seq_len, seq_len]
+    // Non-causal variant: the bias depends only on the distance |i - j|
+    // and no masking is applied (masking, if any, is the caller's job).
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = num_heads * seq_len * seq_len;
+
+    if (idx < total) {
+        int h = idx / (seq_len * seq_len);
+        int i = (idx / seq_len) % seq_len;  // query position
+        int j = idx % seq_len;              // key position
+
+        float slope = slopes[h];
+        // ALiBi bias: -slope * |i - j|, symmetric around the diagonal
+        int distance = abs(i - j);
+        bias[idx] = -slope * (float)distance;
+    }
+}
+
+// Causal-only version (also bakes in the causal mask)
+__global__ void alibi_compute_bias_causal_f32_kernel(
+    float* __restrict__ bias,
+    const float* __restrict__ slopes,
+    int seq_len,
+    int num_heads
+) {
+    // bias: [num_heads, seq_len, seq_len]; only the lower triangle carries
+    // finite values, the upper triangle is masked.
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = num_heads * seq_len * seq_len;
+
+    if (idx < total) {
+        int h = idx / (seq_len * seq_len);
+        int i = (idx / seq_len) % seq_len;
+        int j = idx % seq_len;
+
+        if (j <= i) {
+            float slope = slopes[h];
+            bias[idx] = -slope * (float)(i - j);
+        } else {
+            bias[idx] = -1e9f;  // Causal mask
+        }
+    }
+}
+
+// ============================================================================
+// ALiBi Add Bias - Add bias to attention scores in-place
+// ============================================================================
+
+__global__ void alibi_add_bias_f32_kernel(
+    float* __restrict__ scores,
+    const float* __restrict__ slopes,
+    int batch_size,
+    int num_heads,
+    int q_len,
+    int kv_len,
+    int start_pos
+) {
+    // scores: [batch, num_heads, q_len, kv_len]
+    // For each (i, j) in scores, add -slope * (start_pos + i - j)
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = batch_size * num_heads * q_len * kv_len;
+
+    if (idx < total) {
+        int h = (idx / (q_len * kv_len)) % num_heads;
+        int i = (idx / kv_len) % q_len;  // query position (relative)
+        int j = idx % kv_len;            // key position
+
+        int q_pos = start_pos + i;       // absolute query position
+        int distance = q_pos - j;
+
+        float slope = slopes[h];
+        // Only apply for causal positions (j <= q_pos)
+        if (distance >= 0) {
+            scores[idx] += -slope * (float)distance;
+        }
+        // Note: causal masking should be applied separately
+    }
+}
+
+} // namespace nn
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/nn/nn.cu b/native/ops/nn/nn.cu
index 05d7e77..3ca6ae3 100644
--- a/native/ops/nn/nn.cu
+++ b/native/ops/nn/nn.cu
@@ -27,9 +27,13 @@
 #include "activation/silu.inl"
 #include "activation/sigmoid.inl"
 #include "activation/tanh.inl"
+#include "activation/relu2.inl"
 #include "norm/layernorm.inl"
 #include "norm/rmsnorm.inl"
 #include "rope/rope_inplace.inl"
+#include "rope/rope_ext.inl"
+#include "pope/pope.inl"
+#include "alibi/alibi.inl"
 #include "linear/linear_bias.inl"
 #include "attention/sdpa_causal.inl"
 #include "tensor/tensor.inl"
diff --git a/native/ops/nn/pope/pope.inl b/native/ops/nn/pope/pope.inl
new file mode 100644
index 0000000..a26cdc9
--- /dev/null
+++ b/native/ops/nn/pope/pope.inl
@@ -0,0 +1,103 @@
+/**
+ * PoPE (Positional Encoding) - Additive positional embedding
+ *
+ * Alternative to RoPE using simple addition instead of rotation.
+ */
+
+#include "pope_kernels.cuh"
+
+namespace pygpukit {
+namespace ops {
+
+GPUArray pope_init_encoding(int max_seq_len, int head_dim, float base) {
+    // Create sinusoidal positional encoding table
+    // Shape: [max_seq_len, head_dim]
+    GPUArray encoding({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+
+    int total = max_seq_len * head_dim;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::pope_init_sinusoidal_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(encoding.data()),
+        max_seq_len,
+        head_dim,
+        base);
+
+    sync_and_check("pope_init_encoding kernel failed");
+    return encoding;
+}
+
+void pope_inplace(GPUArray& q, GPUArray& k, const GPUArray& encoding, int start_pos) {
+    // q: [seq_len, n_heads_q, head_dim]
+    // k: [seq_len, n_heads_k, head_dim]
+    // encoding: [max_seq_len, head_dim] (float32)
+
+    if (q.ndim() != 3 || k.ndim() != 3) {
+        throw std::runtime_error("pope_inplace: q and k must be 3D [seq_len, n_heads, head_dim]");
+    }
+    if (encoding.ndim() != 2 || encoding.dtype() != DataType::Float32) {
+        throw std::runtime_error("pope_inplace: encoding must be 2D float32 [max_seq_len, head_dim]");
+    }
+    if (q.dtype() != k.dtype()) {
+        throw std::runtime_error("pope_inplace: q and k dtype mismatch");
+    }
+
+    int seq_len = q.shape()[0];
+    int n_heads_q = q.shape()[1];
+    int n_heads_k = k.shape()[1];
+    int head_dim = q.shape()[2];
+    int max_seq_len = encoding.shape()[0];
+
+    if (k.shape()[0] != seq_len || k.shape()[2] != head_dim) {
+        throw std::runtime_error("pope_inplace: q and k shape mismatch");
+    }
+    if (encoding.shape()[1] != head_dim) {
+        throw std::runtime_error("pope_inplace: encoding head_dim mismatch");
+    }
+    if (start_pos + seq_len > max_seq_len) {
+        throw std::runtime_error("pope_inplace: position exceeds max_seq_len");
+    }
+
+    int total_q = seq_len * n_heads_q * head_dim;
+    int total_k = seq_len * n_heads_k * head_dim;
+    int total_work = std::max(total_q, total_k);
+
+    const int block_size = 256;
+    const int grid_size = (total_work + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    switch (q.dtype()) {
+        case DataType::Float32:
+            nn::pope_apply_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<float*>(q.data()),
+                static_cast<float*>(k.data()),
+                static_cast<const float*>(encoding.data()),
+                seq_len, n_heads_q, n_heads_k, head_dim, start_pos);
+            break;
+        case DataType::Float16:
+            nn::pope_apply_f16_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<__half*>(q.data()),
+                static_cast<__half*>(k.data()),
+                static_cast<const float*>(encoding.data()),
+                seq_len, n_heads_q, n_heads_k, head_dim, start_pos);
+            break;
+        case DataType::BFloat16:
+            nn::pope_apply_bf16_kernel<<<grid_size, block_size, 0, stream>>>(
+                static_cast<__nv_bfloat16*>(q.data()),
+                static_cast<__nv_bfloat16*>(k.data()),
+                static_cast<const float*>(encoding.data()),
+                seq_len, n_heads_q, n_heads_k, head_dim, start_pos);
+            break;
+        default:
+            throw std::runtime_error("pope_inplace: unsupported dtype");
+    }
+
+    sync_and_check("pope_inplace kernel failed");
+}
+
+} // namespace ops
+} // namespace pygpukit
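For orientation, a minimal end-to-end sketch of the PoPE path exposed above (shapes are illustrative):

    import numpy as np
    from pygpukit import from_numpy
    from pygpukit.ops.nn import pope_init_encoding, pope_inplace

    seq_len, n_heads, head_dim = 8, 4, 64
    encoding = pope_init_encoding(1024, head_dim)   # [1024, 64] f32 table

    q = from_numpy(np.random.randn(seq_len, n_heads, head_dim).astype(np.float32))
    k = from_numpy(np.random.randn(seq_len, n_heads, head_dim).astype(np.float32))

    # Adds encoding[start_pos + s, :] to every head of q[s] and k[s], in-place.
    pope_inplace(q, k, encoding, start_pos=0)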
diff --git a/native/ops/nn/pope/pope_kernels.cuh b/native/ops/nn/pope/pope_kernels.cuh
new file mode 100644
index 0000000..2e431e3
--- /dev/null
+++ b/native/ops/nn/pope/pope_kernels.cuh
@@ -0,0 +1,178 @@
+/**
+ * PoPE (Positional Encoding) kernels
+ *
+ * Additive positional encoding as an alternative to RoPE.
+ * Adds sinusoidal or learned position embeddings to Q/K.
+ */
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cmath>
+
+namespace pygpukit {
+namespace ops {
+namespace nn {
+
+// ============================================================================
+// PoPE Init Encoding - Generate sinusoidal position embeddings
+// ============================================================================
+
+__global__ void pope_init_sinusoidal_f32_kernel(
+    float* __restrict__ encoding,
+    int max_seq_len,
+    int head_dim,
+    float base
+) {
+    // encoding: [max_seq_len, head_dim]
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = max_seq_len * head_dim;
+
+    if (idx < total) {
+        int pos = idx / head_dim;
+        int dim = idx % head_dim;
+
+        // Sinusoidal encoding: PE(pos, 2i)   = sin(pos / base^(2i/d))
+        //                      PE(pos, 2i+1) = cos(pos / base^(2i/d))
+        float freq = 1.0f / powf(base, (float)(dim / 2 * 2) / (float)head_dim);
+        float angle = (float)pos * freq;
+
+        if (dim % 2 == 0) {
+            encoding[idx] = sinf(angle);
+        } else {
+            encoding[idx] = cosf(angle);
+        }
+    }
+}
+
+// ============================================================================
+// PoPE Apply - Add position encoding to Q/K
+// ============================================================================
+
+// F32 kernel
+__global__ void pope_apply_f32_kernel(
+    float* __restrict__ q,
+    float* __restrict__ k,
+    const float* __restrict__ encoding,
+    int seq_len,
+    int n_heads_q,
+    int n_heads_k,
+    int head_dim,
+    int start_pos
+) {
+    // q: [seq_len, n_heads_q, head_dim]
+    // k: [seq_len, n_heads_k, head_dim]
+    // encoding: [max_seq_len, head_dim]
+    // The same encoding is added to every head, so the head index is unused.
+
+    int total_q = seq_len * n_heads_q * head_dim;
+    int total_k = seq_len * n_heads_k * head_dim;
+    int total_work = max(total_q, total_k);
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_work) return;
+
+    // Process Q
+    if (idx < total_q) {
+        int s = idx / (n_heads_q * head_dim);
+        int d = idx % head_dim;
+
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        q[idx] += pe;
+    }
+
+    // Process K
+    if (idx < total_k) {
+        int s = idx / (n_heads_k * head_dim);
+        int d = idx % head_dim;
+
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        k[idx] += pe;
+    }
+}
+
+// F16 kernel
+__global__ void pope_apply_f16_kernel(
+    __half* __restrict__ q,
+    __half* __restrict__ k,
+    const float* __restrict__ encoding,  // Keep encoding in f32 for precision
+    int seq_len,
+    int n_heads_q,
+    int n_heads_k,
+    int head_dim,
+    int start_pos
+) {
+    int total_q = seq_len * n_heads_q * head_dim;
+    int total_k = seq_len * n_heads_k * head_dim;
+    int total_work = max(total_q, total_k);
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_work) return;
+
+    // Process Q
+    if (idx < total_q) {
+        int s = idx / (n_heads_q * head_dim);
+        int d = idx % head_dim;
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        float val = __half2float(q[idx]) + pe;
+        q[idx] = __float2half(val);
+    }
+
+    // Process K
+    if (idx < total_k) {
+        int s = idx / (n_heads_k * head_dim);
+        int d = idx % head_dim;
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        float val = __half2float(k[idx]) + pe;
+        k[idx] = __float2half(val);
+    }
+}
+
+// BF16 kernel
+__global__ void pope_apply_bf16_kernel(
+    __nv_bfloat16* __restrict__ q,
+    __nv_bfloat16* __restrict__ k,
+    const float* __restrict__ encoding,
+    int seq_len,
+    int n_heads_q,
+    int n_heads_k,
+    int head_dim,
+    int start_pos
+) {
+    int total_q = seq_len * n_heads_q * head_dim;
+    int total_k = seq_len * n_heads_k * head_dim;
+    int total_work = max(total_q, total_k);
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_work) return;
+
+    // Process Q
+    if (idx < total_q) {
+        int s = idx / (n_heads_q * head_dim);
+        int d = idx % head_dim;
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        float val = __bfloat162float(q[idx]) + pe;
+        q[idx] = __float2bfloat16(val);
+    }
+
+    // Process K
+    if (idx < total_k) {
+        int s = idx / (n_heads_k * head_dim);
+        int d = idx % head_dim;
+        int pos = start_pos + s;
+        float pe = encoding[pos * head_dim + d];
+        float val = __bfloat162float(k[idx]) + pe;
+        k[idx] = __float2bfloat16(val);
+    }
+}
+
+} // namespace nn
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/nn/rope/rope_ext.inl b/native/ops/nn/rope/rope_ext.inl
new file mode 100644
index 0000000..61cee95
--- /dev/null
+++ b/native/ops/nn/rope/rope_ext.inl
@@ -0,0 +1,114 @@
+/**
+ * Extended RoPE dispatch functions for context length extension
+ *
+ * Provides:
+ * - rope_init_ntk_aware: NTK-aware frequency scaling
+ * - rope_init_yarn: YaRN dimension-wise interpolation
+ * - rope_init_linear: Simple linear position interpolation
+ */
+
+#include "rope_ext_kernels.cuh"
+
+namespace pygpukit {
+namespace ops {
+
+std::pair<GPUArray, GPUArray> rope_init_ntk_aware(
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale
+) {
+    // NTK-aware interpolation: scales base frequency instead of positions
+    // base' = base * scale^(dim / (dim - 2))
+
+    GPUArray cos_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+    GPUArray sin_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+
+    int total = max_seq_len * head_dim;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::rope_init_ntk_aware_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(cos_table.data()),
+        static_cast<float*>(sin_table.data()),
+        max_seq_len,
+        head_dim,
+        base,
+        scale);
+
+    sync_and_check("rope_init_ntk_aware kernel failed");
+    return {std::move(cos_table), std::move(sin_table)};
+}
+
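+// Worked example (illustrative numbers): head_dim = 128, scale = 2, base = 10000:
+//   ntk_factor  = 2^(128/126)   ~ 2.022
+//   scaled_base = 10000 * 2.022 ~ 20221
+// A 2x context stretch roughly doubles the effective base, which mostly slows
+// the low-frequency bands instead of uniformly compressing positions the way
+// linear interpolation does.
+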
+std::pair<GPUArray, GPUArray> rope_init_yarn(
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale,
+    int original_max_len,
+    float beta_fast,
+    float beta_slow,
+    float mscale
+) {
+    // YaRN: dimension-wise interpolation with attention scaling
+    // Different scaling for different frequency bands
+
+    GPUArray cos_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+    GPUArray sin_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+
+    int total = max_seq_len * head_dim;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::rope_init_yarn_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(cos_table.data()),
+        static_cast<float*>(sin_table.data()),
+        max_seq_len,
+        head_dim,
+        base,
+        scale,
+        original_max_len,
+        beta_fast,
+        beta_slow,
+        mscale);
+
+    sync_and_check("rope_init_yarn kernel failed");
+    return {std::move(cos_table), std::move(sin_table)};
+}
+
+std::pair<GPUArray, GPUArray> rope_init_linear(
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale
+) {
+    // Linear position interpolation (PI): pos' = pos / scale
+    // Simple baseline, degrades quality at high scales
+
+    GPUArray cos_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+    GPUArray sin_table({(size_t)max_seq_len, (size_t)head_dim}, DataType::Float32);
+
+    int total = max_seq_len * head_dim;
+    const int block_size = 256;
+    const int grid_size = (total + block_size - 1) / block_size;
+
+    cudaStream_t stream = internal::get_capture_stream();
+
+    nn::rope_init_linear_interpolation_f32_kernel<<<grid_size, block_size, 0, stream>>>(
+        static_cast<float*>(cos_table.data()),
+        static_cast<float*>(sin_table.data()),
+        max_seq_len,
+        head_dim,
+        base,
+        scale);
+
+    sync_and_check("rope_init_linear kernel failed");
+    return {std::move(cos_table), std::move(sin_table)};
+}
+
+} // namespace ops
+} // namespace pygpukit
diff --git a/native/ops/nn/rope/rope_ext_kernels.cuh b/native/ops/nn/rope/rope_ext_kernels.cuh
new file mode 100644
index 0000000..fd57862
--- /dev/null
+++ b/native/ops/nn/rope/rope_ext_kernels.cuh
@@ -0,0 +1,214 @@
+/**
+ * Extended RoPE kernels for context length extension
+ *
+ * Implements:
+ * - NTK-aware interpolation
+ * - YaRN (Yet another RoPE extensioN)
+ * - Linear position interpolation (PI)
+ *
+ * These methods allow models to handle sequences longer than training context.
+ */
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cmath>
+
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+namespace pygpukit {
+namespace ops {
+namespace nn {
+
+// ============================================================================
+// NTK-aware RoPE - Scale base frequency for context extension
+// ============================================================================
+
+__global__ void rope_init_ntk_aware_f32_kernel(
+    float* __restrict__ cos_table,
+    float* __restrict__ sin_table,
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale
+) {
+    // NTK-aware: base' = base * scale^(dim / (dim - 2))
+    // This preserves high-frequency components better than linear interpolation
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = max_seq_len * head_dim;
+
+    if (idx < total) {
+        int pos = idx / head_dim;
+        int dim = idx % head_dim;
+        int dim_pair = dim / 2 * 2;  // Paired even dimension (odd dims share its angle)
+
+        // NTK scaling factor for base
+        float ntk_factor = powf(scale, (float)head_dim / ((float)head_dim - 2.0f));
+        float scaled_base = base * ntk_factor;
+
+        // Compute frequency with scaled base
+        float freq = 1.0f / powf(scaled_base, (float)dim_pair / (float)head_dim);
+        float angle = (float)pos * freq;
+
+        cos_table[idx] = cosf(angle);
+        sin_table[idx] = sinf(angle);
+    }
+}
+
+// ============================================================================
+// YaRN RoPE - Dimension-wise interpolation with attention scaling
+// ============================================================================
+
+// YaRN uses different interpolation for different frequency bands:
+// - High frequency (short wavelengths, local attention): no interpolation
+// - Low frequency (long wavelengths): full interpolation
+// - Mid frequency: gradual transition
+
+__device__ __forceinline__ float yarn_get_mscale(float scale, float mscale_factor) {
+    // Attention scaling factor to compensate for interpolation
+    // mscale = mscale_factor * ln(scale) + 1.0 (default mscale_factor = 0.1)
+    if (mscale_factor <= 0.0f) {
+        return 1.0f;  // No scaling
+    }
+    return mscale_factor * logf(scale) + 1.0f;
+}
+
+// Reference helpers (the kernel below ramps on wavelength directly):
+// dimension index whose wavelength completes `num_rotations` turns over the
+// original context, d = D * ln(L / (2*pi*r)) / (2 * ln(base))
+__device__ __forceinline__ float yarn_find_correction_dim(
+    float num_rotations,
+    int head_dim,
+    float base,
+    int max_position_embeddings
+) {
+    return (float)head_dim *
+           logf((float)max_position_embeddings / (num_rotations * 2.0f * (float)M_PI)) /
+           (2.0f * logf(base));
+}
+
+// [low, high] dimension range over which the interpolation ramp applies,
+// clamped to valid dimension indices
+__device__ __forceinline__ float2 yarn_find_correction_range(
+    float beta_fast,
+    float beta_slow,
+    int head_dim,
+    float base,
+    int max_position_embeddings
+) {
+    float low  = floorf(yarn_find_correction_dim(beta_fast, head_dim, base, max_position_embeddings));
+    float high = ceilf(yarn_find_correction_dim(beta_slow, head_dim, base, max_position_embeddings));
+    return make_float2(fmaxf(low, 0.0f), fminf(high, (float)(head_dim - 1)));
+}
+
+__global__ void rope_init_yarn_f32_kernel(
+    float* __restrict__ cos_table,
+    float* __restrict__ sin_table,
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale,
+    int original_max_len,
+    float beta_fast,
+    float beta_slow,
+    float mscale
+) {
+    // YaRN: dimension-wise interpolation
+    // High freq dims: no scaling (preserve local attention)
+    // Low freq dims: full scaling
+    // Mid freq dims: linear ramp
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = max_seq_len * head_dim;
+
+    if (idx < total) {
+        int pos = idx / head_dim;
+        int dim = idx % head_dim;
+        int dim_pair = dim / 2 * 2;  // Even dimension in pair
+
+        // Wavelength thresholds: beta_fast bounds the high-frequency (short
+        // wavelength) end, beta_slow the low-frequency (long wavelength) end.
+        float high_freq_wavelen = (float)original_max_len / beta_fast;
+        float low_freq_wavelen  = (float)original_max_len / beta_slow;
+
+        // Current wavelength for this dimension
+        float wavelen = 2.0f * (float)M_PI * powf(base, (float)dim_pair / (float)head_dim);
+
+        float freq_factor;
+        if (wavelen <= high_freq_wavelen) {
+            // High frequency (short wavelength): no interpolation
+            freq_factor = 1.0f;
+        } else if (wavelen >= low_freq_wavelen) {
+            // Low frequency (long wavelength): full interpolation
+            freq_factor = 1.0f / scale;
+        } else {
+            // Mid frequency: linear ramp from 1.0 down to 1/scale
+            float smooth = (wavelen - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen);
+            freq_factor = (1.0f - smooth) + smooth / scale;
+        }
+
+        // Compute angle with interpolated frequency
+        float inv_freq = 1.0f / powf(base, (float)dim_pair / (float)head_dim);
+        float angle = (float)pos * inv_freq * freq_factor;
+
+        // Apply mscale (attention scaling) to the cached tables
+        float attention_scale = yarn_get_mscale(scale, mscale);
+
+        cos_table[idx] = cosf(angle) * attention_scale;
+        sin_table[idx] = sinf(angle) * attention_scale;
+    }
+}
+
+// ============================================================================
+// Linear Position Interpolation (PI) - Simple baseline
+// ============================================================================
+
+__global__ void rope_init_linear_interpolation_f32_kernel(
+    float* __restrict__ cos_table,
+    float* __restrict__ sin_table,
+    int max_seq_len,
+    int head_dim,
+    float base,
+    float scale
+) {
+    // Linear interpolation: pos' = pos / scale
+    // Simple but degrades quality at high scales
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = max_seq_len * head_dim;
+
+    if (idx < total) {
+        int pos = idx / head_dim;
+        int dim = idx % head_dim;
+        int dim_pair = dim / 2 * 2;
+
+        // Scale position instead of frequency
+        float scaled_pos = (float)pos / scale;
+        float freq = 1.0f / powf(base, (float)dim_pair / (float)head_dim);
+        float angle = scaled_pos * freq;
+
+        cos_table[idx] = cosf(angle);
+        sin_table[idx] = sinf(angle);
+    }
+}
+
+} // namespace nn
+} // namespace ops
+} // namespace pygpukit
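To make the band logic above concrete, here is a NumPy sketch of the per-dimension frequency factor the YaRN kernel computes (parameter values are illustrative):

    import numpy as np

    head_dim, base, scale = 128, 10000.0, 4.0
    original_max_len, beta_fast, beta_slow = 4096, 32.0, 1.0

    dim_pair = (np.arange(head_dim) // 2) * 2
    wavelen = 2 * np.pi * base ** (dim_pair / head_dim)

    high_freq_wavelen = original_max_len / beta_fast  # short-wavelength cutoff
    low_freq_wavelen = original_max_len / beta_slow   # long-wavelength cutoff

    smooth = np.clip((wavelen - high_freq_wavelen)
                     / (low_freq_wavelen - high_freq_wavelen), 0.0, 1.0)
    freq_factor = (1.0 - smooth) + smooth / scale     # ramps from 1.0 to 1/scale

    # High-frequency dims keep factor 1.0; the lowest-frequency dims get 1/scale.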
diff --git a/native/ops/ops.cuh b/native/ops/ops.cuh
index 01a55b1..7ee4626 100644
--- a/native/ops/ops.cuh
+++ b/native/ops/ops.cuh
@@ -11,6 +11,7 @@
 #include "../core/memory.hpp"
 #include
+#include <utility>
 
 namespace pygpukit {
 namespace ops {
@@ -167,6 +168,10 @@ void sigmoid(const GPUArray& input, GPUArray& out);
 GPUArray tanh(const GPUArray& input);
 void tanh(const GPUArray& input, GPUArray& out);
 
+// ReLU squared activation: y = (max(0, x))^2
+GPUArray relu2(const GPUArray& input);
+void relu2(const GPUArray& input, GPUArray& out);
+
 // RoPE (Rotary Position Embedding) - In-place
 // q: [seq_len, n_heads_q, head_dim]
 // k: [seq_len, n_heads_k, head_dim]
@@ -179,6 +184,45 @@ void rope_inplace(GPUArray& q, GPUArray& k, const GPUArray& cos, const GPUArray&
 // cos, sin: [seq_len, head_dim] (f32)
 void rope_inplace_f32table(GPUArray& q, GPUArray& k, const GPUArray& cos, const GPUArray& sin);
 
+// RoPE context extension: NTK-aware scaling
+// Returns (cos_table, sin_table) each [max_seq_len, head_dim]
+std::pair<GPUArray, GPUArray> rope_init_ntk_aware(
+    int max_seq_len, int head_dim, float base = 10000.0f, float scale = 1.0f);
+
+// RoPE context extension: YaRN dimension-wise interpolation
+// Returns (cos_table, sin_table) each [max_seq_len, head_dim]
+std::pair<GPUArray, GPUArray> rope_init_yarn(
+    int max_seq_len, int head_dim, float base = 10000.0f, float scale = 1.0f,
+    int original_max_len = 4096, float beta_fast = 32.0f, float beta_slow = 1.0f, float mscale = 0.1f);
+
+// RoPE context extension: Linear position interpolation
+// Returns (cos_table, sin_table) each [max_seq_len, head_dim]
+std::pair<GPUArray, GPUArray> rope_init_linear(
+    int max_seq_len, int head_dim, float base = 10000.0f, float scale = 1.0f);
+
+// PoPE (Positional Encoding) - additive positional encoding
+// Returns encoding tensor [max_seq_len, head_dim]
+GPUArray pope_init_encoding(int max_seq_len, int head_dim, float base = 10000.0f);
+
+// PoPE in-place application
+// q: [seq_len, n_heads_q, head_dim]
+// k: [seq_len, n_heads_k, head_dim]
+// encoding: [max_seq_len, head_dim] (f32)
+void pope_inplace(GPUArray& q, GPUArray& k, const GPUArray& encoding, int start_pos = 0);
+
+// ALiBi (Attention with Linear Biases) - head-specific slopes
+// Returns slopes tensor [num_heads]
+GPUArray alibi_init_slopes(int num_heads);
+
+// ALiBi bias matrix computation
+// Returns bias tensor [num_heads, seq_len, seq_len]
+GPUArray alibi_compute_bias(int seq_len, int num_heads, const GPUArray& slopes, bool causal = true);
+
+// ALiBi in-place bias addition to attention scores
+// scores: [batch, num_heads, q_len, kv_len]
+// slopes: [num_heads]
+void alibi_add_bias(GPUArray& scores, const GPUArray& slopes, int start_pos = 0);
+
 // Split fused QKV projection output into separate Q, K, V tensors
 // qkv: [seq_len, q_dim + k_dim + v_dim]
 // q_out: [seq_len, q_dim] (can be pre-allocated buffer)
diff --git a/src/pygpukit/ops/nn/__init__.py b/src/pygpukit/ops/nn/__init__.py
index 3e11fdc..aeebe48 100644
--- a/src/pygpukit/ops/nn/__init__.py
+++ b/src/pygpukit/ops/nn/__init__.py
@@ -16,6 +16,7 @@
 # Activation functions
 from pygpukit.ops.nn.activation import (
     gelu,
+    relu2,
     sigmoid,
     silu,
     tanh,
@@ -49,6 +50,17 @@
 # RoPE operations
 from pygpukit.ops.nn.rope import (
+    # ALiBi
+    alibi_add_bias,
+    alibi_compute_bias,
+    alibi_init_slopes,
+    # PoPE
+    pope_init_encoding,
+    pope_inplace,
+    # RoPE extensions
+    rope_init_linear,
+    rope_init_ntk_aware,
+    rope_init_yarn,
     rope_inplace,
     rope_inplace_f32table,
 )
@@ -56,6 +68,7 @@
 __all__ = [
     # Activation
     "gelu",
+    "relu2",
     "silu",
     "sigmoid",
     "tanh",
@@ -69,6 +82,17 @@
     # RoPE
     "rope_inplace",
     "rope_inplace_f32table",
+    # RoPE extensions
+    "rope_init_ntk_aware",
+    "rope_init_yarn",
+    "rope_init_linear",
+    # PoPE
+    "pope_init_encoding",
+    "pope_inplace",
+    # ALiBi
+    "alibi_init_slopes",
+    "alibi_compute_bias",
+    "alibi_add_bias",
     # Linear
     "bias_add_inplace",
     "split_qkv_batch",
diff --git a/src/pygpukit/ops/nn/activation.py b/src/pygpukit/ops/nn/activation.py
index 266b82d..ade8884 100644
--- a/src/pygpukit/ops/nn/activation.py
+++ b/src/pygpukit/ops/nn/activation.py
@@ -169,9 +169,57 @@ def tanh(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
         return from_numpy(np.tanh(x))
 
 
+def relu2(a: GPUArray, *, out: GPUArray | None = None) -> GPUArray:
+    """ReLU squared activation: y = (max(0, x))^2.
+
+    Introduced in the Primer paper (Google, 2021). Benefits:
+    - Stronger sparsity than standard ReLU
+    - Continuous first derivative (unlike ReLU)
+    - Improved training dynamics in some architectures
+
+    Args:
+        a: Input array (float32, float16, or bfloat16).
+        out: Optional pre-allocated output array.
+
+    Returns:
+        A new GPUArray containing the ReLU squared values (or `out` if given).
+
+    Example:
+        >>> x = from_numpy(np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32))
+        >>> y = relu2(x)
+        >>> y.to_numpy()  # [0.0, 0.0, 0.0, 1.0, 4.0]
+    """
+    _validate_float_dtype(a, "relu2")
+    backend = get_backend()
+
+    if isinstance(backend, NativeBackend) and backend.is_available():
+        from pygpukit.core.backend import get_native_module
+
+        native = get_native_module()
+        a_native = a._get_native()
+
+        if out is not None:
+            out_native = out._get_native()
+            native.relu2_(a_native, out_native)
+            return out
+        else:
+            return GPUArray._wrap_native(native.relu2(a_native))
+    else:
+        # CPU fallback
+        x = a.to_numpy()
+        relu_val = np.maximum(0, x)
+        result_np = (relu_val * relu_val).astype(x.dtype)
+        if out is not None:
+            # Update output buffer in-place
+            backend.copy_host_to_device(result_np.ravel(), out._device_ptr)
+            return out
+        return from_numpy(result_np)
+
+
 __all__ = [
     "gelu",
     "silu",
     "sigmoid",
     "tanh",
+    "relu2",
 ]
diff --git a/src/pygpukit/ops/nn/rope.py b/src/pygpukit/ops/nn/rope.py
index 0c81a2f..56aa71c 100644
--- a/src/pygpukit/ops/nn/rope.py
+++ b/src/pygpukit/ops/nn/rope.py
@@ -130,7 +130,249 @@ def rope_inplace_f32table(
     native.rope_inplace_f32table(q_native, k_native, cos_native, sin_native)
 
 
+def rope_init_ntk_aware(
+    max_seq_len: int,
+    head_dim: int,
+    base: float = 10000.0,
+    scale: float = 1.0,
+) -> tuple[GPUArray, GPUArray]:
+    """Initialize RoPE with NTK-aware frequency scaling.
+
+    NTK-aware interpolation scales the base frequency instead of positions:
+        base' = base * scale^(dim / (dim - 2))
+
+    This preserves high-frequency components better than linear interpolation.
+
+    Args:
+        max_seq_len: Maximum sequence length.
+        head_dim: Dimension per head.
+        base: Base for frequency computation (default 10000).
+        scale: Context extension scale factor (e.g., 2.0 for 2x context).
+
+    Returns:
+        Tuple of (cos_table, sin_table) each of shape [max_seq_len, head_dim].
+
+    Example:
+        >>> cos, sin = rope_init_ntk_aware(8192, 128, scale=2.0)
+        >>> rope_inplace(q, k, cos, sin)
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    cos_native, sin_native = native.rope_init_ntk_aware(max_seq_len, head_dim, base, scale)
+    return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+
+
+def rope_init_yarn(
+    max_seq_len: int,
+    head_dim: int,
+    base: float = 10000.0,
+    scale: float = 1.0,
+    original_max_len: int = 4096,
+    beta_fast: float = 32.0,
+    beta_slow: float = 1.0,
+    mscale: float = 0.1,
+) -> tuple[GPUArray, GPUArray]:
+    """Initialize RoPE with YaRN dimension-wise interpolation.
+
+    YaRN (Yet another RoPE extensioN) combines NTK with attention scaling
+    and dimension-wise interpolation for state-of-the-art context extension.
+
+    Different frequency bands are handled differently:
+    - High frequency (short wavelengths, local attention): no interpolation
+    - Low frequency (long wavelengths): full interpolation
+    - Mid frequency: gradual transition
+
+    Args:
+        max_seq_len: Maximum sequence length (extended).
+        head_dim: Dimension per head.
+        base: Base for frequency computation (default 10000).
+        scale: Context extension scale factor.
+        original_max_len: Original training context length.
+        beta_fast: Fast (high-frequency) wavelength threshold (default 32).
+        beta_slow: Slow (low-frequency) wavelength threshold (default 1).
+        mscale: Attention scaling factor (default 0.1).
+
+    Returns:
+        Tuple of (cos_table, sin_table) each of shape [max_seq_len, head_dim].
+
+    Example:
+        >>> cos, sin = rope_init_yarn(32768, 128, scale=4.0, original_max_len=4096)
+        >>> rope_inplace(q, k, cos, sin)
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    cos_native, sin_native = native.rope_init_yarn(
+        max_seq_len, head_dim, base, scale, original_max_len, beta_fast, beta_slow, mscale
+    )
+    return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+
+
+def rope_init_linear(
+    max_seq_len: int,
+    head_dim: int,
+    base: float = 10000.0,
+    scale: float = 1.0,
+) -> tuple[GPUArray, GPUArray]:
+    """Initialize RoPE with linear position interpolation.
+
+    Simple baseline: pos' = pos / scale.
+    Works but degrades quality at high scales.
+
+    Args:
+        max_seq_len: Maximum sequence length.
+        head_dim: Dimension per head.
+        base: Base for frequency computation (default 10000).
+        scale: Context extension scale factor.
+
+    Returns:
+        Tuple of (cos_table, sin_table) each of shape [max_seq_len, head_dim].
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    cos_native, sin_native = native.rope_init_linear(max_seq_len, head_dim, base, scale)
+    return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+
+
+def pope_init_encoding(
+    max_seq_len: int,
+    head_dim: int,
+    base: float = 10000.0,
+) -> GPUArray:
+    """Initialize sinusoidal positional encoding table (PoPE).
+
+    PoPE is an additive positional encoding alternative to RoPE.
+    Uses sinusoidal encoding: PE(pos, 2i)   = sin(pos / base^(2i/d))
+                              PE(pos, 2i+1) = cos(pos / base^(2i/d))
+
+    Args:
+        max_seq_len: Maximum sequence length.
+        head_dim: Dimension per head.
+        base: Base for frequency computation (default 10000).
+
+    Returns:
+        Encoding tensor of shape [max_seq_len, head_dim].
+
+    Example:
+        >>> encoding = pope_init_encoding(2048, 128)
+        >>> pope_inplace(q, k, encoding)
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    encoding_native = native.pope_init_encoding(max_seq_len, head_dim, base)
+    return GPUArray._wrap_native(encoding_native)
+
+
+def pope_inplace(
+    q: GPUArray,
+    k: GPUArray,
+    encoding: GPUArray,
+    start_pos: int = 0,
+) -> None:
+    """Apply additive positional encoding to Q and K in-place.
+
+    PoPE adds positional information by simple addition (vs RoPE's rotation).
+    Simpler compute but limited extrapolation compared to RoPE.
+
+    Args:
+        q: Query tensor [seq_len, n_heads_q, head_dim] (modified in-place).
+        k: Key tensor [seq_len, n_heads_k, head_dim] (modified in-place).
+        encoding: Position encoding [max_seq_len, head_dim] (f32).
+        start_pos: Starting position for incremental decoding.
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    native.pope_inplace(q._get_native(), k._get_native(), encoding._get_native(), start_pos)
+
+
+def alibi_init_slopes(num_heads: int) -> GPUArray:
+    """Initialize ALiBi head-specific slopes.
+
+    ALiBi (Attention with Linear Biases) adds a linear bias to attention
+    scores based on query-key distance: scores[i,j] -= slope * |i - j|
+
+    Each head gets a different slope: m_h = 2^(-8 * (h+1) / num_heads)
+    for h = 0..num_heads-1.
+
+    Args:
+        num_heads: Number of attention heads.
+
+    Returns:
+        Slopes tensor of shape [num_heads].
+
+    Example:
+        >>> slopes = alibi_init_slopes(32)
+        >>> bias = alibi_compute_bias(512, 32, slopes)
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    slopes_native = native.alibi_init_slopes(num_heads)
+    return GPUArray._wrap_native(slopes_native)
+
+
+def alibi_compute_bias(
+    seq_len: int,
+    num_heads: int,
+    slopes: GPUArray,
+    causal: bool = True,
+) -> GPUArray:
+    """Compute ALiBi bias matrix for attention.
+
+    Creates a bias tensor to be added to attention scores.
+    For causal attention, positions j > i are masked with a large
+    negative value (-1e9).
+
+    Args:
+        seq_len: Sequence length.
+        num_heads: Number of attention heads.
+        slopes: Head-specific slopes [num_heads].
+        causal: Whether to apply causal masking (default True).
+
+    Returns:
+        Bias tensor of shape [num_heads, seq_len, seq_len].
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    bias_native = native.alibi_compute_bias(seq_len, num_heads, slopes._get_native(), causal)
+    return GPUArray._wrap_native(bias_native)
+
+
+def alibi_add_bias(
+    scores: GPUArray,
+    slopes: GPUArray,
+    start_pos: int = 0,
+) -> None:
+    """Add ALiBi bias to attention scores in-place.
+
+    Efficiently adds position-dependent bias during incremental decoding.
+
+    Args:
+        scores: Attention scores [batch, num_heads, q_len, kv_len] (modified in-place).
+        slopes: Head-specific slopes [num_heads].
+        start_pos: Starting position for incremental decoding.
+    """
+    from pygpukit.core.backend import get_native_module
+
+    native = get_native_module()
+    native.alibi_add_bias(scores._get_native(), slopes._get_native(), start_pos)
+
+
 __all__ = [
     "rope_inplace",
     "rope_inplace_f32table",
+    # RoPE extensions
+    "rope_init_ntk_aware",
+    "rope_init_yarn",
+    "rope_init_linear",
+    # PoPE
+    "pope_init_encoding",
+    "pope_inplace",
+    # ALiBi
+    "alibi_init_slopes",
+    "alibi_compute_bias",
+    "alibi_add_bias",
 ]
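Taken together, the ALiBi helpers above compose as follows (a minimal sketch; tensor contents are illustrative):

    import numpy as np
    from pygpukit import from_numpy
    from pygpukit.ops.nn import alibi_add_bias, alibi_compute_bias, alibi_init_slopes

    num_heads, seq_len = 8, 16
    slopes = alibi_init_slopes(num_heads)    # [8], m_h = 2^(-8*(h+1)/8)

    # Precomputed variant: full bias matrix added to scores before softmax
    bias = alibi_compute_bias(seq_len, num_heads, slopes, causal=True)

    # Streaming variant: patch scores in-place during incremental decoding
    scores = from_numpy(np.zeros((1, num_heads, 1, seq_len), dtype=np.float32))
    alibi_add_bias(scores, slopes, start_pos=seq_len - 1)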
+ """ + from pygpukit.core.backend import get_native_module + + native = get_native_module() + native.alibi_add_bias(scores._get_native(), slopes._get_native(), start_pos) + + __all__ = [ "rope_inplace", "rope_inplace_f32table", + # RoPE extensions + "rope_init_ntk_aware", + "rope_init_yarn", + "rope_init_linear", + # PoPE + "pope_init_encoding", + "pope_inplace", + # ALiBi + "alibi_init_slopes", + "alibi_compute_bias", + "alibi_add_bias", ] diff --git a/tests/test_positional_encoding.py b/tests/test_positional_encoding.py new file mode 100644 index 0000000..f852ae3 --- /dev/null +++ b/tests/test_positional_encoding.py @@ -0,0 +1,235 @@ +"""Tests for positional encoding operations (PoPE, ALiBi, YaRN, NTK).""" + +import numpy as np +import pytest + +from pygpukit import from_numpy +from pygpukit.ops.nn import ( + alibi_compute_bias, + alibi_init_slopes, + pope_init_encoding, + pope_inplace, + rope_init_linear, + rope_init_ntk_aware, + rope_init_yarn, +) + + +class TestPoPE: + """Test PoPE (Positional Encoding) operations.""" + + def test_pope_init_encoding_shape(self): + """Test that PoPE encoding has correct shape.""" + max_seq_len = 512 + head_dim = 128 + + encoding = pope_init_encoding(max_seq_len, head_dim) + + assert encoding.shape == (max_seq_len, head_dim) + assert str(encoding.dtype) == "float32" + + def test_pope_init_encoding_sinusoidal(self): + """Test that PoPE encoding follows sinusoidal pattern.""" + max_seq_len = 64 + head_dim = 32 + + encoding = pope_init_encoding(max_seq_len, head_dim) + enc_np = encoding.to_numpy() + + # Position 0 should have sin(0) = 0 for even dims + # and cos(0) = 1 for odd dims (approximately) + # Due to frequency scaling, only low-frequency dims will be close to 0/1 + assert enc_np[0, 0] == pytest.approx(0.0, abs=0.01) # sin(0) = 0 + assert enc_np[0, 1] == pytest.approx(1.0, abs=0.01) # cos(0) = 1 + + def test_pope_inplace(self): + """Test PoPE in-place application.""" + seq_len = 4 + n_heads = 2 + head_dim = 8 + max_seq_len = 16 + + q = from_numpy(np.ones((seq_len, n_heads, head_dim), dtype=np.float32)) + k = from_numpy(np.ones((seq_len, n_heads, head_dim), dtype=np.float32)) + encoding = pope_init_encoding(max_seq_len, head_dim) + + q_before = q.to_numpy().copy() + k_before = k.to_numpy().copy() + + pope_inplace(q, k, encoding) + + q_after = q.to_numpy() + k_after = k.to_numpy() + + # Values should be modified (encoding added) + assert not np.allclose(q_after, q_before) + assert not np.allclose(k_after, k_before) + + +class TestALiBi: + """Test ALiBi (Attention with Linear Biases) operations.""" + + def test_alibi_init_slopes_shape(self): + """Test that ALiBi slopes have correct shape.""" + num_heads = 8 + + slopes = alibi_init_slopes(num_heads) + + assert slopes.shape == (num_heads,) + assert str(slopes.dtype) == "float32" + + def test_alibi_init_slopes_values(self): + """Test that ALiBi slopes follow the formula m_h = 2^(-8*h/n).""" + num_heads = 8 + + slopes = alibi_init_slopes(num_heads) + slopes_np = slopes.to_numpy() + + # Verify formula: m_h = 2^(-8 * (h+1) / num_heads) + for h in range(num_heads): + expected = 2 ** (-8 * (h + 1) / num_heads) + assert slopes_np[h] == pytest.approx(expected, rel=1e-5) + + def test_alibi_compute_bias_shape(self): + """Test that ALiBi bias has correct shape.""" + seq_len = 32 + num_heads = 4 + + slopes = alibi_init_slopes(num_heads) + bias = alibi_compute_bias(seq_len, num_heads, slopes) + + assert bias.shape == (num_heads, seq_len, seq_len) + assert str(bias.dtype) == "float32" + + def 
test_alibi_compute_bias_causal(self): + """Test that ALiBi bias is causal (upper triangular is -inf).""" + seq_len = 8 + num_heads = 2 + + slopes = alibi_init_slopes(num_heads) + bias = alibi_compute_bias(seq_len, num_heads, slopes, causal=True) + bias_np = bias.to_numpy() + + # Check that upper triangular (j > i) is very negative (causal mask) + for h in range(num_heads): + for i in range(seq_len): + for j in range(seq_len): + if j > i: + assert bias_np[h, i, j] < -1e8 # Should be -inf or very negative + + def test_alibi_compute_bias_diagonal_zero(self): + """Test that ALiBi bias diagonal is zero (distance = 0).""" + seq_len = 8 + num_heads = 2 + + slopes = alibi_init_slopes(num_heads) + bias = alibi_compute_bias(seq_len, num_heads, slopes) + bias_np = bias.to_numpy() + + # Diagonal should be 0 (distance = 0) + for h in range(num_heads): + for i in range(seq_len): + assert bias_np[h, i, i] == pytest.approx(0.0, abs=1e-6) + + def test_alibi_compute_bias_linear_decrease(self): + """Test that ALiBi bias decreases linearly with distance.""" + seq_len = 8 + num_heads = 4 + + slopes = alibi_init_slopes(num_heads) + bias = alibi_compute_bias(seq_len, num_heads, slopes) + bias_np = bias.to_numpy() + slopes_np = slopes.to_numpy() + + # Check that bias[h, i, j] = -slope * (i - j) for j <= i + for h in range(num_heads): + slope = slopes_np[h] + for i in range(seq_len): + for j in range(i + 1): # Only lower triangular + expected = -slope * (i - j) + assert bias_np[h, i, j] == pytest.approx(expected, rel=1e-4) + + +class TestRoPEExtensions: + """Test RoPE extension methods (NTK, YaRN, Linear).""" + + def test_rope_init_ntk_aware_shape(self): + """Test that NTK-aware RoPE tables have correct shape.""" + max_seq_len = 512 + head_dim = 128 + + cos, sin = rope_init_ntk_aware(max_seq_len, head_dim) + + assert cos.shape == (max_seq_len, head_dim) + assert sin.shape == (max_seq_len, head_dim) + assert str(cos.dtype) == "float32" + assert str(sin.dtype) == "float32" + + def test_rope_init_ntk_aware_scale(self): + """Test that NTK-aware scaling affects frequencies.""" + max_seq_len = 128 + head_dim = 64 + + cos1, sin1 = rope_init_ntk_aware(max_seq_len, head_dim, scale=1.0) + cos2, sin2 = rope_init_ntk_aware(max_seq_len, head_dim, scale=2.0) + + cos1_np = cos1.to_numpy() + cos2_np = cos2.to_numpy() + + # With scale > 1, frequencies should be different + assert not np.allclose(cos1_np, cos2_np) + + def test_rope_init_yarn_shape(self): + """Test that YaRN RoPE tables have correct shape.""" + max_seq_len = 1024 + head_dim = 128 + + cos, sin = rope_init_yarn( + max_seq_len, + head_dim, + scale=2.0, + original_max_len=512, + ) + + assert cos.shape == (max_seq_len, head_dim) + assert sin.shape == (max_seq_len, head_dim) + + def test_rope_init_yarn_vs_linear(self): + """Test that YaRN produces different results than linear interpolation.""" + max_seq_len = 256 + head_dim = 64 + scale = 2.0 + + cos_yarn, sin_yarn = rope_init_yarn(max_seq_len, head_dim, scale=scale) + cos_linear, sin_linear = rope_init_linear(max_seq_len, head_dim, scale=scale) + + # YaRN should produce different frequencies than linear + assert not np.allclose(cos_yarn.to_numpy(), cos_linear.to_numpy()) + + def test_rope_init_linear_shape(self): + """Test that linear interpolation RoPE tables have correct shape.""" + max_seq_len = 512 + head_dim = 128 + + cos, sin = rope_init_linear(max_seq_len, head_dim) + + assert cos.shape == (max_seq_len, head_dim) + assert sin.shape == (max_seq_len, head_dim) + + def test_rope_tables_normalized(self): + """Test 
that cos^2 + sin^2 = 1 for all positions and dimensions.""" + max_seq_len = 64 + head_dim = 32 + + for init_fn in [rope_init_ntk_aware, rope_init_linear]: + cos, sin = init_fn(max_seq_len, head_dim) + cos_np = cos.to_numpy() + sin_np = sin.to_numpy() + + # cos^2 + sin^2 should be ~1 + sum_sq = cos_np**2 + sin_np**2 + np.testing.assert_allclose(sum_sq, 1.0, rtol=1e-4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_relu2.py b/tests/test_relu2.py new file mode 100644 index 0000000..789e26d --- /dev/null +++ b/tests/test_relu2.py @@ -0,0 +1,137 @@ +"""Tests for ReLU squared (relu2) activation function.""" + +import numpy as np +import pytest + +from pygpukit import from_numpy +from pygpukit.core.backend import NativeBackend, get_backend +from pygpukit.core.dtypes import DataType +from pygpukit.ops.nn import relu2 + + +def is_gpu_available(): + """Check if GPU backend is available.""" + backend = get_backend() + return isinstance(backend, NativeBackend) and backend.is_available() + + +def relu2_reference(x: np.ndarray) -> np.ndarray: + """Reference implementation of ReLU squared.""" + relu_val = np.maximum(0, x) + return relu_val * relu_val + + +class TestRelu2: + """Test ReLU squared activation.""" + + def test_relu2_basic_f32(self): + """Test basic ReLU squared with float32.""" + x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32) + expected = relu2_reference(x) + + x_gpu = from_numpy(x) + y_gpu = relu2(x_gpu) + y = y_gpu.to_numpy() + + np.testing.assert_allclose(y, expected, rtol=1e-5) + + def test_relu2_negative_values(self): + """Test that negative values become 0.""" + x = np.array([-5.0, -3.0, -1.0, -0.5], dtype=np.float32) + x_gpu = from_numpy(x) + y_gpu = relu2(x_gpu) + y = y_gpu.to_numpy() + + np.testing.assert_allclose(y, np.zeros_like(x), rtol=1e-5) + + def test_relu2_positive_values(self): + """Test that positive values are squared.""" + x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + expected = np.array([1.0, 4.0, 9.0, 16.0], dtype=np.float32) + + x_gpu = from_numpy(x) + y_gpu = relu2(x_gpu) + y = y_gpu.to_numpy() + + np.testing.assert_allclose(y, expected, rtol=1e-5) + + def test_relu2_2d_array(self): + """Test ReLU squared with 2D array.""" + x = np.random.randn(32, 64).astype(np.float32) + expected = relu2_reference(x) + + x_gpu = from_numpy(x) + y_gpu = relu2(x_gpu) + y = y_gpu.to_numpy() + + np.testing.assert_allclose(y, expected, rtol=1e-5) + + def test_relu2_3d_array(self): + """Test ReLU squared with 3D array (batch, seq, hidden).""" + x = np.random.randn(4, 128, 256).astype(np.float32) + expected = relu2_reference(x) + + x_gpu = from_numpy(x) + y_gpu = relu2(x_gpu) + y = y_gpu.to_numpy() + + np.testing.assert_allclose(y, expected, rtol=1e-5) + + def test_relu2_bf16(self): + """Test ReLU squared with bfloat16.""" + if not is_gpu_available(): + pytest.skip("BF16 requires GPU") + + x = np.random.randn(64, 128).astype(np.float32) + expected = relu2_reference(x) + + # Convert to bf16 on GPU + x_gpu = from_numpy(x).astype(DataType.from_string("bfloat16")) + y_gpu = relu2(x_gpu) + y = y_gpu.astype(DataType.from_string("float32")).to_numpy() + + # BF16 has lower precision + np.testing.assert_allclose(y, expected, rtol=1e-2, atol=1e-2) + + def test_relu2_f16(self): + """Test ReLU squared with float16.""" + if not is_gpu_available(): + pytest.skip("F16 requires GPU") + + x = np.random.randn(64, 128).astype(np.float32) + expected = relu2_reference(x) + + # Convert to f16 on GPU + x_gpu = 
+        y_gpu = relu2(x_gpu)
+        y = y_gpu.astype(DataType.from_string("float32")).to_numpy()
+
+        # F16 has lower precision
+        np.testing.assert_allclose(y, expected, rtol=1e-2, atol=1e-2)
+
+    def test_relu2_with_output_buffer(self):
+        """Test ReLU squared with pre-allocated output buffer."""
+        x = np.random.randn(32, 64).astype(np.float32)
+        expected = relu2_reference(x)
+
+        x_gpu = from_numpy(x)
+        out_gpu = from_numpy(np.zeros_like(x))
+
+        relu2(x_gpu, out=out_gpu)
+        y = out_gpu.to_numpy()
+
+        # Verify output buffer contains correct values
+        np.testing.assert_allclose(y, expected, rtol=1e-5)
+
+    def test_relu2_preserves_shape(self):
+        """Test that ReLU squared preserves input shape."""
+        shapes = [(10,), (10, 20), (5, 10, 15), (2, 3, 4, 5)]
+        for shape in shapes:
+            x = np.random.randn(*shape).astype(np.float32)
+            x_gpu = from_numpy(x)
+            y_gpu = relu2(x_gpu)
+            assert y_gpu.shape == shape
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

From cbcf5c128871f0aab57c319b57b0d7951864742d Mon Sep 17 00:00:00 2001
From: m96-chan
Date: Wed, 31 Dec 2025 07:46:12 +0900
Subject: [PATCH 2/2] fix(nn): add CPU fallback for positional encoding
 operations

Add CPU implementations for PoPE, ALiBi, and RoPE extension functions
to support testing in CI environments without a GPU:

- rope_init_ntk_aware: NTK-aware frequency scaling
- rope_init_yarn: YaRN dimension-wise interpolation
- rope_init_linear: Linear position interpolation
- pope_init_encoding: Sinusoidal positional encoding
- pope_inplace: Additive positional encoding
- alibi_init_slopes: ALiBi head slopes
- alibi_compute_bias: ALiBi bias matrix
- alibi_add_bias: In-place bias addition

Also fixed an in-place update bug in _rope_inplace_cpu by using the
proper backend.copy_host_to_device() instead of the non-existent _data
attribute.

Fixes CI test failures for positional encoding tests.
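
All public entry points now share one dispatch pattern; as a simplified
sketch (the placeholders stand in for the per-function calls shown in
the diff below):

    backend = get_backend()
    if isinstance(backend, NativeBackend) and backend.is_available():
        ...  # call into the native CUDA module
    else:
        ...  # NumPy fallback, e.g. _rope_init_linear_cpu(...)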
Co-Authored-By: Claude Opus 4.5
---
 src/pygpukit/ops/nn/rope.py | 368 ++++++++++++++++++++++++++++++++----
 1 file changed, 335 insertions(+), 33 deletions(-)

diff --git a/src/pygpukit/ops/nn/rope.py b/src/pygpukit/ops/nn/rope.py
index 56aa71c..7ab9c21 100644
--- a/src/pygpukit/ops/nn/rope.py
+++ b/src/pygpukit/ops/nn/rope.py
@@ -5,6 +5,8 @@
 
 from __future__ import annotations
 
+import numpy as np
+
 from pygpukit.core.array import GPUArray
 from pygpukit.core.backend import NativeBackend, get_backend
 from pygpukit.core.factory import from_numpy
@@ -51,6 +53,7 @@ def _rope_inplace_cpu(
     sin: GPUArray,
 ) -> None:
     """CPU implementation of rope_inplace."""
+    backend = get_backend()
     q_np = q.to_numpy()
     k_np = k.to_numpy()
@@ -82,8 +85,8 @@
             k_np[s, h, half_dim:] = k1 * c + k0 * sn
 
     # Update the GPUArray data in-place
-    q._data = from_numpy(q_np)._data
-    k._data = from_numpy(k_np)._data
+    backend.copy_host_to_device(q_np.ravel(), q._device_ptr)
+    backend.copy_host_to_device(k_np.ravel(), k._device_ptr)
 
 
 def _rope_inplace_native(
@@ -156,11 +159,53 @@
         >>> cos, sin = rope_init_ntk_aware(8192, 128, scale=2.0)
         >>> rope_inplace(q, k, cos, sin)
     """
-    from pygpukit.core.backend import get_native_module
+    backend = get_backend()
 
-    native = get_native_module()
-    cos_native, sin_native = native.rope_init_ntk_aware(max_seq_len, head_dim, base, scale)
-    return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+    if isinstance(backend, NativeBackend) and backend.is_available():
+        from pygpukit.core.backend import get_native_module
+
+        native = get_native_module()
+        cos_native, sin_native = native.rope_init_ntk_aware(
+            max_seq_len, head_dim, base, scale
+        )
+        return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+    else:
+        return _rope_init_ntk_aware_cpu(max_seq_len, head_dim, base, scale)
+
+
+def _rope_init_ntk_aware_cpu(
+    max_seq_len: int,
+    head_dim: int,
+    base: float,
+    scale: float,
+) -> tuple[GPUArray, GPUArray]:
+    """CPU implementation of NTK-aware RoPE initialization."""
+    # NTK-aware scaling: base' = base * scale^(dim / (dim - 2))
+    scaled_base = base * (scale ** (head_dim / (head_dim - 2))) if scale > 1.0 else base
+
+    # Compute inverse frequencies
+    half_dim = head_dim // 2
+    inv_freq = 1.0 / (scaled_base ** (np.arange(0, half_dim, dtype=np.float32) / half_dim))
+
+    # Compute positions
+    positions = np.arange(max_seq_len, dtype=np.float32)
+
+    # Compute angles: [max_seq_len, half_dim]
+    angles = np.outer(positions, inv_freq)
+
+    # Compute cos and sin, then interleave to get [max_seq_len, head_dim]
+    cos_half = np.cos(angles)
+    sin_half = np.sin(angles)
+
+    # Interleave: [cos0, cos0, cos1, cos1, ...] for compatibility with RoPE apply
+    cos_table = np.zeros((max_seq_len, head_dim), dtype=np.float32)
+    sin_table = np.zeros((max_seq_len, head_dim), dtype=np.float32)
+    cos_table[:, 0::2] = cos_half
+    cos_table[:, 1::2] = cos_half
+    sin_table[:, 0::2] = sin_half
+    sin_table[:, 1::2] = sin_half
+
+    return from_numpy(cos_table), from_numpy(sin_table)
 
 
 def rope_init_yarn(
@@ -200,13 +245,79 @@
         >>> cos, sin = rope_init_yarn(32768, 128, scale=4.0, original_max_len=4096)
         >>> rope_inplace(q, k, cos, sin)
     """
-    from pygpukit.core.backend import get_native_module
+    backend = get_backend()
 
-    native = get_native_module()
-    cos_native, sin_native = native.rope_init_yarn(
-        max_seq_len, head_dim, base, scale, original_max_len, beta_fast, beta_slow, mscale
-    )
-    return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+    if isinstance(backend, NativeBackend) and backend.is_available():
+        from pygpukit.core.backend import get_native_module
+
+        native = get_native_module()
+        cos_native, sin_native = native.rope_init_yarn(
+            max_seq_len,
+            head_dim,
+            base,
+            scale,
+            original_max_len,
+            beta_fast,
+            beta_slow,
+            mscale,
+        )
+        return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native)
+    else:
+        # NOTE: the CPU fallback reproduces only the interpolated frequency
+        # tables; the mscale attention factor is applied in the native path.
+        return _rope_init_yarn_cpu(
+            max_seq_len, head_dim, base, scale, original_max_len, beta_fast, beta_slow
+        )
+
+
+def _rope_init_yarn_cpu(
+    max_seq_len: int,
+    head_dim: int,
+    base: float,
+    scale: float,
+    original_max_len: int,
+    beta_fast: float,
+    beta_slow: float,
+) -> tuple[GPUArray, GPUArray]:
+    """CPU implementation of YaRN RoPE initialization."""
+    half_dim = head_dim // 2
+
+    # Compute base frequencies
+    inv_freq = 1.0 / (base ** (np.arange(0, half_dim, dtype=np.float32) / half_dim))
+
+    # Compute wavelengths for each dimension
+    wavelengths = 2 * np.pi / inv_freq
+
+    # Wavelength thresholds for the YaRN frequency bands
+    low_freq_wavelen = original_max_len / beta_slow
+    high_freq_wavelen = original_max_len / beta_fast
+
+    # Interpolation factor: 0 = no interpolation (short wavelengths / high
+    # frequencies), 1 = full interpolation (long wavelengths / low frequencies)
+    smooth = np.clip(
+        (wavelengths - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen), 0, 1
+    )
+
+    # Mix original and scaled frequencies: high-frequency dims are kept as-is
+    # (extrapolation), low-frequency dims are interpolated (divided by scale)
+    scaled_inv_freq = inv_freq / scale
+    interpolated_inv_freq = (1 - smooth) * inv_freq + smooth * scaled_inv_freq
+
+    # Compute positions
+    positions = np.arange(max_seq_len, dtype=np.float32)
+
+    # Compute angles
+    angles = np.outer(positions, interpolated_inv_freq)
+
+    # Compute cos and sin
+    cos_half = np.cos(angles)
+    sin_half = np.sin(angles)
+
+    # Interleave
+    cos_table = np.zeros((max_seq_len, head_dim), dtype=np.float32)
+    sin_table = np.zeros((max_seq_len, head_dim), dtype=np.float32)
+    cos_table[:, 0::2] = cos_half
+    cos_table[:, 1::2] = cos_half
+    sin_table[:, 0::2] = sin_half
+    sin_table[:, 1::2] = sin_half
+
+    return from_numpy(cos_table), from_numpy(sin_table)
 
 
 def rope_init_linear(
@@ -229,11 +340,51 @@
     Returns:
         Tuple of (cos_table, sin_table) each of shape [max_seq_len, head_dim].
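+
+    Example:
+        >>> # Illustrative sketch: q and k are assumed to be query/key
+        >>> # tensors of shape [seq_len, n_heads, head_dim].
+        >>> cos, sin = rope_init_linear(8192, 128, scale=2.0)
+        >>> rope_inplace(q, k, cos, sin)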
""" - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - cos_native, sin_native = native.rope_init_linear(max_seq_len, head_dim, base, scale) - return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + cos_native, sin_native = native.rope_init_linear( + max_seq_len, head_dim, base, scale + ) + return GPUArray._wrap_native(cos_native), GPUArray._wrap_native(sin_native) + else: + return _rope_init_linear_cpu(max_seq_len, head_dim, base, scale) + + +def _rope_init_linear_cpu( + max_seq_len: int, + head_dim: int, + base: float, + scale: float, +) -> tuple[GPUArray, GPUArray]: + """CPU implementation of linear position interpolation RoPE.""" + half_dim = head_dim // 2 + + # Compute inverse frequencies + inv_freq = 1.0 / (base ** (np.arange(0, half_dim, dtype=np.float32) / half_dim)) + + # Compute scaled positions (linear interpolation: pos' = pos / scale) + positions = np.arange(max_seq_len, dtype=np.float32) / scale + + # Compute angles + angles = np.outer(positions, inv_freq) + + # Compute cos and sin + cos_half = np.cos(angles) + sin_half = np.sin(angles) + + # Interleave + cos_table = np.zeros((max_seq_len, head_dim), dtype=np.float32) + sin_table = np.zeros((max_seq_len, head_dim), dtype=np.float32) + cos_table[:, 0::2] = cos_half + cos_table[:, 1::2] = cos_half + sin_table[:, 0::2] = sin_half + sin_table[:, 1::2] = sin_half + + return from_numpy(cos_table), from_numpy(sin_table) def pope_init_encoding( @@ -259,11 +410,40 @@ def pope_init_encoding( >>> encoding = pope_init_encoding(2048, 128) >>> pope_inplace(q, k, encoding) """ - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - encoding_native = native.pope_init_encoding(max_seq_len, head_dim, base) - return GPUArray._wrap_native(encoding_native) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + encoding_native = native.pope_init_encoding(max_seq_len, head_dim, base) + return GPUArray._wrap_native(encoding_native) + else: + return _pope_init_encoding_cpu(max_seq_len, head_dim, base) + + +def _pope_init_encoding_cpu( + max_seq_len: int, + head_dim: int, + base: float, +) -> GPUArray: + """CPU implementation of sinusoidal positional encoding.""" + encoding = np.zeros((max_seq_len, head_dim), dtype=np.float32) + + positions = np.arange(max_seq_len, dtype=np.float32) + half_dim = head_dim // 2 + + # Compute inverse frequencies + inv_freq = 1.0 / (base ** (np.arange(0, half_dim, dtype=np.float32) / half_dim)) + + # Compute angles + angles = np.outer(positions, inv_freq) + + # PE(pos, 2i) = sin, PE(pos, 2i+1) = cos + encoding[:, 0::2] = np.sin(angles) + encoding[:, 1::2] = np.cos(angles) + + return from_numpy(encoding) def pope_inplace( @@ -283,10 +463,51 @@ def pope_inplace( encoding: Position encoding [max_seq_len, head_dim] (f32). start_pos: Starting position for incremental decoding. 
""" - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - native.pope_inplace(q._get_native(), k._get_native(), encoding._get_native(), start_pos) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + native.pope_inplace( + q._get_native(), k._get_native(), encoding._get_native(), start_pos + ) + else: + _pope_inplace_cpu(q, k, encoding, start_pos) + + +def _pope_inplace_cpu( + q: GPUArray, + k: GPUArray, + encoding: GPUArray, + start_pos: int, +) -> None: + """CPU implementation of PoPE in-place application.""" + backend = get_backend() + + q_np = q.to_numpy() + k_np = k.to_numpy() + enc_np = encoding.to_numpy() + + seq_len = q_np.shape[0] + n_heads_q = q_np.shape[1] + n_heads_k = k_np.shape[1] + + # Add positional encoding to each position + for s in range(seq_len): + pos = start_pos + s + enc_pos = enc_np[pos] + + # Add to all heads + for h in range(n_heads_q): + q_np[s, h] = q_np[s, h] + enc_pos + + for h in range(n_heads_k): + k_np[s, h] = k_np[s, h] + enc_pos + + # Update the GPUArray data in-place + backend.copy_host_to_device(q_np.ravel(), q._device_ptr) + backend.copy_host_to_device(k_np.ravel(), k._device_ptr) def alibi_init_slopes(num_heads: int) -> GPUArray: @@ -307,11 +528,25 @@ def alibi_init_slopes(num_heads: int) -> GPUArray: >>> slopes = alibi_init_slopes(32) >>> bias = alibi_compute_bias(512, 32, slopes) """ - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - slopes_native = native.alibi_init_slopes(num_heads) - return GPUArray._wrap_native(slopes_native) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + slopes_native = native.alibi_init_slopes(num_heads) + return GPUArray._wrap_native(slopes_native) + else: + return _alibi_init_slopes_cpu(num_heads) + + +def _alibi_init_slopes_cpu(num_heads: int) -> GPUArray: + """CPU implementation of ALiBi slopes initialization.""" + # m_h = 2^(-8 * (h+1) / num_heads) + slopes = np.array( + [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)], dtype=np.float32 + ) + return from_numpy(slopes) def alibi_compute_bias( @@ -334,11 +569,45 @@ def alibi_compute_bias( Returns: Bias tensor of shape [num_heads, seq_len, seq_len]. 
""" - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - bias_native = native.alibi_compute_bias(seq_len, num_heads, slopes._get_native(), causal) - return GPUArray._wrap_native(bias_native) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + bias_native = native.alibi_compute_bias( + seq_len, num_heads, slopes._get_native(), causal + ) + return GPUArray._wrap_native(bias_native) + else: + return _alibi_compute_bias_cpu(seq_len, num_heads, slopes, causal) + + +def _alibi_compute_bias_cpu( + seq_len: int, + num_heads: int, + slopes: GPUArray, + causal: bool, +) -> GPUArray: + """CPU implementation of ALiBi bias computation.""" + slopes_np = slopes.to_numpy() + + # Create bias tensor [num_heads, seq_len, seq_len] + bias = np.zeros((num_heads, seq_len, seq_len), dtype=np.float32) + + # Compute distance matrix + for h in range(num_heads): + slope = slopes_np[h] + for i in range(seq_len): + for j in range(seq_len): + if causal and j > i: + # Causal mask: future positions are masked + bias[h, i, j] = -1e9 + else: + # ALiBi bias: -slope * distance + bias[h, i, j] = -slope * (i - j) + + return from_numpy(bias) def alibi_add_bias( @@ -355,10 +624,43 @@ def alibi_add_bias( slopes: Head-specific slopes [num_heads]. start_pos: Starting position for incremental decoding. """ - from pygpukit.core.backend import get_native_module + backend = get_backend() - native = get_native_module() - native.alibi_add_bias(scores._get_native(), slopes._get_native(), start_pos) + if isinstance(backend, NativeBackend) and backend.is_available(): + from pygpukit.core.backend import get_native_module + + native = get_native_module() + native.alibi_add_bias(scores._get_native(), slopes._get_native(), start_pos) + else: + _alibi_add_bias_cpu(scores, slopes, start_pos) + + +def _alibi_add_bias_cpu( + scores: GPUArray, + slopes: GPUArray, + start_pos: int, +) -> None: + """CPU implementation of ALiBi in-place bias addition.""" + backend = get_backend() + + scores_np = scores.to_numpy() + slopes_np = slopes.to_numpy() + + # scores shape: [batch, num_heads, q_len, kv_len] + batch, num_heads, q_len, kv_len = scores_np.shape + + for b in range(batch): + for h in range(num_heads): + slope = slopes_np[h] + for qi in range(q_len): + q_pos = start_pos + qi + for kj in range(kv_len): + # Distance from query position to key position + distance = q_pos - kj + scores_np[b, h, qi, kj] -= slope * distance + + # Update the GPUArray data in-place + backend.copy_host_to_device(scores_np.ravel(), scores._device_ptr) __all__ = [