diff --git a/patches/xla.patch b/patches/xla.patch deleted file mode 100644 index abd039dea6..0000000000 --- a/patches/xla.patch +++ /dev/null @@ -1,62 +0,0 @@ -diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl -index d6022c80d8..42741e0e6f 100644 ---- a/third_party/gpus/rocm/BUILD.tpl -+++ b/third_party/gpus/rocm/BUILD.tpl -@@ -563,6 +563,24 @@ filegroup( - visibility = ["//visibility:public"], - ) - -+sh_binary( -+ name = "hipcc", -+ srcs = ["%{rocm_root}/bin/hipcc"], -+ visibility = ["//visibility:public"], -+) -+ -+sh_binary( -+ name = "llvm-link", -+ srcs = ["%{rocm_root}/lib/llvm/bin/llvm-link"], -+ visibility = ["//visibility:public"], -+) -+ -+sh_binary( -+ name = "opt", -+ srcs = ["%{rocm_root}/lib/llvm/bin/opt"], -+ visibility = ["//visibility:public"], -+) -+ - filegroup( - name = "toolchain_data", - srcs = glob([ -diff --git a/third_party/rocm_device_libs/build_defs.bzl b/third_party/rocm_device_libs/build_defs.bzl -index 207102de8c..6aa7e535fe 100644 ---- a/third_party/rocm_device_libs/build_defs.bzl -+++ b/third_party/rocm_device_libs/build_defs.bzl -@@ -103,6 +103,7 @@ def _bitcode_library_impl(ctx): - inputs = [strip_out], - outputs = [final_bc], - arguments = [strip_out.path, "-o", final_bc.path], -+ use_default_shell_env = True, - progress_message = "Preparing final bitcode for {}".format(ctx.label.name), - mnemonic = "RocmBitCodeFinalize", - ) -@@ -119,17 +120,17 @@ bitcode_library = rule( - "hdrs": attr.label_list(allow_files = [".h"]), - "file_specific_flags": attr.string_dict(), - "_clang": attr.label( -- default = Label("@llvm-project//clang:clang"), -+ default = Label("@local_config_rocm//rocm:hipcc"), - executable = True, - cfg = "exec", - ), - "_llvm_link": attr.label( -- default = Label("@llvm-project//llvm:llvm-link"), -+ default = Label("@local_config_rocm//rocm:llvm-link"), - executable = True, - cfg = "exec", - ), - "_opt": attr.label( -- default = Label("@llvm-project//llvm:opt"), -+ default = Label("@local_config_rocm//rocm:opt"), - executable = True, - cfg = "exec", - ), diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 1f81fcc95d..50ba6cddc1 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -17,6 +17,4 @@ def repo(extra_patches = [], override_commit = ""): strip_prefix = "openxla-xla-{commit}".format(commit = commit[:7]), urls = ["https://api.github.com/repos/openxla/xla/tarball/{commit}".format(commit = commit)], patch_cmds = XLA_PATCHES + extra_patches, - patches = ["//:patches/xla.patch"], - patch_args = ["-p1"], ) diff --git a/workspace.bzl b/workspace.bzl index 65c3debe6e..7067ea9f20 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -1,4 +1,4 @@ -JAX_COMMIT = "071a6a5a9a70b2447f0164e61e3e0333318422a8" +JAX_COMMIT = "e45b7b654c5734ba0ebf22a771a947f9a5507859" JAX_SHA256 = "" ENZYME_COMMIT = "6c20ffc94c5abff04831f22caf46fe1b25c069be"