Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 109 additions & 106 deletions magi_compiler/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,19 @@ def _isolated_dynamo_config():
yield


def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs):
def get_attr_name_for_wrapper_installed_flag() -> str:
    """Return the attribute name of the "wrapper installed" marker flag.

    Callers in this module read the flag with ``getattr(obj, name, False)``
    and set it to ``True`` with ``setattr`` after patching a class or an
    instance, so the name is kept in one place instead of being repeated
    as a string literal.

    Returns:
        The constant attribute name ``"_magi_wrapper_installed"``.
    """
    return "_magi_wrapper_installed"


def get_attr_name_for_default_state() -> str:
    """Return the default attribute name under which compile state is stored.

    This name is the fallback used when a caller does not pass an explicit
    ``state_attr`` (for example, the function-level wrapper stores its
    state under this attribute).

    Returns:
        The constant attribute name ``"_magi"``.
    """
    return "_magi"


def get_attr_name_for_method_state(method_name: str) -> str:
    """Return the per-method attribute name used to store compile state.

    Embedding the method name in the attribute gives each patched method
    its own state slot, distinct from the default state attribute.

    Args:
        method_name: Name of the method whose state attribute is wanted.

    Returns:
        ``"_magi_state_"`` followed by ``method_name``.
    """
    return f"_magi_state_{method_name}"


def _run_orchestration(state: MagiCompileState, args, kwargs):
"""
Central orchestration logic for magi_compile.

Expand All @@ -86,176 +98,163 @@ def _run_orchestration(state: MagiCompileState, original_invoker, args, kwargs):
- (Optional) Perform AOT compilation and save artifacts.
"""
# JIT Fast Path
if state.compiled_code is not None:
with state.dispatch_to_compiled_fwd(mode="jit"):
return original_invoker()
if state.jit_compiled_code is not None:
with state.dispatch_to_compiled_fwd(mode="jit") as compiled_runtime_invoker:
return compiled_runtime_invoker(*args, **kwargs)

# AOT Fast Path
if state.compile_config.aot:
if state._aot_compiled_fn or state.load_aot_compile_artifacts():
res = state.dispatch_to_compiled_fwd(mode="aot")
if isinstance(state.obj, nn.Module):
with res:
return original_invoker()
with res as compiled_fn:
return compiled_fn(*args, **kwargs)
if state.aot_compiled_fn or state.load_aot_compile_artifacts():
with state.dispatch_to_compiled_fwd(mode="aot") as compiled_runtime_invoker:
return compiled_runtime_invoker(*args, **kwargs)

# First compilation
state._ensure_compiled()

# Mark dynamic and static shapes
_apply_shape_marks(state, args, kwargs)

magi_logger.info(f"Start compiling function {state.original_code_object}")
torch._dynamo.eval_frame.remove_from_cache(state.original_code_object)
magi_logger.info(f"Start compiling function {state.original_code_for_hook}")
torch._dynamo.eval_frame.remove_from_cache(state.original_code_for_hook)
CompileMonitor().start()

try:
if state.compile_config.aot:
with _compilation_context(state):
state.aot_compile(*args, **kwargs)
state.save_aot_compile_artifacts()
res = state.dispatch_to_compiled_fwd(mode="aot")
if isinstance(state.obj, nn.Module):
with res:
return original_invoker()
with res as compiled_fn:
return compiled_fn(*args, **kwargs)
with state.dispatch_to_compiled_fwd(mode="aot") as compiled_runtime_invoker:
return compiled_runtime_invoker(*args, **kwargs)
else:
with _compilation_context(state):
# For JIT, we need to capture bytecode.
with state._capture_compiled_bytecode():
if isinstance(state.obj, nn.Module):
with patch.object(state.obj, "forward", state._compiled_callable):
return original_invoker()
else:
return state._compiled_callable(*args, **kwargs)
with state._jit_capture_compiled_bytecode():
return state.compiled_entry(*args, **kwargs)
finally:
CompileMonitor().end()
state.traced_files.clear()


def _lazy_init_magi_state(
target_obj: object,
base_obj: nn.Module | Callable,
state_holder: object,
compile_obj: object,
dynamic_arg_dims: dict[str, int | list[int]] | None,
enable_if: Callable[[], bool] | None,
config_patch: Callable[[CompileConfig], CompileConfig],
model_tag: str | None,
*,
target_method_name: str | None = None,
state_attr: str | None = None,
):
"""
Lazily initializes the MagiCompileState and attaches it to `target_obj._magi`.
"""
if hasattr(target_obj, "_magi"):
"""Lazily initialize MagiCompileState and attach it on ``state_attr``."""
state_attr = state_attr or get_attr_name_for_default_state()
if getattr(state_holder, state_attr, None) is not None:
return

