
[Cpp API Compatibility] Align misc apis #78555

Merged
SigureMo merged 9 commits into PaddlePaddle:develop from youge325:align-misc-apis
Apr 3, 2026

Conversation

Contributor

@youge325 youge325 commented Apr 1, 2026

PR Category

Execute Infrastructure

PR Types

Improvements

Description

Split from #78484.

Miscellaneous API alignment: Allocator, CUDAFunctions, CUDAGuard, CUDAStream, version.h, and more.

Change details

1. c10/core/Allocator.h interface completion

Additions:

// Type definitions
typedef uint64_t CaptureId_t;
typedef int64_t MempoolId_t;
struct MempoolIdHash { ... };

// DataPtr methods
void* mutable_get() const;

// Allocator registration
void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority);
Allocator* GetAllocator(const DeviceType& t);

// Helper class
struct InefficientStdFunctionContext {
  static DataPtr makeDataPtr(void* ptr, std::function<void(void*)> deleter, Device device);
};

Behavior fixes:

  • is_simple_data_ptr() semantics corrected to get() == get_context() (consistent with PyTorch)

2. c10/cuda/CUDAFunctions.h/cpp refactoring

  • Move some implementations out of CUDAFunctions.h into the new file CUDAFunctions.cpp
  • Fix a Windows build issue

3. c10/cuda/CUDAGuard.h interface completion

New methods:

// CUDAGuard
DeviceIndex original_device() const;

// OptionalCUDAGuard
DeviceIndex original_device() const;
void reset();

4. c10/cuda/CUDAStream.h extensions

New features:

// Priority support
static std::tuple<int, int> priority_range();
CUDAStream(Unchecked /*unused*/, Stream stream);
int priority() const;

// Serialization
StreamData3 pack3() const;
static CUDAStream unpack3(StreamId stream_id, DeviceIndex device_index, DeviceType device_type);

// External stream wrapping
CUDAStream getStreamFromExternal(cudaStream_t ext_stream, c10::DeviceIndex device_index);

// Query and synchronization
bool query() const;
void synchronize() const;

// Hash support
struct hash<CUDAStream>;

5. New torch/version.h

Provides PyTorch-style version macros so third-party libraries (e.g., DeepGEMM, FlashMLA) can detect the version of the Paddle compatibility layer.

6. Test fixes

  • test/cpp/compat/ATen_TensorAccessor_test.cc: fix CUDA test conditions
  • test/cpp/compat/c10_storage_test.cc: fix is_simple_data_ptr assertions
  • test/cpp/compat/c10_layout_test.cc: add SparseCooTensorInferSize assertions
  • test/cpp/compat/compat_basic_test.cc: fix CUDA test conditions

Related documentation

Whether this causes precision changes

Copilot AI review requested due to automatic review settings April 1, 2026 11:19
@paddle-bot

paddle-bot bot commented Apr 1, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot bot added the contributor (External developers) label Apr 1, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR continues the C++ API compatibility alignment work by bringing several misc compat-layer APIs and behaviors closer to LibTorch/PyTorch, and updating the corresponding C++ compat tests.

Changes:

  • Switch CUDA availability checks in compat tests to torch::cuda::is_available().
  • Expand c10::cuda::CUDAStream / c10::cuda::CUDAGuard compat APIs and adjust stream-pool and allocator behaviors to better match PyTorch.
  • Add missing compat headers/sources (e.g., torch/version.h, c10/cuda/CUDAFunctions.cpp) and strengthen sparse tensor test assertions.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.

Summary per file:

  • test/cpp/compat/compat_basic_test.cc: Uses torch::cuda::is_available() for CUDA-gated test sections.
  • test/cpp/compat/c10_storage_test.cc: Aligns CUDA availability checks and updates is_simple_data_ptr expectations to PyTorch semantics.
  • test/cpp/compat/c10_layout_test.cc: Adds extra sparse COO shape/dim assertions.
  • test/cpp/compat/ATen_TensorAccessor_test.cc: Uses torch::cuda::is_available() for CUDA-gated accessor test.
  • paddle/phi/api/include/compat/torch/csrc/api/include/torch/version.h: Introduces LibTorch-style version macros for the compat layer.
  • paddle/phi/api/include/compat/CMakeLists.txt: Adds c10/cuda/CUDAFunctions.cpp to the compat build sources.
  • paddle/phi/api/include/compat/c10/cuda/CUDAStream.h: Adds stream API surface (priority/query/sync/pack/unpack/hash/etc.) and new pool/external stream helpers.
  • paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h: Tracks original/current device to better emulate PyTorch guard semantics.
  • paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h: Moves device_count/device_synchronize to out-of-line definitions; keeps stream sync helper gated by CUDA/HIP.
  • paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.cpp: Implements device_count() and device_synchronize() out-of-line.
  • paddle/phi/api/include/compat/c10/core/Allocator.h: Aligns is_simple_data_ptr semantics and adds allocator registry/utilities.
  • paddle/phi/api/include/compat/ATen/core/ivalue.h: Adds camelCase API aliases and exposes c10::IValue alias for compatibility.


