Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions native/bindings/nn/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,14 @@ void init_nn_activation(py::module_& m) {
"Fused linear + bias + GELU: output = gelu(input @ weight^T + bias)\n"
"Uses CUTLASS TensorCore epilogue fusion for efficiency.\n"
"input: [batch, in_features], weight: [out_features, in_features], bias: [out_features]");

// ReLU squared (Primer paper)
m.def("relu2", py::overload_cast<const GPUArray&>(&ops::relu2),
py::arg("input"),
"ReLU squared activation: y = (max(0, x))^2\n"
"Introduced in the Primer paper (Google, 2021).");

m.def("relu2_", py::overload_cast<const GPUArray&, GPUArray&>(&ops::relu2),
py::arg("input"), py::arg("out"),
"ReLU squared with output buffer (for CUDA Graph capture)");
}
58 changes: 58 additions & 0 deletions native/bindings/nn/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,62 @@ void init_nn_rope(py::module_& m) {
"q: [seq_len, n_heads_q, head_dim] (bf16 or f16)\n"
"k: [seq_len, n_heads_k, head_dim] (bf16 or f16)\n"
"cos, sin: [seq_len, head_dim] (f32)");

// NTK-aware RoPE initialization
m.def("rope_init_ntk_aware", &ops::rope_init_ntk_aware,
py::arg("max_seq_len"), py::arg("head_dim"),
py::arg("base") = 10000.0f, py::arg("scale") = 1.0f,
"Initialize RoPE with NTK-aware frequency scaling.\n"
"Scales base frequency for context extension: base' = base * scale^(dim/(dim-2))\n"
"Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]");

// YaRN RoPE initialization
m.def("rope_init_yarn", &ops::rope_init_yarn,
py::arg("max_seq_len"), py::arg("head_dim"),
py::arg("base") = 10000.0f, py::arg("scale") = 1.0f,
py::arg("original_max_len") = 4096, py::arg("beta_fast") = 32.0f,
py::arg("beta_slow") = 1.0f, py::arg("mscale") = 0.1f,
"Initialize RoPE with YaRN dimension-wise interpolation.\n"
"Different scaling for different frequency bands (low/mid/high).\n"
"Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]");

// Linear position interpolation
m.def("rope_init_linear", &ops::rope_init_linear,
py::arg("max_seq_len"), py::arg("head_dim"),
py::arg("base") = 10000.0f, py::arg("scale") = 1.0f,
"Initialize RoPE with linear position interpolation.\n"
"Simple baseline: pos' = pos / scale. Degrades at high scales.\n"
"Returns: tuple of (cos_table, sin_table) each [max_seq_len, head_dim]");

// PoPE (Positional Encoding) - Alternative to RoPE
m.def("pope_init_encoding", &ops::pope_init_encoding,
py::arg("max_seq_len"), py::arg("head_dim"), py::arg("base") = 10000.0f,
"Initialize sinusoidal positional encoding table.\n"
"Returns: encoding tensor [max_seq_len, head_dim]");

m.def("pope_inplace", &ops::pope_inplace,
py::arg("q"), py::arg("k"), py::arg("encoding"), py::arg("start_pos") = 0,
"Apply additive positional encoding to Q and K in-place.\n"
"q: [seq_len, n_heads_q, head_dim]\n"
"k: [seq_len, n_heads_k, head_dim]\n"
"encoding: [max_seq_len, head_dim] (f32)");

// ALiBi (Attention with Linear Biases)
m.def("alibi_init_slopes", &ops::alibi_init_slopes,
py::arg("num_heads"),
"Initialize ALiBi head-specific slopes.\n"
"m_h = 2^(-8 * h / num_heads)\n"
"Returns: slopes tensor [num_heads]");

m.def("alibi_compute_bias", &ops::alibi_compute_bias,
py::arg("seq_len"), py::arg("num_heads"), py::arg("slopes"),
py::arg("causal") = true,
"Compute ALiBi bias matrix for attention.\n"
"Returns: bias tensor [num_heads, seq_len, seq_len]");

m.def("alibi_add_bias", &ops::alibi_add_bias,
py::arg("scores"), py::arg("slopes"), py::arg("start_pos") = 0,
"Add ALiBi bias to attention scores in-place.\n"
"scores: [batch, num_heads, q_len, kv_len]\n"
"slopes: [num_heads]");
}
74 changes: 74 additions & 0 deletions native/ops/nn/activation/relu2.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* ReLU squared (ReLU^2) activation: (max(0, x))^2
*
* Introduced in the Primer paper (Google, 2021).
* Benefits: stronger sparsity, continuous first derivative.
*/

namespace pygpukit {
namespace ops {

// Internal dispatch helper with capture stream support
static void relu2_dispatch(const GPUArray& input, GPUArray& result) {
size_t n = input.size();
const int block_size = 256;
const int grid_size = (n + block_size - 1) / block_size;

// Use capture stream if available
cudaStream_t stream = internal::get_capture_stream();

switch (input.dtype()) {
case DataType::Float32:
nn::relu2_f32_kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const float*>(input.data()),
static_cast<float*>(result.data()),
n);
break;
case DataType::Float16:
nn::relu2_f16_kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const __half*>(input.data()),
static_cast<__half*>(result.data()),
n);
break;
case DataType::BFloat16:
nn::relu2_bf16_kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const __nv_bfloat16*>(input.data()),
static_cast<__nv_bfloat16*>(result.data()),
n);
break;
default:
break;
}
}

