Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions paddle/phi/api/include/compat/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ class IValue {
bool is_custom_class() const { return tag_ == TypeTag::CustomClass; }
bool is_tuple() const { return tag_ == TypeTag::Tuple; }

// PyTorch-style camelCase type predicates; thin aliases over the
// snake_case is_*() accessors above, for LibTorch source compatibility.
bool isNone() const { return is_none(); }
bool isBool() const { return is_bool(); }
bool isInt() const { return is_int(); }
bool isDouble() const { return is_double(); }
bool isString() const { return is_string(); }
bool isList() const { return is_list(); }
bool isTensor() const { return is_tensor(); }
bool isCustomClass() const { return is_custom_class(); }
bool isTuple() const { return is_tuple(); }

bool to_bool() const {
if (!is_bool()) throw std::runtime_error("Not a bool");
return std::get<bool>(value_);
Expand Down Expand Up @@ -280,6 +290,39 @@ class IValue {
return static_cast<at::ScalarType>(std::get<int64_t>(value_));
}

// PyTorch-style camelCase conversion aliases over the snake_case to_*()
// helpers, for LibTorch source compatibility.
bool toBool() const { return to_bool(); }
int64_t toInt() const { return to_int(); }
double toDouble() const { return to_double(); }
// NOTE(review): assumes to_string() returns a reference whose lifetime is
// tied to this IValue — confirm against the to_string() definition.
const std::string& toStringRef() const { return to_string(); }
std::string_view toStringView() const { return to_string_view(); }
at::Tensor toTensor() const { return to_tensor(); }
at::ScalarType toScalarType() const { return to_scalar_type(); }

// Human-readable name of this IValue's runtime type tag, mirroring
// torch::jit::IValue::tagKind(); yields "InvalidTag" for any tag value
// outside the known set.
std::string tagKind() const {
  if (tag_ == TypeTag::None) return "None";
  if (tag_ == TypeTag::Bool) return "Bool";
  if (tag_ == TypeTag::Int) return "Int";
  if (tag_ == TypeTag::Double) return "Double";
  if (tag_ == TypeTag::String) return "String";
  if (tag_ == TypeTag::Tensor) return "Tensor";
  if (tag_ == TypeTag::GenericList) return "GenericList";
  if (tag_ == TypeTag::CustomClass) return "CustomClass";
  if (tag_ == TypeTag::Tuple) return "Tuple";
  return "InvalidTag";
}

template <typename T>
intrusive_ptr<T> to_custom_class() const {
if (!is_custom_class()) throw std::runtime_error("Not a custom class");
Expand Down Expand Up @@ -637,3 +680,7 @@ intrusive_ptr<T> generic_to(const IValue& ivalue,
}

} // namespace torch

namespace c10 {
// Expose torch::IValue under the c10 namespace so code written against
// LibTorch's c10::IValue compiles against this compat layer.
using IValue = ::torch::IValue;
}
1 change: 1 addition & 0 deletions paddle/phi/api/include/compat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
collect_srcs(api_srcs SRCS c10/core/Device.cpp)
collect_srcs(api_srcs SRCS c10/core/Stream.cpp)
# CUDAFunctions.cpp holds the out-of-line definitions of
# c10::cuda::device_count() / device_synchronize() declared in
# c10/cuda/CUDAFunctions.h.
collect_srcs(api_srcs SRCS c10/cuda/CUDAFunctions.cpp)
collect_srcs(api_srcs SRCS c10/util/typeid.cpp)
collect_srcs(api_srcs SRCS ATen/cuda/EmptyTensor.cpp)
collect_srcs(api_srcs SRCS ATen/cuda/CUDAContextLight.cpp)
Expand Down
101 changes: 98 additions & 3 deletions paddle/phi/api/include/compat/c10/core/Allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
#include <c10/util/Exception.h>
#include <c10/util/UniqueVoidPtr.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
Expand All @@ -37,6 +39,15 @@ namespace c10 {
// Deleter function pointer type (compatible with LibTorch)
using DeleterFnPtr = void (*)(void*);

using CaptureId_t = uint64_t;
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;

// Hash functor for MempoolId_t: hashes by the first id when it is
// nonzero, otherwise by the second. NOTE(review): presumably at most one
// slot of the pair is nonzero per pool id — confirm with the producers.
struct MempoolIdHash {
  std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
    if (mempool_id.first != 0) {
      return mempool_id.first;
    }
    return mempool_id.second;
  }
};

// DataPtr class compatible with LibTorch's c10::DataPtr
// Wraps a pointer with associated device and deleter
class DataPtr {
Expand All @@ -63,6 +74,7 @@ class DataPtr {

// Pointer/context accessors, all forwarding to the wrapped UniqueVoidPtr.
void clear() { ptr_.clear(); }
void* get() const { return ptr_.get(); }
// Mutable access to the data pointer. Here it forwards to the same
// ptr_.get() as get(); the separate name exists because LibTorch's
// DataPtr exposes mutable_get() for write access.
void* mutable_get() { return ptr_.get(); }
void* get_context() const { return ptr_.get_context(); }
void* release_context() { return ptr_.release_context(); }

Expand Down Expand Up @@ -128,15 +140,14 @@ struct Allocator {
// Copies `n` bytes from `data` into a freshly allocated buffer and
// returns the owning DataPtr.
// Requires: input data was allocated by the same allocator.
DataPtr clone(const void* data, std::size_t n) {
  auto new_data = allocate(n);
  // Write through mutable_get(); the merged diff had retained both the
  // old get() call and this one, which copied the data twice.
  copy_data(new_data.mutable_get(), data, n);
  return new_data;
}

// Checks if DataPtr has a simple context, not wrapped with any out of the
// ordinary contexts. "Simple" means the context is the data pointer
// itself (the common case for plain allocations). The merged diff had
// left two return statements here — the old removed line plus the new
// one — making the second unreachable; only the new form is kept.
virtual bool is_simple_data_ptr(const DataPtr& data_ptr) const {
  return data_ptr.get() == data_ptr.get_context();
}

// If this returns a non nullptr, it means that allocate()
Expand Down Expand Up @@ -176,6 +187,90 @@ struct Allocator {
}
};

// Context that pairs a raw pointer with an arbitrary std::function
// deleter, used to build DataPtrs for memory not allocated by a
// c10::Allocator. "Inefficient" because std::function may heap-allocate;
// mirrors c10::InefficientStdFunctionContext in LibTorch.
struct InefficientStdFunctionContext {
  void* ptr_{nullptr};
  std::function<void(void*)> deleter_;

  InefficientStdFunctionContext(void* ptr, std::function<void(void*)> deleter)
      : ptr_(ptr), deleter_(std::move(deleter)) {}

  // Move-only: a copy would run the deleter twice on the same pointer.
  InefficientStdFunctionContext(const InefficientStdFunctionContext&) = delete;
  InefficientStdFunctionContext& operator=(
      const InefficientStdFunctionContext&) = delete;

  // std::exchange(..., nullptr) (rather than a bare std::move) guarantees
  // the moved-from object's destructor is a no-op; a moved-from
  // std::function is otherwise only "valid but unspecified".
  InefficientStdFunctionContext(InefficientStdFunctionContext&& rhs) noexcept
      : ptr_(std::exchange(rhs.ptr_, nullptr)),
        deleter_(std::exchange(rhs.deleter_, nullptr)) {}

  // Fixed: the previous version invoked this->~InefficientStdFunctionContext()
  // and then assigned to the destroyed object's members — UB, and a
  // double-free on self-move-assignment. It was also not noexcept, which
  // containers rely on for move-based reallocation.
  InefficientStdFunctionContext& operator=(
      InefficientStdFunctionContext&& rhs) noexcept {
    if (this != &rhs) {
      // Release the currently held resource before stealing rhs's state.
      if (deleter_) {
        deleter_(ptr_);
      }
      ptr_ = std::exchange(rhs.ptr_, nullptr);
      deleter_ = std::exchange(rhs.deleter_, nullptr);
    }
    return *this;
  }

  ~InefficientStdFunctionContext() {
    if (deleter_) {
      deleter_(ptr_);
    }
  }

  // Wraps (ptr, deleter) into a DataPtr whose heap-allocated context owns
  // the deleter; deleteContext tears the context (and thus ptr) down.
  static DataPtr makeDataPtr(void* ptr,
                             std::function<void(void*)> deleter,
                             Device device) {
    return DataPtr(ptr,
                   new InefficientStdFunctionContext(ptr, std::move(deleter)),
                   &deleteContext,
                   device);
  }

 private:
  static void deleteContext(void* ptr) {
    delete static_cast<InefficientStdFunctionContext*>(ptr);
  }
};

// One registry slot per DeviceType value; CUSTOM is the last enumerator.
inline constexpr size_t kAllocatorRegistrySize =
static_cast<size_t>(DeviceType::CUSTOM) + 1;

// Global allocator registry (C++17 inline variables: a single instance
// shared across all translation units). NOTE(review): accesses are not
// synchronized — assumes registration completes before concurrent use;
// confirm with callers.
inline std::array<Allocator*, kAllocatorRegistrySize> g_allocator_array{};
inline std::array<uint8_t, kAllocatorRegistrySize> g_allocator_priority{};

// Maps a DeviceType to its registry index, rejecting out-of-range values.
inline size_t allocator_device_index(DeviceType t) {
const size_t index = static_cast<size_t>(t);
TORCH_CHECK(index < kAllocatorRegistrySize,
"Allocator device type index out of range: ",
index);
return index;
}

// Registers `alloc` for device type `t`. A registration only takes
// effect when its priority is at least the currently recorded one, so a
// later low-priority registration cannot displace a high-priority
// allocator.
inline void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0) {
  const size_t slot = allocator_device_index(t);
  if (priority < g_allocator_priority[slot]) {
    return;  // keep the existing higher-priority allocator
  }
  g_allocator_array[slot] = alloc;
  g_allocator_priority[slot] = priority;
}

// Looks up the allocator registered for device type `t`; fails via
// TORCH_CHECK when none has been registered.
inline Allocator* GetAllocator(const DeviceType& t) {
  Allocator* const alloc = g_allocator_array[allocator_device_index(t)];
  TORCH_CHECK(alloc != nullptr, "Allocator for ", t, " is not set.");
  return alloc;
}

// Registers `alloc` for device type `t` (at default priority 0) on
// construction; intended for static-initialization-time registration via
// REGISTER_ALLOCATOR below.
template <DeviceType t>
struct AllocatorRegisterer {
explicit AllocatorRegisterer(Allocator* alloc) { SetAllocator(t, alloc); }
};

// Drops a static AllocatorRegisterer into an anonymous namespace so the
// registration runs at static-init time; mirrors LibTorch's macro.
#define REGISTER_ALLOCATOR(t, f) \
namespace { \
static c10::AllocatorRegisterer<t> g_allocator_d(f); \
}

} // namespace c10

namespace at {
Expand Down
44 changes: 44 additions & 0 deletions paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <c10/cuda/CUDAFunctions.h>

namespace c10::cuda {

// Returns the number of GPU devices visible to Paddle; 0 when built
// without CUDA/HIP support.
c10::DeviceIndex device_count() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return phi::backends::gpu::GetGPUDeviceCount();
#else
// Return 0 instead of throwing to match PyTorch API semantics
// at::cuda::is_available() relies on this returning 0/false
return 0;
#endif
}

// Blocks the host until all work queued on the current GPU device has
// completed. Throws Unavailable when Paddle is built without CUDA/HIP.
void device_synchronize() {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// NOTE(review): getting then re-setting the same device id looks like it
// only ensures the device context is bound/initialized for this thread
// before synchronizing — confirm intent.
int curr_device_id = paddle::platform::GetCurrentDeviceId();
paddle::platform::SetDeviceId(curr_device_id);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#endif
#else
PADDLE_THROW(common::errors::Unavailable(
"Paddle is not compiled with CUDA. Cannot visit device synchronize."));
#endif
}

} // namespace c10::cuda
30 changes: 7 additions & 23 deletions paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,21 @@

namespace c10::cuda {

// Number of visible GPU devices; 0 when Paddle is built without CUDA/HIP
// (at::cuda::is_available()-style checks rely on the 0/false result).
// Defined out-of-line in compat/c10/cuda/CUDAFunctions.cpp. The stale
// inline definition that the merged diff had left above this declaration
// is removed: together with the .cpp definition it was a redefinition.
c10::DeviceIndex device_count();

// Blocks the host until all work on the current GPU device finishes;
// throws when Paddle is built without CUDA/HIP. Defined out-of-line in
// compat/c10/cuda/CUDAFunctions.cpp. The stale inline definition that
// the merged diff had left above this declaration is removed: together
// with the .cpp definition it was a redefinition.
void device_synchronize();

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Blocks the host until all work queued on `stream` has completed.
// Only declared for CUDA/HIP builds because gpuStream_t does not exist
// otherwise.
void __inline__ stream_synchronize(gpuStream_t stream) {
phi::backends::gpu::GpuStreamSync(stream);
}
#endif

} // namespace c10::cuda

namespace at::cuda {
// Re-export the c10::cuda synchronization helpers under at::cuda, the
// namespace LibTorch user code expects. NOTE(review): device_count is
// not re-exported here — confirm whether at::cuda::device_count is
// provided elsewhere or intentionally omitted.
using c10::cuda::device_synchronize;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using c10::cuda::stream_synchronize;
#endif
} // namespace at::cuda
Loading
Loading