File tree Expand file tree Collapse file tree 1 file changed +13
-3
lines changed
torch/csrc/jit/tensorexpr Expand file tree Collapse file tree 1 file changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -947,11 +947,18 @@ void CudaCodeGen::CompileToNVRTC(
947947 // Note: the device is saved, set and restored manually here instead of using
948948 // at::DeviceGuard, since at::DeviceGuard was failing to work properly in some scenarios
949949 const auto prior_device = at::cuda::current_device ();
950- at::cuda::set_device (this ->device ().index ());
950+ if (prior_device != this ->device ().index ()) {
951+ at::cuda::set_device (this ->device ().index ());
952+ }
951953 // cudaSetDevice is not required to actually switch the underlying device
952954 // right away, so cudaFree is called to force that change
953955 CudaSetContext (pctx);
954-
956+ if (!pctx) {
957+ std::unique_lock<std::mutex> cudaFreeMutexLock (
958+ *(c10::cuda::CUDACachingAllocator::getFreeMutex ()));
959+ cudaFree (0 );
960+ AT_CUDA_DRIVER_CHECK (nvrtc ().cuCtxGetCurrent (&pctx));
961+ }
955962 // Acquires device and NVRTC properties (for compile arch and occupancy
956963 // calculations)
957964 cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties ();
@@ -1003,7 +1010,10 @@ void CudaCodeGen::CompileToNVRTC(
10031010 AT_CUDA_DRIVER_CHECK (nvrtc ().cuModuleLoadData (&module , ptx.data ()));
10041011 AT_CUDA_DRIVER_CHECK (
10051012 nvrtc ().cuModuleGetFunction (&function_, module , func_name.c_str ()));
1006- at::cuda::set_device (prior_device);
1013+
1014+ if (prior_device != this ->device ().index ()) {
1015+ at::cuda::set_device (prior_device);
1016+ }
10071017}
10081018
10091019CudaCodeGen::~CudaCodeGen () = default ;
You can’t perform that action at this time.
0 commit comments