Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions tests/model_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(input_dtype)


@magi_compile(dynamic_arg_dims={"x": 0})
class MLP(torch.nn.Module):
"""MLP module with traditional architecture (up-projection, activation, and down-projection)"""
class RawMLP(torch.nn.Module):
"""MLP module with traditional architecture (up-projection, activation, and down-projection).

This is the uncompiled base class. Use ``MLP`` for the magi_compile-wrapped variant.
"""

config: MLPConfig

Expand All @@ -81,20 +83,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
- x: (num_tokens, hidden_size)
- output: (num_tokens, hidden_size)
"""
# Pre-normalization
x = self.pre_norm(x).to(torch.bfloat16)
# Up-projection
x = self.up_proj(x).to(torch.float32)
# Activation (SiLU)
x = F.silu(x).to(torch.bfloat16)
# Down-projection
x = self.down_proj(x).to(torch.float32)
return x


@magi_compile(dynamic_arg_dims={"x": 0})
class RMSNormModule(torch.nn.Module):
"""Compiled RMSNorm module for testing"""
class MLP(RawMLP):
"""Compiled MLP module (magi_compile-wrapped ``RawMLP``)."""

pass


class RawRMSNormModule(torch.nn.Module):
"""RMSNorm module for testing.

This is the uncompiled base class. Use ``RMSNormModule`` for the magi_compile-wrapped variant.
"""

config: RMSNormConfig

Expand All @@ -119,6 +126,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.norm(x)


@magi_compile(dynamic_arg_dims={"x": 0})
class RMSNormModule(RawRMSNormModule):
    """magi_compile-wrapped variant of ``RawRMSNormModule`` (token dim is dynamic)."""


def create_rms_norm_model(config: RMSNormConfig, device: torch.device) -> RMSNormModule:
"""Create RMSNorm model

Expand Down
95 changes: 95 additions & 0 deletions tests/perf_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import statistics
from collections.abc import Callable
from dataclasses import dataclass

import torch
from triton.testing import do_bench


@dataclass
class BenchmarkResult:
    """Raw per-iteration benchmark timings plus derived summary statistics.

    All values are in milliseconds.
    """

    # One entry per measured iteration, in milliseconds.
    times_ms: list[float]

    @property
    def median(self) -> float:
        """Median iteration time."""
        return statistics.median(self.times_ms)

    @property
    def mean(self) -> float:
        """Arithmetic mean iteration time."""
        return statistics.mean(self.times_ms)

    @property
    def min(self) -> float:
        """Fastest observed iteration time."""
        return min(self.times_ms)

    @property
    def stdev(self) -> float:
        """Sample standard deviation; 0.0 when only one sample was taken."""
        if len(self.times_ms) > 1:
            return statistics.stdev(self.times_ms)
        return 0.0

    def summary(self, label: str = "") -> str:
        """Render a one-line human-readable summary, optionally tagged with *label*."""
        prefix = "" if not label else f"[{label}] "
        stats = (
            f"median={self.median:.3f}ms mean={self.mean:.3f}ms "
            f"min={self.min:.3f}ms stdev={self.stdev:.3f}ms (n={len(self.times_ms)})"
        )
        return prefix + stats


def cuda_benchmark(
    fn: Callable[[], object],
    *,
    warmup: int = 25,
    rep: int = 100,
    grad_to_none: list[torch.Tensor] | None = None,
    compilation_warmup: int = 0,
) -> BenchmarkResult:
    """Time *fn* on the GPU via triton's ``do_bench`` and collect all samples.

    Args:
        fn: Zero-argument callable to benchmark.
        warmup: Warmup time budget forwarded to ``do_bench``.
        rep: Measurement time budget forwarded to ``do_bench``.
        grad_to_none: Tensors whose ``.grad`` should be reset between runs,
            forwarded to ``do_bench``.
        compilation_warmup: Number of extra eager calls issued (and synced)
            before timing, so one-off compilation cost is excluded.

    Returns:
        A ``BenchmarkResult`` holding every per-iteration time in ms.
    """
    # Trigger any lazy compilation up front so it is not measured.
    for _ in range(compilation_warmup):
        fn()
    if compilation_warmup > 0:
        torch.cuda.synchronize()

    samples = do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all")
    return BenchmarkResult(times_ms=samples)


