
[Cpp API Compatibility] Fix flashmla compile#78550

Merged
SigureMo merged 5 commits into PaddlePaddle:develop from youge325:fix-flashmla-compile
Apr 3, 2026

Conversation

Contributor

@youge325 youge325 commented Apr 1, 2026

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description

Split from #78484.

Fixes the redefinition error encountered when compiling FlashMLA. Code that wants at::cuda::is_available() should uniformly use torch::cuda::is_available() instead; the duplicate export of this definition into the at::cuda namespace in torch/cuda.h is removed.

Change details

1. Fix the torch/cuda.h header issue

namespace at::cuda {
- using torch::cuda::device_count;
- using torch::cuda::is_available;
  using torch::cuda::synchronize;
}

The original code re-exported torch::cuda::is_available() into the at::cuda namespace, conflicting with the definition in ATen/cuda/CUDAContext.h. FlashMLA's use of at::cuda::is_available() then triggered a compile error.

The fix unifies the entry point:

  • torch/cuda.h no longer exports is_available into at::cuda
  • torch::cuda::is_available() is the recommended cross-library compatibility entry point

2. Update test code to use torch::cuda::is_available()

The following test files were modified:

  • test/cpp/compat/ATen_TensorAccessor_test.cc
  • test/cpp/compat/compat_basic_test.cc
  • test/cpp/compat/c10_storage_test.cc

3. Fix Stream-related interfaces (c10/core/Stream.h, c10/core/Stream.cpp)

  • Align the behavior of Stream's comparison operators
  • Fix the exception-throwing semantics of native_handle() on CPU devices

4. Trim the ScalarType implementation

Temporarily remove some ScalarType-related code that is not yet fully aligned (it will be reintroduced in a follow-up PR).

Related documentation

Does this introduce precision changes

Copilot AI review requested due to automatic review settings April 1, 2026 11:12

paddle-bot bot commented Apr 1, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 1, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR updates Paddle’s C++ API compatibility layer to improve PyTorch-compat behavior and fix compilation issues encountered by FlashMLA (split out from #78484).

Changes:

  • Removes torch/cuda.h re-exports of device_count / is_available into at::cuda to avoid symbol/redefinition conflicts.
  • Refactors c10::ScalarType enum/utilities and trims related unit tests.
  • Aligns some compat behaviors/messages (e.g., c10::Stream printing, added bounds/defined checks in a few ATen ops).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.

Show a summary per file

  • test/cpp/compat/c10_ScalarType_test.cc — Removes ScalarType utility branch tests (and an unused include).
  • paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h — Stops re-exporting device_count / is_available into at::cuda.
  • paddle/phi/api/include/compat/c10/core/Stream.h — Changes c10::Stream operator<< formatting to a PyTorch-like form.
  • paddle/phi/api/include/compat/c10/core/Stream.cpp — Changes the unsupported native_handle() failure path/message construction.
  • paddle/phi/api/include/compat/c10/core/ScalarType.h — Refactors ScalarType enum generation and removes several explicit utility branches.
  • paddle/phi/api/include/compat/ATen/ops/std.h — Adds dimension-range validation for the std implementation.
  • paddle/phi/api/include/compat/ATen/ops/select.h — Adds explicit dim/index range checks and negative index normalization.
  • paddle/phi/api/include/compat/ATen/ops/equal.h — Adds defined-tensor checks before comparing tensors.


Comment on lines 234 to 256
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_1, _2, name) \
  case ScalarType::name:          \
    return #name;

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
    case ScalarType::QInt8:
      return "QInt8";
    case ScalarType::QUInt8:
      return "QUInt8";
    case ScalarType::QInt32:
      return "QInt32";
    case ScalarType::QUInt4x2:
      return "QUInt4x2";
    case ScalarType::QUInt2x4:
      return "QUInt2x4";
    case ScalarType::ComplexHalf:
      return "ComplexHalf";
    case ScalarType::Bits1x8:
      return "Bits1x8";
    case ScalarType::Bits2x4:
      return "Bits2x4";
    case ScalarType::Bits4x2:
      return "Bits4x2";
    case ScalarType::Bits8:
      return "Bits8";
    case ScalarType::Bits16:
      return "Bits16";
    case ScalarType::Float8_e5m2fnuz:
      return "Float8_e5m2fnuz";
    case ScalarType::Float8_e4m3fnuz:
      return "Float8_e4m3fnuz";
    case ScalarType::Float8_e8m0fnu:
      return "Float8_e8m0fnu";
    case ScalarType::Float4_e2m1fn_x2:
      return "Float4_e2m1fn_x2";
    case ScalarType::Undefined:
      return "Undefined";
    default:
      return "UNKNOWN_SCALAR";
  }
#undef DEFINE_CASE
}

inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, _2, name) \
  case ScalarType::name:                       \
    return sizeof(ctype);

  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
      return 1;
    case ScalarType::QInt32:
    case ScalarType::Bits16:
      return 4;
    default:
      TORCH_CHECK(false, "Unknown ScalarType");
  }

Copilot AI Apr 1, 2026


ScalarType still declares the quantized enums via AT_FORALL_QINT_TYPES(...), but toString() and elementSize() no longer handle QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4 (the switches only expand AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS). This is a behavioral regression: printing these scalar types will return "UNKNOWN_SCALAR", and elementSize() will throw. Add QInt cases back (e.g., by also expanding AT_FORALL_QINT_TYPES in these switches) and ensure their element sizes match the intended semantics.

