diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py index 2f51eec..6304de8 100644 --- a/magi_compiler/_api.py +++ b/magi_compiler/_api.py @@ -73,7 +73,19 @@ def _isolated_dynamo_config(): yield -def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs): +def get_attr_name_for_wrapper_installed_flag() -> str: + return "_magi_wrapper_installed" + + +def get_attr_name_for_default_state() -> str: + return "_magi" + + +def get_attr_name_for_method_state(method_name: str) -> str: + return f"_magi_state_{method_name}" + + +def _run_orchestration(state: MagiCompileState, args, kwargs): """ Central orchestration logic for magi_compile. @@ -86,19 +98,15 @@ def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs): - (Optional) Perform AOT compilation and save artifacts. """ # JIT Fast Path - if state.compiled_code is not None: - with state.dispatch_to_compiled_fwd(mode="jit"): - return original_invoker() + if state.jit_compiled_code is not None: + with state.dispatch_to_compiled_fwd(mode="jit") as compiled_runtime_invoker: + return compiled_runtime_invoker(*args, **kwargs) # AOT Fast Path if state.compile_config.aot: - if state._aot_compiled_fn or state.load_aot_compile_artifacts(): - res = state.dispatch_to_compiled_fwd(mode="aot") - if isinstance(state.obj, nn.Module): - with res: - return original_invoker() - with res as compiled_fn: - return compiled_fn(*args, **kwargs) + if state.aot_compiled_fn or state.load_aot_compile_artifacts(): + with state.dispatch_to_compiled_fwd(mode="aot") as compiled_runtime_invoker: + return compiled_runtime_invoker(*args, **kwargs) # First compilation state._ensure_compiled() @@ -106,8 +114,8 @@ def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs): # Mark dynamic and static shapes _apply_shape_marks(state, args, kwargs) - magi_logger.info(f"Start compiling function {state.original_code_object}") - 
torch._dynamo.eval_frame.remove_from_cache(state.original_code_object) + magi_logger.info(f"Start compiling function {state.original_code_for_hook}") + torch._dynamo.eval_frame.remove_from_cache(state.original_code_for_hook) CompileMonitor().start() try: @@ -115,147 +123,138 @@ def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs): with _compilation_context(state): state.aot_compile(*args, **kwargs) state.save_aot_compile_artifacts() - res = state.dispatch_to_compiled_fwd(mode="aot") - if isinstance(state.obj, nn.Module): - with res: - return original_invoker() - with res as compiled_fn: - return compiled_fn(*args, **kwargs) + with state.dispatch_to_compiled_fwd(mode="aot") as compiled_runtime_invoker: + return compiled_runtime_invoker(*args, **kwargs) else: with _compilation_context(state): # For JIT, we need to capture bytecode. - with state._capture_compiled_bytecode(): - if isinstance(state.obj, nn.Module): - with patch.object(state.obj, "forward", state._compiled_callable): - return original_invoker() - else: - return state._compiled_callable(*args, **kwargs) + with state._jit_capture_compiled_bytecode(): + return state.compiled_entry(*args, **kwargs) finally: CompileMonitor().end() state.traced_files.clear() def _lazy_init_magi_state( - target_obj: object, - base_obj: nn.Module | Callable, + state_holder: object, + compile_obj: object, dynamic_arg_dims: dict[str, int | list[int]] | None, enable_if: Callable[[], bool] | None, config_patch: Callable[[CompileConfig], CompileConfig], model_tag: str | None, + *, + target_method_name: str | None = None, + state_attr: str | None = None, ): - """ - Lazily initializes the MagiCompileState and attaches it to `target_obj._magi`. 
- """ - if hasattr(target_obj, "_magi"): + """Lazily initialize MagiCompileState and attach it on ``state_attr``.""" + state_attr = state_attr or get_attr_name_for_default_state() + if getattr(state_holder, state_attr, None) is not None: return conf = config_patch(copy.deepcopy(get_compile_config())) enable = enable_if is None or enable_if() if conf.compile_mode == CompileMode.NONE or not enable: - target_obj._magi = None return compilation_counter.num_models_seen += 1 # Infer default model tag if not provided if model_tag is None: - if hasattr(base_obj, "__class__") and isinstance(base_obj, nn.Module): - model_tag = base_obj.__class__.__name__ - else: - model_tag = getattr(base_obj, "__name__", "unknown_func") - - target_obj._magi = MagiCompileState( - base_obj, conf, model_idx=compilation_counter.num_models_seen, model_tag=model_tag, dynamic_arg_dims=dynamic_arg_dims + model_tag = getattr(compile_obj, "__name__", compile_obj.__class__.__name__) + + setattr( + state_holder, + state_attr, + MagiCompileState( + compile_obj, + conf, + model_idx=compilation_counter.num_models_seen, + model_tag=model_tag, + dynamic_arg_dims=dynamic_arg_dims, + target_method_name=target_method_name, + ), ) def _magi_compile_class( - cls, + cls: type, dynamic_arg_dims: dict[str, int | list[int]], enable_if: Callable[[], bool] | None, - config_patch: Callable[[CompileConfig], CompileConfig] | None, + config_patch: Callable[[CompileConfig], CompileConfig], model_tag: str | None, + method_name: str, ): - """Class-level decoration: mutates ``cls.__call__`` so every instance is compiled. + """Install class-level lazy compilation for ``method_name``. - MagiCompileState is created **lazily** on first ``__call__`` via ``_lazy_init_magi_state``, - because at decoration time no instance exists yet, and the global CompileConfig may - not be finalized (e.g. env-vars set after import but before first forward). 
+ This wraps ``cls.__init__`` so every new instance is patched by + ``_magi_compile_bound_method`` after initialization. """ - if getattr(cls, "_magi_compiled", False): + compile_flag_attr = get_attr_name_for_wrapper_installed_flag() + if getattr(cls, compile_flag_attr, False): return cls - config_patch = config_patch or (lambda x: x) - if config_patch(copy.deepcopy(get_compile_config())).offload_config.model_cpu_offload: - _patch_cpu_offload_apply(cls) - - old_call = cls.__call__ - - @torch.compiler.disable() - def wrapper(self, *args, **kwargs): - _lazy_init_magi_state(self, self, dynamic_arg_dims, enable_if, config_patch, model_tag) - state = self._magi + if not callable(getattr(cls, method_name, None)): + raise AttributeError(f"{cls.__name__} has no callable method '{method_name}'") - # Offload arguments if offload is enabled and not yet compiled - if state is not None and state.compile_config.offload_config.model_cpu_offload and state.compiled_code is None: - args = offload(args) - kwargs = offload(kwargs) + if issubclass(cls, nn.Module) and config_patch(copy.deepcopy(get_compile_config())).offload_config.model_cpu_offload: + _patch_cpu_offload_apply(cls) - if state is None or torch.compiler.is_compiling(): - return old_call(self, *args, **kwargs) + old_init = cls.__init__ - with _isolated_dynamo_config(): - return _run_orchestration(state, lambda: old_call(self, *args, **kwargs), args, kwargs) + @functools.wraps(old_init) + def wrapped_init(self, *args, **kwargs): + old_init(self, *args, **kwargs) + _magi_compile_bound_method(self, dynamic_arg_dims, enable_if, config_patch, model_tag, method_name=method_name) - cls.__call__ = wrapper - cls._magi_compiled = True + cls.__init__ = wrapped_init + setattr(cls, compile_flag_attr, True) return cls -def _magi_compile_instance( - module: nn.Module, +def _magi_compile_bound_method( + instance: object, dynamic_arg_dims: dict[str, int | list[int]], enable_if: Callable[[], bool] | None, - config_patch: 
Callable[[CompileConfig], CompileConfig] | None, + config_patch: Callable[[CompileConfig], CompileConfig], model_tag: str | None, + method_name: str, ): - """Instance-level decoration: only this instance is compiled, class is untouched. + """Patch one instance method with lazy state initialization and compiled routing.""" + if not callable(getattr(instance, method_name, None)): + raise AttributeError(f"{instance.__class__.__name__} instance has no callable method '{method_name}'") - MagiCompileState is created **lazily** on first ``forward`` call via ``_lazy_init_magi_state``. - A compiled ``forward`` is installed as an instance attribute so ``Module.__call__`` → - ``self.forward()`` resolves to it, while ``module.__class__.forward`` remains original. + state_attr = get_attr_name_for_method_state(method_name) + if getattr(instance, state_attr, None) is not None: + return instance - Call flow:: - - module(x) - → Module.__call__ (hooks, FSDP, etc.) - → self.forward(x) (finds instance attr → _compiled_forward) - → _run_orchestration - → module.__class__.forward(module, x) # original or bytecode-swapped - """ - if getattr(module, "_magi", None) is not None: - return module - - config_patch = config_patch or (lambda x: x) - - # module.__class__.forward is the unbound class method — never affected by our - # instance-level override, so calling it goes straight to original forward logic. 
- old_call = module.__class__.forward - module._magi_original_forward = module.forward + old_method = getattr(instance, method_name) @torch.compiler.disable() def new_call(*args, **kwargs): - _lazy_init_magi_state(module, module, dynamic_arg_dims, enable_if, config_patch, model_tag) - state = module._magi + state = getattr(instance, state_attr, None) + + if state is None: + _lazy_init_magi_state( + instance, + instance, + dynamic_arg_dims, + enable_if, + config_patch, + model_tag, + target_method_name=method_name, + state_attr=state_attr, + ) + state = getattr(instance, state_attr, None) + if state is None or torch.compiler.is_compiling(): - return old_call(module, *args, **kwargs) + return old_method(*args, **kwargs) with _isolated_dynamo_config(): - return _run_orchestration(state, lambda: module.__class__.__call__(module, *args, **kwargs), args, kwargs) + return _run_orchestration(state, args, kwargs) - module.forward = new_call - module._magi_compiled = True - return module + setattr(instance, method_name, new_call) + setattr(instance, get_attr_name_for_wrapper_installed_flag(), True) + return instance def _magi_compile_function( @@ -265,12 +264,13 @@ def _magi_compile_function( config_patch: Callable[[CompileConfig], CompileConfig] | None, model_tag: str | None, ): - """Function / bound-method level decoration. + """Wrap a function entry with lazy ``MagiCompileState`` and compiled routing. - MagiCompileState is created **lazily** on first call via ``_lazy_init_magi_state``. - The wrapper replaces the original callable. + The returned wrapper initializes state on first invocation and then dispatches + through ``_run_orchestration``. 
""" - if getattr(func, "_magi", None) is not None: + state_attr = get_attr_name_for_default_state() + if getattr(func, state_attr, None) is not None: return func config_patch = config_patch or (lambda x: x) @@ -278,13 +278,16 @@ def _magi_compile_function( @torch.compiler.disable() @functools.wraps(func) # for the original function name and docstring def wrapper(*args, **kwargs): - _lazy_init_magi_state(wrapper, func, dynamic_arg_dims, enable_if, config_patch, model_tag) - state = wrapper._magi + state = getattr(wrapper, state_attr, None) + if state is None: + _lazy_init_magi_state(wrapper, func, dynamic_arg_dims, enable_if, config_patch, model_tag, state_attr=state_attr) + state = getattr(wrapper, state_attr, None) + if state is None or torch.compiler.is_compiling(): return func(*args, **kwargs) with _isolated_dynamo_config(): - return _run_orchestration(state, lambda: func(*args, **kwargs), args, kwargs) + return _run_orchestration(state, args, kwargs) return wrapper @@ -320,7 +323,7 @@ def _apply_shape_marks(state: MagiCompileState, args, kwargs): This is called just before Dynamo tracing to ensure dimensions are correctly generalized in the captured graph. """ - sig = inspect.signature(state._target_callable) + sig = inspect.signature(state.original_entry) bound = sig.bind(*args, **kwargs) bound.apply_defaults() @@ -435,7 +438,7 @@ def _compilation_context(state: MagiCompileState): # 2. 
hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline # the function by calling InliningInstructionTranslator.inline_call_ def _hijack_inline_call_to_collect_traced_files(state: MagiCompileState): - state.traced_files.add(state.original_code_object.co_filename) + state.traced_files.add(state.original_code_for_hook.co_filename) inline_call = InliningInstructionTranslator.inline_call_ def patched(self_): diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 77fb828..8701f52 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -16,39 +16,43 @@ import inspect from typing import Callable, TypeVar -from torch import nn - from ._api import ( _check_dynamic_arg_dims, _infer_dynamic_arg_dims, + _magi_compile_bound_method, _magi_compile_class, _magi_compile_function, - _magi_compile_instance, ) from ._magi_register_custom_op import _magi_register_custom_op_impl from .config import CompileConfig -_T = TypeVar("_T", bound=type[nn.Module]) +_T = TypeVar("_T", bound=type) _F = TypeVar("_F", bound=Callable) -_M = TypeVar("_M", bound=nn.Module) +_O = TypeVar("_O", bound=object) def magi_compile( - obj: _T | _M | _F | None = None, + obj: _T | _O | _F | None = None, *, model_tag: str | None = None, dynamic_arg_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[], bool] | None = None, config_patch: Callable[[CompileConfig], CompileConfig] | None = None, -) -> _T | _M | _F | Callable[[_T | _M | _F], _T | _M | _F]: + method_name: str | None = None, +) -> _T | _O | _F | Callable[[_T | _O | _F], _T | _O | _F]: """ - Compile target objects (nn.Module classes, modules, functions, or methods). + Compile classes, instances, standalone functions, or bound methods. + + Default compile target when no explicit method is passed: + - ``nn.Module``: compile ``forward``. 
+ - Non-module callable class/instance: compile ``forward`` by default; + if missing, users must pass ``method_name`` explicitly. Supported target types ---------------------- - 1) Class (must be an `nn.Module` subclass): - - Affects all instances of the annotated class. - - Compilation dispatch enters via `__call__`, while compiled execution replaces `forward` logic. + 1) Class: + - Hooks ``__init__`` so every new instance gets the default method compiled (same mechanism for + ``nn.Module`` and non-module callable classes). - Example: @magi_compile class MyModel(nn.Module): @@ -61,16 +65,15 @@ def forward(self, x): return x @magi_compile def my_func(x): return x - 3) Instance (nn.Module): - - Compiles a single instance specifically. - - Avoids affecting other instances by creating an instance-specific subclass. + 3) Instance: + - Compiles only that object’s default method (``forward`` by default, or + explicit ``method_name`` for non-module targets). - Example: model = MyModel() model = magi_compile(model) - 4) Method (Bound/Unbound): - - Wraps a specific function attribute (e.g., `model.forward`). - - Enables focused compilation of specific object behaviors. + 4) Bound method: + - Compiles that method on its ``__self__`` (works for ``nn.Module`` and plain objects). - Example: model = MyModel() model.forward = magi_compile(model.forward) @@ -103,6 +106,9 @@ def forward(self, x): ... - dynamic_arg_dims: Dictionary mapping argument names to dynamic dimensions (int or list[int]). - model_tag: Optional tag for caching path (defaults to class/function name). - enable_if: Callable returning bool; compilation happens only if this returns True. + - method_name: Optional explicit method for class/instance targets. If omitted, + ``forward`` is used by default; for non-module targets without ``forward``, + this argument is required. Notes ----- @@ -117,21 +123,45 @@ def forward(self, x): ... 
dynamic_arg_dims=dynamic_arg_dims, enable_if=enable_if, config_patch=config_patch, + method_name=method_name, ) + config_patch = config_patch or (lambda x: x) + + is_bound_method = inspect.ismethod(obj) + is_function = inspect.isfunction(obj) + is_class = inspect.isclass(obj) + is_instance = callable(obj) and not any((is_class, is_function, is_bound_method)) + if not any((is_class, is_instance, is_bound_method, is_function)): + raise TypeError(f"Unsupported type for magi_compile: {type(obj)}") + + if method_name is not None and (is_bound_method or is_function): + entry_name = "bound method" if is_bound_method else "function" + raise ValueError(f"method_name cannot be used when compiling a {entry_name} directly") + # 1. Determine target function for dynamic dim inference - if inspect.isclass(obj): - assert issubclass(obj, nn.Module), f"Expected nn.Module subclass, got {obj}" - target_func = obj.forward - context_name = f"forward method of {obj.__name__}" - elif isinstance(obj, nn.Module): - target_func = obj.forward - context_name = f"forward method of instance {obj.__class__.__name__}" - elif callable(obj): + owner_instance = obj.__self__ if is_bound_method else obj if is_instance else None + owner_class = obj if is_class else owner_instance.__class__ if is_bound_method else obj.__class__ if is_instance else None + + if is_class or is_instance: + method_name = method_name or "forward" + target_func = getattr(owner_class, method_name, None) + context_name = f"{'class' if is_class else 'instance'} {owner_class.__name__}.{method_name}" + elif is_bound_method: + method_name = method_name or obj.__name__ target_func = obj - context_name = f"function/method {obj.__name__}" + context_name = f"bound method {method_name}" else: - raise TypeError(f"Unsupported type for magi_compile: {type(obj)}") + method_name = None + target_func = obj + context_name = f"function {obj.__name__}" + + if not callable(target_func): + if is_class and not method_name: + raise 
AssertionError(f"Class '{owner_class.__name__}' must have forward method or pass method_name explicitly.") + if is_instance and not method_name: + raise AssertionError(f"Instance '{owner_class.__name__}' must have forward method or pass method_name explicitly.") + raise TypeError(f"Target '{method_name}' is not callable for {type(obj)}") # 2. Infer dynamic dims inferred_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(target_func, context_name) @@ -141,14 +171,20 @@ def forward(self, x): ... _check_dynamic_arg_dims(inferred_dims, target_func) - # 3. Logic based on type - if inspect.isclass(obj): - return _magi_compile_class(obj, inferred_dims, enable_if, config_patch, model_tag) - elif isinstance(obj, nn.Module): - return _magi_compile_instance(obj, inferred_dims, enable_if, config_patch, model_tag) - else: + # 3. Dispatch by entry kind (class / instance / bound method / bare function) + + if is_class: + return _magi_compile_class(obj, inferred_dims, enable_if, config_patch, model_tag, method_name) + elif is_instance: + return _magi_compile_bound_method(obj, inferred_dims, enable_if, config_patch, model_tag, method_name) + elif is_bound_method: + _magi_compile_bound_method(owner_instance, inferred_dims, enable_if, config_patch, model_tag, method_name) + return getattr(owner_instance, method_name) + elif is_function: return _magi_compile_function(obj, inferred_dims, enable_if, config_patch, model_tag) + raise TypeError(f"Unsupported type for magi_compile: {type(obj)}") + def magi_register_custom_op( name: str | None = None, diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index c3b5b5e..631f46d 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -126,6 +126,7 @@ def initialize_cache(self, cache_dir: Path, prefix: str = ""): magi_logger.info("MagiCompiler's cache is disabled.") return + self.cache_dir.mkdir(parents=True, exist_ok=True)
magi_logger.info("Using cache directory: %s for MagiCompiler", cache_dir) if self.cache_file_path.exists(): # load the cache from the file diff --git a/magi_compiler/magi_backend/magi_compiler_base.py b/magi_compiler/magi_backend/magi_compiler_base.py index 3f6230e..dac4a48 100644 --- a/magi_compiler/magi_backend/magi_compiler_base.py +++ b/magi_compiler/magi_backend/magi_compiler_base.py @@ -20,6 +20,7 @@ import os import sys from contextlib import contextmanager + from types import CodeType from typing import Callable, Literal @@ -81,11 +82,12 @@ class MagiCompileState: def __init__( self, - obj: torch.nn.Module | Callable, + obj: Callable | object, compile_config: CompileConfig, model_idx: int, model_tag: str, dynamic_arg_dims: dict[str, int | list[int]], + target_method_name: str | None = None, ): self.obj = obj self.compile_config = compile_config @@ -94,20 +96,24 @@ def __init__( self.dynamic_arg_dims = dynamic_arg_dims self.traced_files: OrderedSet = OrderedSet() self.inductor_compile_config: dict = {} - - if isinstance(obj, torch.nn.Module): - self.original_code_object: CodeType = obj.__class__.forward.__code__ - self._target_callable = getattr(obj, "_magi_original_forward", obj.forward) + self.target_method_name: str | None = None + self.target_function: Callable | None = None + + if target_method_name: + self.target_method_name = target_method_name + self.target_function = getattr(obj.__class__, self.target_method_name) + self.original_code_for_hook: CodeType = self.target_function.__code__ + self.original_entry = self.target_function.__get__(obj, obj.__class__) elif callable(obj): - self.original_code_object: CodeType = inspect.unwrap(obj).__code__ - self._target_callable = obj + self.original_code_for_hook: CodeType = inspect.unwrap(obj).__code__ + self.original_entry = obj else: raise TypeError(f"Unsupported object type for MagiCompileState: {type(obj)}") - self.compiled_code: CodeType | None = None - self._aot_compiled_fn:
Callable | None = None - self._compile_artifacts: object | None = None - self._compiled_callable: Callable | None = None + self.compiled_entry: Callable | None = None + self.jit_compiled_code: CodeType | None = None + self.aot_compiled_fn: Callable | None = None + self.aot_compile_artifacts: object | None = None def _ensure_compiled(self): """Lazy initialization of the ``torch.compile`` wrapper. @@ -115,7 +121,7 @@ def _ensure_compiled(self): Called on first actual compilation (JIT or AOT cache miss). On AOT cache hits, this is never called — avoiding ``torch.compile`` overhead entirely. """ - if self._compiled_callable is not None: + if self.compiled_entry is not None: return backend = init_backend( self.compile_config, self.model_idx, self.model_tag, self.traced_files, self.inductor_compile_config @@ -142,8 +148,8 @@ def _ensure_compiled(self): # gets an extra ``.aot_compile`` attribute (eval_frame.py:880) that # delegates to ``torch._dynamo.aot_compile.aot_compile_fullgraph`` # (aot_compile.py:108). - self._compiled_callable = torch.compile( - self._target_callable, fullgraph=True, dynamic=True, backend=backend, options=options + self.compiled_entry = torch.compile( + self.original_entry, fullgraph=True, dynamic=True, backend=backend, options=options ) @property @@ -158,7 +164,7 @@ def aot_compilation_path(self) -> str: traced through (unknown before Dynamo runs). On loading we verify traced-file checksums separately via ``_verify_source_unchanged``. 
""" - hash_key = compute_hash([self._target_callable, self.model_idx, self.compile_config.hash, self.dynamic_arg_dims]) + hash_key = compute_hash([self.original_entry, self.model_idx, self.compile_config.hash, self.dynamic_arg_dims]) cache_dir = os.path.join( self.compile_config.cache_root_dir, "torch_aot_compile", @@ -184,8 +190,8 @@ def load_aot_compile_artifacts(self): from torch._dynamo.aot_compile import CompileArtifacts with open(aot_path, "rb") as f: - self._compile_artifacts = CompileArtifacts.deserialize(f.read()) - self._aot_compiled_fn = self._compile_artifacts.compiled_function() + self.aot_compile_artifacts = CompileArtifacts.deserialize(f.read()) + self.aot_compiled_fn = self.aot_compile_artifacts.compiled_function() magi_logger.info("AOT cache loaded successfully from %s", aot_path) return True @@ -197,7 +203,7 @@ def save_aot_compile_artifacts(self) -> None: aot_path = self.aot_compilation_path with open(aot_path, "wb") as f: - f.write(CompileArtifacts.serialize(self._compile_artifacts)) + f.write(CompileArtifacts.serialize(self.aot_compile_artifacts)) _save_source_checksum(self.aot_compilation_path, self.traced_files) magi_logger.info("AOT path: artifacts saved to %s", aot_path) @@ -228,10 +234,10 @@ def aot_compile(self, *args, **kwargs): self._aot_retry_count = 0 for attempt in range(self._AOT_MAX_RETRIES): try: - self._aot_compiled_fn = self._compiled_callable.aot_compile((args, kwargs)) - save_fn = self._aot_compiled_fn.save_compiled_function + self.aot_compiled_fn = self.compiled_entry.aot_compile((args, kwargs)) + save_fn = self.aot_compiled_fn.save_compiled_function idx = save_fn.__code__.co_freevars.index("self") - self._compile_artifacts = save_fn.__closure__[idx].cell_contents + self.aot_compile_artifacts = save_fn.__closure__[idx].cell_contents return except TensorifyScalarRestartAnalysis: if attempt >= self._AOT_MAX_RETRIES - 1: @@ -243,11 +249,11 @@ def aot_compile(self, *args, **kwargs): attempt + 1, self._AOT_MAX_RETRIES, ) - 
self._compiled_callable = None + self.compiled_entry = None self._ensure_compiled() @contextmanager - def _capture_compiled_bytecode(self): + def _jit_capture_compiled_bytecode(self): """Register a Dynamo bytecode hook to capture compiled bytecode. Each time Dynamo completes compilation of a frame (e.g., the ``forward`` @@ -265,7 +271,7 @@ def _capture_compiled_bytecode(self): def _bytecode_hook(old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" - if old_code is not self.original_code_object: + if old_code is not self.original_code_for_hook: return frame = sys._getframe() while frame and frame.f_back: @@ -277,13 +283,13 @@ def _bytecode_hook(old_code: CodeType, new_code: CodeType): frame = frame.f_locals["frame"] assert frame.f_code == old_code - if isinstance(self.obj, torch.nn.Module): - if hasattr(frame.f_locals, "self") and frame.f_locals["self"] is not self.obj: + if self.target_method_name is not None: + if "self" in frame.f_locals and frame.f_locals["self"] is not self.obj: return emit_after_dynamo_bytecode_transform() # Save the compiled bytecode - self.compiled_code = new_code + self.jit_compiled_code = new_code handle = torch._dynamo.convert_frame.register_bytecode_hook(_bytecode_hook) try: @@ -293,60 +299,39 @@ def _bytecode_hook(old_code: CodeType, new_code: CodeType): @contextmanager def dispatch_to_compiled_fwd(self, mode: Literal["jit", "aot"] = "jit"): + """Temporarily swap in compiled code and yield a callable invoker. + + For JIT mode the original ``__code__`` is swapped with the compiled + bytecode and restored in ``finally``. For AOT mode the pre-compiled + function is used directly with no cleanup needed. """ - Context manager to dispatch to the compiled code. - - For class-level decoration (obj is nn.Module): - Temporarily swaps the class's forward bytecode, yields None. - The caller invokes old_call(self, ...) which picks up the swapped code. 
- - For function-level decoration (obj is Callable): - Temporarily swaps the target function's bytecode, yields the modified function. - The caller invokes the yielded function directly. - - This way: - 1. Dynamo guarantees that the compiled bytecode has exactly the same arguments, - cell variables, and free variables as the original code. Therefore we can - directly switch the code object in the function and call it. - 2. In torch.nn.Module, `__call__` wraps `forward` with critical runtime logic - (hooks, FSDP mechanics, etc.). Switching bytecode ensures these are preserved. - """ + dispatch_via_method = self.target_method_name is not None + if mode == "jit": - assert self.compiled_code is not None - if isinstance(self.obj, torch.nn.Module): - self.obj.__class__.forward.__code__ = self.compiled_code - if hasattr(self.obj, "_magi_original_forward"): - original_forward = self.obj.forward - self.obj.forward = self.obj.__class__.forward.__get__(self.obj, self.obj.__class__) - yield - self.obj.forward = original_forward - else: - yield - self.obj.__class__.forward.__code__ = self.original_code_object + assert self.jit_compiled_code is not None + if dispatch_via_method: + assert self.target_function is not None + original_code = self.target_function.__code__ + self.target_function.__code__ = self.jit_compiled_code + try: + yield self.target_function.__get__(self.obj, self.obj.__class__) + finally: + self.target_function.__code__ = original_code else: - # Function/Method level target = inspect.unwrap(self.obj) - if inspect.ismethod(target): - # Bound method (e.g. 
model.forward) - target.__func__.__code__ = self.compiled_code - yield - target.__func__.__code__ = self.original_code_object - elif inspect.isfunction(target) or hasattr(target, "__code__"): - # Normal function or anything with __code__ - target.__code__ = self.compiled_code - yield - target.__code__ = self.original_code_object - else: - raise AttributeError(f"Target {target} is neither a method nor a function with __code__") + original_code = target.__code__ + target.__code__ = self.jit_compiled_code + try: + yield target + finally: + target.__code__ = original_code + elif mode == "aot": - assert self._aot_compiled_fn is not None - if isinstance(self.obj, torch.nn.Module): - original_forward = self.obj.forward - self.obj.forward = lambda *args, **kwargs: self._aot_compiled_fn(self.obj, *args, **kwargs) - yield - self.obj.forward = original_forward + assert self.aot_compiled_fn is not None + if dispatch_via_method: + yield lambda *args, **kwargs: self.aot_compiled_fn(self.obj, *args, **kwargs) else: - # For functions, AOT returns the compiled function directly - yield self._aot_compiled_fn + yield self.aot_compiled_fn + else: raise ValueError(f"Invalid mode: {mode}") diff --git a/tests/api_tests/test_magi_compile.py b/tests/api_tests/test_magi_compile.py index 63ddd78..5342406 100644 --- a/tests/api_tests/test_magi_compile.py +++ b/tests/api_tests/test_magi_compile.py @@ -18,7 +18,6 @@ import shutil import tempfile -import time from typing import Tuple from unittest.mock import MagicMock, patch @@ -29,6 +28,7 @@ from magi_compiler.api import magi_compile from magi_compiler.config import CompileConfig, CompileMode +from tests.perf_tests import cuda_benchmark @pytest.fixture(autouse=True) @@ -310,6 +310,68 @@ def _sync(model): _check(_sync(mtd_imp), "mtd_imp") _check(_sync(mtd_fact), "mtd_fact") + # 5. 
Non-nn.Module callable class / instance / method + class CallableModel: + def __init__(self, dim: int): + self.dim = dim + self.ln_weight = [torch.randn(dim) for _ in range(3)] + self.ln_bias = [torch.randn(dim) for _ in range(3)] + self.linear_weight = [torch.randn(dim, dim) for _ in range(3)] + self.linear_bias = [torch.randn(dim) for _ in range(3)] + + def copy_from(self, other: "CallableModel"): + self.ln_weight = [w.clone() for w in other.ln_weight] + self.ln_bias = [b.clone() for b in other.ln_bias] + self.linear_weight = [w.clone() for w in other.linear_weight] + self.linear_bias = [b.clone() for b in other.linear_bias] + + def _run_blocks(self, x: torch.Tensor) -> torch.Tensor: + for i in range(3): + res = x + x = torch.nn.functional.layer_norm(x, (self.dim,), self.ln_weight[i], self.ln_bias[i]) + x = torch.nn.functional.linear(x, self.linear_weight[i], self.linear_bias[i]) + x = torch.nn.functional.gelu(x) + x = x + res + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._run_blocks(x) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + + def step(self, x: torch.Tensor) -> torch.Tensor: + return self._run_blocks(x) + + x_non_module = torch.randn(2, size) + native_callable = CallableModel(size) + native_callable_out = native_callable(x_non_module) + + # class path (factory style) + CallableClsFact = compiler(CallableModel) + callable_cls_inst = CallableClsFact(size) + callable_cls_inst.copy_from(native_callable) + + # instance path (factory style, returns compiled callable entry) + callable_inst_obj = CallableModel(size) + callable_inst_obj.copy_from(native_callable) + callable_inst_entry = compiler(callable_inst_obj) + + # method path (bound method; dispatch decided by self type) + callable_mtd_inst = CallableModel(size) + callable_mtd_inst.copy_from(native_callable) + callable_mtd_inst.step = compiler(callable_mtd_inst.step) + + assert_close(callable_cls_inst(x_non_module), native_callable_out, 
rtol=1e-3, atol=1e-3) + assert_close(callable_inst_entry(x_non_module), native_callable_out, rtol=1e-3, atol=1e-3) + assert_close(callable_mtd_inst.step(x_non_module), native_callable_out, rtol=1e-3, atol=1e-3) + + x_non_module_2 = torch.randn(5, size) + native_callable_out_2 = native_callable(x_non_module_2) + assert_close(callable_cls_inst(x_non_module_2), native_callable_out_2, rtol=1e-3, atol=1e-3) + assert_close(callable_inst_entry(x_non_module_2), native_callable_out_2, rtol=1e-3, atol=1e-3) + assert_close(callable_mtd_inst.step(x_non_module_2), native_callable_out_2, rtol=1e-3, atol=1e-3) + def test_nested_function_calls(self): """Test compilation of model with nested function calls.""" @@ -485,25 +547,28 @@ class ClsSimpleModel(SimpleModel): @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support for stable timing") def test_simple_model_timing_class_function_instance_method(self): - """Lightweight timing sanity: Class / Function / Instance / Method entrypoints.""" + """Timing sanity on a heavier workload: Class / Function / Instance / Method entrypoints.""" class SimpleModel(nn.Module): def __init__(self): super().__init__() - self.dim = 32 + self.dim = 128 self.layers = nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(4)]) + self.norms = nn.ModuleList([nn.LayerNorm(self.dim) for _ in range(4)]) def forward(self, x: torch.Tensor) -> torch.Tensor: - for layer in self.layers: + for i, layer in enumerate(self.layers): res = x + x = self.norms[i](x) x = layer(x) x = torch.nn.functional.gelu(x) + x = layer(x) x = x + res return x device = torch.device("cuda:0") - seq_len = 16 - test_input = torch.randn(seq_len, 32, device=device) + seq_len = 256 + test_input = torch.randn(seq_len, 128, device=device) native = SimpleModel().to(device).eval() @@ -529,24 +594,72 @@ def func_entry(x: torch.Tensor) -> torch.Tensor: mtd_model.load_state_dict(native.state_dict()) mtd_model.forward = magi_compile(mtd_model.forward, 
dynamic_arg_dims={"x": 0}) - def _bench(callable_obj, label: str, warmup: int = 5, iters: int = 200) -> float: - with torch.no_grad(): - for _ in range(warmup): - callable_obj(test_input) - torch.cuda.synchronize() - start = time.perf_counter() - with torch.no_grad(): - for _ in range(iters): - callable_obj(test_input) - torch.cuda.synchronize() - elapsed = time.perf_counter() - start - print(f"{label}: {elapsed:.4f}s") - return elapsed - - t_class = _bench(cls_model, "class") - t_func = _bench(func_entry, "function") - t_inst = _bench(inst_model, "instance") - t_mtd = _bench(mtd_model, "method") + class NonModulePerf: + def __init__(self, dim: int): + self.dim = dim + self.ln_weight = [torch.randn(dim, device=device) for _ in range(4)] + self.ln_bias = [torch.randn(dim, device=device) for _ in range(4)] + self.linear_weight = [torch.randn(dim, dim, device=device) for _ in range(4)] + self.linear_bias = [torch.randn(dim, device=device) for _ in range(4)] + + def copy_from(self, other: "NonModulePerf"): + self.ln_weight = [w.clone() for w in other.ln_weight] + self.ln_bias = [b.clone() for b in other.ln_bias] + self.linear_weight = [w.clone() for w in other.linear_weight] + self.linear_bias = [b.clone() for b in other.linear_bias] + + def _run(self, x: torch.Tensor) -> torch.Tensor: + for i in range(4): + res = x + x = torch.nn.functional.layer_norm(x, (self.dim,), self.ln_weight[i], self.ln_bias[i]) + x = torch.nn.functional.linear(x, self.linear_weight[i], self.linear_bias[i]) + x = torch.nn.functional.gelu(x) + x = torch.nn.functional.linear(x, self.linear_weight[i], self.linear_bias[i]) + x = x + res + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._run(x) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + + def step(self, x: torch.Tensor) -> torch.Tensor: + return self._run(x) + + @magi_compile(dynamic_arg_dims={"x": 0}) + class CompiledNonModulePerf(NonModulePerf): + def forward(self, x: 
torch.Tensor) -> torch.Tensor: + return super().forward(x) + + non_module_native = NonModulePerf(128) + + non_module_cls = CompiledNonModulePerf(128) + non_module_cls.copy_from(non_module_native) + + non_module_inst_obj = NonModulePerf(128) + non_module_inst_obj.copy_from(non_module_native) + non_module_inst = magi_compile(non_module_inst_obj, dynamic_arg_dims={"x": 0}) + + non_module_mtd_obj = NonModulePerf(128) + non_module_mtd_obj.copy_from(non_module_native) + non_module_mtd_obj.step = magi_compile(non_module_mtd_obj.step, dynamic_arg_dims={"x": 0}) + + with torch.no_grad(): + class_result = cuda_benchmark(lambda: cls_model(test_input), compilation_warmup=3) + func_result = cuda_benchmark(lambda: func_entry(test_input), compilation_warmup=3) + inst_result = cuda_benchmark(lambda: inst_model(test_input), compilation_warmup=3) + method_result = cuda_benchmark(lambda: mtd_model(test_input), compilation_warmup=3) + + print(class_result.summary("class")) + print(func_result.summary("function")) + print(inst_result.summary("instance")) + print(method_result.summary("method")) + + t_class = class_result.median / 1000.0 + t_func = func_result.median / 1000.0 + t_inst = inst_result.median / 1000.0 + t_mtd = method_result.median / 1000.0 compiled_times = [t_class, t_func, t_inst, t_mtd] max_compiled = max(compiled_times) @@ -555,3 +668,25 @@ def _bench(callable_obj, label: str, warmup: int = 5, iters: int = 200) -> float "Magi entry timings diverged too much: " f"class={t_class:.4f}s, function={t_func:.4f}s, instance={t_inst:.4f}s, method={t_mtd:.4f}s" ) + + # non-nn.Module callable class / instance / method timing sanity + with torch.no_grad(): + non_module_class_result = cuda_benchmark(lambda: non_module_cls(test_input), compilation_warmup=3) + non_module_instance_result = cuda_benchmark(lambda: non_module_inst(test_input), compilation_warmup=3) + non_module_method_result = cuda_benchmark(lambda: non_module_mtd_obj.step(test_input), compilation_warmup=3) + + 
print(non_module_class_result.summary("non_module_class")) + print(non_module_instance_result.summary("non_module_instance")) + print(non_module_method_result.summary("non_module_method")) + + t_nm_class = non_module_class_result.median / 1000.0 + t_nm_inst = non_module_instance_result.median / 1000.0 + t_nm_mtd = non_module_method_result.median / 1000.0 + + nm_times = [t_nm_class, t_nm_inst, t_nm_mtd] + max_nm = max(nm_times) + min_nm = min(nm_times) + assert max_nm / min_nm < 1.2, ( + "Non-module entry timings diverged too much: " + f"class={t_nm_class:.4f}s, instance={t_nm_inst:.4f}s, method={t_nm_mtd:.4f}s" + ) diff --git a/tests/model_definition.py b/tests/model_definition.py index 13c75bd..d7edf35 100644 --- a/tests/model_definition.py +++ b/tests/model_definition.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass +from typing import Self import torch import torch.nn as nn @@ -177,6 +178,97 @@ def create_mlp_model_with_initial_params(config: MLPConfig, device: torch.device return model, initial_params +class RawNonModuleMLP: + """Non-module MLP workload aligned with ``RawMLP`` math.""" + + def __init__(self, hidden_size: int, intermediate_size: int, device: torch.device, dtype: torch.dtype = torch.bfloat16): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.device = device + self.dtype = dtype + self.eps = 1e-6 + + self.pre_norm_weight = torch.ones(hidden_size, device=device, dtype=torch.float32) + self.up_proj_weight = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) + self.down_proj_weight = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + + def copy_from(self, other: Self) -> None: + self.pre_norm_weight = other.pre_norm_weight.clone() + self.up_proj_weight = other.up_proj_weight.clone() + self.down_proj_weight = other.down_proj_weight.clone() + + def _rms_norm(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + variance = 
x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(self.pre_norm_weight.dtype) * self.pre_norm_weight + return x.to(input_dtype) + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + x = self._rms_norm(inp).to(torch.bfloat16) + x = torch.nn.functional.linear(x, self.up_proj_weight).to(torch.float32) + x = torch.nn.functional.silu(x).to(torch.bfloat16) + x = torch.nn.functional.linear(x, self.down_proj_weight).to(torch.float32) + return x + + def __call__(self, inp: torch.Tensor) -> torch.Tensor: + return self.forward(inp) + + def step(self, inp: torch.Tensor) -> torch.Tensor: + return self.forward(inp) + + +class RawNonModulePointwiseFusionChain: + """Non-module pointwise chain aligned with ``PointwiseFusionChain`` math.""" + + def copy_from(self, other: Self) -> None: + _ = other + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + x = inp + x = x * 0.5 + x = x + 1.0 + x = torch.relu(x) + x = x * x + x = x - 0.5 + x = torch.sigmoid(x) + return x + + def __call__(self, inp: torch.Tensor) -> torch.Tensor: + return self.forward(inp) + + def step(self, inp: torch.Tensor) -> torch.Tensor: + return self.forward(inp) + + +class RawNonModuleNormResidualActivation: + """Non-module norm+residual+activation workload aligned with module math.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + self.hidden_size = hidden_size + self.eps = eps + self.weight = torch.ones(hidden_size, dtype=torch.float32) + + def copy_from(self, other: Self) -> None: + self.weight = other.weight.clone() + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(self.weight.dtype) * self.weight.to(x.device) + return x.to(input_dtype) + + def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(self._norm(x) + residual) + + def 
__call__(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return self.forward(x, residual) + + def step(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return self.forward(x, residual) + + @dataclass class TransformerConfig: """Configuration for the Transformer model""" diff --git a/tests/perf_tests/test_mlp_perf.py b/tests/perf_tests/test_mlp_perf.py index 9563e2d..b5edbcc 100644 --- a/tests/perf_tests/test_mlp_perf.py +++ b/tests/perf_tests/test_mlp_perf.py @@ -26,7 +26,7 @@ from magi_compiler import magi_compile from magi_compiler.config import CompileMode -from tests.model_definition import MLPConfig, RawMLP +from tests.model_definition import MLPConfig, RawMLP, RawNonModuleMLP from tests.perf_tests import cuda_benchmark, print_perf_comparison from tests.perf_tests.utils import assert_speedup @@ -186,3 +186,97 @@ def test_mlp_method_decoration(mlp_device, mlp_input, mlp_baselines): extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16", ) assert_speedup(magi_vs_eager, eager_result, magi_result, "method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_non_module_mlp_class_decoration_speedup(mlp_device, mlp_input): + base = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + + @magi_compile(dynamic_arg_dims={"inp": 0}) + class CompiledNonModuleMLP(RawNonModuleMLP): + def __call__(self, inp: torch.Tensor) -> torch.Tensor: + return super().__call__(inp) + + eager_obj = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + eager_obj.copy_from(base) + + compiled_obj = CompiledNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + compiled_obj.copy_from(base) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(mlp_input)) + compiled_result = cuda_benchmark(lambda: compiled_obj(mlp_input), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Non-module MLP - 
class decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16", + ) + assert_speedup(speedup, eager_result, compiled_result, "non_module_class", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_non_module_mlp_instance_decoration_speedup(mlp_device, mlp_input): + base = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + + eager_obj = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + eager_obj.copy_from(base) + + inst_obj = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + inst_obj.copy_from(base) + compiled_obj = magi_compile(inst_obj, dynamic_arg_dims={"inp": 0}) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(mlp_input)) + compiled_result = cuda_benchmark(lambda: compiled_obj(mlp_input), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Non-module MLP - instance decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16", + ) + assert_speedup(speedup, eager_result, compiled_result, "non_module_instance", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_non_module_mlp_method_decoration_speedup(mlp_device, mlp_input): + base = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + + eager_obj = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + eager_obj.copy_from(base) + + mtd_obj = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + mtd_obj.copy_from(base) + mtd_obj.step = magi_compile(mtd_obj.step, dynamic_arg_dims={"inp": 0}) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj.step(mlp_input)) + compiled_result = cuda_benchmark(lambda: mtd_obj.step(mlp_input), compilation_warmup=3) + + speedup, _ = 
print_perf_comparison( + "Non-module MLP - method decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16", + ) + assert_speedup(speedup, eager_result, compiled_result, "non_module_method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_non_module_mlp_eager_vs_module_eager_speed(mlp_device, mlp_input): + config = _build_config() + module_eager = RawMLP(config).to(mlp_device).eval() + + non_module = RawNonModuleMLP(HIDDEN_SIZE, INTERMEDIATE_SIZE, mlp_device) + with torch.no_grad(): + module_eager_result = cuda_benchmark(lambda: module_eager(mlp_input)) + non_module_eager_result = cuda_benchmark(lambda: non_module(mlp_input)) + + print_perf_comparison( + "MLP eager baseline check: module vs non-module", + module_eager_result, + non_module_eager_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16", + ) diff --git a/tests/perf_tests/test_norm_residual_fusion_perf.py b/tests/perf_tests/test_norm_residual_fusion_perf.py index 4e20a68..9b7b51b 100644 --- a/tests/perf_tests/test_norm_residual_fusion_perf.py +++ b/tests/perf_tests/test_norm_residual_fusion_perf.py @@ -31,7 +31,7 @@ from magi_compiler import magi_compile from magi_compiler.config import CompileMode -from tests.model_definition import RMSNorm +from tests.model_definition import RawNonModuleNormResidualActivation, RMSNorm from tests.perf_tests import cuda_benchmark, print_perf_comparison from tests.perf_tests.utils import assert_speedup @@ -199,3 +199,104 @@ def test_norm_residual_method_decoration(nra_device, nra_inputs, nra_baselines): extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) dtype=bf16", ) assert_speedup(magi_vs_eager, eager_result, magi_result, "method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def 
test_norm_residual_non_module_class_decoration_speedup(nra_device, nra_inputs): + x, residual = nra_inputs + + @magi_compile(dynamic_arg_dims={"x": 0, "residual": 0}) + class CompiledNonModuleNRA(RawNonModuleNormResidualActivation): + def __call__(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return super().__call__(x, residual) + + base = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + + eager_obj = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + eager_obj.copy_from(base) + + compiled_obj = CompiledNonModuleNRA(HIDDEN_SIZE) + compiled_obj.copy_from(base) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(x, residual)) + compiled_result = cuda_benchmark(lambda: compiled_obj(x, residual), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Norm+Residual non-module - class decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) dtype=bf16", + ) + assert_speedup(speedup, eager_result, compiled_result, "norm_residual_non_module_class", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_norm_residual_non_module_instance_decoration_speedup(nra_device, nra_inputs): + x, residual = nra_inputs + + base = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + + eager_obj = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + eager_obj.copy_from(base) + + inst_obj = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + inst_obj.copy_from(base) + compiled_obj = magi_compile(inst_obj, dynamic_arg_dims={"x": 0, "residual": 0}) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(x, residual)) + compiled_result = cuda_benchmark(lambda: compiled_obj(x, residual), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Norm+Residual non-module - instance decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) dtype=bf16", + ) + 
assert_speedup(speedup, eager_result, compiled_result, "norm_residual_non_module_instance", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_norm_residual_non_module_method_decoration_speedup(nra_device, nra_inputs): + x, residual = nra_inputs + + base = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + + eager_obj = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + eager_obj.copy_from(base) + + mtd_obj = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + mtd_obj.copy_from(base) + mtd_obj.step = magi_compile(mtd_obj.step, dynamic_arg_dims={"x": 0, "residual": 0}) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj.step(x, residual)) + compiled_result = cuda_benchmark(lambda: mtd_obj.step(x, residual), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Norm+Residual non-module - method decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) dtype=bf16", + ) + assert_speedup(speedup, eager_result, compiled_result, "norm_residual_non_module_method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_norm_residual_non_module_eager_vs_module_eager_speed(nra_device, nra_inputs): + x, residual = nra_inputs + + module_eager = NormResidualActivation(HIDDEN_SIZE).to(nra_device).eval() + non_module_eager = RawNonModuleNormResidualActivation(HIDDEN_SIZE) + + with torch.no_grad(): + module_eager_result = cuda_benchmark(lambda: module_eager(x, residual)) + non_module_eager_result = cuda_benchmark(lambda: non_module_eager(x, residual)) + + print_perf_comparison( + "Norm+Residual eager baseline check: module vs non-module", + module_eager_result, + non_module_eager_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) dtype=bf16", + ) diff --git a/tests/perf_tests/test_pointwise_fusion_perf.py b/tests/perf_tests/test_pointwise_fusion_perf.py index 
526f3cd..209bd71 100644 --- a/tests/perf_tests/test_pointwise_fusion_perf.py +++ b/tests/perf_tests/test_pointwise_fusion_perf.py @@ -30,6 +30,7 @@ from magi_compiler import magi_compile from magi_compiler.config import CompileMode +from tests.model_definition import RawNonModulePointwiseFusionChain from tests.perf_tests import cuda_benchmark, print_perf_comparison from tests.perf_tests.utils import assert_speedup @@ -190,3 +191,81 @@ def test_pointwise_method_decoration(pointwise_device, pointwise_input, pointwis extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE})", ) assert_speedup(magi_vs_eager, eager_result, magi_result, "method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_pointwise_non_module_class_decoration_speedup(pointwise_device, pointwise_input): + @magi_compile(dynamic_arg_dims={"inp": 0}) + class CompiledNonModulePointwise(RawNonModulePointwiseFusionChain): + def __call__(self, inp: torch.Tensor) -> torch.Tensor: + return super().__call__(inp) + + eager_obj = RawNonModulePointwiseFusionChain() + compiled_obj = CompiledNonModulePointwise() + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(pointwise_input)) + compiled_result = cuda_benchmark(lambda: compiled_obj(pointwise_input), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Pointwise non-module - class decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE})", + ) + assert_speedup(speedup, eager_result, compiled_result, "pointwise_non_module_class", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_pointwise_non_module_instance_decoration_speedup(pointwise_device, pointwise_input): + eager_obj = RawNonModulePointwiseFusionChain() + inst_obj = RawNonModulePointwiseFusionChain() + compiled_obj = magi_compile(inst_obj, dynamic_arg_dims={"inp": 0}) + + with 
torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj(pointwise_input)) + compiled_result = cuda_benchmark(lambda: compiled_obj(pointwise_input), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Pointwise non-module - instance decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE})", + ) + assert_speedup(speedup, eager_result, compiled_result, "pointwise_non_module_instance", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_pointwise_non_module_method_decoration_speedup(pointwise_device, pointwise_input): + eager_obj = RawNonModulePointwiseFusionChain() + mtd_obj = RawNonModulePointwiseFusionChain() + mtd_obj.step = magi_compile(mtd_obj.step, dynamic_arg_dims={"inp": 0}) + + with torch.no_grad(): + eager_result = cuda_benchmark(lambda: eager_obj.step(pointwise_input)) + compiled_result = cuda_benchmark(lambda: mtd_obj.step(pointwise_input), compilation_warmup=3) + + speedup, _ = print_perf_comparison( + "Pointwise non-module - method decoration", + eager_result, + compiled_result, + extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE})", + ) + assert_speedup(speedup, eager_result, compiled_result, "pointwise_non_module_method", SPEEDUP_VS_EAGER_THRESHOLD) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") +def test_pointwise_non_module_eager_vs_module_eager_speed(pointwise_device, pointwise_input): + module_eager = PointwiseFusionChain().to(pointwise_device).eval() + non_module_eager = RawNonModulePointwiseFusionChain() + + with torch.no_grad(): + module_eager_result = cuda_benchmark(lambda: module_eager(pointwise_input)) + non_module_eager_result = cuda_benchmark(lambda: non_module_eager(pointwise_input)) + + print_perf_comparison( + "Pointwise eager baseline check: module vs non-module", + module_eager_result, + non_module_eager_result, + extra_info=f"shape=({NUM_TOKENS}, 
{HIDDEN_SIZE})", + )