Skip to content

Commit b8bc254

Browse files
committed
Maintain the requirement for CUDA builds
1 parent 33e668f commit b8bc254

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

jaxlib/tools/build_gpu_kernels_wheel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@
7272
)
7373
args = parser.parse_args()
7474

75+
if args.enable_cuda and not args.nvidia_wheel_versions_data:
76+
parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds')
77+
7578
r = runfiles.Create()
7679
pyext = "pyd" if build_utils.is_windows() else "so"
7780

jaxlib/tools/build_gpu_plugin_wheel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@
7777
help="NVIDIA wheel versions data",
7878
)
7979
args = parser.parse_args()
80-
80+
if args.enable_cuda and not args.nvidia_wheel_versions_data:
81+
parser.error('argument --nvidia_wheel_versions_data is required for CUDA builds')
8182
r = runfiles.Create()
8283

8384

0 commit comments

Comments
 (0)