Comment on lines +48 to +49
PD_CHECK(false,
"native_handle() is not supported for this device type (",

Copilot AI Apr 1, 2026


PD_CHECK(false, ...) will throw a PD_Exception that appends an additional "Expected false, but it's not satisfied." context. If the goal is to match PyTorch's error message format exactly (as the comment states), this will not produce an exact match. Prefer throwing directly (e.g., TORCH_CHECK(false, ...) / PD_THROW(...)) or pass the real condition into PD_CHECK(...) instead of false so the extra context remains meaningful.

Suggested change

-  PD_CHECK(false,
-           "native_handle() is not supported for this device type (",
+  PD_THROW("native_handle() is not supported for this device type (",

Comment on lines +28 to 51
if (dim < 0 || dim >= self.dim()) {
PD_CHECK(false,
"select(): index ",
orig_dim,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim);
}
// Handle negative index
int64_t orig_index = index;
if (index < 0) {
int64_t dim_size = self.size(dim);
index = dim_size + index;
index = self.size(dim) + index;
}
// Check index is valid
if (index < 0 || index >= self.size(dim)) {
PD_CHECK(false,
"select(): index ",
orig_index,
" out of range for tensor of size ",
self.sizes(),
" at dimension ",
orig_dim < 0 ? orig_dim + self.dim() : orig_dim);
}

Copilot AI Apr 1, 2026


New error handling was added for invalid dim / index values, but the existing select tests (e.g., test/cpp/compat/ATen_select_test.cc) don't cover these out-of-range branches. Add test cases that assert an exception is thrown for (1) dim out of range (including negative beyond -self.dim()), and (2) index out of range (including negative beyond -size(dim)).


codecov-commenter commented Apr 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@53326c8). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78550   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         2           
  Lines              ?         6           
  Branches           ?         0           
===========================================
  Hits               ?         6           
  Misses             ?         0           
  Partials           ?         0           


Contributor

@ShigureNyako ShigureNyako left a comment


I'm requesting changes for now; there are 2 blocking issues:

  1. paddle/phi/api/include/compat/c10/core/ScalarType.h

    • This change generates ScalarType from AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS and deletes the cases previously filled in for toString() / elementSize().
    • As a result, existing types such as ComplexHalf, Bits*, Float8_e5m2fnuz, Float8_e4m3fnuz, Float8_e8m0fnu, and Float4_e2m1fn_x2 are dropped; paddle/phi/api/include/compat/ATen/cuda/CUDADataType.h in the repo still references Float8_e4m3fnuz / Float8_e5m2fnuz, so the ROCm path will fail to compile outright.
    • Meanwhile QInt8/QUInt8/QInt32/QUInt4x2/QUInt2x4 degrade to UNKNOWN_SCALAR / "Unknown ScalarType" in toString() / elementSize().
    • This is a source-level BREAKING CHANGE, and it is also inconsistent with the current ScalarType set and helper-function coverage in upstream PyTorch.
  2. In the same file, #define DEFINE_ST_ENUM_VAL_ is followed by #undef DEFINE_ENUM_ST_ENUM_VAL_; the macro names do not match, so DEFINE_ST_ENUM_VAL_ leaks into subsequent expansions.

Additional notes:

  • Removing the at::cuda::* re-exports from torch/csrc/api/include/torch/cuda.h is itself closer to upstream PyTorch, and the Stream operator<< change is largely an upstream alignment as well.
  • However, torch/cuda.h no longer provides the include-side aliases at::cuda::is_available()/device_count(); for downstream code that relies on these aliases and only includes torch/cuda.h this is still a source-level change. If it stays, it should be called out explicitly in the PR description.
  • On the failing CI: PR-CI-SOT fails because importing libpaddle.so under Python 3.14 hits an undefined _Py_TriggerGC, and Linux-IXUCA fails on a Paddle-iluvatar patch apply; neither looks directly caused by this PR's diff.

I suggest fixing the ScalarType regressions first, then continuing with the FlashMLA part of the fix.

Contributor

@ShigureNyako ShigureNyako left a comment


I re-reviewed the current version; both blockers from my previous round are resolved, so this can pass.

  1. The ScalarType compatibility regressions are restored

    • The source-level enum values, constant aliases, and compatibility paths for ComplexHalf / QInt* / Bits* / Float8_*fnuz / Float4_e2m1fn_x2 — including toString(), elementSize(), isFloat8Type(), isReducedFloatingType(), isComplexType(), isSignedType() — are all back.
    • The ROCm-path types that ATen/cuda/CUDADataType.h depends on are reachable again, so the compile risk I flagged last round is gone.
    • The new c10_ScalarType_test.cc cases also cover the restored branches; notably Bits16 now returns 2 per 16-bit semantics, which is closer to upstream PyTorch's bits16 definition.
  2. The macro-name issue is gone

    • The earlier #undef DEFINE_ENUM_ST_ENUM_VAL_ / DEFINE_ST_ENUM_VAL_ mismatch no longer exists after this round's ScalarType rollback.

Additional notes:

  • The source-level change of removing the at::cuda::{device_count,is_available} re-exports from torch/cuda.h remains, but it is the core change that fixes the FlashMLA redefinition conflict, and its direction is consistent with upstream PyTorch.
  • On CI: PR-CI-SOT now passes; the still-failing Linux-IXUCA is again the Paddle-iluvatar patch-apply failure, which does not look directly related to this PR's diff.

No new blockers from my side based on the current diff, so I'm approving.

Member

@SigureMo SigureMo left a comment


LGTMeow 🐾

@SigureMo SigureMo merged commit cebfafd into PaddlePaddle:develop Apr 3, 2026
82 of 83 checks passed
@youge325 youge325 deleted the fix-flashmla-compile branch April 3, 2026 01:39

Labels

contributor External developers
