diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 7d20281225..bd288babfe 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -80,7 +80,9 @@ def download_file( # Handle local file copy if os.path.exists(source): try: - shutil.copy(source, local_path) + temp_path = f"{local_path}.tmp" + shutil.copy(source, temp_path) + os.replace(temp_path, local_path) # Atomic rename logger.info(f"File copied successfully: {local_path}") return True except Exception as e: @@ -93,9 +95,13 @@ def download_file( response = session.get(source, timeout=timeout) response.raise_for_status() - with open(local_path, "wb") as file: + temp_path = f"{local_path}.tmp" + with open(temp_path, "wb") as file: file.write(response.content) + # Atomic rename to prevent readers from seeing partial writes + os.replace(temp_path, local_path) + logger.info( f"File downloaded successfully: {source} -> {local_path}" )