Comment on lines 230 to +240
/**
* Set the current CUDA stream for the device of the given stream in the
* calling thread.
*
* Implements per-thread, per-device current stream semantics: the change is
* local to the calling OS thread and does not affect any shared state such as
* Paddle's GPUContext. Other threads continue to see their own current stream.
*/
inline CUDAStream getStreamFromPool(const bool isHighPriority,
c10::DeviceIndex device_index) {
return getStreamFromPool(isHighPriority ? -1 : 0, device_index);

Copilot AI Apr 1, 2026


The doc comment above getStreamFromPool(...) describes "Set the current CUDA stream..." but the function below is a stream-pool accessor. This mismatch is misleading for users and makes the header harder to maintain; update the comment to describe getStreamFromPool (and keep setCurrentCUDAStream documented at its own definition).

* Paddle's GPUContext. Other threads continue to see their own current stream.
*/
inline CUDAStream getStreamFromPool(const bool isHighPriority,
c10::DeviceIndex device_index) {

Copilot AI Apr 1, 2026


getStreamFromPool(bool isHighPriority, DeviceIndex ...) no longer has a default device_index. This makes calls like getStreamFromPool(true) either fail to compile or (more dangerously) bind to the getStreamFromPool(int priority, DeviceIndex=-1) overload with priority=1, changing semantics (high-priority requested but low-priority returned). Add device_index = -1 to the bool overload (matching PyTorch) and ensure bool arguments cannot silently resolve to the int overload.

Suggested change
c10::DeviceIndex device_index) {
c10::DeviceIndex device_index = -1) {

@youge325
Contributor Author

youge325 commented Apr 1, 2026

/re-run all-failed

1 similar comment
@youge325
Contributor Author

youge325 commented Apr 2, 2026

/re-run all-failed

Contributor

@ShigureNyako ShigureNyako left a comment


I agree with the direction of this split, and the interface completion for Allocator / CUDAGuard / CUDAStream / IValue covers quite a few gaps. However, I am requesting changes on the current version; there are two blocking points:

  1. In paddle/phi/api/include/compat/c10/cuda/CUDAStream.h, getStreamFromPool(bool, ...) is missing the device_index = -1 default argument. This is not only inconsistent with PyTorch's signature; it also lets getStreamFromPool(true) silently bind to the int priority overload and return a low-priority stream. That is an actual behavioral bug, not merely a difference in declaration shape.
  2. This "misc apis" split PR still mixes in the removal of raw_stream(). That change is a breaking change for the current compat surface, and it is already carried by #78549 (Remove deepep legacy apis). Keeping it in #78555 re-couples the scopes after the split and leaves this group of split PRs in a half-migrated state.

Please sort out these two points first: add the default argument to the bool overload plus a covering test for getStreamFromPool(true), and strip the legacy cleanup out of this PR, keeping only the additions/fixes directly related to aligning the remaining interfaces. I will continue with the next round after that.


// TODO(youge325): Remove after DeepEP paddle branch is updated to use
// stream()
cudaStream_t raw_stream() const { return stream(); }
Contributor


This deletes the existing raw_stream() along the way, but that is no longer simply "completing the remaining interfaces"; it is a breaking change for the current compat surface. More importantly, the raw_stream() cleanup has already been split out into #78549 (Remove deepep legacy apis); if #78555 keeps this deletion, it re-mixes the scopes after the split and makes the two PRs duplicate/conflict on the same API change. Suggest leaving the legacy cleanup in #78549 and keeping only newly added/aligned interfaces in this PR.

Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest leaving the legacy cleanup in #78549 and keeping only newly added/aligned interfaces in this PR.

Done. The latest commit 53cb19c6 restores raw_stream() to CUDAStream and leaves that cleanup to the separate #78549; this PR now keeps only the misc API alignment changes.

* Paddle's GPUContext. Other threads continue to see their own current stream.
*/
inline CUDAStream getStreamFromPool(const bool isHighPriority,
c10::DeviceIndex device_index) {
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The device_index = -1 default argument is missing here, inconsistent with PyTorch's getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1). As written, a call to getStreamFromPool(true) does not hit this bool overload; it silently falls through to the getStreamFromPool(int priority, DeviceIndex device_index = -1) above. Since true is treated as 1, and the current implementation only takes the high-priority branch when priority < 0, the call returns a low-priority stream, the exact opposite of the requested semantics. Please add the default argument here and add a covering test for getStreamFromPool(true).

Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The device_index = -1 default argument is missing here, inconsistent with PyTorch's getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1).

Done. The latest commit completes the signature as getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1) and adds GetStreamFromPoolBoolOverloadPreservesHighPriority in test/cpp/compat/c10_Stream_test.cc, covering the high-priority behavior of the single-argument getStreamFromPool(true).

Contributor

@ShigureNyako ShigureNyako left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This round of updates resolves both blocking points from my previous review:

  1. getStreamFromPool(bool, ...) now carries the default arguments and matches PyTorch's signature;
  2. raw_stream() has been restored, and the legacy cleanup is no longer mixed into this misc API alignment PR.

In addition, test/cpp/compat/c10_Stream_test.cc now covers the high-priority semantics of getStreamFromPool(true), which directly guards against the earlier overload-binding problem. No new blocking issues from me at the code level, so approving for now; CI is still running, so just keep watching the results.

@youge325
Contributor Author

youge325 commented Apr 2, 2026

/re-run all-failed

1 similar comment
@youge325
Contributor Author

youge325 commented Apr 2, 2026

/re-run all-failed

@codecov-commenter

codecov-commenter commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 95.55556% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@bb246f6). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...addle/phi/api/include/compat/c10/cuda/CUDAStream.h 80.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78555   +/-   ##
==========================================
  Coverage           ?   95.55%           
==========================================
  Files              ?        4           
  Lines              ?       45           
  Branches           ?        0           
==========================================
  Hits               ?       43           
  Misses             ?        2           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.


Contributor

@ShigureNyako ShigureNyako left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to block this again: after merging develop into this new head, a new blocking problem appeared in paddle/phi/api/include/compat/c10/cuda/CUDAStream.h:

  • The file now contains two getStreamFromPool(const bool isHighPriority, c10::DeviceIndex device_index) functions; the newly added version below only adds default arguments and a comment, but default arguments do not participate in the function signature, so this is not a legal overload but a duplicate definition.
  • In other words, what I approved last round was the 53cb19c6 version; the current head 5f387c44 needs to collapse this duplicate definition before we continue watching CI.

Suggestion: keep a single bool version and move the default arguments/comment back onto the existing implementation instead of appending a second function with the same signature.

* This overload matches PyTorch's bool-based entry point and preserves the
* single-argument form `getStreamFromPool(true)` for high-priority requests.
*/
inline CUDAStream getStreamFromPool(const bool isHighPriority = false,
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While keeping the existing getStreamFromPool(const bool isHighPriority, c10::DeviceIndex device_index) implementation above, this adds another same-named function distinguished only by default arguments. Default arguments do not participate in the function signature, so this is not an overload but a duplicate definition of the same signature; once merged into the current header it becomes a redefinition.

Suggestion: do not append a second bool version; instead move the default arguments and comment back onto the one existing implementation above, keeping a single getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1).

Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: do not append a second bool version; instead move the default arguments and comment back onto the one existing implementation above, keeping a single getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1).

