From 3c1d2bd4b9e4a8732e2788b74421017f7e6036f2 Mon Sep 17 00:00:00 2001 From: logansg Date: Mon, 23 Mar 2026 14:10:51 -0400 Subject: [PATCH 1/3] Add production cast loaders and autotuned runtime benchmarks --- kernelforge/__init__.py | 25 ++ run_cast.py | 488 +++++++++++++++++----- run_cast_autotune.py | 492 +++++++++++++++++++++++ tests/test_cast_ops_vs_pytorch.py | 391 ++++++++++++++++++ tests/test_load_cast_resnet50.py | 137 +++++++ tests/test_run_cast_autotune_resnet50.py | 135 +++++++ 6 files changed, 1577 insertions(+), 91 deletions(-) create mode 100644 kernelforge/__init__.py create mode 100644 run_cast_autotune.py create mode 100644 tests/test_cast_ops_vs_pytorch.py create mode 100644 tests/test_load_cast_resnet50.py create mode 100644 tests/test_run_cast_autotune_resnet50.py diff --git a/kernelforge/__init__.py b/kernelforge/__init__.py new file mode 100644 index 00000000..ce8c7fb2 --- /dev/null +++ b/kernelforge/__init__.py @@ -0,0 +1,25 @@ +"""KernelForge runtime loading helpers.""" + +from __future__ import annotations + +from run_cast import CastModelRuntime, load_cast + + +def load( + cast_path: str, + *, + device: str | None = None, + model_args: dict | None = None, + no_kernels: bool = False, + opt_level: str = "-O3", +) -> CastModelRuntime: + return load_cast( + cast_path, + model_args=model_args, + no_kernels=no_kernels, + opt_level=opt_level, + device=device, + ) + + +__all__ = ["CastModelRuntime", "load", "load_cast"] diff --git a/run_cast.py b/run_cast.py index 1d3db537..ab07e0bb 100755 --- a/run_cast.py +++ b/run_cast.py @@ -1,14 +1,21 @@ #!/usr/bin/env python3 -"""Run a .cast inference package produced by KernelForge.""" +"""Load and run a .cast inference package produced by KernelForge.""" import argparse +import contextlib +import glob import hashlib import inspect import json import os import re +import shutil +import threading import time import zipfile +from typing import Any, Callable + +import torch.nn as nn def verify_checksums(zf: zipfile.ZipFile) -> None: @@ -34,9 +41,68 @@ def verify_checksums(zf: zipfile.ZipFile) -> None: raise RuntimeError(f"Checksum mismatch for {rel_path}") -def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: str = "-O0"): +def ensure_cuda_toolkit_env() -> str | None: + """Point CUDA_HOME/CUDACXX/PATH and torch cpp_extension at a working nvcc.""" + + candidates: list[str] = [] + + cuda_home = os.getenv("CUDA_HOME") + if cuda_home: + candidates.append(os.path.abspath(os.path.expanduser(cuda_home))) + + cudacxx = os.getenv("CUDACXX") + if cudacxx: + candidates.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.expanduser(cudacxx))))) + + nvcc_on_path = shutil.which("nvcc") + if nvcc_on_path: + candidates.append(os.path.dirname(os.path.dirname(os.path.abspath(nvcc_on_path)))) + + candidates.append("/usr/local/cuda") + candidates.extend(sorted(glob.glob("/usr/local/cuda-*"), reverse=True)) + + seen: set[str] = set() + for candidate in candidates: + resolved = os.path.abspath(candidate) + if resolved in seen: + continue + seen.add(resolved) + + nvcc = os.path.join(resolved, "bin", "nvcc") + if not os.path.exists(nvcc): + continue + + os.environ["CUDA_HOME"] = resolved + os.environ["CUDACXX"] = nvcc + + path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] + nvcc_dir = os.path.dirname(nvcc) + if nvcc_dir not in path_entries: + os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] + + try: + import torch.utils.cpp_extension as cpp_extension + + cpp_extension.CUDA_HOME = resolved + except Exception: + pass + + return resolved + + return None + + +def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: str = "-O3"): import re - from torch.utils.cpp_extension import load_inline + + cuda_home = ensure_cuda_toolkit_env() + if cuda_home is None: + raise RuntimeError("No CUDA toolkit with nvcc was found for JIT compilation.") + + import torch.utils.cpp_extension as cpp_extension + + cpp_extension.CUDA_HOME = cuda_home + load_inline = cpp_extension.load_inline with open(kernel_cu_path) as f: cuda_src = f.read() @@ -60,10 +126,218 @@ def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: ) -def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = False, opt_level: str = "-O0"): - import torch +_PATCH_STATE = threading.local() +_ORIGINAL_FUNCTIONALS: dict[str, Callable[..., Any]] = {} +_ATEN_FALLBACK_CALL_RE = re.compile( + r"\bat::(?:adaptive_avg_pool\d*d|batch_norm|conv\d*d|linear|max_pool\d*d|relu)\s*\(" +) +_KNOWN_FAST_OPS = { + "torch_nn_functional_linear", + "torch_nn_functional_max_pool2d", + "torch_nn_functional_relu", + "torch_nn_functional_batch_norm", +} +_FOCUS_OPS = { + "torch_nn_functional_linear", + "torch_nn_functional_max_pool2d", + "torch_nn_functional_batch_norm", +} +_KERNEL_POLICIES = ("all", "skip_aten", "known_fast", "focus_ops") + + +def _patch_stack() -> list[dict[str, Callable[..., Any]]]: + stack = getattr(_PATCH_STATE, "stack", None) + if stack is None: + stack = [] + _PATCH_STATE.stack = stack + return stack + + +@contextlib.contextmanager +def _activate_functional_patches( + patch_map: dict[str, Callable[..., Any]], +): + stack = _patch_stack() + stack.append(patch_map) + try: + yield + finally: + stack.pop() + + +def _ensure_functional_dispatch(fn_attr: str) -> Callable[..., Any] | None: import torch.nn.functional as F + current = getattr(F, fn_attr, None) + if current is None: + return None + + original = _ORIGINAL_FUNCTIONALS.get(fn_attr) + if original is not None: + return original + + original = current + _ORIGINAL_FUNCTIONALS[fn_attr] = original + + def dispatched(*args, **kwargs): + stack = getattr(_PATCH_STATE, "stack", None) + if stack: + patch = stack[-1].get(fn_attr) + if patch is not None: + return patch(*args, **kwargs) + return original(*args, **kwargs) + + setattr(F, fn_attr, dispatched) + return original + + +def _launch_arity(kernel_cu: str, ext: Any) -> int | None: + n_launch = None + if os.path.exists(kernel_cu): + try: + with open(kernel_cu) as handle: + cu_src = handle.read() + match = re.search(r"torch::Tensor\s+launch\s*\(([^)]*)\)", cu_src) + if match: + params = [part.strip() for part in match.group(1).split(",") if part.strip()] + n_launch = len(params) + except Exception: + pass + if n_launch is None: + try: + n_launch = len(inspect.signature(ext.launch).parameters) + except Exception: + pass + return n_launch + + +def _kernel_calls_aten_fallback(kernel_cu: str) -> bool: + if not os.path.exists(kernel_cu): + return False + try: + with open(kernel_cu) as handle: + cuda_src = handle.read() + except Exception: + return False + return bool(_ATEN_FALLBACK_CALL_RE.search(cuda_src)) + + +def _kernel_policy_skip_reason(op_name: str, kernel_cu: str, kernel_policy: str) -> str | None: + if kernel_policy == "all": + return None + + uses_aten_fallback = _kernel_calls_aten_fallback(kernel_cu) + + if kernel_policy == "skip_aten": + if uses_aten_fallback: + return "kernel source falls back to ATen" + return None + + if kernel_policy == "known_fast": + if uses_aten_fallback: + return "kernel source falls back to ATen" + if op_name not in _KNOWN_FAST_OPS: + return "op is not in the known-fast allowlist" + return None + + if kernel_policy == "focus_ops": + if uses_aten_fallback: + return "kernel source falls back to ATen" + if op_name not in _FOCUS_OPS: + return "op is not in the focus-ops allowlist" + return None + + raise ValueError(f"Unknown kernel policy: {kernel_policy}") + + +def _build_functional_patch( + *, + op_name: str, + ext: Any, + orig_fn: Callable[..., Any], + n_launch: int | None, + orig_params: list[str] | None, +) -> Callable[..., Any]: + import torch + + def patched(*args, **kwargs): + try: + if orig_params is not None: + resolved = {orig_params[i]: value for i, value in enumerate(args) if i < len(orig_params)} + resolved.update(kwargs) + ordered = [resolved.get(name) for name in orig_params] + else: + ordered = list(args) + + limit = n_launch if n_launch is not None else len(ordered) + call_args: list[Any] = [] + tensor_args: list[torch.Tensor] = [] + for value in ordered[:limit]: + if isinstance(value, torch.Tensor): + tensor_args.append(value) + call_args.append(value.contiguous()) + else: + call_args.append(value) + + if not tensor_args: + return orig_fn(*args, **kwargs) + + if any(not tensor.is_cuda for tensor in tensor_args): + return orig_fn(*args, **kwargs) + + first_device = tensor_args[0].device + if any(tensor.device != first_device for tensor in tensor_args): + return orig_fn(*args, **kwargs) + + return ext.launch(*call_args) + except Exception: + return orig_fn(*args, **kwargs) + + patched.__name__ = f"cast_patch_{op_name}" + return patched + + +class CastModelRuntime(nn.Module): + """A thin nn.Module wrapper that activates cast kernel patches per forward.""" + + def __init__(self, model, functional_patches: dict[str, Callable[..., Any]]): + import torch.nn as nn + + if not isinstance(model, nn.Module): + raise TypeError("model must be an nn.Module") + super().__init__() + self.model = model + self._cast_functional_patches = functional_patches + + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + model = self._modules.get("model") + if model is None: + raise + return getattr(model, name) + + def forward(self, *args, **kwargs): + if not self._cast_functional_patches: + return self.model(*args, **kwargs) + with _activate_functional_patches(self._cast_functional_patches): + return self.model(*args, **kwargs) + + def run(self, *args, **kwargs): + return self(*args, **kwargs) + + +def load_cast( + cast_path: str, + model_args: dict | None = None, + no_kernels: bool = False, + opt_level: str = "-O3", + device: str | None = None, + kernel_policy: str = "all", +): + import torch + cast_path = os.path.abspath(cast_path) cache_key = hashlib.sha256(open(cast_path, "rb").read()).hexdigest() cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "cast", cache_key) @@ -97,6 +371,12 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = os.makedirs(build_dir, exist_ok=True) _F_PREFIX = "torch_nn_functional_" + functional_patches: dict[str, Callable[..., Any]] = {} + + if kernel_policy not in _KERNEL_POLICIES: + raise ValueError( + f"Unknown kernel policy '{kernel_policy}'. Expected one of: {', '.join(_KERNEL_POLICIES)}" + ) for op in manifest["ops"]: op_name = op["name"] @@ -110,27 +390,40 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = print(f" [WARN] CUDA not available — skipping kernel for {op_name}") continue + skip_reason = _kernel_policy_skip_reason(op_name, kernel_cu, kernel_policy) + if skip_reason: + print(f" [kernel-policy={kernel_policy}] Skipping kernel for {op_name}: {skip_reason}") + continue + # Try precompiled .so for the current GPU first gpu_sm = "sm_{0}{1}".format(*torch.cuda.get_device_capability()) precompiled = op.get("precompiled", {}) so_rel = precompiled.get(gpu_sm) so_path = os.path.join(cache_dir, so_rel) if so_rel else None + ext = None if so_path and os.path.exists(so_path): - import importlib.util as _ilu - _spec = _ilu.spec_from_file_location(op_name, so_path) - ext = _ilu.module_from_spec(_spec) - _spec.loader.exec_module(ext) # type: ignore[union-attr] - print(f" Loaded precompiled {op_name} ({gpu_sm})") - else: + try: + import importlib.util as _ilu + + _spec = _ilu.spec_from_file_location(op_name, so_path) + ext = _ilu.module_from_spec(_spec) + _spec.loader.exec_module(ext) # type: ignore[union-attr] + print(f" Loaded precompiled {op_name} ({gpu_sm})") + except Exception as exc: + print(f" [WARN] Failed to load precompiled {op_name} ({gpu_sm}): {exc}") + + if ext is None: if so_rel: print(f" [WARN] Precompiled .so not found for {gpu_sm}, falling back to JIT") if not os.path.exists(kernel_cu): print(f" [WARN] No kernel.cu for {op_name}, skipping") continue - if "CUDA_HOME" not in os.environ: - os.environ["CUDA_HOME"] = "/usr/local/cuda-12.1" - ext = compile_kernel(kernel_cu, op_name, build_dir, opt_level=opt_level) + try: + ext = compile_kernel(kernel_cu, op_name, build_dir, opt_level=opt_level) + except Exception as exc: + print(f" [WARN] Failed to prepare kernel for {op_name}, using native PyTorch: {exc}") + continue # Generic patch: decode torch.nn.functional. from the op_name convention. if not op_name.startswith(_F_PREFIX): @@ -138,67 +431,32 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = continue fn_attr = op_name[len(_F_PREFIX):] - orig_fn = getattr(F, fn_attr, None) - if orig_fn is None: + original = _ensure_functional_dispatch(fn_attr) + if original is None: print(f" [WARN] torch.nn.functional.{fn_attr} not found, skipping patch") continue - # Determine how many args launch() expects by parsing the kernel source. - # This is more reliable than inspect.signature on C extensions. - n_launch = None - if os.path.exists(kernel_cu): - try: - with open(kernel_cu) as _f: - _cu_src = _f.read() - _m = re.search(r"torch::Tensor\s+launch\s*\(([^)]*)\)", _cu_src) - if _m: - _params = [p.strip() for p in _m.group(1).split(",") if p.strip()] - n_launch = len(_params) - except Exception: - pass - if n_launch is None: - try: - n_launch = len(inspect.signature(ext.launch).parameters) - except Exception: - pass + n_launch = _launch_arity(kernel_cu, ext) # Resolve the original function's parameter names so we can handle # both positional and keyword call sites correctly. try: - orig_params = list(inspect.signature(orig_fn).parameters.keys()) + orig_params = list(inspect.signature(original).parameters.keys()) except Exception: orig_params = None - def _make_patch(ext=ext, orig=orig_fn, n=n_launch, params=orig_params): - def patched(*args, **kwargs): - try: - # Resolve all argument values by name, handling mixed positional/keyword callers. - if params is not None: - resolved = {params[i]: v for i, v in enumerate(args) if i < len(params)} - resolved.update(kwargs) - ordered = [resolved.get(p) for p in params] - else: - ordered = list(args) - # Truncate to what launch() expects, move tensors to CUDA. - limit = n if n is not None else len(ordered) - call_args = [] - for val in ordered[:limit]: - if isinstance(val, torch.Tensor): - call_args.append(val.cuda().contiguous()) - else: - call_args.append(val) - return ext.launch(*call_args) - except Exception: - return orig(*args, **kwargs) - return patched - - setattr(F, fn_attr, _make_patch()) - print(f" Patched torch.nn.functional.{fn_attr} → {op_name}") + functional_patches[fn_attr] = _build_functional_patch( + op_name=op_name, + ext=ext, + orig_fn=original, + n_launch=n_launch, + orig_params=orig_params, + ) + print(f" Registered runtime patch torch.nn.functional.{fn_attr} → {op_name}") # 5. Load model class from model.py import importlib.util import sys - import torch.nn as nn model_py = os.path.join(cache_dir, "model.py") spec = importlib.util.spec_from_file_location("cast_model", model_py) @@ -273,10 +531,53 @@ def patched(*args, **kwargs): model.load_state_dict(state_dict) model.eval() - return model + runtime_model = CastModelRuntime(model, functional_patches) + runtime_model.eval() + if device: + runtime_model = runtime_model.to(device) + + return runtime_model + + +def _parse_shape(raw: str) -> tuple[int, ...]: + parts = [part.strip() for part in raw.split(",") if part.strip()] + if not parts: + raise ValueError("input shape must contain at least one dimension") + dims = tuple(int(part) for part in parts) + if any(dim <= 0 for dim in dims): + raise ValueError("input shape dimensions must be positive integers") + return dims + + +def _benchmark_loaded_model(model, device: str, runs: int, input_shape: tuple[int, ...]) -> None: + import torch + + dummy = torch.randn(*input_shape, device=device) + print(f"Running {runs} inference pass(es) with input shape {list(dummy.shape)} ...") + + with torch.inference_mode(): + _ = model(dummy) + if device == "cuda": + torch.cuda.synchronize() + + t0 = time.perf_counter() + for _ in range(runs): + out = model(dummy) + if device == "cuda": + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) / runs * 1000 + + logits = out.logits if hasattr(out, "logits") else out + if hasattr(logits, "shape"): + print(f"Output shape : {list(logits.shape)}") + print(f"Average latency : {elapsed_ms:.2f} ms") + if getattr(logits, "ndim", 0) >= 2 and logits.shape[0] >= 1: + top5 = logits[0].topk(min(5, logits.shape[-1])) + print(f"Top-5 indices : {top5.indices.tolist()}") + print(f"Top-5 scores : {[f'{v:.4f}' for v in top5.values.tolist()]}") -def main() -> None: +def main(*, default_kernel_policy: str = "all") -> None: import torch parser = argparse.ArgumentParser(description="Run a KernelForge .cast inference package") @@ -286,7 +587,6 @@ def main() -> None: default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (default: cuda if available)", ) - parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing") parser.add_argument( "--model-args", metavar="JSON", @@ -301,9 +601,32 @@ def main() -> None: ) parser.add_argument( "--opt-level", - default="-O0", + default="-O3", choices=["-O0", "-O1", "-O2", "-O3"], - help="NVCC optimisation level for JIT compilation (default: -O0).", + help="NVCC optimisation level for JIT compilation (default: -O3).", + ) + parser.add_argument( + "--kernel-policy", + default=default_kernel_policy, + choices=list(_KERNEL_POLICIES), + help=( + "Kernel selection policy: " + "'all' loads every exported kernel, " + "'skip_aten' skips kernels that just call back into ATen, " + "'known_fast' keeps only the current known-fast allowlist, " + "'focus_ops' keeps only the smallest fast-op subset." + ), + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Run a dummy-input benchmark after loading the model.", + ) + parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing with --benchmark") + parser.add_argument( + "--input-shape", + default="1,3,224,224", + help="Comma-separated dummy input shape for --benchmark (default: 1,3,224,224).", ) args = parser.parse_args() @@ -313,33 +636,16 @@ def main() -> None: model_args=extra_model_args, no_kernels=args.no_kernels, opt_level=args.opt_level, + device=args.device, + kernel_policy=args.kernel_policy, ) - device = args.device - model = model.to(device) - print(f"\nModel ready on {device}") - - dummy = torch.randn(1, 3, 224, 224, device=device) - print(f"Running {args.runs} inference pass(es) with input shape {list(dummy.shape)} ...") + print(f"\nModel ready on {args.device}") + print(f"Kernel policy : {args.kernel_policy}") + print("Import load_cast(...) in production code to load a .cast as a normal model object.") - with torch.no_grad(): - # warmup - _ = model(dummy) - if device == "cuda": - torch.cuda.synchronize() - - t0 = time.perf_counter() - for _ in range(args.runs): - out = model(dummy) - if device == "cuda": - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) / args.runs * 1000 - - logits = out.logits if hasattr(out, "logits") else out - print(f"Output shape : {list(logits.shape)}") - print(f"Average latency : {elapsed_ms:.2f} ms") - top5 = logits[0].topk(5) - print(f"Top-5 indices : {top5.indices.tolist()}") - print(f"Top-5 scores : {[f'{v:.4f}' for v in top5.values.tolist()]}") + if args.benchmark: + input_shape = _parse_shape(args.input_shape) + _benchmark_loaded_model(model, args.device, args.runs, input_shape) if __name__ == "__main__": diff --git a/run_cast_autotune.py b/run_cast_autotune.py new file mode 100644 index 00000000..bb43154d --- /dev/null +++ b/run_cast_autotune.py @@ -0,0 +1,492 @@ +#!/usr/bin/env python3 +"""Autotuning .cast runtime that keeps only end-to-end beneficial kernels.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import shutil +import tempfile +import time +from pathlib import Path +from typing import Any + +import torch.nn as nn + +import run_cast + +_DEFAULT_MARGIN = 0.01 + + +def ensure_cuda_toolkit_env() -> str | None: + """Point CUDA_HOME/CUDACXX/PATH at a working nvcc if one is available.""" + + candidates: list[Path] = [] + + cuda_home = os.getenv("CUDA_HOME") + if cuda_home: + candidates.append(Path(cuda_home)) + + cudacxx = os.getenv("CUDACXX") + if cudacxx: + candidates.append(Path(cudacxx).expanduser().resolve().parent.parent) + + nvcc_on_path = shutil.which("nvcc") + if nvcc_on_path: + candidates.append(Path(nvcc_on_path).resolve().parent.parent) + + candidates.append(Path("/usr/local/cuda")) + candidates.extend(sorted(Path("/usr/local").glob("cuda-*"), reverse=True)) + + seen: set[Path] = set() + for candidate in candidates: + resolved = candidate.resolve(strict=False) + if resolved in seen: + continue + seen.add(resolved) + + nvcc = resolved / "bin" / "nvcc" + if not nvcc.exists(): + continue + + os.environ["CUDA_HOME"] = str(resolved) + os.environ.setdefault("CUDACXX", str(nvcc)) + + path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] + nvcc_dir = str(nvcc.parent) + if nvcc_dir not in path_entries: + os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] + return str(resolved) + + return None + + +def _normalize_input(value: Any) -> Any: + import torch + + if isinstance(value, torch.Tensor): + return { + "type": "tensor", + "shape": list(value.shape), + "dtype": str(value.dtype), + "device": str(value.device), + "requires_grad": bool(value.requires_grad), + } + if isinstance(value, dict): + return {str(key): _normalize_input(val) for key, val in sorted(value.items(), key=lambda item: str(item[0]))} + if isinstance(value, (list, tuple)): + return [_normalize_input(item) for item in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return {"type": type(value).__name__, "repr": repr(value)} + + +def _input_signature(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + payload = { + "args": _normalize_input(args), + "kwargs": _normalize_input(kwargs), + } + return json.dumps(payload, sort_keys=True, separators=(",", ":")) + + +def _first_tensor_device(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + import torch + + def visit(value: Any) -> str | None: + if isinstance(value, torch.Tensor): + return value.device.type + if isinstance(value, dict): + for item in value.values(): + found = visit(item) + if found is not None: + return found + return None + if isinstance(value, (list, tuple)): + for item in value: + found = visit(item) + if found is not None: + return found + return None + return None + + found = visit(args) + if found is not None: + return found + found = visit(kwargs) + return found or "cpu" + + +def _flatten_output(value: Any) -> Any: + import torch + + if isinstance(value, torch.Tensor): + return ("tensor", value) + if hasattr(value, "to_tuple") and callable(value.to_tuple): + return ("tuple_like", tuple(_flatten_output(item) for item in value.to_tuple())) + if isinstance(value, dict): + return ("dict", tuple((str(key), _flatten_output(val)) for key, val in sorted(value.items(), key=lambda item: str(item[0])))) + if isinstance(value, (list, tuple)): + return ("seq", tuple(_flatten_output(item) for item in value)) + return ("value", value) + + +def _outputs_match(reference: Any, candidate: Any) -> bool: + import torch + + ref = _flatten_output(reference) + cand = _flatten_output(candidate) + + def compare(left: Any, right: Any) -> bool: + if left[0] != right[0]: + return False + kind = left[0] + if kind == "tensor": + try: + atol = 1e-3 if left[1].dtype in (torch.float16, torch.bfloat16) else 1e-4 + rtol = 1e-3 if left[1].dtype in (torch.float16, torch.bfloat16) else 1e-4 + torch.testing.assert_close(left[1], right[1], atol=atol, rtol=rtol) + return True + except Exception: + return False + if kind in {"tuple_like", "seq"}: + if len(left[1]) != len(right[1]): + return False + return all(compare(l_item, r_item) for l_item, r_item in zip(left[1], right[1])) + if kind == "dict": + if len(left[1]) != len(right[1]): + return False + return all( + l_key == r_key and compare(l_val, r_val) + for (l_key, l_val), (r_key, r_val) in zip(left[1], right[1]) + ) + return left[1] == right[1] + + return compare(ref, cand) + + +def _cast_cache_key(cast_path: str) -> str: + cast_bytes = Path(cast_path).read_bytes() + return hashlib.sha256(cast_bytes).hexdigest() + + +class AutotunedCastModelRuntime(nn.Module): + """A runtime wrapper that benchmarks patch subsets and keeps only the winners.""" + + def __init__( + self, + runtime_model: run_cast.CastModelRuntime, + *, + cast_path: str, + opt_level: str, + base_policy: str, + warmup_runs: int, + timed_runs: int, + improvement_margin: float, + cache_enabled: bool, + ) -> None: + if not isinstance(runtime_model, run_cast.CastModelRuntime): + raise TypeError("runtime_model must be a CastModelRuntime") + super().__init__() + self.model = runtime_model + self._all_patches = dict(runtime_model._cast_functional_patches) + self._active_patch_names: list[str] = list(self._all_patches.keys()) + self._signature_patch_names: dict[str, list[str]] = {} + self._signature_patch_maps: dict[str, dict[str, Any]] = {} + self._opt_level = opt_level + self._base_policy = base_policy + self._warmup_runs = warmup_runs + self._timed_runs = timed_runs + self._improvement_margin = improvement_margin + self._cache_enabled = cache_enabled + self._cast_key = _cast_cache_key(cast_path) + self._cache_path = self._build_cache_path(self._cast_key) + self._cache_data = self._load_cache() + + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + model = self._modules.get("model") + if model is None: + raise + return getattr(model, name) + + @staticmethod + def _build_cache_path(cast_key: str) -> Path: + cache_root = Path.home() / ".cache" / "cast_autotune" + gpu_name = "cpu" + if os.getenv("CUDA_VISIBLE_DEVICES", "") == "": + pass + try: + import torch + + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0).replace(" ", "_") + except Exception: + pass + return cache_root / f"{cast_key}_{gpu_name}.json" + + def _load_cache(self) -> dict[str, Any]: + if not self._cache_enabled or not self._cache_path.exists(): + return {} + try: + return json.loads(self._cache_path.read_text()) + except Exception: + return {} + + def _persist_cache(self) -> None: + if not self._cache_enabled: + return + self._cache_path.parent.mkdir(parents=True, exist_ok=True) + payload = json.dumps(self._cache_data, indent=2, sort_keys=True) + with tempfile.NamedTemporaryFile("w", dir=self._cache_path.parent, delete=False) as handle: + handle.write(payload) + temp_path = Path(handle.name) + temp_path.replace(self._cache_path) + + def _select_patch_names(self, names: list[str]) -> dict[str, Any]: + return {name: self._all_patches[name] for name in names if name in self._all_patches} + + def _set_active_patch_names(self, names: list[str]) -> None: + self._active_patch_names = list(names) + self.model._cast_functional_patches = self._select_patch_names(names) + + def _run_once(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return self.model.run(*args, **kwargs) + + def _benchmark_patch_names(self, patch_names: list[str], args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, float]: + import torch + + self._set_active_patch_names(patch_names) + device_type = _first_tensor_device(args, kwargs) + + with torch.inference_mode(): + output = None + for _ in range(self._warmup_runs): + output = self._run_once(args, kwargs) + if device_type == "cuda": + torch.cuda.synchronize() + + if device_type == "cuda": + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(self._timed_runs): + output = self._run_once(args, kwargs) + end.record() + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) / self._timed_runs + else: + t0 = time.perf_counter() + for _ in range(self._timed_runs): + output = self._run_once(args, kwargs) + elapsed_ms = (time.perf_counter() - t0) / self._timed_runs * 1000.0 + + return output, elapsed_ms + + def autotune(self, *args: Any, **kwargs: Any) -> list[str]: + signature = _input_signature(args, kwargs) + cached = self._signature_patch_names.get(signature) + if cached is not None: + return list(cached) + + cache_entry = self._cache_data.get(signature) + if isinstance(cache_entry, list): + names = [name for name in cache_entry if name in self._all_patches] + self._signature_patch_names[signature] = names + self._signature_patch_maps[signature] = self._select_patch_names(names) + self._set_active_patch_names(names) + return list(names) + + baseline_output, baseline_ms = self._benchmark_patch_names([], args, kwargs) + + current_names: list[str] = [] + current_ms = baseline_ms + current_output = baseline_output + remaining = list(self._all_patches.keys()) + + while remaining: + best_candidate_name: str | None = None + best_candidate_output = None + best_candidate_ms = current_ms + + for candidate in remaining: + candidate_names = current_names + [candidate] + candidate_output, candidate_ms = self._benchmark_patch_names(candidate_names, args, kwargs) + if not _outputs_match(baseline_output, candidate_output): + continue + if candidate_ms < best_candidate_ms: + best_candidate_name = candidate + best_candidate_output = candidate_output + best_candidate_ms = candidate_ms + + if best_candidate_name is None: + break + if best_candidate_ms >= current_ms * (1.0 - self._improvement_margin): + break + + current_names.append(best_candidate_name) + remaining.remove(best_candidate_name) + current_ms = best_candidate_ms + current_output = best_candidate_output + + improved = True + while improved and current_names: + improved = False + for candidate in list(current_names): + trial_names = [name for name in current_names if name != candidate] + trial_output, trial_ms = self._benchmark_patch_names(trial_names, args, kwargs) + if not _outputs_match(baseline_output, trial_output): + continue + if trial_ms < current_ms * (1.0 - self._improvement_margin): + current_names = trial_names + current_ms = trial_ms + current_output = trial_output + improved = True + break + + if not _outputs_match(baseline_output, current_output): + current_names = [] + + self._signature_patch_names[signature] = list(current_names) + self._signature_patch_maps[signature] = self._select_patch_names(current_names) + self._cache_data[signature] = list(current_names) + self._set_active_patch_names(current_names) + self._persist_cache() + return list(current_names) + + def active_patch_names(self, *args: Any, **kwargs: Any) -> list[str]: + if args or kwargs: + return self.autotune(*args, **kwargs) + return list(self._active_patch_names) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + signature = _input_signature(args, kwargs) + patch_names = self._signature_patch_names.get(signature) + if patch_names is None: + patch_names = self.autotune(*args, **kwargs) + self._set_active_patch_names(patch_names) + return self.model.run(*args, **kwargs) + + def run(self, *args: Any, **kwargs: Any) -> Any: + return self(*args, **kwargs) + + +def load_cast( + cast_path: str, + *, + device: str | None = None, + model_args: dict | None = None, + opt_level: str = "-O3", + base_policy: str = "skip_aten", + warmup_runs: int = 3, + timed_runs: int = 5, + improvement_margin: float = _DEFAULT_MARGIN, + cache_enabled: bool = True, +) -> AutotunedCastModelRuntime: + ensure_cuda_toolkit_env() + + runtime_model = run_cast.load_cast( + cast_path, + model_args=model_args, + no_kernels=False, + opt_level=opt_level, + device=device, + kernel_policy=base_policy, + ) + + return AutotunedCastModelRuntime( + runtime_model, + cast_path=cast_path, + opt_level=opt_level, + base_policy=base_policy, + warmup_runs=warmup_runs, + timed_runs=timed_runs, + improvement_margin=improvement_margin, + cache_enabled=cache_enabled, + ) + + +def _parse_shape(raw: str) -> tuple[int, ...]: + return run_cast._parse_shape(raw) + + +def _benchmark_loaded_model(model: AutotunedCastModelRuntime, device: str, runs: int, input_shape: tuple[int, ...]) -> None: + import torch + + dummy = torch.randn(*input_shape, device=device) + tuned = model.autotune(dummy) + print(f"Autotuned kernels: {tuned}") + run_cast._benchmark_loaded_model(model, device, runs, input_shape) + + +def main() -> None: + import torch + + parser = argparse.ArgumentParser(description="Run a KernelForge .cast inference package with end-to-end autotuning") + parser.add_argument("cast_file", help="Path to .cast file") + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run on (default: cuda if available)", + ) + parser.add_argument( + "--model-args", + metavar="JSON", + default=None, + help='JSON config to instantiate the model, e.g. \'{"model_type":"resnet",...}\'.', + ) + parser.add_argument( + "--opt-level", + default="-O3", + choices=["-O0", "-O1", "-O2", "-O3"], + help="NVCC optimisation level for JIT compilation (default: -O3).", + ) + parser.add_argument( + "--base-policy", + default="skip_aten", + choices=list(run_cast._KERNEL_POLICIES), + help="Candidate kernel policy before autotuning (default: skip_aten).", + ) + parser.add_argument("--warmup-runs", type=int, default=3, help="Warmup runs for autotuning benchmarks.") + parser.add_argument("--timed-runs", type=int, default=5, help="Timed runs for autotuning benchmarks.") + parser.add_argument( + "--improvement-margin", + type=float, + default=_DEFAULT_MARGIN, + help="Minimum fractional improvement required to keep a kernel (default: 0.01).", + ) + parser.add_argument("--no-cache", action="store_true", help="Disable autotune cache persistence.") + parser.add_argument("--benchmark", action="store_true", help="Run a dummy-input benchmark after loading the model.") + parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing with --benchmark") + parser.add_argument( + "--input-shape", + default="1,3,224,224", + help="Comma-separated dummy input shape for --benchmark (default: 1,3,224,224).", + ) + args = parser.parse_args() + + extra_model_args = json.loads(args.model_args) if args.model_args else None + model = load_cast( + args.cast_file, + device=args.device, + model_args=extra_model_args, + opt_level=args.opt_level, + base_policy=args.base_policy, + warmup_runs=args.warmup_runs, + timed_runs=args.timed_runs, + improvement_margin=args.improvement_margin, + cache_enabled=not args.no_cache, + ) + print(f"\nModel ready on {args.device}") + print(f"Candidate policy: {args.base_policy}") + + if args.benchmark: + input_shape = _parse_shape(args.input_shape) + _benchmark_loaded_model(model, args.device, args.runs, input_shape) + + +if __name__ == "__main__": + main() diff --git a/tests/test_cast_ops_vs_pytorch.py b/tests/test_cast_ops_vs_pytorch.py new file mode 100644 index 00000000..d83685d2 --- /dev/null +++ b/tests/test_cast_ops_vs_pytorch.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import hashlib +import importlib.util +import inspect +import json +import os +import shutil +import sys +import time +import zipfile +from pathlib import Path +from typing import Any + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +torch = pytest.importorskip("torch") +import torch.nn.functional as F + +import kernelforge +from run_cast import _launch_arity, compile_kernel, verify_checksums + +_F_PREFIX = "torch_nn_functional_" +_DEFAULT_SKIP_OPS = {"torch_nn_functional_conv2d"} +_FUNCTIONAL_PARAM_NAMES: dict[str, list[str]] = { + "adaptive_avg_pool2d": ["input", "output_size"], + "batch_norm": ["input", "running_mean", "running_var", "weight", "bias", "training", "momentum", "eps"], + "conv2d": ["input", "weight", "bias", "stride", "padding", "dilation", "groups"], + "linear": ["input", "weight", "bias"], + "max_pool2d": ["input", "kernel_size", "stride", "padding", "dilation", "ceil_mode", "return_indices"], + "relu": ["input", "inplace"], +} +_FUNCTIONAL_DEFAULTS: dict[str, dict[str, Any]] = { + "batch_norm": {"weight": None, "bias": None, "training": False, "momentum": 0.1, "eps": 1e-5}, + "conv2d": {"bias": None, "stride": 1, "padding": 0, "dilation": 1, "groups": 1}, + "linear": {"bias": None}, + "max_pool2d": {"stride": None, "padding": 0, "dilation": 1, "ceil_mode": False, "return_indices": False}, + "relu": {"inplace": False}, +} + + +def _candidate_cast_paths() -> list[Path]: + candidates: list[Path] = [] + + env_path = os.getenv("RESNET50_CAST_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + default_export = ( + REPO_ROOT + / "kernels" + / "projects" + / "oioioio - RTX 3050 Laptop GPU" + / "exports" + / "oioioio - RTX 3050 Laptop GPU.cast" + ) + candidates.append(default_export) + + exports_root = REPO_ROOT / "kernels" / "projects" + candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) + + unique: list[Path] = [] + seen: set[Path] = set() + for path in candidates: + resolved = path.resolve(strict=False) + if resolved in seen: + continue + seen.add(resolved) + unique.append(path) + return unique + + +def _find_resnet50_cast() -> Path: + for path in _candidate_cast_paths(): + if path.exists(): + return path + pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") + + +def _cache_dir_for(cast_path: Path) -> Path: + cache_key = hashlib.sha256(cast_path.read_bytes()).hexdigest() + return Path.home() / ".cache" / "cast" / cache_key + + +def _extract_cast(cast_path: Path) -> tuple[Path, dict[str, Any]]: + cache_dir = _cache_dir_for(cast_path) + with zipfile.ZipFile(cast_path) as zf: + verify_checksums(zf) + if not cache_dir.is_dir(): + zf.extractall(cache_dir) + manifest = json.loads(zf.read("manifest.json")) + return cache_dir, manifest + + +def _ensure_cuda_toolkit() -> None: + candidates: list[Path] = [] + + cuda_home = os.getenv("CUDA_HOME") + if cuda_home: + candidates.append(Path(cuda_home)) + + cudacxx = os.getenv("CUDACXX") + if cudacxx: + candidates.append(Path(cudacxx).resolve().parent.parent) + + nvcc_on_path = shutil.which("nvcc") + if nvcc_on_path: + candidates.append(Path(nvcc_on_path).resolve().parent.parent) + + candidates.append(Path("/usr/local/cuda")) + candidates.extend(sorted(Path("/usr/local").glob("cuda-*"), reverse=True)) + + seen: set[Path] = set() + for candidate in candidates: + resolved = candidate.resolve(strict=False) + if resolved in seen: + continue + seen.add(resolved) + + nvcc = resolved / "bin" / "nvcc" + if not nvcc.exists(): + continue + + os.environ["CUDA_HOME"] = str(resolved) + os.environ.setdefault("CUDACXX", str(nvcc)) + + path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] + nvcc_dir = str(nvcc.parent) + if nvcc_dir not in path_entries: + os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] + return + + pytest.skip("No CUDA toolkit with nvcc was found for JIT kernel benchmarking.") + + +def _capture_functional_calls( + model: torch.nn.Module, + pixel_values: torch.Tensor, + fn_attrs: list[str], +) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]: + captures: dict[str, dict[str, Any]] = {} + originals: dict[str, Any] = {} + + for fn_attr in fn_attrs: + original = getattr(F, fn_attr, None) + if original is not None: + originals[fn_attr] = original + + def make_wrapper(name: str, original): + def wrapped(*args, **kwargs): + captures.setdefault(name, {"args": args, "kwargs": kwargs}) + return original(*args, **kwargs) + + wrapped.__name__ = f"capture_{name}" + return wrapped + + for fn_attr, original in originals.items(): + setattr(F, fn_attr, make_wrapper(fn_attr, original)) + + try: + with torch.no_grad(): + _ = model.run(pixel_values=pixel_values) + if pixel_values.is_cuda: + torch.cuda.synchronize() + finally: + for fn_attr, original in originals.items(): + setattr(F, fn_attr, original) + + return captures, originals + + +def _load_extension( + cache_dir: Path, + op: dict[str, Any], + *, + opt_level: str, +) -> tuple[Any, Path]: + op_name = op["name"] + kernel_cu = cache_dir / op["cuda_source"] + gpu_sm = "sm_{0}{1}".format(*torch.cuda.get_device_capability()) + precompiled = op.get("precompiled", {}) + so_rel = precompiled.get(gpu_sm) + so_path = cache_dir / so_rel if so_rel else None + + if so_path and so_path.exists(): + spec = importlib.util.spec_from_file_location(op_name, so_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Could not load precompiled module for {op_name}") + ext = importlib.util.module_from_spec(spec) + spec.loader.exec_module(ext) + return ext, kernel_cu + + if not kernel_cu.exists(): + raise RuntimeError(f"No kernel source found for {op_name}") + + _ensure_cuda_toolkit() + build_dir = cache_dir / "build_kernel_bench" + build_dir.mkdir(parents=True, exist_ok=True) + ext = compile_kernel(str(kernel_cu), op_name, str(build_dir), opt_level=opt_level) + return ext, kernel_cu + + +def _prepare_launch_args( + fn_attr: str, + original, + ext: Any, + kernel_cu: Path, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> list[Any]: + param_names = _FUNCTIONAL_PARAM_NAMES.get(fn_attr) + if param_names is not None: + resolved = {param_names[i]: value for i, value in enumerate(args) if i < len(param_names)} + resolved.update(kwargs) + defaults = _FUNCTIONAL_DEFAULTS.get(fn_attr, {}) + ordered = [resolved.get(name, defaults.get(name)) for name in param_names] + if fn_attr == "max_pool2d" and len(ordered) >= 3 and ordered[2] is None: + ordered[2] = ordered[1] + else: + try: + signature = inspect.signature(original) + except Exception: + signature = None + + if signature is not None: + bound = signature.bind_partial(*args, **kwargs) + bound.apply_defaults() + ordered = [bound.arguments.get(name) for name in signature.parameters.keys()] + else: + ordered = list(args) + + if not ordered: + ordered = list(args) + + launch_arity = _launch_arity(str(kernel_cu), ext) + try: + ext_arity = len(inspect.signature(ext.launch).parameters) + except Exception: + ext_arity = None + + if launch_arity is None: + limit = ext_arity if ext_arity is not None else len(ordered) + elif ext_arity is None: + limit = launch_arity + else: + limit = max(launch_arity, ext_arity) + + launch_args: list[Any] = [] + for value in ordered[:limit]: + if isinstance(value, torch.Tensor): + launch_args.append(value.contiguous()) + else: + launch_args.append(value) + return launch_args + + +def _benchmark_callable( + fn, + *, + device: str, + warmup_runs: int = 3, + timed_runs: int = 10, +) -> float: + with torch.no_grad(): + if device == "cuda": + for _ in range(warmup_runs): + fn() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(timed_runs): + fn() + end_event.record() + torch.cuda.synchronize() + return start_event.elapsed_time(end_event) / timed_runs + + for _ in range(warmup_runs): + fn() + start = time.perf_counter() + for _ in range(timed_runs): + fn() + return (time.perf_counter() - start) / timed_runs * 1000.0 + + +def _benchmark_signature(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + parts: list[str] = [] + for value in args: + if isinstance(value, torch.Tensor): + parts.append(f"Tensor{tuple(value.shape)}:{value.dtype}") + else: + parts.append(repr(value)) + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor): + parts.append(f"{key}=Tensor{tuple(value.shape)}:{value.dtype}") + else: + parts.append(f"{key}={value!r}") + return ", ".join(parts) + + +def _assert_outputs_close(op_name: str, torch_output: Any, kernel_output: Any) -> None: + assert isinstance(torch_output, torch.Tensor), f"{op_name}: expected tensor output from PyTorch" + assert isinstance(kernel_output, torch.Tensor), f"{op_name}: expected tensor output from custom kernel" + assert torch_output.shape == kernel_output.shape, f"{op_name}: output shapes differ" + + atol = 1e-3 if torch_output.dtype in (torch.float16, torch.bfloat16) else 1e-4 + rtol = 1e-3 if torch_output.dtype in (torch.float16, torch.bfloat16) else 1e-4 + torch.testing.assert_close(torch_output, kernel_output, atol=atol, rtol=rtol) + + +def test_cast_kernel_benchmarks_vs_pytorch() -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for kernel-vs-PyTorch benchmarking.") + + cast_path = _find_resnet50_cast() + cache_dir, manifest = _extract_cast(cast_path) + + device = "cuda" + pixel_values = torch.randn(1, 3, 224, 224, device=device) + model = kernelforge.load(str(cast_path), device=device, no_kernels=True) + + benchable_ops = [ + op for op in manifest["ops"] + if op["name"].startswith(_F_PREFIX) and op["name"] not in _DEFAULT_SKIP_OPS + ] + fn_attrs = [op["name"][len(_F_PREFIX):] for op in benchable_ops] + + captures, originals = _capture_functional_calls(model, pixel_values, fn_attrs) + + results: list[dict[str, Any]] = [] + for op in benchable_ops: + op_name = op["name"] + fn_attr = op_name[len(_F_PREFIX):] + capture = captures.get(fn_attr) + original = originals.get(fn_attr) + if capture is None or original is None: + continue + + ext, kernel_cu = _load_extension(cache_dir, op, opt_level="-O3") + launch_args = _prepare_launch_args( + fn_attr, + original, + ext, + kernel_cu, + capture["args"], + capture["kwargs"], + ) + + with torch.no_grad(): + torch_output = original(*capture["args"], **capture["kwargs"]) + kernel_output = ext.launch(*launch_args) + if device == "cuda": + torch.cuda.synchronize() + + _assert_outputs_close(op_name, torch_output, kernel_output) + + torch_ms = _benchmark_callable( + lambda: original(*capture["args"], **capture["kwargs"]), + device=device, + ) + kernel_ms = _benchmark_callable( + lambda: ext.launch(*launch_args), + device=device, + ) + + results.append( + { + "op_name": op_name, + "signature": _benchmark_signature(capture["args"], capture["kwargs"]), + "torch_ms": torch_ms, + "kernel_ms": kernel_ms, + "speedup": torch_ms / kernel_ms if kernel_ms else float("inf"), + } + ) + + assert results, "No benchmarkable optimized kernels were found in the .cast file." + + print("\nKernel benchmark results:") + for result in results: + print( + f" {result['op_name']:<40}" + f"kernel {result['kernel_ms']:>8.3f} ms | " + f"torch {result['torch_ms']:>8.3f} ms | " + f"speedup {result['speedup']:>6.2f}x" + ) + print(f" sample: {result['signature']}") diff --git a/tests/test_load_cast_resnet50.py b/tests/test_load_cast_resnet50.py new file mode 100644 index 00000000..03aa3e47 --- /dev/null +++ b/tests/test_load_cast_resnet50.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import os +import sys +import time +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +torch = pytest.importorskip("torch") + +import kernelforge + + +def _candidate_cast_paths() -> list[Path]: + candidates: list[Path] = [] + + env_path = os.getenv("RESNET50_CAST_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + default_export = ( + REPO_ROOT + / "kernels" + / "projects" + / "oioioio - RTX 3050 Laptop GPU" + / "exports" + / "oioioio - RTX 3050 Laptop GPU.cast" + ) + candidates.append(default_export) + + exports_root = REPO_ROOT / "kernels" / "projects" + candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) + + unique: list[Path] = [] + seen: set[Path] = set() + for path in candidates: + resolved = path.resolve(strict=False) + if resolved in seen: + continue + seen.add(resolved) + unique.append(path) + return unique + + +def _find_resnet50_cast() -> Path: + for path in _candidate_cast_paths(): + if path.exists(): + return path + pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") + + +def _run_once(model: torch.nn.Module, pixel_values: torch.Tensor) -> object: + device = pixel_values.device.type + + output = model.run(pixel_values=pixel_values) + if device == "cuda": + torch.cuda.synchronize() + return output + + +def _benchmark_models_alternating( + pytorch_model: torch.nn.Module, + optimized_model: torch.nn.Module, + pixel_values: torch.Tensor, + *, + warmup_runs: int = 3, + timed_runs: int = 5, +) -> tuple[object, float, object, float]: + pytorch_total_ms = 0.0 + optimized_total_ms = 0.0 + + with torch.inference_mode(): + for _ in range(warmup_runs): + _run_once(pytorch_model, pixel_values) + _run_once(optimized_model, pixel_values) + + pytorch_output = None + optimized_output = None + for _ in range(timed_runs): + start = time.perf_counter() + pytorch_output = _run_once(pytorch_model, pixel_values) + pytorch_total_ms += (time.perf_counter() - start) * 1000.0 + + start = time.perf_counter() + optimized_output = _run_once(optimized_model, pixel_values) + optimized_total_ms += (time.perf_counter() - start) * 1000.0 + + return ( + pytorch_output, + pytorch_total_ms / timed_runs, + optimized_output, + optimized_total_ms / timed_runs, + ) + + +def test_load_cast_resnet50_benchmark() -> None: + cast_path = _find_resnet50_cast() + device = "cuda" if torch.cuda.is_available() else "cpu" + pixel_values = torch.randn(1, 3, 224, 224, device=device) + + optimized_model = kernelforge.load(str(cast_path), device=device, opt_level="-O3") + pytorch_model = kernelforge.load(str(cast_path), device=device, no_kernels=True) + + assert isinstance(optimized_model, torch.nn.Module) + assert isinstance(pytorch_model, torch.nn.Module) + assert optimized_model.training is False + assert pytorch_model.training is False + assert hasattr(optimized_model, "run") + assert hasattr(pytorch_model, "run") + + pytorch_output, pytorch_ms, optimized_output, optimized_ms = _benchmark_models_alternating( + pytorch_model, + optimized_model, + pixel_values, + warmup_runs=3, + timed_runs=5, + ) + + optimized_logits = optimized_output.logits if hasattr(optimized_output, "logits") else optimized_output + pytorch_logits = pytorch_output.logits if hasattr(pytorch_output, "logits") else pytorch_output + + assert isinstance(optimized_logits, torch.Tensor) + assert isinstance(pytorch_logits, torch.Tensor) + assert optimized_logits.device.type == device + assert pytorch_logits.device.type == device + assert optimized_logits.shape == pytorch_logits.shape + assert optimized_logits.shape[0] == 1 + assert optimized_logits.ndim == 2 + assert optimized_logits.shape[1] > 0 + + print(f"\n.cast runtime average latency : {optimized_ms:.2f} ms") + print(f"PyTorch average latency : {pytorch_ms:.2f} ms") diff --git a/tests/test_run_cast_autotune_resnet50.py b/tests/test_run_cast_autotune_resnet50.py new file mode 100644 index 00000000..c5cec02b --- /dev/null +++ b/tests/test_run_cast_autotune_resnet50.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +torch = pytest.importorskip("torch") + +import run_cast +import run_cast_autotune + + +def _candidate_cast_paths() -> list[Path]: + candidates: list[Path] = [] + + env_path = os.getenv("RESNET50_CAST_PATH") + if env_path: + candidates.append(Path(env_path).expanduser()) + + default_export = ( + REPO_ROOT + / "kernels" + / "projects" + / "oioioio - RTX 3050 Laptop GPU" + / "exports" + / "oioioio - RTX 3050 Laptop GPU.cast" + ) + candidates.append(default_export) + + exports_root = REPO_ROOT / "kernels" / "projects" + candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) + + unique: list[Path] = [] + seen: set[Path] = set() + for path in candidates: + resolved = path.resolve(strict=False) + if resolved in seen: + continue + seen.add(resolved) + unique.append(path) + return unique + + +def _find_resnet50_cast() -> Path: + for path in _candidate_cast_paths(): + if path.exists(): + return path + pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") + + +def _benchmark_model( + model: torch.nn.Module, + pixel_values: torch.Tensor, + *, + warmup_runs: int = 3, + timed_runs: int = 5, +) -> tuple[object, float]: + device = pixel_values.device.type + + with torch.inference_mode(): + output = None + for _ in range(warmup_runs): + output = model.run(pixel_values=pixel_values) + if device == "cuda": + torch.cuda.synchronize() + + if device == "cuda": + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(timed_runs): + output = model.run(pixel_values=pixel_values) + end.record() + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) / timed_runs + else: + import time + + t0 = time.perf_counter() + for _ in range(timed_runs): + output = model.run(pixel_values=pixel_values) + elapsed_ms = (time.perf_counter() - t0) / timed_runs * 1000.0 + + return output, elapsed_ms + + +def test_run_cast_autotune_resnet50_speed() -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for the autotuned cast benchmark.") + + cast_path = _find_resnet50_cast() + device = "cuda" + pixel_values = torch.randn(1, 3, 224, 224, device=device) + + run_cast_autotune.ensure_cuda_toolkit_env() + + pytorch_model = run_cast.load_cast(str(cast_path), device=device, no_kernels=True) + full_model = run_cast.load_cast(str(cast_path), device=device, opt_level="-O3") + autotuned_model = run_cast_autotune.load_cast( + str(cast_path), + device=device, + opt_level="-O3", + base_policy="skip_aten", + warmup_runs=3, + timed_runs=5, + cache_enabled=False, + ) + + tuned_patch_names = autotuned_model.autotune(pixel_values=pixel_values) + + pytorch_output, pytorch_ms = _benchmark_model(pytorch_model, pixel_values) + full_output, full_ms = _benchmark_model(full_model, pixel_values) + autotuned_output, autotuned_ms = _benchmark_model(autotuned_model, pixel_values) + + pytorch_logits = pytorch_output.logits if hasattr(pytorch_output, "logits") else pytorch_output + full_logits = full_output.logits if hasattr(full_output, "logits") else full_output + autotuned_logits = autotuned_output.logits if hasattr(autotuned_output, "logits") else autotuned_output + + assert isinstance(pytorch_logits, torch.Tensor) + assert isinstance(full_logits, torch.Tensor) + assert isinstance(autotuned_logits, torch.Tensor) + assert pytorch_logits.shape == full_logits.shape == autotuned_logits.shape + assert torch.allclose(pytorch_logits, full_logits, atol=1e-3, rtol=1e-3) + assert torch.allclose(pytorch_logits, autotuned_logits, atol=1e-3, rtol=1e-3) + + print("\nAutotuned kernels :", tuned_patch_names) + print(f"PyTorch fallback latency : {pytorch_ms:.2f} ms") + print(f"run_cast full latency : {full_ms:.2f} ms") + print(f"autotuned runtime latency : {autotuned_ms:.2f} ms") From e16eff07dbfe95377458410d1b5b2180b0944052 Mon Sep 17 00:00:00 2001 From: logansg Date: Wed, 25 Mar 2026 00:10:43 -0400 Subject: [PATCH 2/3] Remove autotuned cast runtime from PR --- run_cast_autotune.py | 492 ----------------------- tests/test_run_cast_autotune_resnet50.py | 135 ------- 2 files changed, 627 deletions(-) delete mode 100644 run_cast_autotune.py delete mode 100644 tests/test_run_cast_autotune_resnet50.py diff --git a/run_cast_autotune.py b/run_cast_autotune.py deleted file mode 100644 index bb43154d..00000000 --- a/run_cast_autotune.py +++ /dev/null @@ -1,492 +0,0 @@ -#!/usr/bin/env python3 -"""Autotuning .cast runtime that keeps only end-to-end beneficial kernels.""" - -from __future__ import annotations - -import argparse -import hashlib -import json -import os -import shutil -import tempfile -import time -from pathlib import Path -from typing import Any - -import torch.nn as nn - -import run_cast - -_DEFAULT_MARGIN = 0.01 - - -def ensure_cuda_toolkit_env() -> str | None: - """Point CUDA_HOME/CUDACXX/PATH at a working nvcc if one is available.""" - - candidates: list[Path] = [] - - cuda_home = os.getenv("CUDA_HOME") - if cuda_home: - candidates.append(Path(cuda_home)) - - cudacxx = os.getenv("CUDACXX") - if cudacxx: - candidates.append(Path(cudacxx).expanduser().resolve().parent.parent) - - nvcc_on_path = shutil.which("nvcc") - if nvcc_on_path: - candidates.append(Path(nvcc_on_path).resolve().parent.parent) - - candidates.append(Path("/usr/local/cuda")) - candidates.extend(sorted(Path("/usr/local").glob("cuda-*"), reverse=True)) - - seen: set[Path] = set() - for candidate in candidates: - resolved = candidate.resolve(strict=False) - if resolved in seen: - continue - seen.add(resolved) - - nvcc = resolved / "bin" / "nvcc" - if not nvcc.exists(): - continue - - os.environ["CUDA_HOME"] = str(resolved) - os.environ.setdefault("CUDACXX", str(nvcc)) - - path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] - nvcc_dir = str(nvcc.parent) - if nvcc_dir not in path_entries: - os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] - return str(resolved) - - return None - - -def _normalize_input(value: Any) -> Any: - import torch - - if isinstance(value, torch.Tensor): - return { - "type": "tensor", - "shape": list(value.shape), - "dtype": str(value.dtype), - "device": str(value.device), - "requires_grad": bool(value.requires_grad), - } - if isinstance(value, dict): - return {str(key): _normalize_input(val) for key, val in sorted(value.items(), key=lambda item: str(item[0]))} - if isinstance(value, (list, tuple)): - return [_normalize_input(item) for item in value] - if isinstance(value, (str, int, float, bool)) or value is None: - return value - return {"type": type(value).__name__, "repr": repr(value)} - - -def _input_signature(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: - payload = { - "args": _normalize_input(args), - "kwargs": _normalize_input(kwargs), - } - return json.dumps(payload, sort_keys=True, separators=(",", ":")) - - -def _first_tensor_device(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: - import torch - - def visit(value: Any) -> str | None: - if isinstance(value, torch.Tensor): - return value.device.type - if isinstance(value, dict): - for item in value.values(): - found = visit(item) - if found is not None: - return found - return None - if isinstance(value, (list, tuple)): - for item in value: - found = visit(item) - if found is not None: - return found - return None - return None - - found = visit(args) - if found is not None: - return found - found = visit(kwargs) - return found or "cpu" - - -def _flatten_output(value: Any) -> Any: - import torch - - if isinstance(value, torch.Tensor): - return ("tensor", value) - if hasattr(value, "to_tuple") and callable(value.to_tuple): - return ("tuple_like", tuple(_flatten_output(item) for item in value.to_tuple())) - if isinstance(value, dict): - return ("dict", tuple((str(key), _flatten_output(val)) for key, val in sorted(value.items(), key=lambda item: str(item[0])))) - if isinstance(value, (list, tuple)): - return ("seq", tuple(_flatten_output(item) for item in value)) - return ("value", value) - - -def _outputs_match(reference: Any, candidate: Any) -> bool: - import torch - - ref = _flatten_output(reference) - cand = _flatten_output(candidate) - - def compare(left: Any, right: Any) -> bool: - if left[0] != right[0]: - return False - kind = left[0] - if kind == "tensor": - try: - atol = 1e-3 if left[1].dtype in (torch.float16, torch.bfloat16) else 1e-4 - rtol = 1e-3 if left[1].dtype in (torch.float16, torch.bfloat16) else 1e-4 - torch.testing.assert_close(left[1], right[1], atol=atol, rtol=rtol) - return True - except Exception: - return False - if kind in {"tuple_like", "seq"}: - if len(left[1]) != len(right[1]): - return False - return all(compare(l_item, r_item) for l_item, r_item in zip(left[1], right[1])) - if kind == "dict": - if len(left[1]) != len(right[1]): - return False - return all( - l_key == r_key and compare(l_val, r_val) - for (l_key, l_val), (r_key, r_val) in zip(left[1], right[1]) - ) - return left[1] == right[1] - - return compare(ref, cand) - - -def _cast_cache_key(cast_path: str) -> str: - cast_bytes = Path(cast_path).read_bytes() - return hashlib.sha256(cast_bytes).hexdigest() - - -class AutotunedCastModelRuntime(nn.Module): - """A runtime wrapper that benchmarks patch subsets and keeps only the winners.""" - - def __init__( - self, - runtime_model: run_cast.CastModelRuntime, - *, - cast_path: str, - opt_level: str, - base_policy: str, - warmup_runs: int, - timed_runs: int, - improvement_margin: float, - cache_enabled: bool, - ) -> None: - if not isinstance(runtime_model, run_cast.CastModelRuntime): - raise TypeError("runtime_model must be a CastModelRuntime") - super().__init__() - self.model = runtime_model - self._all_patches = dict(runtime_model._cast_functional_patches) - self._active_patch_names: list[str] = list(self._all_patches.keys()) - self._signature_patch_names: dict[str, list[str]] = {} - self._signature_patch_maps: dict[str, dict[str, Any]] = {} - self._opt_level = opt_level - self._base_policy = base_policy - self._warmup_runs = warmup_runs - self._timed_runs = timed_runs - self._improvement_margin = improvement_margin - self._cache_enabled = cache_enabled - self._cast_key = _cast_cache_key(cast_path) - self._cache_path = self._build_cache_path(self._cast_key) - self._cache_data = self._load_cache() - - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - model = self._modules.get("model") - if model is None: - raise - return getattr(model, name) - - @staticmethod - def _build_cache_path(cast_key: str) -> Path: - cache_root = Path.home() / ".cache" / "cast_autotune" - gpu_name = "cpu" - if os.getenv("CUDA_VISIBLE_DEVICES", "") == "": - pass - try: - import torch - - if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(0).replace(" ", "_") - except Exception: - pass - return cache_root / f"{cast_key}_{gpu_name}.json" - - def _load_cache(self) -> dict[str, Any]: - if not self._cache_enabled or not self._cache_path.exists(): - return {} - try: - return json.loads(self._cache_path.read_text()) - except Exception: - return {} - - def _persist_cache(self) -> None: - if not self._cache_enabled: - return - self._cache_path.parent.mkdir(parents=True, exist_ok=True) - payload = json.dumps(self._cache_data, indent=2, sort_keys=True) - with tempfile.NamedTemporaryFile("w", dir=self._cache_path.parent, delete=False) as handle: - handle.write(payload) - temp_path = Path(handle.name) - temp_path.replace(self._cache_path) - - def _select_patch_names(self, names: list[str]) -> dict[str, Any]: - return {name: self._all_patches[name] for name in names if name in self._all_patches} - - def _set_active_patch_names(self, names: list[str]) -> None: - self._active_patch_names = list(names) - self.model._cast_functional_patches = self._select_patch_names(names) - - def _run_once(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: - return self.model.run(*args, **kwargs) - - def _benchmark_patch_names(self, patch_names: list[str], args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, float]: - import torch - - self._set_active_patch_names(patch_names) - device_type = _first_tensor_device(args, kwargs) - - with torch.inference_mode(): - output = None - for _ in range(self._warmup_runs): - output = self._run_once(args, kwargs) - if device_type == "cuda": - torch.cuda.synchronize() - - if device_type == "cuda": - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(self._timed_runs): - output = self._run_once(args, kwargs) - end.record() - torch.cuda.synchronize() - elapsed_ms = start.elapsed_time(end) / self._timed_runs - else: - t0 = time.perf_counter() - for _ in range(self._timed_runs): - output = self._run_once(args, kwargs) - elapsed_ms = (time.perf_counter() - t0) / self._timed_runs * 1000.0 - - return output, elapsed_ms - - def autotune(self, *args: Any, **kwargs: Any) -> list[str]: - signature = _input_signature(args, kwargs) - cached = self._signature_patch_names.get(signature) - if cached is not None: - return list(cached) - - cache_entry = self._cache_data.get(signature) - if isinstance(cache_entry, list): - names = [name for name in cache_entry if name in self._all_patches] - self._signature_patch_names[signature] = names - self._signature_patch_maps[signature] = self._select_patch_names(names) - self._set_active_patch_names(names) - return list(names) - - baseline_output, baseline_ms = self._benchmark_patch_names([], args, kwargs) - - current_names: list[str] = [] - current_ms = baseline_ms - current_output = baseline_output - remaining = list(self._all_patches.keys()) - - while remaining: - best_candidate_name: str | None = None - best_candidate_output = None - best_candidate_ms = current_ms - - for candidate in remaining: - candidate_names = current_names + [candidate] - candidate_output, candidate_ms = self._benchmark_patch_names(candidate_names, args, kwargs) - if not _outputs_match(baseline_output, candidate_output): - continue - if candidate_ms < best_candidate_ms: - best_candidate_name = candidate - best_candidate_output = candidate_output - best_candidate_ms = candidate_ms - - if best_candidate_name is None: - break - if best_candidate_ms >= current_ms * (1.0 - self._improvement_margin): - break - - current_names.append(best_candidate_name) - remaining.remove(best_candidate_name) - current_ms = best_candidate_ms - current_output = best_candidate_output - - improved = True - while improved and current_names: - improved = False - for candidate in list(current_names): - trial_names = [name for name in current_names if name != candidate] - trial_output, trial_ms = self._benchmark_patch_names(trial_names, args, kwargs) - if not _outputs_match(baseline_output, trial_output): - continue - if trial_ms < current_ms * (1.0 - self._improvement_margin): - current_names = trial_names - current_ms = trial_ms - current_output = trial_output - improved = True - break - - if not _outputs_match(baseline_output, current_output): - current_names = [] - - self._signature_patch_names[signature] = list(current_names) - self._signature_patch_maps[signature] = self._select_patch_names(current_names) - self._cache_data[signature] = list(current_names) - self._set_active_patch_names(current_names) - self._persist_cache() - return list(current_names) - - def active_patch_names(self, *args: Any, **kwargs: Any) -> list[str]: - if args or kwargs: - return self.autotune(*args, **kwargs) - return list(self._active_patch_names) - - def forward(self, *args: Any, **kwargs: Any) -> Any: - signature = _input_signature(args, kwargs) - patch_names = self._signature_patch_names.get(signature) - if patch_names is None: - patch_names = self.autotune(*args, **kwargs) - self._set_active_patch_names(patch_names) - return self.model.run(*args, **kwargs) - - def run(self, *args: Any, **kwargs: Any) -> Any: - return self(*args, **kwargs) - - -def load_cast( - cast_path: str, - *, - device: str | None = None, - model_args: dict | None = None, - opt_level: str = "-O3", - base_policy: str = "skip_aten", - warmup_runs: int = 3, - timed_runs: int = 5, - improvement_margin: float = _DEFAULT_MARGIN, - cache_enabled: bool = True, -) -> AutotunedCastModelRuntime: - ensure_cuda_toolkit_env() - - runtime_model = run_cast.load_cast( - cast_path, - model_args=model_args, - no_kernels=False, - opt_level=opt_level, - device=device, - kernel_policy=base_policy, - ) - - return AutotunedCastModelRuntime( - runtime_model, - cast_path=cast_path, - opt_level=opt_level, - base_policy=base_policy, - warmup_runs=warmup_runs, - timed_runs=timed_runs, - improvement_margin=improvement_margin, - cache_enabled=cache_enabled, - ) - - -def _parse_shape(raw: str) -> tuple[int, ...]: - return run_cast._parse_shape(raw) - - -def _benchmark_loaded_model(model: AutotunedCastModelRuntime, device: str, runs: int, input_shape: tuple[int, ...]) -> None: - import torch - - dummy = torch.randn(*input_shape, device=device) - tuned = model.autotune(dummy) - print(f"Autotuned kernels: {tuned}") - run_cast._benchmark_loaded_model(model, device, runs, input_shape) - - -def main() -> None: - import torch - - parser = argparse.ArgumentParser(description="Run a KernelForge .cast inference package with end-to-end autotuning") - parser.add_argument("cast_file", help="Path to .cast file") - parser.add_argument( - "--device", - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device to run on (default: cuda if available)", - ) - parser.add_argument( - "--model-args", - metavar="JSON", - default=None, - help='JSON config to instantiate the model, e.g. \'{"model_type":"resnet",...}\'.', - ) - parser.add_argument( - "--opt-level", - default="-O3", - choices=["-O0", "-O1", "-O2", "-O3"], - help="NVCC optimisation level for JIT compilation (default: -O3).", - ) - parser.add_argument( - "--base-policy", - default="skip_aten", - choices=list(run_cast._KERNEL_POLICIES), - help="Candidate kernel policy before autotuning (default: skip_aten).", - ) - parser.add_argument("--warmup-runs", type=int, default=3, help="Warmup runs for autotuning benchmarks.") - parser.add_argument("--timed-runs", type=int, default=5, help="Timed runs for autotuning benchmarks.") - parser.add_argument( - "--improvement-margin", - type=float, - default=_DEFAULT_MARGIN, - help="Minimum fractional improvement required to keep a kernel (default: 0.01).", - ) - parser.add_argument("--no-cache", action="store_true", help="Disable autotune cache persistence.") - parser.add_argument("--benchmark", action="store_true", help="Run a dummy-input benchmark after loading the model.") - parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing with --benchmark") - parser.add_argument( - "--input-shape", - default="1,3,224,224", - help="Comma-separated dummy input shape for --benchmark (default: 1,3,224,224).", - ) - args = parser.parse_args() - - extra_model_args = json.loads(args.model_args) if args.model_args else None - model = load_cast( - args.cast_file, - device=args.device, - model_args=extra_model_args, - opt_level=args.opt_level, - base_policy=args.base_policy, - warmup_runs=args.warmup_runs, - timed_runs=args.timed_runs, - improvement_margin=args.improvement_margin, - cache_enabled=not args.no_cache, - ) - print(f"\nModel ready on {args.device}") - print(f"Candidate policy: {args.base_policy}") - - if args.benchmark: - input_shape = _parse_shape(args.input_shape) - _benchmark_loaded_model(model, args.device, args.runs, input_shape) - - -if __name__ == "__main__": - main() diff --git a/tests/test_run_cast_autotune_resnet50.py b/tests/test_run_cast_autotune_resnet50.py deleted file mode 100644 index c5cec02b..00000000 --- a/tests/test_run_cast_autotune_resnet50.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -import os -import sys -from pathlib import Path - -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -torch = pytest.importorskip("torch") - -import run_cast -import run_cast_autotune - - -def _candidate_cast_paths() -> list[Path]: - candidates: list[Path] = [] - - env_path = os.getenv("RESNET50_CAST_PATH") - if env_path: - candidates.append(Path(env_path).expanduser()) - - default_export = ( - REPO_ROOT - / "kernels" - / "projects" - / "oioioio - RTX 3050 Laptop GPU" - / "exports" - / "oioioio - RTX 3050 Laptop GPU.cast" - ) - candidates.append(default_export) - - exports_root = REPO_ROOT / "kernels" / "projects" - candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) - - unique: list[Path] = [] - seen: set[Path] = set() - for path in candidates: - resolved = path.resolve(strict=False) - if resolved in seen: - continue - seen.add(resolved) - unique.append(path) - return unique - - -def _find_resnet50_cast() -> Path: - for path in _candidate_cast_paths(): - if path.exists(): - return path - pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") - - -def _benchmark_model( - model: torch.nn.Module, - pixel_values: torch.Tensor, - *, - warmup_runs: int = 3, - timed_runs: int = 5, -) -> tuple[object, float]: - device = pixel_values.device.type - - with torch.inference_mode(): - output = None - for _ in range(warmup_runs): - output = model.run(pixel_values=pixel_values) - if device == "cuda": - torch.cuda.synchronize() - - if device == "cuda": - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(timed_runs): - output = model.run(pixel_values=pixel_values) - end.record() - torch.cuda.synchronize() - elapsed_ms = start.elapsed_time(end) / timed_runs - else: - import time - - t0 = time.perf_counter() - for _ in range(timed_runs): - output = model.run(pixel_values=pixel_values) - elapsed_ms = (time.perf_counter() - t0) / timed_runs * 1000.0 - - return output, elapsed_ms - - -def test_run_cast_autotune_resnet50_speed() -> None: - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for the autotuned cast benchmark.") - - cast_path = _find_resnet50_cast() - device = "cuda" - pixel_values = torch.randn(1, 3, 224, 224, device=device) - - run_cast_autotune.ensure_cuda_toolkit_env() - - pytorch_model = run_cast.load_cast(str(cast_path), device=device, no_kernels=True) - full_model = run_cast.load_cast(str(cast_path), device=device, opt_level="-O3") - autotuned_model = run_cast_autotune.load_cast( - str(cast_path), - device=device, - opt_level="-O3", - base_policy="skip_aten", - warmup_runs=3, - timed_runs=5, - cache_enabled=False, - ) - - tuned_patch_names = autotuned_model.autotune(pixel_values=pixel_values) - - pytorch_output, pytorch_ms = _benchmark_model(pytorch_model, pixel_values) - full_output, full_ms = _benchmark_model(full_model, pixel_values) - autotuned_output, autotuned_ms = _benchmark_model(autotuned_model, pixel_values) - - pytorch_logits = pytorch_output.logits if hasattr(pytorch_output, "logits") else pytorch_output - full_logits = full_output.logits if hasattr(full_output, "logits") else full_output - autotuned_logits = autotuned_output.logits if hasattr(autotuned_output, "logits") else autotuned_output - - assert isinstance(pytorch_logits, torch.Tensor) - assert isinstance(full_logits, torch.Tensor) - assert isinstance(autotuned_logits, torch.Tensor) - assert pytorch_logits.shape == full_logits.shape == autotuned_logits.shape - assert torch.allclose(pytorch_logits, full_logits, atol=1e-3, rtol=1e-3) - assert torch.allclose(pytorch_logits, autotuned_logits, atol=1e-3, rtol=1e-3) - - print("\nAutotuned kernels :", tuned_patch_names) - print(f"PyTorch fallback latency : {pytorch_ms:.2f} ms") - print(f"run_cast full latency : {full_ms:.2f} ms") - print(f"autotuned runtime latency : {autotuned_ms:.2f} ms") From 5cb3d89d24ac143ac7ebd5167ed9ff8317f2e295 Mon Sep 17 00:00:00 2001 From: logansg Date: Wed, 25 Mar 2026 00:13:22 -0400 Subject: [PATCH 3/3] remove test cases --- tests/test_cast_ops_vs_pytorch.py | 391 ------------------------------ tests/test_load_cast_resnet50.py | 137 ----------- 2 files changed, 528 deletions(-) delete mode 100644 tests/test_cast_ops_vs_pytorch.py delete mode 100644 tests/test_load_cast_resnet50.py diff --git a/tests/test_cast_ops_vs_pytorch.py b/tests/test_cast_ops_vs_pytorch.py deleted file mode 100644 index d83685d2..00000000 --- a/tests/test_cast_ops_vs_pytorch.py +++ /dev/null @@ -1,391 +0,0 @@ -from __future__ import annotations - -import hashlib -import importlib.util -import inspect -import json -import os -import shutil -import sys -import time -import zipfile -from pathlib import Path -from typing import Any - -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -torch = pytest.importorskip("torch") -import torch.nn.functional as F - -import kernelforge -from run_cast import _launch_arity, compile_kernel, verify_checksums - -_F_PREFIX = "torch_nn_functional_" -_DEFAULT_SKIP_OPS = {"torch_nn_functional_conv2d"} -_FUNCTIONAL_PARAM_NAMES: dict[str, list[str]] = { - "adaptive_avg_pool2d": ["input", "output_size"], - "batch_norm": ["input", "running_mean", "running_var", "weight", "bias", "training", "momentum", "eps"], - "conv2d": ["input", "weight", "bias", "stride", "padding", "dilation", "groups"], - "linear": ["input", "weight", "bias"], - "max_pool2d": ["input", "kernel_size", "stride", "padding", "dilation", "ceil_mode", "return_indices"], - "relu": ["input", "inplace"], -} -_FUNCTIONAL_DEFAULTS: dict[str, dict[str, Any]] = { - "batch_norm": {"weight": None, "bias": None, "training": False, "momentum": 0.1, "eps": 1e-5}, - "conv2d": {"bias": None, "stride": 1, "padding": 0, "dilation": 1, "groups": 1}, - "linear": {"bias": None}, - "max_pool2d": {"stride": None, "padding": 0, "dilation": 1, "ceil_mode": False, "return_indices": False}, - "relu": {"inplace": False}, -} - - -def _candidate_cast_paths() -> list[Path]: - candidates: list[Path] = [] - - env_path = os.getenv("RESNET50_CAST_PATH") - if env_path: - candidates.append(Path(env_path).expanduser()) - - default_export = ( - REPO_ROOT - / "kernels" - / "projects" - / "oioioio - RTX 3050 Laptop GPU" - / "exports" - / "oioioio - RTX 3050 Laptop GPU.cast" - ) - candidates.append(default_export) - - exports_root = REPO_ROOT / "kernels" / "projects" - candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) - - unique: list[Path] = [] - seen: set[Path] = set() - for path in candidates: - resolved = path.resolve(strict=False) - if resolved in seen: - continue - seen.add(resolved) - unique.append(path) - return unique - - -def _find_resnet50_cast() -> Path: - for path in _candidate_cast_paths(): - if path.exists(): - return path - pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") - - -def _cache_dir_for(cast_path: Path) -> Path: - cache_key = hashlib.sha256(cast_path.read_bytes()).hexdigest() - return Path.home() / ".cache" / "cast" / cache_key - - -def _extract_cast(cast_path: Path) -> tuple[Path, dict[str, Any]]: - cache_dir = _cache_dir_for(cast_path) - with zipfile.ZipFile(cast_path) as zf: - verify_checksums(zf) - if not cache_dir.is_dir(): - zf.extractall(cache_dir) - manifest = json.loads(zf.read("manifest.json")) - return cache_dir, manifest - - -def _ensure_cuda_toolkit() -> None: - candidates: list[Path] = [] - - cuda_home = os.getenv("CUDA_HOME") - if cuda_home: - candidates.append(Path(cuda_home)) - - cudacxx = os.getenv("CUDACXX") - if cudacxx: - candidates.append(Path(cudacxx).resolve().parent.parent) - - nvcc_on_path = shutil.which("nvcc") - if nvcc_on_path: - candidates.append(Path(nvcc_on_path).resolve().parent.parent) - - candidates.append(Path("/usr/local/cuda")) - candidates.extend(sorted(Path("/usr/local").glob("cuda-*"), reverse=True)) - - seen: set[Path] = set() - for candidate in candidates: - resolved = candidate.resolve(strict=False) - if resolved in seen: - continue - seen.add(resolved) - - nvcc = resolved / "bin" / "nvcc" - if not nvcc.exists(): - continue - - os.environ["CUDA_HOME"] = str(resolved) - os.environ.setdefault("CUDACXX", str(nvcc)) - - path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] - nvcc_dir = str(nvcc.parent) - if nvcc_dir not in path_entries: - os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] - return - - pytest.skip("No CUDA toolkit with nvcc was found for JIT kernel benchmarking.") - - -def _capture_functional_calls( - model: torch.nn.Module, - pixel_values: torch.Tensor, - fn_attrs: list[str], -) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]: - captures: dict[str, dict[str, Any]] = {} - originals: dict[str, Any] = {} - - for fn_attr in fn_attrs: - original = getattr(F, fn_attr, None) - if original is not None: - originals[fn_attr] = original - - def make_wrapper(name: str, original): - def wrapped(*args, **kwargs): - captures.setdefault(name, {"args": args, "kwargs": kwargs}) - return original(*args, **kwargs) - - wrapped.__name__ = f"capture_{name}" - return wrapped - - for fn_attr, original in originals.items(): - setattr(F, fn_attr, make_wrapper(fn_attr, original)) - - try: - with torch.no_grad(): - _ = model.run(pixel_values=pixel_values) - if pixel_values.is_cuda: - torch.cuda.synchronize() - finally: - for fn_attr, original in originals.items(): - setattr(F, fn_attr, original) - - return captures, originals - - -def _load_extension( - cache_dir: Path, - op: dict[str, Any], - *, - opt_level: str, -) -> tuple[Any, Path]: - op_name = op["name"] - kernel_cu = cache_dir / op["cuda_source"] - gpu_sm = "sm_{0}{1}".format(*torch.cuda.get_device_capability()) - precompiled = op.get("precompiled", {}) - so_rel = precompiled.get(gpu_sm) - so_path = cache_dir / so_rel if so_rel else None - - if so_path and so_path.exists(): - spec = importlib.util.spec_from_file_location(op_name, so_path) - if spec is None or spec.loader is None: - raise RuntimeError(f"Could not load precompiled module for {op_name}") - ext = importlib.util.module_from_spec(spec) - spec.loader.exec_module(ext) - return ext, kernel_cu - - if not kernel_cu.exists(): - raise RuntimeError(f"No kernel source found for {op_name}") - - _ensure_cuda_toolkit() - build_dir = cache_dir / "build_kernel_bench" - build_dir.mkdir(parents=True, exist_ok=True) - ext = compile_kernel(str(kernel_cu), op_name, str(build_dir), opt_level=opt_level) - return ext, kernel_cu - - -def _prepare_launch_args( - fn_attr: str, - original, - ext: Any, - kernel_cu: Path, - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> list[Any]: - param_names = _FUNCTIONAL_PARAM_NAMES.get(fn_attr) - if param_names is not None: - resolved = {param_names[i]: value for i, value in enumerate(args) if i < len(param_names)} - resolved.update(kwargs) - defaults = _FUNCTIONAL_DEFAULTS.get(fn_attr, {}) - ordered = [resolved.get(name, defaults.get(name)) for name in param_names] - if fn_attr == "max_pool2d" and len(ordered) >= 3 and ordered[2] is None: - ordered[2] = ordered[1] - else: - try: - signature = inspect.signature(original) - except Exception: - signature = None - - if signature is not None: - bound = signature.bind_partial(*args, **kwargs) - bound.apply_defaults() - ordered = [bound.arguments.get(name) for name in signature.parameters.keys()] - else: - ordered = list(args) - - if not ordered: - ordered = list(args) - - launch_arity = _launch_arity(str(kernel_cu), ext) - try: - ext_arity = len(inspect.signature(ext.launch).parameters) - except Exception: - ext_arity = None - - if launch_arity is None: - limit = ext_arity if ext_arity is not None else len(ordered) - elif ext_arity is None: - limit = launch_arity - else: - limit = max(launch_arity, ext_arity) - - launch_args: list[Any] = [] - for value in ordered[:limit]: - if isinstance(value, torch.Tensor): - launch_args.append(value.contiguous()) - else: - launch_args.append(value) - return launch_args - - -def _benchmark_callable( - fn, - *, - device: str, - warmup_runs: int = 3, - timed_runs: int = 10, -) -> float: - with torch.no_grad(): - if device == "cuda": - for _ in range(warmup_runs): - fn() - torch.cuda.synchronize() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for _ in range(timed_runs): - fn() - end_event.record() - torch.cuda.synchronize() - return start_event.elapsed_time(end_event) / timed_runs - - for _ in range(warmup_runs): - fn() - start = time.perf_counter() - for _ in range(timed_runs): - fn() - return (time.perf_counter() - start) / timed_runs * 1000.0 - - -def _benchmark_signature(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: - parts: list[str] = [] - for value in args: - if isinstance(value, torch.Tensor): - parts.append(f"Tensor{tuple(value.shape)}:{value.dtype}") - else: - parts.append(repr(value)) - for key, value in kwargs.items(): - if isinstance(value, torch.Tensor): - parts.append(f"{key}=Tensor{tuple(value.shape)}:{value.dtype}") - else: - parts.append(f"{key}={value!r}") - return ", ".join(parts) - - -def _assert_outputs_close(op_name: str, torch_output: Any, kernel_output: Any) -> None: - assert isinstance(torch_output, torch.Tensor), f"{op_name}: expected tensor output from PyTorch" - assert isinstance(kernel_output, torch.Tensor), f"{op_name}: expected tensor output from custom kernel" - assert torch_output.shape == kernel_output.shape, f"{op_name}: output shapes differ" - - atol = 1e-3 if torch_output.dtype in (torch.float16, torch.bfloat16) else 1e-4 - rtol = 1e-3 if torch_output.dtype in (torch.float16, torch.bfloat16) else 1e-4 - torch.testing.assert_close(torch_output, kernel_output, atol=atol, rtol=rtol) - - -def test_cast_kernel_benchmarks_vs_pytorch() -> None: - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for kernel-vs-PyTorch benchmarking.") - - cast_path = _find_resnet50_cast() - cache_dir, manifest = _extract_cast(cast_path) - - device = "cuda" - pixel_values = torch.randn(1, 3, 224, 224, device=device) - model = kernelforge.load(str(cast_path), device=device, no_kernels=True) - - benchable_ops = [ - op for op in manifest["ops"] - if op["name"].startswith(_F_PREFIX) and op["name"] not in _DEFAULT_SKIP_OPS - ] - fn_attrs = [op["name"][len(_F_PREFIX):] for op in benchable_ops] - - captures, originals = _capture_functional_calls(model, pixel_values, fn_attrs) - - results: list[dict[str, Any]] = [] - for op in benchable_ops: - op_name = op["name"] - fn_attr = op_name[len(_F_PREFIX):] - capture = captures.get(fn_attr) - original = originals.get(fn_attr) - if capture is None or original is None: - continue - - ext, kernel_cu = _load_extension(cache_dir, op, opt_level="-O3") - launch_args = _prepare_launch_args( - fn_attr, - original, - ext, - kernel_cu, - capture["args"], - capture["kwargs"], - ) - - with torch.no_grad(): - torch_output = original(*capture["args"], **capture["kwargs"]) - kernel_output = ext.launch(*launch_args) - if device == "cuda": - torch.cuda.synchronize() - - _assert_outputs_close(op_name, torch_output, kernel_output) - - torch_ms = _benchmark_callable( - lambda: original(*capture["args"], **capture["kwargs"]), - device=device, - ) - kernel_ms = _benchmark_callable( - lambda: ext.launch(*launch_args), - device=device, - ) - - results.append( - { - "op_name": op_name, - "signature": _benchmark_signature(capture["args"], capture["kwargs"]), - "torch_ms": torch_ms, - "kernel_ms": kernel_ms, - "speedup": torch_ms / kernel_ms if kernel_ms else float("inf"), - } - ) - - assert results, "No benchmarkable optimized kernels were found in the .cast file." - - print("\nKernel benchmark results:") - for result in results: - print( - f" {result['op_name']:<40}" - f"kernel {result['kernel_ms']:>8.3f} ms | " - f"torch {result['torch_ms']:>8.3f} ms | " - f"speedup {result['speedup']:>6.2f}x" - ) - print(f" sample: {result['signature']}") diff --git a/tests/test_load_cast_resnet50.py b/tests/test_load_cast_resnet50.py deleted file mode 100644 index 03aa3e47..00000000 --- a/tests/test_load_cast_resnet50.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations - -import os -import sys -import time -from pathlib import Path - -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -torch = pytest.importorskip("torch") - -import kernelforge - - -def _candidate_cast_paths() -> list[Path]: - candidates: list[Path] = [] - - env_path = os.getenv("RESNET50_CAST_PATH") - if env_path: - candidates.append(Path(env_path).expanduser()) - - default_export = ( - REPO_ROOT - / "kernels" - / "projects" - / "oioioio - RTX 3050 Laptop GPU" - / "exports" - / "oioioio - RTX 3050 Laptop GPU.cast" - ) - candidates.append(default_export) - - exports_root = REPO_ROOT / "kernels" / "projects" - candidates.extend(sorted(exports_root.glob("*/exports/*.cast"))) - - unique: list[Path] = [] - seen: set[Path] = set() - for path in candidates: - resolved = path.resolve(strict=False) - if resolved in seen: - continue - seen.add(resolved) - unique.append(path) - return unique - - -def _find_resnet50_cast() -> Path: - for path in _candidate_cast_paths(): - if path.exists(): - return path - pytest.skip("No ResNet-50 .cast export found. Set RESNET50_CAST_PATH to a cast file.") - - -def _run_once(model: torch.nn.Module, pixel_values: torch.Tensor) -> object: - device = pixel_values.device.type - - output = model.run(pixel_values=pixel_values) - if device == "cuda": - torch.cuda.synchronize() - return output - - -def _benchmark_models_alternating( - pytorch_model: torch.nn.Module, - optimized_model: torch.nn.Module, - pixel_values: torch.Tensor, - *, - warmup_runs: int = 3, - timed_runs: int = 5, -) -> tuple[object, float, object, float]: - pytorch_total_ms = 0.0 - optimized_total_ms = 0.0 - - with torch.inference_mode(): - for _ in range(warmup_runs): - _run_once(pytorch_model, pixel_values) - _run_once(optimized_model, pixel_values) - - pytorch_output = None - optimized_output = None - for _ in range(timed_runs): - start = time.perf_counter() - pytorch_output = _run_once(pytorch_model, pixel_values) - pytorch_total_ms += (time.perf_counter() - start) * 1000.0 - - start = time.perf_counter() - optimized_output = _run_once(optimized_model, pixel_values) - optimized_total_ms += (time.perf_counter() - start) * 1000.0 - - return ( - pytorch_output, - pytorch_total_ms / timed_runs, - optimized_output, - optimized_total_ms / timed_runs, - ) - - -def test_load_cast_resnet50_benchmark() -> None: - cast_path = _find_resnet50_cast() - device = "cuda" if torch.cuda.is_available() else "cpu" - pixel_values = torch.randn(1, 3, 224, 224, device=device) - - optimized_model = kernelforge.load(str(cast_path), device=device, opt_level="-O3") - pytorch_model = kernelforge.load(str(cast_path), device=device, no_kernels=True) - - assert isinstance(optimized_model, torch.nn.Module) - assert isinstance(pytorch_model, torch.nn.Module) - assert optimized_model.training is False - assert pytorch_model.training is False - assert hasattr(optimized_model, "run") - assert hasattr(pytorch_model, "run") - - pytorch_output, pytorch_ms, optimized_output, optimized_ms = _benchmark_models_alternating( - pytorch_model, - optimized_model, - pixel_values, - warmup_runs=3, - timed_runs=5, - ) - - optimized_logits = optimized_output.logits if hasattr(optimized_output, "logits") else optimized_output - pytorch_logits = pytorch_output.logits if hasattr(pytorch_output, "logits") else pytorch_output - - assert isinstance(optimized_logits, torch.Tensor) - assert isinstance(pytorch_logits, torch.Tensor) - assert optimized_logits.device.type == device - assert pytorch_logits.device.type == device - assert optimized_logits.shape == pytorch_logits.shape - assert optimized_logits.shape[0] == 1 - assert optimized_logits.ndim == 2 - assert optimized_logits.shape[1] > 0 - - print(f"\n.cast runtime average latency : {optimized_ms:.2f} ms") - print(f"PyTorch average latency : {pytorch_ms:.2f} ms")