From dea6fb0e1691789de9e21749128a8ce5c2daaa36 Mon Sep 17 00:00:00 2001 From: zhaoguochun1995 Date: Mon, 4 Mar 2024 16:28:53 +0800 Subject: [PATCH 1/2] add diopi generator --- impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp index 3f960eb03..9005b3ef5 100644 --- a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp +++ b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp @@ -2865,9 +2865,11 @@ std::pair NPUGeneratorImpl::philox_engine_inputs(uint64_t in return ret; } +thread_local static at::Generator gDiopiGenerator[16]; + namespace detail { -const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { INTERFACE_NOT_IMPL; } +const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { return gDiopiGenerator[device_index]; } } // namespace detail @@ -3046,9 +3048,11 @@ inline const at::Tensor buildATen(diopiConstTensorHandle_t tensor) { #endif at::Generator buildATen(diopiGeneratorHandle_t generator) { - auto gen = at::make_generator(current_device()); + int64_t currentDeviceIndex = current_device(); + auto gen = at::make_generator(currentDeviceIndex); auto impl = static_cast(gen.unsafeGetGeneratorImpl()); impl->generator_ = generator; + at_npu::gDiopiGenerator[currentDeviceIndex] = gen; return gen; } From 6d552ba061e0d2a567a460e8e5dc785e299173c0 Mon Sep 17 00:00:00 2001 From: jingguo-st Date: Wed, 10 Apr 2024 10:22:25 +0800 Subject: [PATCH 2/2] fix rotary_embedding, getDefaultNPUGenerator --- .../diopi_impl/functions_ext/rotary_embedding.cpp | 2 +- impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp index 1c7e6f4ea..a8ee8fac3 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp @@ -24,7 +24,7 @@ at::Tensor viewAs4D(const at::Tensor& input) { for (int i = 0; i < dim; ++i) { viewShape[i + n - dim] = inputShape[i]; } - return input.view(viewShape); + return impl::aten::viewStorage(input, viewShape); } DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos, diff --git a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp index 0747ef9c3..39c0d50eb 100755 --- a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp +++ b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp @@ -2932,7 +2932,12 @@ thread_local static at::Generator gDiopiGenerator[16]; namespace detail { -const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { return gDiopiGenerator[device_index]; } +const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { + if (device_index == -1) { + device_index = current_device(); + } + return gDiopiGenerator[device_index]; +} } // namespace detail