Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,21 @@ pytype_test(
"notap",
],
)

pytype_test(
name = "jax_wheel_imports_test",
srcs = ["//jaxlib/tools:wheel_imports_test.py"],
args = [
"--package-name=jax",
],
main = "wheel_imports_test.py",
tags = [
"manual",
"notap",
],
deps = [
":jax_py_import",
"//jaxlib/tools:jaxlib_py_import"
],
)

12 changes: 6 additions & 6 deletions ci/run_bazel_test_cpu_rbe.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
fi

if [[ "$JAXCI_BUILD_JAXLIB" == "false" ]]; then
WHEEL_SIZE_TESTS=""
WHEEL_TESTS=""
else
WHEEL_SIZE_TESTS="//jaxlib/tools:jaxlib_wheel_size_test"
WHEEL_TESTS="//jaxlib/tools:jaxlib_wheel_size_test //jaxlib/tools:jaxlib_wheel_imports_test"
fi

if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then
WHEEL_SIZE_TESTS="$WHEEL_SIZE_TESTS //:jax_wheel_size_test"
WHEEL_TESTS="$WHEEL_TESTS //:jax_wheel_size_test //:jax_wheel_imports_test"
fi

if [[ "$JAXCI_HERMETIC_PYTHON_VERSION" == *"-nogil" ]]; then
Expand Down Expand Up @@ -82,7 +82,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] )
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
--test_output=errors \
--color=yes \
$WHEEL_SIZE_TESTS \
$WHEEL_TESTS \
//tests:cpu_tests //tests:backend_independent_tests \
//jax/experimental/jax2tf/tests:jax2tf_test_cpu \
//tests/multiprocess:cpu_tests \
Expand All @@ -100,7 +100,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] )
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
--test_output=errors \
--color=yes \
$WHEEL_SIZE_TESTS \
$WHEEL_TESTS \
//tests:cpu_tests //tests:backend_independent_tests \
//jax/experimental/jax2tf/tests:jax2tf_test_cpu \
//tests/multiprocess:cpu_tests \
Expand All @@ -119,7 +119,7 @@ else
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
--test_output=errors \
--color=yes \
$WHEEL_SIZE_TESTS \
$WHEEL_TESTS \
//tests:cpu_tests //tests:backend_independent_tests \
//jax/experimental/jax2tf/tests:jax2tf_test_cpu \
//tests/multiprocess:cpu_tests \
Expand Down
13 changes: 8 additions & 5 deletions ci/run_bazel_test_cuda_rbe.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ fi
source "ci/utilities/setup_build_environment.sh"

if [[ "$JAXCI_BUILD_JAXLIB" == "false" ]]; then
WHEEL_SIZE_TESTS=""
WHEEL_TESTS=""
else
WHEEL_SIZE_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \
WHEEL_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \
//jaxlib/tools:jax_cuda_pjrt_wheel_size_test \
//jaxlib/tools:jaxlib_wheel_size_test"
//jaxlib/tools:jaxlib_wheel_size_test \
//jaxlib/tools:jax_cuda_plugin_wheel_imports_test \
//jaxlib/tools:jax_cuda_pjrt_wheel_imports_test \
//jaxlib/tools:jaxlib_wheel_imports_test"
fi

if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then
WHEEL_SIZE_TESTS="$WHEEL_SIZE_TESTS //:jax_wheel_size_test"
WHEEL_TESTS="$WHEEL_TESTS //:jax_wheel_size_test //:jax_wheel_imports_test"
fi

if [[ "$JAXCI_BUILD_JAXLIB" != "true" ]]; then
Expand Down Expand Up @@ -71,4 +74,4 @@ bazel test --config=rbe_linux_x86_64_cuda${JAXCI_CUDA_VERSION} \
--//jax:build_jax=$JAXCI_BUILD_JAX \
//tests:gpu_tests //tests:backend_independent_tests \
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \
$WHEEL_SIZE_TESTS
$WHEEL_TESTS
63 changes: 62 additions & 1 deletion jaxlib/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ licenses(["notice"]) # Apache 2

package(default_visibility = ["//visibility:public"])

exports_files(["wheel_size_test.py"])
exports_files([
"wheel_size_test.py",
"wheel_imports_test.py",
])

genrule(
name = "platform_tags_py",
Expand Down Expand Up @@ -476,6 +479,17 @@ py_import(
wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]),
)

# Py_import targets for tests.
py_import(
name = "jax_cuda_plugin_py_import_for_test",
wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version),
)

py_import(
name = "jax_cuda_pjrt_py_import_for_test",
wheel = ":jax_cuda{cuda}_pjrt_wheel".format(cuda = cuda_major_version),
)

# Mosaic GPU

py_binary(
Expand Down Expand Up @@ -633,3 +647,50 @@ pytype_test(
"notap",
],
)

