diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 0e772f91..12a6799d 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -219,13 +219,13 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, fused_kv_cache.data_ptr(), {num_kv_blocks, block_kv, head_dim}, {kv_cache_stride_bytes, head_dim, 1}, - torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(fused_kv_cache.device()) + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) ); const auto& kv_cache_scales = torch::from_blob( fused_kv_cache.data_ptr() + block_kv * head_dim, {num_kv_blocks, block_kv}, {kv_cache_stride_bytes / static_cast<int64_t>(sizeof(float)), 1}, - torch::TensorOptions().dtype(torch::kFloat32).device(fused_kv_cache.device()) + torch::TensorOptions().dtype(torch::kFloat32) ); // Allocate output diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index 418b2c4d..dcc4def0 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -35,7 +35,7 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { DG_HOST_ASSERT(not disable_ue8m0_cast); - const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().dtype(torch::kInt).device(sf.device())).floor_divide_(128)); + const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); } diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index de4dc89b..ec81bec0 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -3,6 +3,7 @@ #define PADDLE_WITH_CUDA // make sure gpuStream_t declaration #include +#include <torch/version.h> #include #include @@ -19,7 +20,7 @@ class DeviceRuntime { static constexpr size_t kCublasLtWorkspaceSize =
32 * 1024 * 1024; public: -#if false +#if TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3) // For PyTorch 2.3+, share the PyTorch cuBLASLt handle DeviceRuntime() = default; diff --git a/csrc/utils/compatibility.hpp b/csrc/utils/compatibility.hpp index 45b557c0..fb45e3d8 100644 --- a/csrc/utils/compatibility.hpp +++ b/csrc/utils/compatibility.hpp @@ -1,9 +1,10 @@ #pragma once +#include <torch/version.h> #include // `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1 -#define DG_FP8_COMPATIBLE true +#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1)) // `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1 -#define DG_TENSORMAP_COMPATIBLE true \ No newline at end of file +#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) \ No newline at end of file