// Allocating variant of ReLU squared: y = (max(0, x))^2.
// Returns a freshly allocated array with the same shape/dtype as `input`.
GPUArray relu2(const GPUArray& input) {
    const DataType dt = input.dtype();
    const bool supported = (dt == DataType::Float32) ||
                           (dt == DataType::Float16) ||
                           (dt == DataType::BFloat16);
    if (!supported) {
        throw std::runtime_error("relu2 only supports float32, float16, bfloat16");
    }

    GPUArray result(input.shape(), input.dtype());
    relu2_dispatch(input, result);
    sync_and_check("relu2 kernel failed");
    return result;
}

// ReLU squared with output buffer (for CUDA Graph capture)
void relu2(const GPUArray& input, GPUArray& out) {
if (input.dtype() != DataType::Float32 &&
input.dtype() != DataType::Float16 && input.dtype() != DataType::BFloat16) {
throw std::runtime_error("relu2 only supports float32, float16, bfloat16");
}
if (input.dtype() != out.dtype()) {
throw std::runtime_error("relu2: dtype mismatch between input and output");
}
if (input.shape() != out.shape()) {
throw std::runtime_error("relu2: shape mismatch between input and output");
}

relu2_dispatch(input, out);
sync_and_check("relu2 kernel failed");
}

} // namespace ops
} // namespace pygpukit
38 changes: 38 additions & 0 deletions native/ops/nn/activation_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ __device__ __forceinline__ float sigmoid_f32(float x) {
return 1.0f / (1.0f + expf(-x));
}

// ReLU squared in f32: y = (max(0, x))^2. fmaxf also clamps NaN to 0.
__device__ __forceinline__ float relu2_f32(float x) {
    const float clamped = fmaxf(x, 0.0f);
    return clamped * clamped;
}

// ============================================================================
// Kernel declarations (always available)
// ============================================================================
Expand Down Expand Up @@ -88,6 +93,14 @@ __global__ void tanh_f16_kernel(const __half* __restrict__ input,
__global__ void tanh_bf16_kernel(const __nv_bfloat16* __restrict__ input,
__nv_bfloat16* __restrict__ output, size_t n);

// ReLU squared (Primer paper)
__global__ void relu2_f32_kernel(const float* __restrict__ input,
float* __restrict__ output, size_t n);
__global__ void relu2_f16_kernel(const __half* __restrict__ input,
__half* __restrict__ output, size_t n);
__global__ void relu2_bf16_kernel(const __nv_bfloat16* __restrict__ input,
__nv_bfloat16* __restrict__ output, size_t n);

// ============================================================================
// Kernel definitions (only when PYGPUKIT_IMPLEMENT_NN_KERNELS is defined)
// ============================================================================
Expand Down Expand Up @@ -229,6 +242,31 @@ __global__ void tanh_bf16_kernel(const __nv_bfloat16* __restrict__ input,
}
}

// ReLU squared kernels
__global__ void relu2_f32_kernel(const float* __restrict__ input,
                                 float* __restrict__ output, size_t n) {
    // Widen blockIdx.x before multiplying: blockIdx.x * blockDim.x is
    // evaluated in unsigned int and wraps once the launch covers more than
    // 2^32 threads, producing out-of-range reads/writes.
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < n) output[idx] = relu2_f32(input[idx]);
}

__global__ void relu2_f16_kernel(const __half* __restrict__ input,
                                 __half* __restrict__ output, size_t n) {
    // Widen before multiplying to avoid unsigned-int wraparound on very
    // large launches (see relu2_f32_kernel). Compute in f32 for accuracy.
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < n) {
        float x = __half2float(input[idx]);
        output[idx] = __float2half(relu2_f32(x));
    }
}

__global__ void relu2_bf16_kernel(const __nv_bfloat16* __restrict__ input,
                                  __nv_bfloat16* __restrict__ output, size_t n) {
    // Widen before multiplying to avoid unsigned-int wraparound on very
    // large launches (see relu2_f32_kernel). Compute in f32 for accuracy.
    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < n) {
        float x = __bfloat162float(input[idx]);
        output[idx] = __float2bfloat16(relu2_f32(x));
    }
}

#endif // PYGPUKIT_IMPLEMENT_NN_KERNELS

} // namespace nn
Expand Down
109 changes: 109 additions & 0 deletions native/ops/nn/alibi/alibi.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/**
* ALiBi (Attention with Linear Biases) dispatch functions
*
* Provides:
* - alibi_init_slopes: Compute head-specific slopes
* - alibi_compute_bias: Create bias matrix for attention
* - alibi_add_bias: Add bias to attention scores in-place
*/

#include "alibi_kernels.cuh"

