diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 4ecb803ebee1..b48bf7d4d55a 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -67,11 +67,14 @@ parser.add_argument( "--nvidia_wheel_versions_data", default=None, - required=True, + required=False, help="NVIDIA wheel versions data", ) args = parser.parse_args() +if args.enable_cuda and not args.nvidia_wheel_versions_data: + parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds') + r = runfiles.Create() pyext = "pyd" if build_utils.is_windows() else "so" diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 7ef6fdaaeea0..26887df22696 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -73,11 +73,12 @@ parser.add_argument( "--nvidia_wheel_versions_data", default=None, - required=True, + required=False, help="NVIDIA wheel versions data", ) args = parser.parse_args() - +if args.enable_cuda and not args.nvidia_wheel_versions_data: + parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds') r = runfiles.Create()