pytype_test(
name = "jaxlib_wheel_imports_test",
srcs = [":wheel_imports_test.py"],
args = [
"--package-name=jaxlib",
],
main = "wheel_imports_test.py",
tags = [
"manual",
"notap",
],
deps = [":jaxlib_py_import"],
)

pytype_test(
name = "jax_cuda_plugin_wheel_imports_test",
srcs = [":wheel_imports_test.py"],
args = [
"--package-name=jax_cuda{cuda}_plugin".format(cuda = cuda_major_version),
],
main = "wheel_imports_test.py",
tags = [
"manual",
"notap",
],
deps = [":jax_cuda_plugin_py_import_for_test"],
)

pytype_test(
name = "jax_cuda_pjrt_wheel_imports_test",
srcs = [":wheel_imports_test.py"],
args = [
"--package-name=jax_plugins.xla_cuda{cuda}".format(cuda = cuda_major_version),
],
main = "wheel_imports_test.py",
tags = [
"manual",
"notap",
],
deps = [
# CUDA PJRT wheel depends on jax and jaxlib
":jax_cuda_pjrt_py_import_for_test",
":jaxlib_py_import",
"//:jax_py_import",
],
)
170 changes: 170 additions & 0 deletions jaxlib/tools/wheel_imports_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Verifies that all modules in a wheel are importable.

This script is designed to be run after a wheel has been installed. It
discovers all modules within a specified package and attempts to import each
one. This helps catch packaging issues where modules are missing or have
unmet dependencies.
"""

import argparse
import importlib
import logging
import pkgutil

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def parse_args():
"""Arguments parser."""
parser = argparse.ArgumentParser(
description="Helper for the wheel package importing verification",
fromfile_prefix_chars="@",
)
parser.add_argument(
"--package-name", required=True, help="Name of the package to test"
)
return parser.parse_args()


def _discover_modules(package_name: str) -> list[str]:
"""Discovers all modules in a package.

Uses pkgutil.walk_packages to find all modules in a given package. It
includes an error handler to gracefully skip any modules that cause an
error during the discovery process.

Args:
package_name: The name of the package to inspect (e.g., 'jax').

Returns:
A sorted list of the names of all discoverable modules in the package.
"""
modules: set[str] = set()
package = importlib.import_module(package_name)
if hasattr(package, "__path__"):

def onerror(name):
"""An error handler for walk_packages to log and continue."""
logger.warning(
"pkgutil.walk_packages failed on module %s. Skipping.", name
)

for _, name, _ in pkgutil.walk_packages(
package.__path__, package.__name__ + ".", onerror=onerror
):
modules.add(name)

return sorted(modules)


def _is_c_extension(error: str) -> bool:
"""Returns True if the import error is from a C extension."""
return (
"dynamic module does not define module export function" in error.lower()
)


def verify_wheel_imports(args):
"""Verifies that all modules in the specified package can be imported.

This function first discovers all modules in the package and then attempts
to import each one. It logs any import failures and raises a RuntimeError if
any modules fail to import.

Args:
args: An argparse.Namespace object containing the parsed command-line
arguments.

Raises:
RuntimeError: If any modules fail to import.
"""
modules = _discover_modules(args.package_name)

modules_to_skip = [
# requires `xprof` to be installed
"jax.collect_profile",
# it's dependending on the Mosaic GPU bindings and will fail to import
"jax._src.lib.mosaic_gpu",
]
failed_imports = []

for module in modules:
if module not in modules_to_skip:
try:
importlib.import_module(module)
except ModuleNotFoundError as e:
# If the missing module is not part of the package being tested, it's
# an optional dependency. We can safely skip it.
if e.name and not e.name.startswith(args.package_name):
logger.info(
"Skipping module %s due to optional dependency: %s", module, e
)
else:
logger.warning(
"Module %s failed with an internal import error: %s", module, e
)
failed_imports.append(module)
except ImportError as e:
error_str = str(e)
# Some errors are expected for optional parts of JAX. We check for
# specific error strings here because JAX may raise generic
# ImportError exceptions for these cases.
if _is_c_extension(error_str):
logger.info(
"Skipping module %s due to optional dependencies or not being"
" importable: %s",
module,
e,
)
else:
logger.warning(
"Module %s failed with an internal import error: %s", module, e
)
failed_imports.append(module)
except Exception as e: # pylint: disable=broad-exception-caught
error_str = str(e)
# Some modules define config options at import time. Since we import
# all modules, we might try to define the same config option
# multiple times, which raises a generic Exception. We check for the
# error string to safely skip these import errors.
if "already defined" in error_str:
logger.info(
"Skipping module %s due to already defined config option: %s",
module,
e,
)
else:
logger.warning(
"Module %s raised an exception of type %s: %s",
module,
type(e).__name__,
e,
)
failed_imports.append(module)

if failed_imports:
raise RuntimeError(
f"Failed to import {len(failed_imports)}/{len(modules)} modules"
f" modules: {failed_imports}"
)

logger.info("Import of modules successful")


if __name__ == "__main__":
verify_wheel_imports(parse_args())
Loading