[Fusion] use paddle.nn.functional.swiglu to replace manual chunk+silu #707
base: develop
```diff
@@ -18,7 +18,6 @@
 import logging
 
 import paddle
-import paddle.nn.functional as F
 
 from paddlefleet.jit import jit_fuser
 from paddlefleet.utils import nvtx_decorator
@@ -38,8 +37,7 @@ def swiglu(y):
     Returns:
         paddle.Tensor: Result of SwiGLU activation: SiLU(y1) * y2, where y1, y2 are the split halves.
     """
-    y_1, y_2 = paddle.chunk(y, 2, -1)
-    return F.silu(y_1) * y_2
+    return paddle.nn.functional.swiglu(y)
```
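For reference, the semantics both code paths are meant to compute can be sketched in plain NumPy (a Paddle-free illustration, not the repo's implementation): split the last axis in half, apply SiLU to the first half, and gate it with the second half.

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_reference(y):
    # Mirrors the old chunk + silu path: split the last axis in
    # two, apply SiLU to the first half, gate with the second.
    y_1, y_2 = np.split(y, 2, axis=-1)
    return silu(y_1) * y_2

rng = np.random.default_rng(0)
y = rng.standard_normal((4, 8)).astype(np.float32)
out = swiglu_reference(y)
assert out.shape == (4, 4)  # output has half the last-dim width
```

A fused `swiglu` op should match this reference bit-for-bit only up to floating-point reassociation, which is exactly why the precision discussion below matters.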
Contributor

I can't treat this as an equivalent refactor for now. Judging from this PR's integration logs, the replacement has already introduced observable precision drift across several GLM4.5 scenarios, and it directly triggered. I'd suggest doing two things before merging:

If these precision changes are expected, please also update the baselines in sync and go through the corresponding precision approval process.
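The kind of precision drift described in the review comment above can be quantified with a small comparison helper. This is a hypothetical NumPy sketch (the helper name and report fields are illustrative, not part of the repo):

```python
import numpy as np

def precision_report(ref, new, name="swiglu"):
    # Hypothetical helper: summarize drift between a reference
    # output (old chunk + silu path) and a fused-op output.
    ref = np.asarray(ref, dtype=np.float64)
    new = np.asarray(new, dtype=np.float64)
    abs_diff = np.abs(ref - new)
    rel_diff = abs_diff / np.maximum(np.abs(ref), 1e-12)
    return {
        "op": name,
        "max_abs_diff": float(abs_diff.max()),
        "max_rel_diff": float(rel_diff.max()),
    }

# Identical tensors report zero drift.
report = precision_report([1.0, 2.0], [1.0, 2.0])
assert report["max_abs_diff"] == 0.0
```

Logging such a report per scenario would make it easy to decide whether an observed difference is benign fp reassociation or a real regression that needs baseline approval.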
```diff
 
 @jit_fuser
```
After switching the `swiglu()` implementation here to `paddle.nn.functional.swiglu`, the repo's existing single-card unit tests (e.g. `tests/single_card_tests/transformer/test_mlp.py`) still default to `hidden_act=F.gelu`, so they do not cover the `bias_activation_fusion=True + gated_linear_unit=True + hidden_act=F.silu` path, which is the one that reaches `swiglu()` in this file. I'd suggest adding a test scenario that explicitly sets `hidden_act=F.silu` and checks regression alignment of both the forward output and the backward gradients (against the previous `paddle.chunk + F.silu` results).
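The regression check suggested above can be sketched without Paddle: a NumPy reference of the old chunk + silu forward, its analytic backward, and a finite-difference verification of that backward. In the real test, `paddle.nn.functional.swiglu`'s forward output and gradients would be compared against such a reference; all names below are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_fwd(y):
    # Reference forward: SiLU(y1) * y2 with y split on the last axis.
    y_1, y_2 = np.split(y, 2, axis=-1)
    return sigmoid(y_1) * y_1 * y_2

def swiglu_bwd(y, grad_out):
    # Analytic gradient of SiLU(y1) * y2 with respect to y.
    # d SiLU(x)/dx = s(x) + x * s(x) * (1 - s(x))
    y_1, y_2 = np.split(y, 2, axis=-1)
    s = sigmoid(y_1)
    d_y1 = grad_out * y_2 * (s + y_1 * s * (1.0 - s))
    d_y2 = grad_out * s * y_1
    return np.concatenate([d_y1, d_y2], axis=-1)

def numerical_grad(y, grad_out, eps=1e-5):
    # Central finite differences of sum(swiglu_fwd(y) * grad_out).
    g = np.zeros_like(y)
    it = np.nditer(y, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        y_p, y_m = y.copy(), y.copy()
        y_p[idx] += eps
        y_m[idx] -= eps
        g[idx] = ((swiglu_fwd(y_p) - swiglu_fwd(y_m)) * grad_out).sum() / (2 * eps)
    return g

rng = np.random.default_rng(0)
y = rng.standard_normal((2, 6))
go = rng.standard_normal((2, 3))
# Backward must agree with finite differences before it can serve
# as the baseline for the fused op.
assert np.allclose(swiglu_bwd(y, go), numerical_grad(y, go), atol=1e-6)
```

A Paddle test scenario would then run the MLP with `hidden_act=F.silu`, capture forward outputs and parameter gradients, and assert `allclose` against values produced by the pre-PR `paddle.chunk + F.silu` implementation.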
swiglu()的实现切换为paddle.nn.functional.swiglu后,仓库内现有的单卡单测(例如tests/single_card_tests/transformer/test_mlp.py)仍默认使用hidden_act=F.gelu,因此不会覆盖到bias_activation_fusion=True + gated_linear_unit=True + hidden_act=F.silu这条会走到本文件swiglu()的路径。建议补充一个单测场景显式设置hidden_act=F.silu,并对 forward 输出及 backward 梯度做回归对齐(参考之前paddle.chunk + F.silu的结果)。