def print_perf_comparison(
    title: str,
    eager: BenchmarkResult,
    magi: BenchmarkResult,
    torch_compile: BenchmarkResult | None = None,
    extra_info: str = "",
) -> tuple[float, float]:
    """Print a framed comparison of benchmark results and return speedups.

    Args:
        title: Heading printed inside the frame.
        eager: Eager-mode baseline result.
        magi: magi_compile result.
        torch_compile: Optional torch.compile result; its lines are omitted
            when not supplied.
        extra_info: Optional extra line printed under the title.

    Returns:
        ``(magi_vs_eager, magi_vs_torch)`` median-time speedup ratios; the
        second is 0.0 when no torch.compile result was supplied.
    """
    rule = "=" * 78
    have_tc = torch_compile is not None

    magi_vs_eager = eager.median / magi.median
    if have_tc:
        torch_vs_eager = eager.median / torch_compile.median
        magi_vs_torch = torch_compile.median / magi.median
    else:
        torch_vs_eager = 0.0
        magi_vs_torch = 0.0

    print(f"\n{rule}")
    print(title)
    if extra_info:
        print(f" {extra_info}")
    print(rule)
    print(f" {eager.summary('eager ')}")
    if have_tc:
        print(f" {torch_compile.summary('torch.compile ')}")
    print(f" {magi.summary('magi_compile ')}")
    print(" ---")
    if have_tc:
        print(f" torch.compile vs eager: {torch_vs_eager:.2f}x")
    print(f" magi_compile vs eager: {magi_vs_eager:.2f}x")
    if have_tc:
        print(f" magi_compile vs torch.compile: {magi_vs_torch:.2f}x")
    print(rule)
    return magi_vs_eager, magi_vs_torch
13 changes: 13 additions & 0 deletions tests/perf_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
188 changes: 188 additions & 0 deletions tests/perf_tests/test_mlp_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Performance test: end-to-end MLP block.

Covers all supported compilation paths (class, instance, instance+TC, method).

Measured baseline (H100):
torch.compile ~1.8x vs eager
magi_compile ~1.8x vs eager (all paths)
"""

import pytest
import torch

from magi_compiler import magi_compile
from magi_compiler.config import CompileMode
from tests.model_definition import MLPConfig, RawMLP
from tests.perf_tests import cuda_benchmark, print_perf_comparison
from tests.perf_tests.utils import assert_speedup

HIDDEN_SIZE = 2048
INTERMEDIATE_SIZE = 8192
NUM_TOKENS = 8192
SPEEDUP_VS_EAGER_THRESHOLD = 1.5


def _build_config():
    """Construct the bf16 MLPConfig shared by every perf test in this module."""
    return MLPConfig(
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        params_dtype=torch.bfloat16,
    )


# ── Shared baselines (computed once per module) ────────────────────────


@pytest.fixture(scope="module")
def mlp_device():
    """Module-scoped device: CUDA when available, CPU otherwise."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(scope="module")
def mlp_input(mlp_device):
    """Module-scoped random bf16 activation of shape (NUM_TOKENS, HIDDEN_SIZE)."""
    return torch.randn(
        NUM_TOKENS,
        HIDDEN_SIZE,
        device=mlp_device,
        dtype=torch.bfloat16,
    )


@pytest.fixture(scope="module")
def mlp_baselines(mlp_device, mlp_input):
    """Benchmark the eager and torch.compile baselines once for the whole module.

    Returns:
        ``(eager_result, torch_compile_result)`` as ``BenchmarkResult`` objects.
    """
    cfg = _build_config()
    eager_model = RawMLP(cfg).to(mlp_device).eval()
    # A second, independent instance is compiled so the two baselines do not share state.
    compiled_model = torch.compile(RawMLP(cfg).to(mlp_device).eval(), backend="inductor")
    with torch.no_grad():
        eager_result = cuda_benchmark(lambda: eager_model(mlp_input))
        torch_result = cuda_benchmark(lambda: compiled_model(mlp_input), compilation_warmup=3)
    return eager_result, torch_result


