"""Benchmark softmax_online op against torch.softmax and torch.compile(torch.softmax).

The softmax_online backend is selected through the ``FORGE_SOFTMAX_IMPL``
environment variable, which this script sets from ``--impl``:

- ``auto`` (default): try the kernel if present, otherwise fall back to the reference.
- ``ref``: force the reference path.
- ``kernel``: require the kernel path and fail fast if it is missing/incomplete.

Examples::

    uv run python bench/benchmark_online_softmax.py --impl auto
    uv run python bench/benchmark_online_softmax.py --impl kernel
"""

import argparse
import os

import torch

SHORT_M = [128, 512, 2048, 8192]
SHORT_N = [1024, 2048, 4096, 8192]

LONG_M = [64, 128, 256]
LONG_N = [16384, 32768, 65536, 131072]

DEFAULT_DTYPES = ["float16", "bfloat16", "float32"]
IMPL_CHOICES = ("auto", "ref", "kernel")


def _split_csv(s: str) -> list[str]:
    """Split a comma-separated string, dropping blank entries (e.g. a trailing comma)."""
    return [tok for tok in (part.strip() for part in s.split(",")) if tok]


def parse_int_list(s: str) -> list[int]:
    """Parse a comma-separated list of integers such as ``"128, 512,2048"``.

    Raises:
        ValueError: if an entry is not an integer or no entry is left after
            dropping blanks. argparse turns this into a usage error.
    """
    values = [int(tok) for tok in _split_csv(s)]
    if not values:
        raise ValueError("expected at least one integer")
    return values


def parse_str_list(s: str) -> list[str]:
    """Parse a comma-separated list of names such as ``"float16,bfloat16"``.

    Raises:
        ValueError: if no entry is left after dropping blanks.
    """
    values = _split_csv(s)
    if not values:
        raise ValueError("expected at least one name")
    return values


def _resolve_dtype(name: str) -> torch.dtype:
    """Map a dtype name to a floating-point ``torch.dtype`` or raise ``ValueError``."""
    dtype = getattr(torch, name, None)
    # getattr alone would accept any torch attribute (functions, modules, int dtypes).
    if not isinstance(dtype, torch.dtype) or not dtype.is_floating_point:
        raise ValueError(f"unsupported dtype '{name}': expected a floating-point torch dtype name")
    return dtype


def _row_prefix(m: int, n: int, dtype_str: str, op: str, pass_name: str) -> str:
    """Format the leading columns shared by result rows and error rows."""
    return f"{m:>6} {n:>6} {dtype_str:<10} {op:<18} {pass_name:<5} "


