diff --git a/build/build.py b/build/build.py index 70e9bc225eb5..e47895c5f280 100755 --- a/build/build.py +++ b/build/build.py @@ -722,6 +722,13 @@ async def main(): wheel_dir = wheel.replace("cuda", f"cuda{cuda_major_version}").replace( "-", "_" ) + elif "rocm" in wheel: + if args.editable: + # For editable builds, use the actual ROCm version since directory paths cannot contain wildcards + wheel_dir = wheel.replace("rocm", f"rocm{args.rocm_version}").replace("-", "_") + else: + # For non-editable builds, use wildcard pattern to match any ROCm version in glob patterns + wheel_dir = wheel.replace("rocm", "rocm*").replace("-", "_") else: wheel_dir = wheel @@ -739,7 +746,7 @@ async def main(): wheel_version_suffix += ( f"+{wheel_git_hash}{custom_wheel_version_suffix}" ) - if wheel in ["jax", "jax-cuda-pjrt"]: + if wheel in ["jax", "jax-cuda-pjrt", "jax-rocm-pjrt"]: python_tag = "py" else: python_tag = "cp"