
feat(nn): Add ReLU² (Squared ReLU) activation function #171

@m96-chan

Summary

Implement ReLU² (Squared ReLU) activation function for MLP layers.

Formula

ReLU²(x) = (max(0, x))² = ReLU(x)²

Background

ReLU² was introduced in the Primer paper and has shown benefits:

  • Sparsity: Stronger sparsity than standard ReLU
  • Smoothness: Continuous first derivative (unlike ReLU)
  • Performance: Improved training dynamics in some architectures

Used in:

  • Primer (Google, 2021)
  • Some MoE implementations
  • Sparse attention variants

Comparison

Activation   Formula       Derivative
ReLU         max(0, x)     1 if x > 0, else 0
ReLU²        max(0, x)²    2·max(0, x)  (i.e. 2x if x > 0, else 0)
GELU         x · Φ(x)      Φ(x) + x·φ(x)   (φ: standard normal PDF)
SiLU         x · σ(x)      σ(x) + x·σ(x)·(1 − σ(x))
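To make the ReLU² row concrete, here is a minimal sketch of scalar device helpers for the forward value and its derivative. The names relu2_fwd and relu2_grad are placeholders for illustration, not existing pygpukit symbols:

// Scalar helpers (sketch only)
__device__ __forceinline__ float relu2_fwd(float x) {
    float r = fmaxf(0.0f, x);
    return r * r;                  // ReLU²(x) = max(0, x)²
}

__device__ __forceinline__ float relu2_grad(float x) {
    return 2.0f * fmaxf(0.0f, x);  // d/dx ReLU²(x) = 2·max(0, x), continuous at x = 0
}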

Proposed Implementation

Native Kernels

native/ops/nn/activation/
├── gelu.inl      # Existing
├── silu.inl      # Existing
└── relu2.inl     # NEW

Add to activation_kernels.cuh:

// ReLU² kernel
__global__ void relu2_f32_kernel(
    const float* __restrict__ x,
    float* __restrict__ y,
    size_t n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;  // widen before the multiply to avoid 32-bit overflow for very large n
    if (idx < n) {
        float val = x[idx];
        float relu_val = fmaxf(0.0f, val);
        y[idx] = relu_val * relu_val;
    }
}
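The f16/bf16 variants listed in the tasks could follow the same pattern. A minimal f16 sketch (math done in f32 for simplicity; the kernel name mirrors the f32 naming and is not yet part of the codebase) plus a hypothetical host-side launcher, with an assumed block size of 256:

#include <cuda_fp16.h>

// ReLU² kernel, f16 (sketch: a tuned version could use half2 / vectorized loads)
__global__ void relu2_f16_kernel(
    const __half* __restrict__ x,
    __half* __restrict__ y,
    size_t n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = __half2float(x[idx]);
        float relu_val = fmaxf(0.0f, val);
        y[idx] = __float2half(relu_val * relu_val);
    }
}

// Hypothetical launcher (naming and launch config are assumptions)
inline void launch_relu2_f32(const float* x, float* y, size_t n, cudaStream_t stream) {
    const int threads = 256;
    const int blocks  = (int)((n + threads - 1) / threads);
    relu2_f32_kernel<<<blocks, threads, 0, stream>>>(x, y, n);
}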

Python API

from pygpukit.ops.nn import relu2

# Basic usage
y = relu2(x)

# With pre-allocated output (for CUDA Graph)
relu2(x, out=y)
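The out= form matters for CUDA Graph capture because a captured launch bakes fixed device pointers into the graph, so the output buffer must not change between replays. A minimal host-side sketch of that capture flow (function name, buffer names, and launch configuration are illustrative, not part of the proposed API):

// Capture one relu2 launch and replay it without re-recording
void capture_relu2_graph(const float* d_x, float* d_y, size_t n) {
    const int threads = 256;                      // launch config is illustrative
    const int blocks  = (int)((n + threads - 1) / threads);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cudaGraph_t graph;
    cudaGraphExec_t graph_exec;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    relu2_f32_kernel<<<blocks, threads, 0, stream>>>(d_x, d_y, n);  // pointers frozen into the graph
    cudaStreamEndCapture(stream, &graph);

    // CUDA >= 11.4; the plain cudaGraphInstantiate signature differs across CUDA versions
    cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);
}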

Tasks

  • Implement relu2 kernels (f32, f16, bf16)
  • Add in-place variant relu2_ (see the sketch after this list)
  • Add Python bindings
  • Add Python API in ops/nn.py
  • Add tests
  • Add to benchmark suite
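For the in-place variant, a minimal sketch (the kernel name mirrors the naming above and is not yet in the codebase):

// In-place ReLU²: overwrite x with ReLU²(x)
__global__ void relu2_inplace_f32_kernel(
    float* __restrict__ x,
    size_t n
) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float relu_val = fmaxf(0.0f, x[idx]);
        x[idx] = relu_val * relu_val;
    }
}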

References
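  • Primer: Searching for Efficient Transformers for Language Modeling. So et al., 2021. arXiv:2109.08668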
