From 3ccf22f9cff48888ad6f445225bb3d74fb3d9596 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Sun, 19 Oct 2025 18:34:17 -0700 Subject: [PATCH] Add static_shapes=True to layernorm and rmsnorm Summary: We add static_shapes=True to improve perfs since we have multiple different shapes Differential Revision: D84663972 --- examples/layer_norm.py | 2 +- examples/rms_norm.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/layer_norm.py b/examples/layer_norm.py index 809df10f6..14c32ad32 100644 --- a/examples/layer_norm.py +++ b/examples/layer_norm.py @@ -21,7 +21,7 @@ # %% -@helion.kernel +@helion.kernel(static_shapes=True) def layer_norm_fwd( x: torch.Tensor, normalized_shape: list[int], diff --git a/examples/rms_norm.py b/examples/rms_norm.py index 0d1342ae5..f732a85a2 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -29,7 +29,7 @@ # %% -@helion.kernel +@helion.kernel(static_shapes=True) def rms_norm_fwd( x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -70,7 +70,7 @@ def rms_norm_fwd( return out, inv_rms.reshape(-1, 1) -@helion.kernel +@helion.kernel(static_shapes=True) def rms_norm_bwd( grad_out: torch.Tensor, x: torch.Tensor,