[Cpp API Compatibility] Fix flashmla compile#78550
SigureMo merged 5 commits into PaddlePaddle:develop from
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
Pull request overview
This PR updates Paddle’s C++ API compatibility layer to improve PyTorch-compat behavior and fix compilation issues encountered by FlashMLA (split out from #78484).
Changes:
- Removes `torch/cuda.h` re-exports of `device_count` / `is_available` into `at::cuda` to avoid symbol/redefinition conflicts.
- Refactors the `c10::ScalarType` enum/utilities and trims related unit tests.
- Aligns some compat behaviors/messages (e.g., `c10::Stream` printing, added bounds/defined checks in a few ATen ops).
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| test/cpp/compat/c10_ScalarType_test.cc | Removes ScalarType utility branch tests (and an unused include). |
| paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h | Stops re-exporting device_count / is_available into at::cuda. |
| paddle/phi/api/include/compat/c10/core/Stream.h | Changes c10::Stream operator<< formatting to a PyTorch-like form. |
| paddle/phi/api/include/compat/c10/core/Stream.cpp | Changes unsupported native_handle() failure path/message construction. |
| paddle/phi/api/include/compat/c10/core/ScalarType.h | Refactors ScalarType enum generation and removes several explicit utility branches. |
| paddle/phi/api/include/compat/ATen/ops/std.h | Adds dimension-range validation for std implementation. |
| paddle/phi/api/include/compat/ATen/ops/select.h | Adds explicit dim/index range checks and negative index normalization. |
| paddle/phi/api/include/compat/ATen/ops/equal.h | Adds defined-tensor checks before comparing tensors. |
```cpp
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_1, _2, name) \
  case ScalarType::name:          \
    return #name;

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    case ScalarType::QInt8:
      return "QInt8";
    case ScalarType::QUInt8:
      return "QUInt8";
    case ScalarType::QInt32:
      return "QInt32";
    case ScalarType::QUInt4x2:
      return "QUInt4x2";
    case ScalarType::QUInt2x4:
      return "QUInt2x4";
    case ScalarType::ComplexHalf:
      return "ComplexHalf";
    case ScalarType::Bits1x8:
      return "Bits1x8";
    case ScalarType::Bits2x4:
      return "Bits2x4";
    case ScalarType::Bits4x2:
      return "Bits4x2";
    case ScalarType::Bits8:
      return "Bits8";
    case ScalarType::Bits16:
      return "Bits16";
    case ScalarType::Float8_e5m2fnuz:
      return "Float8_e5m2fnuz";
    case ScalarType::Float8_e4m3fnuz:
      return "Float8_e4m3fnuz";
    case ScalarType::Float8_e8m0fnu:
      return "Float8_e8m0fnu";
    case ScalarType::Float4_e2m1fn_x2:
      return "Float4_e2m1fn_x2";
    case ScalarType::Undefined:
      return "Undefined";
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}

inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, _2, name) \
  case ScalarType::name:                       \
    return sizeof(ctype);

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
      return 1;
    case ScalarType::QInt32:
    case ScalarType::Bits16:
      return 4;
    default:
      TORCH_CHECK(false, "Unknown ScalarType");
  }
}
```
ScalarType still declares the quantized enums via AT_FORALL_QINT_TYPES(...), but toString() and elementSize() no longer handle QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4 (the switches only expand AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS). This is a behavioral regression: printing these scalar types will return "UNKNOWN_SCALAR", and elementSize() will throw. Add QInt cases back (e.g., by also expanding AT_FORALL_QINT_TYPES in these switches) and ensure their element sizes match the intended semantics.
```cpp
  PD_CHECK(false,
           "native_handle() is not supported for this device type (",
```
PD_CHECK(false, ...) will throw a PD_Exception that appends an additional "Expected false, but it's not satisfied." context. If the goal is to match PyTorch's error message format exactly (as the comment states), this will not produce an exact match. Prefer throwing directly (e.g., TORCH_CHECK(false, ...) / PD_THROW(...)) or pass the real condition into PD_CHECK(...) instead of false so the extra context remains meaningful.
Suggested change:

```diff
-  PD_CHECK(false,
-           "native_handle() is not supported for this device type (",
+  PD_THROW("native_handle() is not supported for this device type (",
```
```cpp
  if (dim < 0 || dim >= self.dim()) {
    PD_CHECK(false,
             "select(): index ",
             orig_dim,
             " out of range for tensor of size ",
             self.sizes(),
             " at dimension ",
             orig_dim);
  }
  // Handle negative index
  int64_t orig_index = index;
  if (index < 0) {
    index = self.size(dim) + index;
  }
  // Check index is valid
  if (index < 0 || index >= self.size(dim)) {
    PD_CHECK(false,
             "select(): index ",
             orig_index,
             " out of range for tensor of size ",
             self.sizes(),
             " at dimension ",
             orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
  }
```
New error handling was added for invalid dim / index values, but the existing select tests (e.g., test/cpp/compat/ATen_select_test.cc) don't cover these out-of-range branches. Add test cases that assert an exception is thrown for (1) dim out of range (including negative beyond -self.dim()), and (2) index out of range (including negative beyond -size(dim)).
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
##           develop    #78550   +/-  ##
===========================================
  Coverage         ?   100.00%
===========================================
  Files            ?         2
  Lines            ?         6
  Branches         ?         0
===========================================
  Hits             ?         6
  Misses           ?         0
  Partials         ?         0
```
ShigureNyako left a comment
Requesting changes for now; there are two blocking issues:

1. `paddle/phi/api/include/compat/c10/core/ScalarType.h`
   - This change generates `ScalarType` from `AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS` and deletes the cases previously filled in inside `toString()` / `elementSize()`.
   - As a result, currently existing types such as `ComplexHalf`, `Bits*`, `Float8_e5m2fnuz`, `Float8_e4m3fnuz`, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2` are dropped; `paddle/phi/api/include/compat/ATen/cuda/CUDADataType.h` in the repo still references `Float8_e4m3fnuz` / `Float8_e5m2fnuz`, so the ROCm path will fail to compile outright.
   - Meanwhile, `QInt8` / `QUInt8` / `QInt32` / `QUInt4x2` / `QUInt2x4` degrade to `UNKNOWN_SCALAR` / "Unknown ScalarType" in `toString()` / `elementSize()`.
   - This is a source-level BREAKING CHANGE, and it is also inconsistent with the coverage of the current `ScalarType` set and helper functions in upstream PyTorch.
2. In the same file, `#define DEFINE_ST_ENUM_VAL_` is followed by `#undef DEFINE_ENUM_ST_ENUM_VAL_`; the macro names do not match, so `DEFINE_ST_ENUM_VAL_` leaks into subsequent expansions.
Additional notes:

- Removing the `at::cuda::*` re-exports from `torch/csrc/api/include/torch/cuda.h` is itself closer to upstream PyTorch; the `Stream` `operator<<` change is also basically aligned with upstream.
- However, since `torch/cuda.h` no longer provides include-side aliases for `at::cuda::is_available()` / `device_count()`, this is still a source-level change for downstream code that relies on those aliases and only includes `torch/cuda.h`. If it is kept, I suggest calling it out explicitly in the PR description.
- On the failing CI: `PR-CI-SOT` fails because `_Py_TriggerGC` is undefined when importing `libpaddle.so` in a Python 3.14 environment, and `Linux-IXUCA` fails at the Paddle-iluvatar patch apply step; neither looks like a problem directly introduced by this PR's diff.

I suggest fixing this group of ScalarType regressions first, then continuing with the FlashMLA part of the fix.
ShigureNyako left a comment
I went through the current version again; the two blockers I raised last round are resolved, so this can pass now.

1. The `ScalarType` compatibility regression has been restored
   - The source-level enum values and constant aliases for `ComplexHalf / QInt* / Bits* / Float8_*fnuz / Float4_e2m1fn_x2`, along with the compatibility paths `toString()`, `elementSize()`, `isFloat8Type()`, `isReducedFloatingType()`, `isComplexType()`, `isSignedType()`, etc., are all back. The ROCm-path types that `ATen/cuda/CUDADataType.h` depends on are reachable again, so the compile risk I pointed out last round is gone.
   - The newly added `c10_ScalarType_test.cc` cases also cover these restored branches; `Bits16` now returns `2` under 16-bit semantics, which is also closer to upstream PyTorch's `bits16` definition.
2. The macro-name issue is gone
   - The earlier mismatch between `#undef DEFINE_ENUM_ST_ENUM_VAL_` and `DEFINE_ST_ENUM_VAL_` no longer exists now that this round has rolled back and restored `ScalarType`.

Additional notes:

- The source-level change of removing the `at::cuda::{device_count,is_available}` re-exports from `torch/cuda.h` is still present, but it is the core change of this PR for fixing the FlashMLA redefinition conflict, and it is directionally consistent with upstream PyTorch.
- On CI, to distinguish the failures: `PR-CI-SOT` now passes; the still-failing `Linux-IXUCA` is again the Paddle-iluvatar patch apply failure, which does not look directly related to this PR's diff.

Based on the current diff, I have no new blocking issues, so approving.
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
Split from #78484.

Fixes the redefinition errors hit when compiling FlashMLA. Code that needs `at::cuda::is_available()` should uniformly use `torch::cuda::is_available()`; the duplicate export of this definition into the `at::cuda` namespace is removed from `torch/cuda.h`.

Change details

1. Fix the `torch/cuda.h` header

The original code re-exported `torch::cuda::is_available()` into the `at::cuda` namespace, conflicting with the definition in `ATen/cuda/CUDAContext.h`; FlashMLA's use of `at::cuda::is_available()` then triggered a compile error. The fix unifies the entry point:

- stop exporting `is_available` into `at::cuda` from `torch/cuda.h`
- keep `torch::cuda::is_available()` as the cross-library compatibility entry point

2. Update test code to use `torch::cuda::is_available()`

Modified the following test files:

- test/cpp/compat/ATen_TensorAccessor_test.cc
- test/cpp/compat/compat_basic_test.cc
- test/cpp/compat/c10_storage_test.cc

3. Fix Stream-related interfaces (`c10/core/Stream.h`, `c10/core/Stream.cpp`)

- `Stream`'s comparison-operator behavior
- the exception-throwing semantics of `native_handle()` on CPU devices

4. Trim the ScalarType implementation

Temporarily removes some `ScalarType`-related code that is not yet fully aligned (to be reintroduced in a follow-up PR).

Related documentation

Does this change numerical precision?

No