conf = config_patch(copy.deepcopy(get_compile_config()))
enable = enable_if is None or enable_if()
if conf.compile_mode == CompileMode.NONE or not enable:
target_obj._magi = None
return

compilation_counter.num_models_seen += 1

# Infer default model tag if not provided
if model_tag is None:
if hasattr(base_obj, "__class__") and isinstance(base_obj, nn.Module):
model_tag = base_obj.__class__.__name__
else:
model_tag = getattr(base_obj, "__name__", "unknown_func")

target_obj._magi = MagiCompileState(
base_obj, conf, model_idx=compilation_counter.num_models_seen, model_tag=model_tag, dynamic_arg_dims=dynamic_arg_dims
model_tag = getattr(compile_obj, "__name__", compile_obj.__class__.__name__)

setattr(
state_holder,
state_attr,
MagiCompileState(
compile_obj,
conf,
model_idx=compilation_counter.num_models_seen,
model_tag=model_tag,
dynamic_arg_dims=dynamic_arg_dims,
target_method_name=target_method_name,
),
)


def _magi_compile_class(
cls,
cls: type,
dynamic_arg_dims: dict[str, int | list[int]],
enable_if: Callable[[], bool] | None,
config_patch: Callable[[CompileConfig], CompileConfig] | None,
config_patch: Callable[[CompileConfig], CompileConfig],
model_tag: str | None,
method_name: str,
):
"""Class-level decoration: mutates ``cls.__call__`` so every instance is compiled.
"""Install class-level lazy compilation for ``method_name``.

MagiCompileState is created **lazily** on first ``__call__`` via ``_lazy_init_magi_state``,
because at decoration time no instance exists yet, and the global CompileConfig may
not be finalized (e.g. env-vars set after import but before first forward).
This wraps ``cls.__init__`` so every new instance is patched by
``_magi_compile_bound_method`` after initialization.
"""
if getattr(cls, "_magi_compiled", False):
compile_flag_attr = get_attr_name_for_wrapper_installed_flag()
if getattr(cls, compile_flag_attr, False):
return cls

config_patch = config_patch or (lambda x: x)
if config_patch(copy.deepcopy(get_compile_config())).offload_config.model_cpu_offload:
_patch_cpu_offload_apply(cls)

old_call = cls.__call__

@torch.compiler.disable()
def wrapper(self, *args, **kwargs):
_lazy_init_magi_state(self, self, dynamic_arg_dims, enable_if, config_patch, model_tag)
state = self._magi
if not callable(getattr(cls, method_name, None)):
raise AttributeError(f"{cls.__name__} has no callable method '{method_name}'")

# Offload arguments if offload is enabled and not yet compiled
if state is not None and state.compile_config.offload_config.model_cpu_offload and state.compiled_code is None:
args = offload(args)
kwargs = offload(kwargs)
if issubclass(cls, nn.Module) and config_patch(copy.deepcopy(get_compile_config())).offload_config.model_cpu_offload:
_patch_cpu_offload_apply(cls)

if state is None or torch.compiler.is_compiling():
return old_call(self, *args, **kwargs)
old_init = cls.__init__

with _isolated_dynamo_config():
return _run_orchestration(state, lambda: old_call(self, *args, **kwargs), args, kwargs)
@functools.wraps(old_init)
def wrapped_init(self, *args, **kwargs):
old_init(self, *args, **kwargs)
_magi_compile_bound_method(self, dynamic_arg_dims, enable_if, config_patch, model_tag, method_name=method_name)

cls.__call__ = wrapper
cls._magi_compiled = True
cls.__init__ = wrapped_init
setattr(cls, compile_flag_attr, True)
return cls


def _magi_compile_instance(
module: nn.Module,
def _magi_compile_bound_method(
instance: object,
dynamic_arg_dims: dict[str, int | list[int]],
enable_if: Callable[[], bool] | None,
config_patch: Callable[[CompileConfig], CompileConfig] | None,
config_patch: Callable[[CompileConfig], CompileConfig],
model_tag: str | None,
method_name: str,
):
"""Instance-level decoration: only this instance is compiled, class is untouched.
"""Patch one instance method with lazy state initialization and compiled routing."""
if not callable(getattr(instance, method_name, None)):
raise AttributeError(f"{instance.__class__.__name__} instance has no callable method '{method_name}'")

MagiCompileState is created **lazily** on first ``forward`` call via ``_lazy_init_magi_state``.
A compiled ``forward`` is installed as an instance attribute so ``Module.__call__`` →
``self.forward()`` resolves to it, while ``module.__class__.forward`` remains original.
state_attr = get_attr_name_for_method_state(method_name)
if getattr(instance, state_attr, None) is not None:
return instance

Call flow::

module(x)
→ Module.__call__ (hooks, FSDP, etc.)
→ self.forward(x) (finds instance attr → _compiled_forward)
→ _run_orchestration
→ module.__class__.forward(module, x) # original or bytecode-swapped
"""
if getattr(module, "_magi", None) is not None:
return module

config_patch = config_patch or (lambda x: x)

# module.__class__.forward is the unbound class method — never affected by our
# instance-level override, so calling it goes straight to original forward logic.
old_call = module.__class__.forward
module._magi_original_forward = module.forward
old_method = getattr(instance, method_name)

@torch.compiler.disable()
def new_call(*args, **kwargs):
_lazy_init_magi_state(module, module, dynamic_arg_dims, enable_if, config_patch, model_tag)
state = module._magi
state = getattr(instance, state_attr, None)

if state is None:
_lazy_init_magi_state(
instance,
instance,
dynamic_arg_dims,
enable_if,
config_patch,
model_tag,
target_method_name=method_name,
state_attr=state_attr,
)
state = getattr(instance, state_attr, None)

if state is None or torch.compiler.is_compiling():
return old_call(module, *args, **kwargs)
return old_method(*args, **kwargs)

with _isolated_dynamo_config():
return _run_orchestration(state, lambda: module.__class__.__call__(module, *args, **kwargs), args, kwargs)
return _run_orchestration(state, args, kwargs)

module.forward = new_call
module._magi_compiled = True
return module
setattr(instance, method_name, new_call)
setattr(instance, get_attr_name_for_wrapper_installed_flag(), True)
return instance


def _magi_compile_function(
Expand All @@ -265,26 +264,30 @@ def _magi_compile_function(
config_patch: Callable[[CompileConfig], CompileConfig] | None,
model_tag: str | None,
):
"""Function / bound-method level decoration.
"""Wrap a function entry with lazy ``MagiCompileState`` and compiled routing.

MagiCompileState is created **lazily** on first call via ``_lazy_init_magi_state``.
The wrapper replaces the original callable.
The returned wrapper initializes state on first invocation and then dispatches
through ``_run_orchestration``.
"""
if getattr(func, "_magi", None) is not None:
state_attr = get_attr_name_for_default_state()
if getattr(func, state_attr, None) is not None:
return func

config_patch = config_patch or (lambda x: x)

@torch.compiler.disable()
@functools.wraps(func) # for the original function name and docstring
def wrapper(*args, **kwargs):
_lazy_init_magi_state(wrapper, func, dynamic_arg_dims, enable_if, config_patch, model_tag)
state = wrapper._magi
state = getattr(wrapper, state_attr, None)
if state is None:
_lazy_init_magi_state(wrapper, func, dynamic_arg_dims, enable_if, config_patch, model_tag, state_attr=state_attr)
state = getattr(wrapper, state_attr, None)

if state is None or torch.compiler.is_compiling():
return func(*args, **kwargs)

with _isolated_dynamo_config():
return _run_orchestration(state, lambda: func(*args, **kwargs), args, kwargs)
return _run_orchestration(state, args, kwargs)

return wrapper

Expand Down Expand Up @@ -320,7 +323,7 @@ def _apply_shape_marks(state: MagiCompileState, args, kwargs):
This is called just before Dynamo tracing to ensure dimensions are
correctly generalized in the captured graph.
"""
sig = inspect.signature(state._target_callable)
sig = inspect.signature(state.original_entry)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

Expand Down Expand Up @@ -435,7 +438,7 @@ def _compilation_context(state: MagiCompileState):
# 2. hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
def _hijack_inline_call_to_collect_traced_files(state: MagiCompileState):
state.traced_files.add(state.original_code_object.co_filename)
state.traced_files.add(state.original_code_for_hook.co_filename)
inline_call = InliningInstructionTranslator.inline_call_

def patched(self_):
Expand Down
Loading
Loading