
[Cpp API Compatibility] Align some compat APIs#78484

Closed
youge325 wants to merge 22 commits into PaddlePaddle:develop from youge325:cAlign

Conversation

@youge325
Contributor

@youge325 youge325 commented Mar 25, 2026

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description

Remove the old DeepEP interfaces whose migration is already complete.

Expose c10::Stream as at::Stream. Interfaces under namespace at are all exposed to namespace torch; DeepEP has one interface that temporarily uses at::cuda::CUDAStream and needs to be switched back to torch::Stream.

Fix a redefinition error hit when compiling FlashMLA: to use at::cuda::is_available(), you need #include <ATen/cuda/CUDAContext.h>, not #include <torch/csrc/api/include/torch/cuda.h>.

Fix a bug where at::arange, when no dtype is specified, creates a tensor whose data type is not kLong; this is needed for DeepGEMM alignment.

The aligned APIs can be viewed in PFCCLab/PaddleCppAPITest#58.

Whether precision changes are introduced

Copilot AI review requested due to automatic review settings March 25, 2026 13:33
@paddle-bot

paddle-bot bot commented Mar 25, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Mar 25, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR updates Paddle’s C++ ATen/c10 compatibility layer to better align API shapes with PyTorch and removes legacy DeepEP-transition interfaces.

Changes:

  • Tighten CUDA stream-pool APIs by adding a getStreamFromPool(int priority, ...) overload and mapping the bool form onto it; also add missing CUDA error checking.
  • Expose c10::Stream as at::Stream.
  • Remove legacy raw_stream() / cudaStream_t compatibility overloads for stream/event/record_stream.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

File Description
paddle/phi/api/include/compat/c10/cuda/CUDAStream.h Adds priority-based pool API + adjusts high/low selection semantics; removes legacy raw_stream(); adds error-checking for priority range query.
paddle/phi/api/include/compat/c10/core/Stream.h Re-exports c10::Stream into namespace at.
paddle/phi/api/include/compat/c10/core/Event.h Removes legacy record(cudaStream_t) compatibility overload.
paddle/phi/api/include/compat/ATen/ops/record_stream.h Removes record_stream(at::cuda::CUDAStream) and record_stream(cudaStream_t) overloads, leaving record_stream(at::Stream).
paddle/phi/api/include/compat/ATen/core/TensorBody.h Removes corresponding Tensor::record_stream overload declarations for CUDAStream/cudaStream_t.



namespace at {
using c10::Stream;
}

Copilot AI Mar 25, 2026


