diff --git a/src/paddlefleet/fusions/fused_bias_swiglu.py b/src/paddlefleet/fusions/fused_bias_swiglu.py index 09932e267..be3cf99ab 100644 --- a/src/paddlefleet/fusions/fused_bias_swiglu.py +++ b/src/paddlefleet/fusions/fused_bias_swiglu.py @@ -18,7 +18,6 @@ import logging import paddle -import paddle.nn.functional as F from paddlefleet.jit import jit_fuser from paddlefleet.utils import nvtx_decorator @@ -38,8 +37,7 @@ def swiglu(y): Returns: paddle.Tensor: Result of SwiGLU activation: SiLU(y1) * y2, where y1, y2 are the split halves. """ - y_1, y_2 = paddle.chunk(y, 2, -1) - return F.silu(y_1) * y_2 + return paddle.nn.functional.swiglu(y) @jit_fuser