Done. The current head 085d55a0 removes that extra bool forward/duplicate declaration and keeps only one getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1) definition with default arguments; raw_stream() also remains untouched, per the split-boundary requirement that its cleanup stays outside this PR.

Contributor

@ShigureNyako ShigureNyako left a comment


This new head still needs adjustment: the previous round's same-signature duplicate definition has been collapsed, but CUDAStream.h now carries the default arguments on both the forward declaration and the definition.

Specifically:

  • above there is already getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1);
  • the definition below then writes = false / = -1 again.

This forms a default argument redefinition, which is still not an acceptable final state. Keep the default arguments in exactly one place: either put them on the forward declaration and drop them from the definition, or delete the forward declaration and keep a single definition carrying the defaults.

* This overload matches PyTorch's bool-based entry point and preserves the
* single-argument form `getStreamFromPool(true)` for high-priority requests.
*/
inline CUDAStream getStreamFromPool(const bool isHighPriority = false,
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This round's merge changed the forward declaration above to getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1);, but the definition here repeats the default arguments. This version no longer duplicates the function body, but it still forms a default argument redefinition.

Pick one of two options:

  1. keep the default arguments only on the forward declaration and drop them from the definition;
  2. delete the forward declaration and put the default arguments on this single definition.

As it stands, the PyTorch-alignment fix has not truly converged.

Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pick one of two options: keep the default arguments only on the forward declaration and drop them from the definition, or delete the forward declaration and keep a single definition carrying the defaults.

Done. The latest head takes the second option: the forward declaration is removed and only the single bool-overload definition with default arguments remains, so the default argument redefinition flagged last round is resolved.

Contributor

@ShigureNyako ShigureNyako left a comment


I re-reviewed the current head 085d55a0; the default argument redefinition that blocked the previous round has been fixed:

  1. CUDAStream.h has dropped the redundant bool forward declaration and keeps a single getStreamFromPool(const bool isHighPriority = false, c10::DeviceIndex device_index = -1) definition;
  2. the PyTorch-consistent semantics of getStreamFromPool(true) are still covered by the existing test;
  3. raw_stream() remains in place in this PR, the split boundary has not been polluted again, and I see no new BREAKING CHANGE.

No new blocking issues from me at the code level, so approving. One note: the Linux-IXUCA failure in this run looks like an external rebase / patch-apply pipeline problem; I did not see the current compat changes re-trigger the earlier kind of CUDAStream.h declaration error. Just keep tracking the CI results.

@SigureMo SigureMo closed this Apr 3, 2026
@SigureMo SigureMo reopened this Apr 3, 2026
Member

@SigureMo SigureMo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTMeow 🐾

@SigureMo SigureMo merged commit 6d87891 into PaddlePaddle:develop Apr 3, 2026
91 of 92 checks passed

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants