Skip to content

Commit af048f8

Browse files
author
Protonu Basu
committed
set device for cuda-codegen if new device is not prior device
1 parent 7c13a07 commit af048f8

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -947,11 +947,18 @@ void CudaCodeGen::CompileToNVRTC(
947947
// Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
948948
// properly in some scenarios
949949
const auto prior_device = at::cuda::current_device();
950-
at::cuda::set_device(this->device().index());
950+
if (prior_device != this->device().index()) {
951+
at::cuda::set_device(this->device().index());
952+
}
951953
// cudaSetDevice does not have to really change the underlying device if it
952954
// doesn't have to, so calling cudaFree to force that change
953955
CudaSetContext(pctx);
954-
956+
if (!pctx) {
957+
std::unique_lock<std::mutex> cudaFreeMutexLock(
958+
*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
959+
cudaFree(0);
960+
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
961+
}
955962
// Acquires device and NVRTC properties (for compile arch and occupancy
956963
// calculations)
957964
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
@@ -1003,7 +1010,10 @@ void CudaCodeGen::CompileToNVRTC(
10031010
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
10041011
AT_CUDA_DRIVER_CHECK(
10051012
nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
1006-
at::cuda::set_device(prior_device);
1013+
1014+
if (prior_device != this->device().index()) {
1015+
at::cuda::set_device(prior_device);
1016+
}
10071017
}
10081018

10091019
CudaCodeGen::~CudaCodeGen() = default;

0 commit comments

Comments
 (0)