From 33e668f91f9de6136eac69b11a6a7dfc7f89faa4 Mon Sep 17 00:00:00 2001 From: Thanh Binh Date: Fri, 31 Oct 2025 18:12:40 -0500 Subject: [PATCH 1/2] Make nvidia version data optional for ROCm builds (cherry picked from commit 730790a538196299d317a429107b7f4771319077) --- jaxlib/tools/build_gpu_kernels_wheel.py | 2 +- jaxlib/tools/build_gpu_plugin_wheel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 4ecb803ebee1..20babb2f38fd 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -67,7 +67,7 @@ parser.add_argument( "--nvidia_wheel_versions_data", default=None, - required=True, + required=False, help="NVIDIA wheel versions data", ) args = parser.parse_args() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 7ef6fdaaeea0..8256a89a91b3 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -73,7 +73,7 @@ parser.add_argument( "--nvidia_wheel_versions_data", default=None, - required=True, + required=False, help="NVIDIA wheel versions data", ) args = parser.parse_args() From b8bc25465b55578cf5b43eeaccba33c61a5e854c Mon Sep 17 00:00:00 2001 From: Thanh Binh Date: Thu, 13 Nov 2025 16:28:24 -0600 Subject: [PATCH 2/2] Maintain the requirement for CUDA builds --- jaxlib/tools/build_gpu_kernels_wheel.py | 3 +++ jaxlib/tools/build_gpu_plugin_wheel.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 20babb2f38fd..b48bf7d4d55a 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -72,6 +72,9 @@ ) args = parser.parse_args() +if args.enable_cuda and not args.nvidia_wheel_versions_data: + parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds') + r = runfiles.Create() pyext = "pyd" if build_utils.is_windows() else "so" diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 8256a89a91b3..26887df22696 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -77,7 +77,8 @@ help="NVIDIA wheel versions data", ) args = parser.parse_args() - +if args.enable_cuda and not args.nvidia_wheel_versions_data: + parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds') r = runfiles.Create()