diff --git a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp index 1c7e6f4ea..a8ee8fac3 100644 --- a/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp +++ b/impl/ascend_npu/diopi_impl/functions_ext/rotary_embedding.cpp @@ -24,7 +24,7 @@ at::Tensor viewAs4D(const at::Tensor& input) { for (int i = 0; i < dim; ++i) { viewShape[i + n - dim] = inputShape[i]; } - return input.view(viewShape); + return impl::aten::viewStorage(input, viewShape); } DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos, diff --git a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp index eec9a0cc7..39c0d50eb 100755 --- a/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp +++ b/impl/ascend_npu/torch_npu/csrc/DIOPIAdapter.cpp @@ -2928,9 +2928,16 @@ std::pair NPUGeneratorImpl::philox_engine_inputs(uint64_t in return ret; } +thread_local static at::Generator gDiopiGenerator[16]; + namespace detail { -const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { INTERFACE_NOT_IMPL; } +const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) { + if (device_index == -1) { + device_index = current_device(); + } + return gDiopiGenerator[device_index]; +} } // namespace detail @@ -3109,9 +3116,11 @@ inline const at::Tensor buildATen(diopiConstTensorHandle_t tensor) { #endif at::Generator buildATen(diopiGeneratorHandle_t generator) { - auto gen = at::make_generator(current_device()); + int64_t currentDeviceIndex = current_device(); + auto gen = at::make_generator(currentDeviceIndex); auto impl = static_cast(gen.unsafeGetGeneratorImpl()); impl->generator_ = generator; + at_npu::gDiopiGenerator[currentDeviceIndex] = gen; return gen; }