c10::Stream is now exposed as at::Stream, but (unlike most other c10/core/* compat headers) it is not re-exported into namespace torch. If the goal is that namespace at APIs are also visible under torch (as mentioned in the PR description and as done e.g. in c10/core/Device.h), consider adding a namespace torch { using c10::Stream; } re-export here as well to keep the surface consistent for downstream users.

Suggested change
}
}
namespace torch {
using c10::Stream;
}

Comment on lines 680 to 682
}

void record_stream(at::Stream s) const;

Copilot AI Mar 25, 2026


After removing the record_stream(at::cuda::CUDAStream) / record_stream(cudaStream_t) overloads, the earlier forward-declaration block for c10::cuda::CUDAStream in this header (and its accompanying comment about the overload) is now unused/misleading. Please remove or update that forward-declaration/comment block to reflect the new API surface.

Comment on lines 49 to 52
#endif
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
inline void Tensor::record_stream(at::cuda::CUDAStream s) const {
  record_stream(static_cast<at::Stream>(s));
}

// TODO(youge325): Remove after DeepEP paddle branch is updated to use
// at::Stream
inline void Tensor::record_stream(cudaStream_t s) const {
  auto dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(tensor_.impl());
  PD_CHECK(dense_tensor != nullptr,
           "record_stream only supports DenseTensor, but got a non-dense "
           "tensor implementation.");
  PD_CHECK(dense_tensor->place().GetType() != phi::AllocationType::CPU,
           "record_stream is not supported for CPU tensors.");
  paddle::memory::RecordStream(dense_tensor->Holder(), s);
}
#endif
} // namespace at

Copilot AI Mar 25, 2026


With the CUDAStream/cudaStream_t record_stream overloads removed from this header, the conditional #include <c10/cuda/CUDAStream.h> at the top of this file appears to be unused now. Consider dropping that include to reduce unnecessary dependencies/compile time.

Comment on lines 190 to +192
inline CUDAStream getStreamFromPool(const bool isHighPriority = false,
c10::DeviceIndex device_index = -1);


Copilot AI Mar 25, 2026


The getStreamFromPool(bool, device_index) forward declaration is only needed to carry default arguments, but it makes the header a bit harder to follow since the actual implementation is below. Consider moving the defaults onto the inline definition (and removing the separate forward declaration), or defining the bool overload before the int priority overload, to avoid a declaration-only prototype.

@youge325
Contributor Author

/re-run all-failed

@SigureMo
Member

Remove the old DeepEP interfaces whose migration is already complete

Is there a PaddleFleet PR that updates the DeepEP commit? This PR can't just delete these directly; it doesn't complete the loop of "Paddle update (keep compat interfaces)" -> "DeepEP update" -> "PaddleFleet update" -> "Paddle update (clean up compat interfaces)".

Also, this PR should wait for #78389, which will fix the problem that CI can't catch compilation errors.

@youge325
Contributor Author

Also, this PR should wait for #78389, which will fix the problem that CI can't catch compilation errors.

OK, I'll rebase once that CI passes; there are also some more commits coming to align the Device-related interfaces.

@codecov-commenter

codecov-commenter commented Mar 28, 2026

Codecov Report

❌ Patch coverage is 97.20280% with 4 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@571341b). Learn more about missing BASE report.

Files with missing lines Patch % Lines
paddle/phi/api/include/compat/c10/core/Device.cpp 91.89% 3 Missing ⚠️
paddle/phi/api/include/compat/ATen/ops/resize.h 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78484   +/-   ##
==========================================
  Coverage           ?   97.20%           
==========================================
  Files              ?        8           
  Lines              ?      143           
  Branches           ?        0           
==========================================
  Hits               ?      139           
  Misses             ?        4           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

@SigureMo
Member

SigureMo commented Apr 1, 2026

@ShigureNyako please take a look at this PR

Contributor

@ShigureNyako ShigureNyako left a comment


As I read it, the main line of this PR is to keep aligning Paddle's C++ compat layer with PyTorch semantics, focusing on the stream / guard / device related interfaces, the default dtype of at::arange, and a batch of accompanying tests. The overall direction is right, and the arange / resize_ / Device changes and the new tests look reasonable to me.

For now I'm blocking on one compatibility issue that must be addressed: c10::Stream is currently only re-exported into namespace at, not also into namespace torch. The PR description explicitly says downstream needs to switch from the temporary at::cuda::CUDAStream back to torch::Stream; without torch::Stream, that migration path is incomplete and will leave a compile regression. Once that is fixed, I don't see any other clear blockers.

} // namespace std

namespace at {
using c10::Stream;
Contributor


This currently only re-exports c10::Stream into namespace at, without also exposing it to namespace torch. From the PR description, one goal of this round of compat changes is to let downstream switch the temporary at::cuda::CUDAStream back to torch::Stream; without namespace torch { using c10::Stream; } here, switching back to torch::Stream will still fail to compile. I'd suggest keeping this consistent with compat headers like c10/core/Device.h and adding the torch-side alias as well.

Contributor Author


See the file paddle/phi/api/include/compat/torch/csrc/api/include/torch/types.h: to align with torch, all namespace at interfaces are exposed to namespace torch.

Contributor


"See the file paddle/phi/api/include/compat/torch/csrc/api/include/torch/types.h: to align with torch, all namespace at interfaces are exposed to namespace torch."

Confirmed; this explanation holds. torch/types.h does expose the declarations in at to torch via using namespace at;, and upstream PyTorch's torch/types.h has the same design; the corresponding upstream c10/core/Stream.h likewise does not separately write namespace torch { using c10::Stream; }. Given this context, torch::Stream does not need an extra re-export in c10/core/Stream.h, so I'm withdrawing this blocker.

Contributor

@ShigureNyako ShigureNyako left a comment


I re-checked your explanation about torch/types.h, and it holds: the compat layer's torch/types.h does expose the declarations in at to torch via using namespace at;, and upstream PyTorch's torch/types.h has the same design; the corresponding upstream c10/core/Stream.h likewise does not separately write namespace torch { using c10::Stream; }.

Based on that, my earlier judgment that torch::Stream must be separately re-exported in c10/core/Stream.h does not hold as a blocker, so that blocking point is lifted. Beyond this, I see no new clear issues with the PR's main line of stream / device / arange compat alignment or its test coverage, so this round gets an approve.

@youge325
Contributor Author

youge325 commented Apr 1, 2026

/re-run all-failed

@SigureMo
Member

SigureMo commented Apr 3, 2026

Can this PR be closed now? Only #78549 should remain.

@ShigureNyako
Contributor

Can this PR be closed now? Only #78549 should remain.

I just checked the current state:

So given the current state, this big PR indeed no longer needs to stay open; closing it is reasonable, and we just need to keep following #78549.

@youge325 youge325 closed this Apr 3, 2026
@youge325 youge325 deleted the cAlign branch April 3, 2026 03:34

Labels

contributor External developers

5 participants