# ── Tests ──────────────────────────────────────────────────────────────


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_class_decoration(mlp_device, mlp_input, mlp_baselines):
    """MLP block: @magi_compile class decoration."""
    eager_result, torch_result = mlp_baselines

    # Compile via class-level decoration of a RawMLP subclass.
    @magi_compile(dynamic_arg_dims={"x": 0})
    class CompiledMLP(RawMLP):
        pass

    model = CompiledMLP(_build_config()).to(mlp_device).eval()

    with torch.no_grad():
        magi_result = cuda_benchmark(lambda: model(mlp_input), compilation_warmup=3)

    speedup, _ = print_perf_comparison(
        "MLP - class decoration",
        eager_result,
        magi_result,
        torch_result,
        extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16",
    )
    assert_speedup(speedup, eager_result, magi_result, "class", SPEEDUP_VS_EAGER_THRESHOLD)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_instance_decoration(mlp_device, mlp_input, mlp_baselines):
    """MLP block: magi_compile(instance) decoration."""
    eager_result, torch_result = mlp_baselines

    # Compile by wrapping an already-constructed module instance.
    model = magi_compile(RawMLP(_build_config()).to(mlp_device), dynamic_arg_dims={"x": 0})
    model.eval()

    with torch.no_grad():
        magi_result = cuda_benchmark(lambda: model(mlp_input), compilation_warmup=3)

    speedup, _ = print_perf_comparison(
        "MLP - instance decoration",
        eager_result,
        magi_result,
        torch_result,
        extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16",
    )
    assert_speedup(speedup, eager_result, magi_result, "instance", SPEEDUP_VS_EAGER_THRESHOLD)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_instance_torch_compile_mode(mlp_device, mlp_input, mlp_baselines):
    """MLP block: magi_compile(instance, mode=TORCH_COMPILE)."""
    eager_result, torch_result = mlp_baselines

    def _tc_mode(cfg):
        # Route compilation through the plain torch.compile backend.
        cfg.compile_mode = CompileMode.TORCH_COMPILE
        return cfg

    model = magi_compile(
        RawMLP(_build_config()).to(mlp_device),
        dynamic_arg_dims={"x": 0},
        config_patch=_tc_mode,
    )
    model.eval()

    with torch.no_grad():
        magi_result = cuda_benchmark(lambda: model(mlp_input), compilation_warmup=3)

    speedup, _ = print_perf_comparison(
        "MLP - instance (TORCH_COMPILE mode)",
        eager_result,
        magi_result,
        torch_result,
        extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16",
    )
    assert_speedup(speedup, eager_result, magi_result, "instance_tc", SPEEDUP_VS_EAGER_THRESHOLD)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_function_decoration(mlp_device, mlp_input, mlp_baselines):
    """MLP block: @magi_compile function-level entry."""
    eager_result, torch_result = mlp_baselines

    mlp = RawMLP(_build_config()).to(mlp_device).eval()

    # Compile a free function that closes over the eager module.
    @magi_compile(dynamic_arg_dims={"x": 0})
    def compiled_entry(x: torch.Tensor) -> torch.Tensor:
        return mlp(x)

    with torch.no_grad():
        magi_result = cuda_benchmark(lambda: compiled_entry(mlp_input), compilation_warmup=3)

    speedup, _ = print_perf_comparison(
        "MLP - function decoration",
        eager_result,
        magi_result,
        torch_result,
        extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16",
    )
    assert_speedup(speedup, eager_result, magi_result, "function", SPEEDUP_VS_EAGER_THRESHOLD)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_method_decoration(mlp_device, mlp_input, mlp_baselines):
    """MLP block: magi_compile(model.forward) method decoration."""
    eager_result, torch_result = mlp_baselines

    model = RawMLP(_build_config()).to(mlp_device).eval()
    # Rebind the bound forward method to its compiled wrapper; the instance
    # attribute shadows the class method, so model(...) hits the wrapper.
    model.forward = magi_compile(model.forward, dynamic_arg_dims={"x": 0})

    with torch.no_grad():
        magi_result = cuda_benchmark(lambda: model(mlp_input), compilation_warmup=3)

    speedup, _ = print_perf_comparison(
        "MLP - method decoration",
        eager_result,
        magi_result,
        torch_result,
        extra_info=f"shape=({NUM_TOKENS}, {HIDDEN_SIZE}) intermediate={INTERMEDIATE_SIZE} dtype=bf16",
    )
    assert_speedup(speedup, eager_result, magi_result, "method", SPEEDUP_VS_EAGER_THRESHOLD)
Loading
Loading