def main():
    """Parse CLI options, run the forward/backward benchmark matrix and print a table."""
    parser = argparse.ArgumentParser(description="Benchmark softmax_online op")
    parser.add_argument(
        "--long", action="store_true", help="Use long-N benchmark suite (small M, large N)"
    )
    parser.add_argument("--m-sizes", type=parse_int_list, default=None)
    parser.add_argument("--n-sizes", type=parse_int_list, default=None)
    parser.add_argument("--dtypes", type=parse_str_list, default=DEFAULT_DTYPES)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument(
        "--impl",
        choices=list(IMPL_CHOICES),
        default=None,
        help=(
            "softmax_online backend mode (FORGE_SOFTMAX_IMPL); "
            "defaults to the environment value, else 'auto'"
        ),
    )
    args = parser.parse_args()

    # An explicit --impl wins; otherwise honour an already-exported
    # FORGE_SOFTMAX_IMPL instead of silently resetting it to "auto".
    impl = args.impl
    if impl is None:
        impl = os.getenv("FORGE_SOFTMAX_IMPL", "auto").strip().lower()
        if impl not in IMPL_CHOICES:
            parser.error(
                f"FORGE_SOFTMAX_IMPL must be one of {', '.join(IMPL_CHOICES)}, got '{impl}'"
            )
    os.environ["FORGE_SOFTMAX_IMPL"] = impl

    # Resolve dtype names up front so a typo fails before any benchmark has run.
    try:
        dtypes = [(name, _resolve_dtype(name)) for name in args.dtypes]
    except ValueError as exc:
        parser.error(str(exc))

    m_sizes = args.m_sizes if args.m_sizes is not None else (LONG_M if args.long else SHORT_M)
    n_sizes = args.n_sizes if args.n_sizes is not None else (LONG_N if args.long else SHORT_N)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required for benchmarking")

    # Project imports are deferred until the backend mode is set; this also keeps
    # the CLI helpers above importable without the project package.
    from forge_cute_py.ops.softmax_online import softmax_bwd, softmax_fwd
    from forge_cute_py.util.bench import do_bench, estimate_bandwidth, summarize_times

    def measure(fn, nbytes):
        """Return (p50 in ms, bandwidth estimate) for ``fn``."""
        times = do_bench(fn, warmup=args.warmup, rep=args.iterations)
        p50 = summarize_times(times)["p50_ms"]
        return p50, estimate_bandwidth(nbytes, p50)

    def report(prefix, p50, bw, ratio):
        print(f"{prefix}{p50:>10.4f} {bw:>10.2f} {ratio:>10.2f}x")

    def attempt(prefix, prepare, nbytes, baseline_p50):
        """Benchmark one candidate; print an ERROR row instead of aborting the run."""
        try:
            fn = prepare()  # performs the untimed priming call
            p50, bw = measure(fn, nbytes)
            ratio = p50 / baseline_p50 if baseline_p50 > 0 else float("inf")
            report(prefix, p50, bw, ratio)
        except Exception as e:  # deliberate best effort: keep benchmarking other configs
            print(f"{prefix}{'ERROR':>10} {'':>10} {'':>10} {e}")

    def prepare_compiled(x):
        compiled_ref = torch.compile(lambda t: torch.softmax(t, dim=-1))
        compiled_ref(x)
        return lambda: compiled_ref(x)

    def prepare_online_fwd(x):
        softmax_fwd(x, dim=-1)
        return lambda: softmax_fwd(x, dim=-1)

    def prepare_online_bwd(x, dy):
        y_ours = softmax_fwd(x, dim=-1)
        softmax_bwd(dy, y_ours, dim=-1)
        return lambda: softmax_bwd(dy, y_ours, dim=-1)

    gpu_name = torch.cuda.get_device_name(0)
    suite = "long" if args.long else "short"
    print(f"softmax_online benchmarks [{suite}] ({gpu_name}) [impl={impl}]")
    print()

    header = (
        f"{'M':>6} {'N':>6} {'Dtype':<10} {'Op':<18} {'Pass':<5} "
        f"{'p50 (ms)':>10} {'BW (GB/s)':>10} {'vs torch':>10}"
    )
    print(header)
    print("-" * len(header))

    for m in m_sizes:
        for n in n_sizes:
            for dtype_str, dtype in dtypes:
                x = torch.randn(m, n, device="cuda", dtype=dtype)
                elem = x.element_size()

                # --- Forward bandwidth: read input + write output ---
                fwd_bytes = 2 * m * n * elem

                # --- torch.softmax fwd baseline ---
                torch_fwd_p50, torch_fwd_bw = measure(
                    lambda x=x: torch.softmax(x, dim=-1), fwd_bytes
                )
                report(
                    _row_prefix(m, n, dtype_str, "torch.softmax", "fwd"),
                    torch_fwd_p50,
                    torch_fwd_bw,
                    1.0,
                )

                # --- torch.compile fwd ---
                attempt(
                    _row_prefix(m, n, dtype_str, "torch.compile", "fwd"),
                    lambda x=x: prepare_compiled(x),
                    fwd_bytes,
                    torch_fwd_p50,
                )

                # --- softmax_online fwd ---
                attempt(
                    _row_prefix(m, n, dtype_str, "softmax_online", "fwd"),
                    lambda x=x: prepare_online_fwd(x),
                    fwd_bytes,
                    torch_fwd_p50,
                )

                # --- Backward pass benchmarks ---
                # Pre-compute softmax output y and fake upstream gradient dy
                y = torch.softmax(x, dim=-1)
                dy = torch.randn_like(y)

                # Backward bandwidth: read dy + read y + write dx = 3 * M * N * elem
                bwd_bytes = 3 * m * n * elem

                # --- torch backward baseline ---
                torch_bwd_p50, torch_bwd_bw = measure(
                    lambda x=x, y=y, dy=dy: torch._softmax_backward_data(dy, y, -1, x.dtype),
                    bwd_bytes,
                )
                report(
                    _row_prefix(m, n, dtype_str, "torch.softmax", "bwd"),
                    torch_bwd_p50,
                    torch_bwd_bw,
                    1.0,
                )

                # --- softmax_online bwd ---
                attempt(
                    _row_prefix(m, n, dtype_str, "softmax_online", "bwd"),
                    lambda x=x, dy=dy: prepare_online_bwd(x, dy),
                    bwd_bytes,
                    torch_bwd_p50,
                )

            # NOTE(review): indentation was lost in the reviewed patch; this blank
            # separator is assumed to follow each (M, N) group -- confirm placement.
            print()


