diff --git a/kernelforge/__init__.py b/kernelforge/__init__.py new file mode 100644 index 0000000..ce8c7fb --- /dev/null +++ b/kernelforge/__init__.py @@ -0,0 +1,25 @@ +"""KernelForge runtime loading helpers.""" + +from __future__ import annotations + +from run_cast import CastModelRuntime, load_cast + + +def load( + cast_path: str, + *, + device: str | None = None, + model_args: dict | None = None, + no_kernels: bool = False, + opt_level: str = "-O3", +) -> CastModelRuntime: + return load_cast( + cast_path, + model_args=model_args, + no_kernels=no_kernels, + opt_level=opt_level, + device=device, + ) + + +__all__ = ["CastModelRuntime", "load", "load_cast"] diff --git a/run_cast.py b/run_cast.py index 1d3db53..ab07e0b 100755 --- a/run_cast.py +++ b/run_cast.py @@ -1,14 +1,21 @@ #!/usr/bin/env python3 -"""Run a .cast inference package produced by KernelForge.""" +"""Load and run a .cast inference package produced by KernelForge.""" import argparse +import contextlib +import glob import hashlib import inspect import json import os import re +import shutil +import threading import time import zipfile +from typing import Any, Callable + +import torch.nn as nn def verify_checksums(zf: zipfile.ZipFile) -> None: @@ -34,9 +41,68 @@ def verify_checksums(zf: zipfile.ZipFile) -> None: raise RuntimeError(f"Checksum mismatch for {rel_path}") -def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: str = "-O0"): +def ensure_cuda_toolkit_env() -> str | None: + """Point CUDA_HOME/CUDACXX/PATH and torch cpp_extension at a working nvcc.""" + + candidates: list[str] = [] + + cuda_home = os.getenv("CUDA_HOME") + if cuda_home: + candidates.append(os.path.abspath(os.path.expanduser(cuda_home))) + + cudacxx = os.getenv("CUDACXX") + if cudacxx: + candidates.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.expanduser(cudacxx))))) + + nvcc_on_path = shutil.which("nvcc") + if nvcc_on_path: + candidates.append(os.path.dirname(os.path.dirname(os.path.abspath(nvcc_on_path)))) + + candidates.append("/usr/local/cuda") + candidates.extend(sorted(glob.glob("/usr/local/cuda-*"), reverse=True)) + + seen: set[str] = set() + for candidate in candidates: + resolved = os.path.abspath(candidate) + if resolved in seen: + continue + seen.add(resolved) + + nvcc = os.path.join(resolved, "bin", "nvcc") + if not os.path.exists(nvcc): + continue + + os.environ["CUDA_HOME"] = resolved + os.environ["CUDACXX"] = nvcc + + path_entries = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else [] + nvcc_dir = os.path.dirname(nvcc) + if nvcc_dir not in path_entries: + os.environ["PATH"] = nvcc_dir if not path_entries else nvcc_dir + os.pathsep + os.environ["PATH"] + + try: + import torch.utils.cpp_extension as cpp_extension + + cpp_extension.CUDA_HOME = resolved + except Exception: + pass + + return resolved + + return None + + +def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: str = "-O3"): import re - from torch.utils.cpp_extension import load_inline + + cuda_home = ensure_cuda_toolkit_env() + if cuda_home is None: + raise RuntimeError("No CUDA toolkit with nvcc was found for JIT compilation.") + + import torch.utils.cpp_extension as cpp_extension + + cpp_extension.CUDA_HOME = cuda_home + load_inline = cpp_extension.load_inline with open(kernel_cu_path) as f: cuda_src = f.read() @@ -60,10 +126,218 @@ def compile_kernel(kernel_cu_path: str, op_name: str, build_dir: str, opt_level: ) -def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = False, opt_level: str = "-O0"): - import torch +_PATCH_STATE = threading.local() +_ORIGINAL_FUNCTIONALS: dict[str, Callable[..., Any]] = {} +_ATEN_FALLBACK_CALL_RE = re.compile( + r"\bat::(?:adaptive_avg_pool\d*d|batch_norm|conv\d*d|linear|max_pool\d*d|relu)\s*\(" +) +_KNOWN_FAST_OPS = { + "torch_nn_functional_linear", + "torch_nn_functional_max_pool2d", + "torch_nn_functional_relu", + "torch_nn_functional_batch_norm", +} +_FOCUS_OPS = { + "torch_nn_functional_linear", + "torch_nn_functional_max_pool2d", + "torch_nn_functional_batch_norm", +} +_KERNEL_POLICIES = ("all", "skip_aten", "known_fast", "focus_ops") + + +def _patch_stack() -> list[dict[str, Callable[..., Any]]]: + stack = getattr(_PATCH_STATE, "stack", None) + if stack is None: + stack = [] + _PATCH_STATE.stack = stack + return stack + + +@contextlib.contextmanager +def _activate_functional_patches( + patch_map: dict[str, Callable[..., Any]], +): + stack = _patch_stack() + stack.append(patch_map) + try: + yield + finally: + stack.pop() + + +def _ensure_functional_dispatch(fn_attr: str) -> Callable[..., Any] | None: import torch.nn.functional as F + current = getattr(F, fn_attr, None) + if current is None: + return None + + original = _ORIGINAL_FUNCTIONALS.get(fn_attr) + if original is not None: + return original + + original = current + _ORIGINAL_FUNCTIONALS[fn_attr] = original + + def dispatched(*args, **kwargs): + stack = getattr(_PATCH_STATE, "stack", None) + if stack: + patch = stack[-1].get(fn_attr) + if patch is not None: + return patch(*args, **kwargs) + return original(*args, **kwargs) + + setattr(F, fn_attr, dispatched) + return original + + +def _launch_arity(kernel_cu: str, ext: Any) -> int | None: + n_launch = None + if os.path.exists(kernel_cu): + try: + with open(kernel_cu) as handle: + cu_src = handle.read() + match = re.search(r"torch::Tensor\s+launch\s*\(([^)]*)\)", cu_src) + if match: + params = [part.strip() for part in match.group(1).split(",") if part.strip()] + n_launch = len(params) + except Exception: + pass + if n_launch is None: + try: + n_launch = len(inspect.signature(ext.launch).parameters) + except Exception: + pass + return n_launch + + +def _kernel_calls_aten_fallback(kernel_cu: str) -> bool: + if not os.path.exists(kernel_cu): + return False + try: + with open(kernel_cu) as handle: + cuda_src = handle.read() + except Exception: + return False + return bool(_ATEN_FALLBACK_CALL_RE.search(cuda_src)) + + +def _kernel_policy_skip_reason(op_name: str, kernel_cu: str, kernel_policy: str) -> str | None: + if kernel_policy == "all": + return None + + uses_aten_fallback = _kernel_calls_aten_fallback(kernel_cu) + + if kernel_policy == "skip_aten": + if uses_aten_fallback: + return "kernel source falls back to ATen" + return None + + if kernel_policy == "known_fast": + if uses_aten_fallback: + return "kernel source falls back to ATen" + if op_name not in _KNOWN_FAST_OPS: + return "op is not in the known-fast allowlist" + return None + + if kernel_policy == "focus_ops": + if uses_aten_fallback: + return "kernel source falls back to ATen" + if op_name not in _FOCUS_OPS: + return "op is not in the focus-ops allowlist" + return None + + raise ValueError(f"Unknown kernel policy: {kernel_policy}") + + +def _build_functional_patch( + *, + op_name: str, + ext: Any, + orig_fn: Callable[..., Any], + n_launch: int | None, + orig_params: list[str] | None, +) -> Callable[..., Any]: + import torch + + def patched(*args, **kwargs): + try: + if orig_params is not None: + resolved = {orig_params[i]: value for i, value in enumerate(args) if i < len(orig_params)} + resolved.update(kwargs) + ordered = [resolved.get(name) for name in orig_params] + else: + ordered = list(args) + + limit = n_launch if n_launch is not None else len(ordered) + call_args: list[Any] = [] + tensor_args: list[torch.Tensor] = [] + for value in ordered[:limit]: + if isinstance(value, torch.Tensor): + tensor_args.append(value) + call_args.append(value.contiguous()) + else: + call_args.append(value) + + if not tensor_args: + return orig_fn(*args, **kwargs) + + if any(not tensor.is_cuda for tensor in tensor_args): + return orig_fn(*args, **kwargs) + + first_device = tensor_args[0].device + if any(tensor.device != first_device for tensor in tensor_args): + return orig_fn(*args, **kwargs) + + return ext.launch(*call_args) + except Exception: + return orig_fn(*args, **kwargs) + + patched.__name__ = f"cast_patch_{op_name}" + return patched + + +class CastModelRuntime(nn.Module): + """A thin nn.Module wrapper that activates cast kernel patches per forward.""" + + def __init__(self, model, functional_patches: dict[str, Callable[..., Any]]): + import torch.nn as nn + + if not isinstance(model, nn.Module): + raise TypeError("model must be an nn.Module") + super().__init__() + self.model = model + self._cast_functional_patches = functional_patches + + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + model = self._modules.get("model") + if model is None: + raise + return getattr(model, name) + + def forward(self, *args, **kwargs): + if not self._cast_functional_patches: + return self.model(*args, **kwargs) + with _activate_functional_patches(self._cast_functional_patches): + return self.model(*args, **kwargs) + + def run(self, *args, **kwargs): + return self(*args, **kwargs) + + +def load_cast( + cast_path: str, + model_args: dict | None = None, + no_kernels: bool = False, + opt_level: str = "-O3", + device: str | None = None, + kernel_policy: str = "all", +): + import torch + cast_path = os.path.abspath(cast_path) cache_key = hashlib.sha256(open(cast_path, "rb").read()).hexdigest() cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "cast", cache_key) @@ -97,6 +371,12 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = os.makedirs(build_dir, exist_ok=True) _F_PREFIX = "torch_nn_functional_" + functional_patches: dict[str, Callable[..., Any]] = {} + + if kernel_policy not in _KERNEL_POLICIES: + raise ValueError( + f"Unknown kernel policy '{kernel_policy}'. Expected one of: {', '.join(_KERNEL_POLICIES)}" + ) for op in manifest["ops"]: op_name = op["name"] @@ -110,27 +390,40 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = print(f" [WARN] CUDA not available — skipping kernel for {op_name}") continue + skip_reason = _kernel_policy_skip_reason(op_name, kernel_cu, kernel_policy) + if skip_reason: + print(f" [kernel-policy={kernel_policy}] Skipping kernel for {op_name}: {skip_reason}") + continue + # Try precompiled .so for the current GPU first gpu_sm = "sm_{0}{1}".format(*torch.cuda.get_device_capability()) precompiled = op.get("precompiled", {}) so_rel = precompiled.get(gpu_sm) so_path = os.path.join(cache_dir, so_rel) if so_rel else None + ext = None if so_path and os.path.exists(so_path): - import importlib.util as _ilu - _spec = _ilu.spec_from_file_location(op_name, so_path) - ext = _ilu.module_from_spec(_spec) - _spec.loader.exec_module(ext) # type: ignore[union-attr] - print(f" Loaded precompiled {op_name} ({gpu_sm})") - else: + try: + import importlib.util as _ilu + + _spec = _ilu.spec_from_file_location(op_name, so_path) + ext = _ilu.module_from_spec(_spec) + _spec.loader.exec_module(ext) # type: ignore[union-attr] + print(f" Loaded precompiled {op_name} ({gpu_sm})") + except Exception as exc: + print(f" [WARN] Failed to load precompiled {op_name} ({gpu_sm}): {exc}") + + if ext is None: if so_rel: print(f" [WARN] Precompiled .so not found for {gpu_sm}, falling back to JIT") if not os.path.exists(kernel_cu): print(f" [WARN] No kernel.cu for {op_name}, skipping") continue - if "CUDA_HOME" not in os.environ: - os.environ["CUDA_HOME"] = "/usr/local/cuda-12.1" - ext = compile_kernel(kernel_cu, op_name, build_dir, opt_level=opt_level) + try: + ext = compile_kernel(kernel_cu, op_name, build_dir, opt_level=opt_level) + except Exception as exc: + print(f" [WARN] Failed to prepare kernel for {op_name}, using native PyTorch: {exc}") + continue # Generic patch: decode torch.nn.functional. from the op_name convention. if not op_name.startswith(_F_PREFIX): @@ -138,67 +431,32 @@ def load_cast(cast_path: str, model_args: dict | None = None, no_kernels: bool = continue fn_attr = op_name[len(_F_PREFIX):] - orig_fn = getattr(F, fn_attr, None) - if orig_fn is None: + original = _ensure_functional_dispatch(fn_attr) + if original is None: print(f" [WARN] torch.nn.functional.{fn_attr} not found, skipping patch") continue - # Determine how many args launch() expects by parsing the kernel source. - # This is more reliable than inspect.signature on C extensions. - n_launch = None - if os.path.exists(kernel_cu): - try: - with open(kernel_cu) as _f: - _cu_src = _f.read() - _m = re.search(r"torch::Tensor\s+launch\s*\(([^)]*)\)", _cu_src) - if _m: - _params = [p.strip() for p in _m.group(1).split(",") if p.strip()] - n_launch = len(_params) - except Exception: - pass - if n_launch is None: - try: - n_launch = len(inspect.signature(ext.launch).parameters) - except Exception: - pass + n_launch = _launch_arity(kernel_cu, ext) # Resolve the original function's parameter names so we can handle # both positional and keyword call sites correctly. try: - orig_params = list(inspect.signature(orig_fn).parameters.keys()) + orig_params = list(inspect.signature(original).parameters.keys()) except Exception: orig_params = None - def _make_patch(ext=ext, orig=orig_fn, n=n_launch, params=orig_params): - def patched(*args, **kwargs): - try: - # Resolve all argument values by name, handling mixed positional/keyword callers. - if params is not None: - resolved = {params[i]: v for i, v in enumerate(args) if i < len(params)} - resolved.update(kwargs) - ordered = [resolved.get(p) for p in params] - else: - ordered = list(args) - # Truncate to what launch() expects, move tensors to CUDA. - limit = n if n is not None else len(ordered) - call_args = [] - for val in ordered[:limit]: - if isinstance(val, torch.Tensor): - call_args.append(val.cuda().contiguous()) - else: - call_args.append(val) - return ext.launch(*call_args) - except Exception: - return orig(*args, **kwargs) - return patched - - setattr(F, fn_attr, _make_patch()) - print(f" Patched torch.nn.functional.{fn_attr} → {op_name}") + functional_patches[fn_attr] = _build_functional_patch( + op_name=op_name, + ext=ext, + orig_fn=original, + n_launch=n_launch, + orig_params=orig_params, + ) + print(f" Registered runtime patch torch.nn.functional.{fn_attr} → {op_name}") # 5. Load model class from model.py import importlib.util import sys - import torch.nn as nn model_py = os.path.join(cache_dir, "model.py") spec = importlib.util.spec_from_file_location("cast_model", model_py) @@ -273,10 +531,53 @@ def patched(*args, **kwargs): model.load_state_dict(state_dict) model.eval() - return model + runtime_model = CastModelRuntime(model, functional_patches) + runtime_model.eval() + if device: + runtime_model = runtime_model.to(device) + + return runtime_model + + +def _parse_shape(raw: str) -> tuple[int, ...]: + parts = [part.strip() for part in raw.split(",") if part.strip()] + if not parts: + raise ValueError("input shape must contain at least one dimension") + dims = tuple(int(part) for part in parts) + if any(dim <= 0 for dim in dims): + raise ValueError("input shape dimensions must be positive integers") + return dims + + +def _benchmark_loaded_model(model, device: str, runs: int, input_shape: tuple[int, ...]) -> None: + import torch + + dummy = torch.randn(*input_shape, device=device) + print(f"Running {runs} inference pass(es) with input shape {list(dummy.shape)} ...") + + with torch.inference_mode(): + _ = model(dummy) + if device == "cuda": + torch.cuda.synchronize() + + t0 = time.perf_counter() + for _ in range(runs): + out = model(dummy) + if device == "cuda": + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) / runs * 1000 + + logits = out.logits if hasattr(out, "logits") else out + if hasattr(logits, "shape"): + print(f"Output shape : {list(logits.shape)}") + print(f"Average latency : {elapsed_ms:.2f} ms") + if getattr(logits, "ndim", 0) >= 2 and logits.shape[0] >= 1: + top5 = logits[0].topk(min(5, logits.shape[-1])) + print(f"Top-5 indices : {top5.indices.tolist()}") + print(f"Top-5 scores : {[f'{v:.4f}' for v in top5.values.tolist()]}") -def main() -> None: +def main(*, default_kernel_policy: str = "all") -> None: import torch parser = argparse.ArgumentParser(description="Run a KernelForge .cast inference package") @@ -286,7 +587,6 @@ def main() -> None: default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on (default: cuda if available)", ) - parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing") parser.add_argument( "--model-args", metavar="JSON", @@ -301,9 +601,32 @@ def main() -> None: ) parser.add_argument( "--opt-level", - default="-O0", + default="-O3", choices=["-O0", "-O1", "-O2", "-O3"], - help="NVCC optimisation level for JIT compilation (default: -O0).", + help="NVCC optimisation level for JIT compilation (default: -O3).", + ) + parser.add_argument( + "--kernel-policy", + default=default_kernel_policy, + choices=list(_KERNEL_POLICIES), + help=( + "Kernel selection policy: " + "'all' loads every exported kernel, " + "'skip_aten' skips kernels that just call back into ATen, " + "'known_fast' keeps only the current known-fast allowlist, " + "'focus_ops' keeps only the smallest fast-op subset." + ), + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Run a dummy-input benchmark after loading the model.", + ) + parser.add_argument("--runs", type=int, default=5, help="Inference passes for timing with --benchmark") + parser.add_argument( + "--input-shape", + default="1,3,224,224", + help="Comma-separated dummy input shape for --benchmark (default: 1,3,224,224).", ) args = parser.parse_args() @@ -313,33 +636,16 @@ def main() -> None: model_args=extra_model_args, no_kernels=args.no_kernels, opt_level=args.opt_level, + device=args.device, + kernel_policy=args.kernel_policy, ) - device = args.device - model = model.to(device) - print(f"\nModel ready on {device}") - - dummy = torch.randn(1, 3, 224, 224, device=device) - print(f"Running {args.runs} inference pass(es) with input shape {list(dummy.shape)} ...") + print(f"\nModel ready on {args.device}") + print(f"Kernel policy : {args.kernel_policy}") + print("Import load_cast(...) in production code to load a .cast as a normal model object.") - with torch.no_grad(): - # warmup - _ = model(dummy) - if device == "cuda": - torch.cuda.synchronize() - - t0 = time.perf_counter() - for _ in range(args.runs): - out = model(dummy) - if device == "cuda": - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) / args.runs * 1000 - - logits = out.logits if hasattr(out, "logits") else out - print(f"Output shape : {list(logits.shape)}") - print(f"Average latency : {elapsed_ms:.2f} ms") - top5 = logits[0].topk(5) - print(f"Top-5 indices : {top5.indices.tolist()}") - print(f"Top-5 scores : {[f'{v:.4f}' for v in top5.values.tolist()]}") + if args.benchmark: + input_shape = _parse_shape(args.input_shape) + _benchmark_loaded_model(model, args.device, args.runs, input_shape) if __name__ == "__main__":