namespace pygpukit {
namespace ops {

// Compute ALiBi per-head slopes on-device (formula lives in
// nn::alibi_init_slopes_kernel).
// Returns a float32 tensor of shape [num_heads].
// Throws std::runtime_error if num_heads is not positive.
GPUArray alibi_init_slopes(int num_heads) {
    if (num_heads <= 0) {
        // A non-positive count would otherwise yield an invalid 0-block
        // launch, or a huge allocation from casting a negative to size_t.
        throw std::runtime_error("alibi_init_slopes: num_heads must be positive");
    }

    // Create slopes tensor: [num_heads]
    GPUArray slopes({(size_t)num_heads}, DataType::Float32);

    const int block_size = 256;
    const int grid_size = (num_heads + block_size - 1) / block_size;

    cudaStream_t stream = internal::get_capture_stream();

    nn::alibi_init_slopes_kernel<<<grid_size, block_size, 0, stream>>>(
        static_cast<float*>(slopes.data()),
        num_heads);

    sync_and_check("alibi_init_slopes kernel failed");
    return slopes;
}

// Build the full ALiBi bias matrix for attention.
// Returns a float32 tensor of shape [num_heads, seq_len, seq_len]; the
// causal flag selects between the causal and bidirectional kernels.
// Throws std::runtime_error on invalid sizes or a non-f32 slopes tensor.
GPUArray alibi_compute_bias(int seq_len, int num_heads, const GPUArray& slopes, bool causal) {
    if (seq_len <= 0 || num_heads <= 0) {
        throw std::runtime_error("alibi_compute_bias: seq_len and num_heads must be positive");
    }
    if (slopes.dtype() != DataType::Float32) {
        throw std::runtime_error("alibi_compute_bias: slopes must be float32");
    }
    if (slopes.size() != (size_t)num_heads) {
        throw std::runtime_error("alibi_compute_bias: slopes size must match num_heads");
    }

    // Create bias tensor: [num_heads, seq_len, seq_len]
    GPUArray bias({(size_t)num_heads, (size_t)seq_len, (size_t)seq_len}, DataType::Float32);

    // Count elements in size_t: num_heads * seq_len^2 overflows signed int
    // (UB) already at e.g. 32 heads x 8192^2.
    const size_t total = (size_t)num_heads * (size_t)seq_len * (size_t)seq_len;
    const int block_size = 256;
    const unsigned int grid_size =
        static_cast<unsigned int>((total + block_size - 1) / block_size);

    cudaStream_t stream = internal::get_capture_stream();

    if (causal) {
        nn::alibi_compute_bias_causal_f32_kernel<<<grid_size, block_size, 0, stream>>>(
            static_cast<float*>(bias.data()),
            static_cast<const float*>(slopes.data()),
            seq_len,
            num_heads);
    } else {
        nn::alibi_compute_bias_f32_kernel<<<grid_size, block_size, 0, stream>>>(
            static_cast<float*>(bias.data()),
            static_cast<const float*>(slopes.data()),
            seq_len,
            num_heads);
    }

    sync_and_check("alibi_compute_bias kernel failed");
    return bias;
}

// Add ALiBi positional bias to attention scores in-place.
// scores: [batch, num_heads, q_len, kv_len] (float32), modified in-place
// slopes: [num_heads] (float32)
// start_pos: position offset for the query rows — presumably the absolute
//            position of the first query during incremental decode; TODO
//            confirm against alibi_add_bias_f32_kernel.
// Throws std::runtime_error on shape/dtype mismatches.
void alibi_add_bias(GPUArray& scores, const GPUArray& slopes, int start_pos) {
    if (scores.ndim() != 4) {
        throw std::runtime_error("alibi_add_bias: scores must be 4D [batch, heads, q_len, kv_len]");
    }
    if (scores.dtype() != DataType::Float32) {
        throw std::runtime_error("alibi_add_bias: scores must be float32");
    }
    if (slopes.dtype() != DataType::Float32) {
        throw std::runtime_error("alibi_add_bias: slopes must be float32");
    }

    int batch_size = scores.shape()[0];
    int num_heads = scores.shape()[1];
    int q_len = scores.shape()[2];
    int kv_len = scores.shape()[3];

    if (slopes.size() != (size_t)num_heads) {
        throw std::runtime_error("alibi_add_bias: slopes size must match num_heads");
    }

    // Count elements in size_t: batch * heads * q_len * kv_len overflows
    // signed int (UB) at realistic attention sizes.
    const size_t total = scores.size();
    if (total == 0) {
        // A 0-block launch is an invalid CUDA configuration; nothing to do.
        return;
    }

    const int block_size = 256;
    const unsigned int grid_size =
        static_cast<unsigned int>((total + block_size - 1) / block_size);

    cudaStream_t stream = internal::get_capture_stream();

    nn::alibi_add_bias_f32_kernel<<<grid_size, block_size, 0, stream>>>(
        static_cast<float*>(scores.data()),
        static_cast<const float*>(slopes.data()),
        batch_size,
        num_heads,
        q_len,
        kv_len,
        start_pos);

    sync_and_check("alibi_add_bias kernel failed");
}

} // namespace ops
} // namespace pygpukit
Loading