if __name__ == "__main__":
    main()
from __future__ import annotations

import importlib
import inspect
import os
from types import ModuleType
from typing import Callable

import torch

SoftmaxForwardImpl = Callable[[torch.Tensor, int], torch.Tensor]
SoftmaxBackwardImpl = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]

_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
_SUPPORTED_IMPL_MODES = {"auto", "ref", "kernel"}


def _normalize_dim(dim: int, ndim: int) -> int:
    """Convert a possibly-negative ``dim`` to a non-negative axis index (0 or 1).

    Raises:
        ValueError: if the normalized axis is neither 0 nor 1. The message
            reports the value the caller passed, not the normalized one.
    """
    normalized = dim if dim >= 0 else ndim + dim
    if normalized not in (0, 1):
        raise ValueError(f"softmax_online expects dim in {{-1, 0, 1}} for 2D tensors, got {dim}")
    return normalized


def _validate_impl_mode() -> str:
    """Read FORGE_SOFTMAX_IMPL (default ``auto``) and return it lower-cased.

    The variable is read on every call, so changing it at runtime takes effect
    on the next op invocation.

    Raises:
        ValueError: if the value is not one of ``auto``, ``ref``, ``kernel``.
    """
    impl = os.getenv("FORGE_SOFTMAX_IMPL", "auto").strip().lower()
    if impl not in _SUPPORTED_IMPL_MODES:
        choices = ", ".join(sorted(_SUPPORTED_IMPL_MODES))
        raise ValueError(f"FORGE_SOFTMAX_IMPL must be one of {{{choices}}}, got '{impl}'")
    return impl


def _load_kernel_module() -> tuple[ModuleType | None, str | None]:
    """Try to import the kernel module.

    Returns:
        ``(module, None)`` on success, otherwise ``(None, reason)`` where
        ``reason`` is a human-readable explanation used in error messages.
    """
    module_name = "forge_cute_py.kernels.softmax_online"
    try:
        return importlib.import_module(module_name), None
    except ModuleNotFoundError as exc:
        # Distinguish "the kernel file itself is absent" from "one of the
        # kernel's own imports is absent".
        if exc.name == module_name:
            return None, f"module '{module_name}' was not found"
        if exc.name is None:
            return None, f"failed importing '{module_name}': {exc}"
        return None, f"dependency '{exc.name}' missing while importing '{module_name}'"
    except Exception as exc:  # pragma: no cover - defensive runtime diagnostics
        return None, f"failed importing '{module_name}': {exc}"


def _resolve_kernel_forward(module: ModuleType) -> SoftmaxForwardImpl:
    """Return the kernel module's forward entry point.

    ``softmax_fwd`` is preferred; ``softmax_online`` is accepted as a fallback.

    Raises:
        AttributeError: if neither attribute is a callable.
    """
    for attr in ("softmax_fwd", "softmax_online"):
        fn = getattr(module, attr, None)
        if callable(fn):
            return fn
    raise AttributeError(
        "kernel module must define a callable 'softmax_fwd(x, dim)' or 'softmax_online(x, dim)'"
    )


def _resolve_kernel_backward(module: ModuleType) -> SoftmaxBackwardImpl | None:
    """Return the kernel module's ``softmax_bwd`` callable, or ``None`` if absent."""
    fn = getattr(module, "softmax_bwd", None)
    if callable(fn):
        return fn
    return None


