From a1449527565a0b791e6d6e748eee0206053860ae Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 10 Nov 2025 09:06:50 -0800 Subject: [PATCH] Add wheel import tests for JAX and jaxlib. This change introduces `wheel_imports_test.py`, a script that discovers and attempts to import all modules within a specified Python package. New Bazel `pytype_test` targets are added for `jax`, `jaxlib`, `jax_cuda_plugin`, and `jax_cuda_pjrt` wheels to run this test. These tests are integrated into the CPU and CUDA CI workflows to catch packaging issues. PiperOrigin-RevId: 830469562 --- BUILD.bazel | 18 +++ ci/run_bazel_test_cpu_rbe.sh | 12 +- ci/run_bazel_test_cuda_rbe.sh | 13 ++- jaxlib/tools/BUILD.bazel | 63 ++++++++++- jaxlib/tools/wheel_imports_test.py | 170 +++++++++++++++++++++++++++++ 5 files changed, 264 insertions(+), 12 deletions(-) create mode 100644 jaxlib/tools/wheel_imports_test.py diff --git a/BUILD.bazel b/BUILD.bazel index ce175425f435..e297377b02a6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -158,3 +158,21 @@ pytype_test( "notap", ], ) + +pytype_test( + name = "jax_wheel_imports_test", + srcs = ["//jaxlib/tools:wheel_imports_test.py"], + args = [ + "--package-name=jax", + ], + main = "wheel_imports_test.py", + tags = [ + "manual", + "notap", + ], + deps = [ + ":jax_py_import", + "//jaxlib/tools:jaxlib_py_import" + ], +) + diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index bca37b5c1831..dd21ef9b90bc 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -46,13 +46,13 @@ if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then fi if [[ "$JAXCI_BUILD_JAXLIB" == "false" ]]; then - WHEEL_SIZE_TESTS="" + WHEEL_TESTS="" else - WHEEL_SIZE_TESTS="//jaxlib/tools:jaxlib_wheel_size_test" + WHEEL_TESTS="//jaxlib/tools:jaxlib_wheel_size_test //jaxlib/tools:jaxlib_wheel_imports_test" fi if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then - WHEEL_SIZE_TESTS="$WHEEL_SIZE_TESTS //:jax_wheel_size_test" + WHEEL_TESTS="$WHEEL_TESTS //:jax_wheel_size_test //:jax_wheel_imports_test" fi if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then @@ -82,7 +82,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - $WHEEL_SIZE_TESTS \ + $WHEEL_TESTS \ //tests:cpu_tests //tests:backend_independent_tests \ //jax/experimental/jax2tf/tests:jax2tf_test_cpu \ //tests/multiprocess:cpu_tests \ @@ -100,7 +100,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - $WHEEL_SIZE_TESTS \ + $WHEEL_TESTS \ //tests:cpu_tests //tests:backend_independent_tests \ //jax/experimental/jax2tf/tests:jax2tf_test_cpu \ //tests/multiprocess:cpu_tests \ @@ -119,7 +119,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - $WHEEL_SIZE_TESTS \ + $WHEEL_TESTS \ //tests:cpu_tests //tests:backend_independent_tests \ //jax/experimental/jax2tf/tests:jax2tf_test_cpu \ //tests/multiprocess:cpu_tests \ diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 8aaee0505c0f..93ef3d40b32f 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -35,15 +35,18 @@ fi source "ci/utilities/setup_build_environment.sh" if [[ "$JAXCI_BUILD_JAXLIB" == "false" ]]; then - WHEEL_SIZE_TESTS="" + WHEEL_TESTS="" else - WHEEL_SIZE_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + WHEEL_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \ //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ - //jaxlib/tools:jaxlib_wheel_size_test" + //jaxlib/tools:jaxlib_wheel_size_test \ + //jaxlib/tools:jax_cuda_plugin_wheel_imports_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_imports_test \ + //jaxlib/tools:jaxlib_wheel_imports_test" fi if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then - WHEEL_SIZE_TESTS="$WHEEL_SIZE_TESTS //:jax_wheel_size_test" + WHEEL_TESTS="$WHEEL_TESTS //:jax_wheel_size_test //:jax_wheel_imports_test" fi if [[ "$JAXCI_BUILD_JAXLIB" != "true" ]]; then @@ -71,4 +74,4 @@ bazel test --config=rbe_linux_x86_64_cuda${JAXCI_CUDA_VERSION} \ --//jax:build_jax=$JAXCI_BUILD_JAX \ //tests:gpu_tests //tests:backend_independent_tests \ //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ - $WHEEL_SIZE_TESTS + $WHEEL_TESTS diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 3a1a48736e17..3533a880a317 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -43,7 +43,10 @@ licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) -exports_files(["wheel_size_test.py"]) +exports_files([ + "wheel_size_test.py", + "wheel_imports_test.py", +]) genrule( name = "platform_tags_py", @@ -476,6 +479,17 @@ py_import( wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) +# Py_import targets for tests. +py_import( + name = "jax_cuda_plugin_py_import_for_test", + wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version), +) + +py_import( + name = "jax_cuda_pjrt_py_import_for_test", + wheel = ":jax_cuda{cuda}_pjrt_wheel".format(cuda = cuda_major_version), +) + # Mosaic GPU py_binary( @@ -633,3 +647,50 @@ pytype_test( "notap", ], ) + +pytype_test( + name = "jaxlib_wheel_imports_test", + srcs = [":wheel_imports_test.py"], + args = [ + "--package-name=jaxlib", + ], + main = "wheel_imports_test.py", + tags = [ + "manual", + "notap", + ], + deps = [":jaxlib_py_import"], +) + +pytype_test( + name = "jax_cuda_plugin_wheel_imports_test", + srcs = [":wheel_imports_test.py"], + args = [ + "--package-name=jax_cuda{cuda}_plugin".format(cuda = cuda_major_version), + ], + main = "wheel_imports_test.py", + tags = [ + "manual", + "notap", + ], + deps = [":jax_cuda_plugin_py_import_for_test"], +) + +pytype_test( + name = "jax_cuda_pjrt_wheel_imports_test", + srcs = [":wheel_imports_test.py"], + args = [ + "--package-name=jax_plugins.xla_cuda{cuda}".format(cuda = cuda_major_version), + ], + main = "wheel_imports_test.py", + tags = [ + "manual", + "notap", + ], + deps = [ + # CUDA PJRT wheel depends on jax and jaxlib + ":jax_cuda_pjrt_py_import_for_test", + ":jaxlib_py_import", + "//:jax_py_import", + ], +) diff --git a/jaxlib/tools/wheel_imports_test.py b/jaxlib/tools/wheel_imports_test.py new file mode 100644 index 000000000000..a72c9b66cda7 --- /dev/null +++ b/jaxlib/tools/wheel_imports_test.py @@ -0,0 +1,170 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Verifies that all modules in a wheel are importable. + +This script is designed to be run after a wheel has been installed. It +discovers all modules within a specified package and attempts to import each +one. This helps catch packaging issues where modules are missing or have +unmet dependencies. +""" + +import argparse +import importlib +import logging +import pkgutil + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def parse_args(): + """Arguments parser.""" + parser = argparse.ArgumentParser( + description="Helper for the wheel package importing verification", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--package-name", required=True, help="Name of the package to test" + ) + return parser.parse_args() + + +def _discover_modules(package_name: str) -> list[str]: + """Discovers all modules in a package. + + Uses pkgutil.walk_packages to find all modules in a given package. It + includes an error handler to gracefully skip any modules that cause an + error during the discovery process. + + Args: + package_name: The name of the package to inspect (e.g., 'jax'). + + Returns: + A sorted list of the names of all discoverable modules in the package. + """ + modules: set[str] = set() + package = importlib.import_module(package_name) + if hasattr(package, "__path__"): + + def onerror(name): + """An error handler for walk_packages to log and continue.""" + logger.warning( + "pkgutil.walk_packages failed on module %s. Skipping.", name + ) + + for _, name, _ in pkgutil.walk_packages( + package.__path__, package.__name__ + ".", onerror=onerror + ): + modules.add(name) + + return sorted(modules) + + +def _is_c_extension(error: str) -> bool: + """Returns True if the import error is from a C extension.""" + return ( + "dynamic module does not define module export function" in error.lower() + ) + + +def verify_wheel_imports(args): + """Verifies that all modules in the specified package can be imported. + + This function first discovers all modules in the package and then attempts + to import each one. It logs any import failures and raises a RuntimeError if + any modules fail to import. + + Args: + args: An argparse.Namespace object containing the parsed command-line + arguments. + + Raises: + RuntimeError: If any modules fail to import. + """ + modules = _discover_modules(args.package_name) + + modules_to_skip = [ + # requires `xprof` to be installed + "jax.collect_profile", + # it's dependending on the Mosaic GPU bindings and will fail to import + "jax._src.lib.mosaic_gpu", + ] + failed_imports = [] + + for module in modules: + if module not in modules_to_skip: + try: + importlib.import_module(module) + except ModuleNotFoundError as e: + # If the missing module is not part of the package being tested, it's + # an optional dependency. We can safely skip it. + if e.name and not e.name.startswith(args.package_name): + logger.info( + "Skipping module %s due to optional dependency: %s", module, e + ) + else: + logger.warning( + "Module %s failed with an internal import error: %s", module, e + ) + failed_imports.append(module) + except ImportError as e: + error_str = str(e) + # Some errors are expected for optional parts of JAX. We check for + # specific error strings here because JAX may raise generic + # ImportError exceptions for these cases. + if _is_c_extension(error_str): + logger.info( + "Skipping module %s due to optional dependencies or not being" + " importable: %s", + module, + e, + ) + else: + logger.warning( + "Module %s failed with an internal import error: %s", module, e + ) + failed_imports.append(module) + except Exception as e: # pylint: disable=broad-exception-caught + error_str = str(e) + # Some modules define config options at import time. Since we import + # all modules, we might try to define the same config option + # multiple times, which raises a generic Exception. We check for the + # error string to safely skip these import errors. + if "already defined" in error_str: + logger.info( + "Skipping module %s due to already defined config option: %s", + module, + e, + ) + else: + logger.warning( + "Module %s raised an exception of type %s: %s", + module, + type(e).__name__, + e, + ) + failed_imports.append(module) + + if failed_imports: + raise RuntimeError( + f"Failed to import {len(failed_imports)}/{len(modules)} modules" + f" modules: {failed_imports}" + ) + + logger.info("Import of modules successful") + + +if __name__ == "__main__": + verify_wheel_imports(parse_args())