def _call_with_dim(fn: Callable[..., torch.Tensor], args: tuple, dim: int) -> torch.Tensor:
    """Call ``fn(*args, dim)``, passing ``dim`` by keyword when ``fn`` accepts it.

    The calling convention is chosen from the signature so the kernel runs
    exactly once. Retrying on ``TypeError`` would re-run a kernel that failed
    with its own ``TypeError`` and could hide the original error.
    """
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Opaque callable (no introspectable signature): keep the historical
        # keyword-first, positional-second probing.
        try:
            return fn(*args, dim=dim)
        except TypeError:
            return fn(*args, dim)

    dim_param = params.get("dim")
    if dim_param is not None:
        by_keyword = dim_param.kind is not inspect.Parameter.POSITIONAL_ONLY
    else:
        by_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if by_keyword:
        return fn(*args, dim=dim)
    return fn(*args, dim)


def _call_forward(fn: SoftmaxForwardImpl, x: torch.Tensor, dim: int) -> torch.Tensor:
    """Invoke a forward implementation as ``fn(x, dim=dim)`` or ``fn(x, dim)``."""
    return _call_with_dim(fn, (x,), dim)


def _call_backward(
    fn: SoftmaxBackwardImpl,
    dy: torch.Tensor,
    y: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """Invoke a backward implementation as ``fn(dy, y, dim=dim)`` or ``fn(dy, y, dim)``."""
    return _call_with_dim(fn, (dy, y), dim)


def _reference_softmax_forward(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Compute softmax with the project's reference implementation."""
    # Imported lazily, as in the original code.
    from forge_cute_py.ref import softmax_online as softmax_online_ref

    return softmax_online_ref(x, dim=dim)


def _reference_softmax_backward(dy: torch.Tensor, y: torch.Tensor, dim: int) -> torch.Tensor:
    """Reference softmax gradient: ``dx = y * (dy - sum(dy * y, dim))``."""
    dot_product = (dy * y).sum(dim=dim, keepdim=True)
    return y * (dy - dot_product)


def _forward_impl(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Run the forward pass on the backend selected by FORGE_SOFTMAX_IMPL.

    ``ref`` never probes the kernel module. ``auto`` falls back to the reference
    when the kernel module, its entry point, or its implementation is missing.
    ``kernel`` raises instead of falling back.

    Raises:
        ValueError: if FORGE_SOFTMAX_IMPL holds an unsupported value.
        NotImplementedError: in ``kernel`` mode when the kernel cannot be used.
    """
    impl = _validate_impl_mode()
    if impl == "ref":
        return _reference_softmax_forward(x, dim)

    module, reason = _load_kernel_module()
    if module is None:
        if impl == "auto":
            return _reference_softmax_forward(x, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but softmax kernel is unavailable: "
            f"{reason}. Add forge_cute_py/kernels/softmax_online.py with softmax_fwd()."
        )

    try:
        kernel_forward = _resolve_kernel_forward(module)
    except AttributeError as exc:
        if impl == "auto":
            return _reference_softmax_forward(x, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but softmax kernel forward entry point is "
            f"incomplete: {exc}"
        ) from exc

    try:
        return _call_forward(kernel_forward, x, dim)
    except NotImplementedError as exc:
        # Only "not implemented" triggers the fallback; any other kernel error
        # propagates to the caller.
        if impl == "auto":
            return _reference_softmax_forward(x, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but softmax kernel forward is not implemented."
        ) from exc


def _backward_impl(dy: torch.Tensor, y: torch.Tensor, dim: int) -> torch.Tensor:
    """Run the backward pass on the backend selected by FORGE_SOFTMAX_IMPL.

    Mirrors ``_forward_impl``: ``ref`` skips the kernel probe, ``auto`` falls
    back to the reference gradient, ``kernel`` raises when the kernel module,
    ``softmax_bwd`` or its implementation is missing.

    Raises:
        ValueError: if FORGE_SOFTMAX_IMPL holds an unsupported value.
        NotImplementedError: in ``kernel`` mode when the kernel cannot be used.
    """
    impl = _validate_impl_mode()
    if impl == "ref":
        return _reference_softmax_backward(dy, y, dim)

    module, reason = _load_kernel_module()
    if module is None:
        if impl == "auto":
            return _reference_softmax_backward(dy, y, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but softmax kernel is unavailable: "
            f"{reason}. Add forge_cute_py/kernels/softmax_online.py with softmax_bwd()."
        )

    kernel_backward = _resolve_kernel_backward(module)
    if kernel_backward is None:
        if impl == "auto":
            return _reference_softmax_backward(dy, y, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but 'softmax_bwd' is missing in "
            "forge_cute_py.kernels.softmax_online."
        )

    try:
        return _call_backward(kernel_backward, dy, y, dim)
    except NotImplementedError as exc:
        if impl == "auto":
            return _reference_softmax_backward(dy, y, dim)
        raise NotImplementedError(
            "FORGE_SOFTMAX_IMPL=kernel requested, but softmax kernel backward is not implemented."
        ) from exc
+ ) from exc + + +def _ensure_forward_inputs(x: torch.Tensor, out: torch.Tensor, dim: int) -> int: + if x.dim() != 2: + raise ValueError("Input must be 2D") + if not x.is_cuda: + raise ValueError("Tensor must be on CUDA device") + if x.dtype not in _SUPPORTED_DTYPES: + raise ValueError(f"Unsupported dtype: {x.dtype}") + if out.shape != x.shape: + raise ValueError("Output shape must match input") + if out.dtype != x.dtype: + raise ValueError("Output dtype must match input dtype") + if out.device != x.device: + raise ValueError("Output device must match input device") + return _normalize_dim(dim, x.ndim) + + +def _ensure_backward_inputs(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: int) -> int: + if dy.dim() != 2 or y.dim() != 2 or dx.dim() != 2: + raise ValueError("Tensors must be 2D") + if dy.shape != y.shape or y.shape != dx.shape: + raise ValueError("All tensors must have same shape") + if not dy.is_cuda or not y.is_cuda or not dx.is_cuda: + raise ValueError("Tensors must be on CUDA") + if dy.dtype != y.dtype or y.dtype != dx.dtype: + raise ValueError("dy, y, and dx must have same dtype") + if dy.dtype not in _SUPPORTED_DTYPES: + raise ValueError(f"Unsupported dtype: {dy.dtype}") + return _normalize_dim(dim, dy.ndim) + + +@torch.library.custom_op("forge_cute_py::_softmax_fwd", mutates_args={"out"}) +def _softmax_fwd(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: + """Softmax forward pass.""" + dim = _ensure_forward_inputs(x, out, dim) + result = _forward_impl(x, dim) + if result.shape != x.shape: + raise ValueError( + f"softmax forward produced invalid shape {result.shape}, expected {x.shape}" + ) + if result.dtype != x.dtype: + raise ValueError( + f"softmax forward produced invalid dtype {result.dtype}, expected {x.dtype}" + ) + if result.device != x.device: + raise ValueError(f"softmax forward produced output on {result.device}, expected {x.device}") out.copy_(result) @@ -41,32 +217,21 @@ def softmax_fwd(x: torch.Tensor, dim: int = -1) 
-> torch.Tensor: @torch.library.custom_op("forge_cute_py::_softmax_backward", mutates_args={"dx"}) def _softmax_backward(dy: torch.Tensor, y: torch.Tensor, dx: torch.Tensor, dim: int = -1) -> None: - """Softmax backward pass. - - For softmax output y = softmax(x), gradient: grad_x = y * (grad_y - dot) - where dot = (grad_y * y).sum(dim, keepdim=True) - - Args: - dy: Upstream gradients (M, N) - y: Softmax output (M, N) - dx: Input gradients (mutated in-place) - dim: Dimension softmax was applied over - """ - assert dy.dim() == 2 and y.dim() == 2, "Tensors must be 2D" - assert dy.shape == y.shape == dx.shape, "All tensors must have same shape" - assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA" - assert dy.dtype == y.dtype, "dy and y must have same dtype" - assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64], ( - "Unsupported dtype" - ) - - # Normalize dim - dim = dim if dim >= 0 else dy.ndim + dim - assert dim in [0, 1], f"dim must be 0 or 1 for 2D, got {dim}" - - # Compute gradient (numerically stable) - dot_product = (dy * y).sum(dim=dim, keepdim=True) - result = y * (dy - dot_product) + """Softmax backward pass.""" + dim = _ensure_backward_inputs(dy, y, dx, dim) + result = _backward_impl(dy, y, dim) + if result.shape != dy.shape: + raise ValueError( + f"softmax backward produced invalid shape {result.shape}, expected {dy.shape}" + ) + if result.dtype != dy.dtype: + raise ValueError( + f"softmax backward produced invalid dtype {result.dtype}, expected {dy.dtype}" + ) + if result.device != dy.device: + raise ValueError( + f"softmax backward produced output on {result.device}, expected {dy.device}" + ) dx.copy_(result) @@ -96,19 +261,11 @@ def backward(ctx, dy): def softmax_online(x: torch.Tensor, dim: int = -1) -> torch.Tensor: - """Online softmax with automatic differentiation support. 
import importlib

import pytest
import torch

from forge_cute_py.ops import softmax_online
from forge_cute_py.ref import softmax_online as ref_softmax_online

# Import the ops *module* by dotted name: the plain name `softmax_online`
# imported above is also bound here and is called as the op function.
softmax_online_ops = importlib.import_module("forge_cute_py.ops.softmax_online")


@pytest.fixture(autouse=True)
def _reset_softmax_impl_env(monkeypatch):
    """Start every test with FORGE_SOFTMAX_IMPL unset (i.e. the default mode)."""
    monkeypatch.delenv("FORGE_SOFTMAX_IMPL", raising=False)


def _patch_missing_kernel_module(monkeypatch):
    """Make importing the softmax kernel module fail as if it did not exist.

    Every other module name is forwarded to the real ``import_module``.
    """
    original_import_module = softmax_online_ops.importlib.import_module

    def missing_kernel_module(name, *args, **kwargs):
        if name == "forge_cute_py.kernels.softmax_online":
            exc = ModuleNotFoundError(f"No module named '{name}'")
            # Set .name to the module name, as a real missing-module import would.
            exc.name = name
            raise exc
        return original_import_module(name, *args, **kwargs)

    monkeypatch.setattr(softmax_online_ops.importlib, "import_module", missing_kernel_module)


def test_softmax_online_auto_falls_back_to_ref_when_kernel_missing(monkeypatch):
    """In ``auto`` mode a missing kernel module yields the reference result."""
    monkeypatch.setenv("FORGE_SOFTMAX_IMPL", "auto")
    _patch_missing_kernel_module(monkeypatch)

    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    y = softmax_online(x, dim=-1)
    y_ref = ref_softmax_online(x, dim=-1)
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)


def test_softmax_online_kernel_mode_requires_kernel(monkeypatch):
    """In ``kernel`` mode a missing kernel module raises NotImplementedError."""
    monkeypatch.setenv("FORGE_SOFTMAX_IMPL", "kernel")
    _patch_missing_kernel_module(monkeypatch)

    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    with pytest.raises(NotImplementedError, match="FORGE_SOFTMAX_IMPL=kernel"):
        softmax_online(x, dim=-1)


def test_softmax_online_ref_mode_skips_kernel_probe(monkeypatch):
    """In ``ref`` mode the kernel module is never imported."""
    # A mutable dict lets the nested function record the call without `nonlocal`.
    import_called = {"value": False}
    original_import_module = softmax_online_ops.importlib.import_module

    def tracking_import(name, *args, **kwargs):
        if name == "forge_cute_py.kernels.softmax_online":
            import_called["value"] = True
            raise AssertionError("Kernel module should not be imported in ref mode")
        return original_import_module(name, *args, **kwargs)

    monkeypatch.setenv("FORGE_SOFTMAX_IMPL", "ref")
    monkeypatch.setattr(softmax_online_ops.importlib, "import_module", tracking_import)

    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    y = softmax_online(x, dim=-1)
    y_ref = ref_softmax_online(x, dim=-1)

    assert import_called["value"] is False
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)


def test_softmax_online_rejects_invalid_impl_mode(monkeypatch):
    """An unsupported FORGE_SOFTMAX_IMPL value raises ValueError."""
    monkeypatch.setenv("FORGE_SOFTMAX_IMPL", "unknown")
    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    with pytest.raises(ValueError, match="FORGE_SOFTMAX_IMPL"):
        softmax_online(x, dim=-1)