diff --git a/jax/_src/aot.py b/jax/_src/aot.py
new file mode 100644
index 000000000000..63ccdc537e11
--- /dev/null
+++ b/jax/_src/aot.py
@@ -0,0 +1,201 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""JAX AOT API"""
+
+from collections.abc import Hashable
+import functools
+import traceback
+from typing import Any, Callable, Sequence
+
+
+from absl import logging
+from jax._src import aot_util
+from jax._src import api
+from jax._src import api_util
+from jax._src import core
+from jax._src import linear_util as lu
+from jax._src import mesh as mesh_lib
+from jax._src import traceback_util
+from jax._src import tree_util
+from jax._src import util
+from jax._src.interpreters import batching
+from jax._src.interpreters import mlir
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func as func_dialect
+
+
+UserKey = Hashable | Callable[..., Hashable]
+ComponentKey = aot_util.ComponentKey
+get_cache = aot_util.get_cache
+
+
+def component(
+    key: UserKey = None,
+) -> Callable[..., Any]:
+  """Returns a decorator that turns a function into a keyed component.
+
+  The wrapper is jitted and binds `component_p` on the flattened arguments.
+  Wrappers are memoized in `aot_util._wrapper_cache` under `ComponentKey(key)`:
+  if that key is already present, the cached wrapper is returned and the newly
+  decorated function is ignored.
+
+  Args:
+    key: User-supplied key identifying the component.
+
+  Returns:
+    A decorator. The wrapper it produces carries two extra attributes: `key`
+    (the `ComponentKey`) and `fun` (the function the wrapper was built from).
+  """
+
+  def _component(fun: Callable[..., Any]):
+    # TODO(dsuo): Need to consider static args, etc if fun is jitted.
+    component_key = ComponentKey(key)
+
+    if component_key in aot_util._wrapper_cache.cache_keys():
+      return aot_util._wrapper_cache.get(component_key)
+
+    @api.jit
+    @util.wraps(fun)
+    @traceback_util.api_boundary
+    def wrapper(*args, **kwargs):
+      args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
+      wrapped_fun = aot_util.wrap_init(fun, "component")
+      flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree)
+      # TODO(dsuo): We need this in vmap(vmap(f)) because otherwise we create a
+      # new flat_fun and may not have called it yet; store will be empty.
+      flat_fun = aot_util.cached_flat_fun(flat_fun)
+      out_flat = component_p.bind(
+          *args_flat,
+          fun=api.jit(flat_fun.f_transformed),
+          component_key=component_key,
+      )
+      return tree_util.tree_unflatten(out_tree(), out_flat)
+
+    wrapper.key = component_key
+    wrapper.fun = fun
+    aot_util._wrapper_cache.put(component_key, wrapper)
+    return wrapper
+
+  return _component
+
+
+def component_impl(*args, fun: Callable[..., Any], **_):
+  """Eager implementation of `component_p`: calls `fun` on the arguments."""
+  return fun(*args)
+
+
+def component_abstract_eval(
+    *args,
+    fun: Callable[..., Any],
+    component_key: ComponentKey,
+) -> Sequence[core.AbstractValue] | None:
+  """Returns the component's output avals, reusing a cached entry if present.
+
+  On a miss the avals are derived from `api.eval_shape(fun, *args)` and stored
+  in a new `CacheEntry`; storing does nothing when no cache is configured.
+  """
+  # TODO(dsuo): Is this an effectful rule since we read/write to disk?
+  entry = aot_util.get_entry(component_key)
+  if entry is None:
+    # TODO(dsuo): By the time we get to lowering, our trace context has picked
+    # up an empty AbstractMesh. Don't know why.
+    with mesh_lib.use_abstract_mesh(mesh_lib.AbstractMesh((), (), ())):
+      avals_out = tree_util.tree_map(
+          lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args)
+      )
+    aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out))
+  return entry.avals_out
+
+
+def component_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *args,
+    fun: Callable[..., Any],
+    component_key: ComponentKey,
+) -> Sequence[ir.Value]:
+  """Lowers the component to a call into its (possibly cached) MLIR module.
+
+  Raises:
+    ValueError: if a cache is configured but holds no entry for
+      `component_key`; abstract evaluation is expected to have created it.
+  """
+  module_name = aot_util.get_module_name(ctx.module_context.module)
+  entry = aot_util.get_entry(component_key, ctx.module_context.context)
+  if entry is None:
+    if aot_util.get_cache() is not None:
+      raise ValueError("Should hit abstract_eval already, which would populate.")
+    # Without a configured cache `put_entry` stores nothing, so there is never
+    # an entry to find here. Lower the component directly instead of failing.
+    module = aot_util.lower_component_to_module(
+        ctx, fun, module_name, component_key
+    )
+    return aot_util.get_module_results(ctx, module, module_name, *args)
+
+  if (module := entry.module) is None:
+    entry.module = module = aot_util.lower_component_to_module(
+        ctx, fun, module_name, component_key
+    )
+    aot_util.put_entry(component_key, entry, update=True)
+
+  return aot_util.get_module_results(ctx, module, module_name, *args)
+
+
+def component_batcher(
+    axis_data,
+    vals_in,
+    dims_in,
+    fun: Callable[..., Any],
+    component_key: ComponentKey,
+):
+  """Batching rule: re-binds `component_p` on a batched version of `fun`.
+
+  The batched computation is bound under `ComponentKey.vmap(component_key)`.
+
+  Returns:
+    The bound outputs together with `dims_out()` from `batching.batch_subtrace`.
+  """
+  # Missing from batching process_call:
+  # TODO(dsuo): Ignore ragged.
+  # TODO(dsuo): Ignore updating annotations.
+
+  # TODO(dsuo): This doesn't handle nesting.
+  wrapped_fun = aot_util.wrap_init(fun, "vmap(component)")
+
+  # ????(dsuo): I don't understand trace tags.
+  batched_fun, dims_out = batching.batch_subtrace(
+      wrapped_fun, core.TraceTag(), axis_data, tuple(dims_in)
+  )
+  # TODO(dsuo): We might need to reset stores because we may be calling a cached
+  # wrapped fun.
+  batched_fun = aot_util.maybe_reset_stores(batched_fun)
+
+  vals_out = component_p.bind(
+      *vals_in,
+      fun=batched_fun.f_transformed,
+      component_key=ComponentKey.vmap(component_key),
+  )
+  return vals_out, dims_out()
+
+
+def component_jvp(arg_values, arg_tangents):
+  """Placeholder JVP rule; not registered for `component_p` in this file."""
+  raise NotImplementedError("JVP of component is not implemented yet.")
+
+
+component_p = core.Primitive("component")
+component_p.multiple_results = True
+component_p.def_impl(component_impl)
+component_p.def_abstract_eval(component_abstract_eval)
+mlir.register_lowering(component_p, component_lowering)
+batching.fancy_primitive_batchers[component_p] = component_batcher
diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py
new file mode 100644
index 000000000000..d0b14da44409
--- /dev/null
+++ b/jax/_src/aot_util.py
@@ -0,0 +1,278 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""JAX AOT API utilities."""
+
+from collections.abc import Hashable
+import functools
+import pickle
+import traceback
+from typing import Any, Callable, NamedTuple, Self, Sequence
+
+from absl import logging
+from jax._src import api
+from jax._src import config
+from jax._src import core
+from jax._src import linear_util as lu
+from jax._src import stages
+from jax._src import tree_util
+from jax._src import util
+from jax._src.interpreters import mlir
+from jax._src.lib import xla_client as xc
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func as func_dialect
+
+
+# For now, we don't worry about serialization.
+SerializedType = bytes | Any
+
+
+class ComponentKey:
+  """Hashable wrapper around the user-supplied key of a component.
+
+  Equality and hashing are delegated to `user_key`.
+  """
+
+  def __init__(self, user_key: Hashable):
+    self.user_key = user_key
+
+  def __hash__(self):
+    return hash(self.user_key)
+
+  def __eq__(self, other):
+    return isinstance(other, ComponentKey) and self.user_key == other.user_key
+
+  def __str__(self):
+    # `user_key` is any hashable (e.g. None or a tuple), but `__str__` must
+    # return a `str`; returning the raw key raised TypeError for non-strings.
+    return str(self.user_key)
+
+  def __repr__(self):
+    return self.__str__()
+
+  # TODO(dsuo): This is just a hack for now.
+  @classmethod
+  def vmap(cls, key):
+    """Returns a new key whose user key is `vmap(<key.user_key>)`."""
+    return cls(f"vmap({key.user_key})")
+
+
+def _validate_component_cache(val):
+  """Validator for `jax_component_cache`: accepts None or a `Cache`.
+
+  A non-None cache is registered with `util.register_cache` as "aot_cache".
+
+  Raises:
+    TypeError: if `val` is neither None nor a `Cache`.
+  """
+  logging.info("Validating component cache config.")
+  if val is None:
+    return
+  if not isinstance(val, Cache):
+    # Explicit raise instead of `assert`, which is stripped under `python -O`.
+    raise TypeError(
+        "jax_component_cache must be None or a Cache, got"
+        f" {type(val).__name__}."
+    )
+  util.register_cache(val, "aot_cache")
+
+
+component_cache = config.string_or_object_state(
+    name="jax_component_cache",
+    default=None,
+    help="Cache dir for components. Components won't be cached if None.",
+    validator=_validate_component_cache,
+)
+
+
+class CacheEntry:
+  """Cached data for one component: output avals and, optionally, its module."""
+
+  def __init__(
+      self,
+      avals_out: Sequence[core.AbstractValue] | None,
+      module: ir.Module | None = None,
+  ):
+    self.avals_out = avals_out
+    self.module = module
+
+  def serialize(self) -> SerializedType:
+    """Pickles `(avals_out, module_bytecode)`; bytecode is None without a module."""
+    module_bytecode = None
+    if self.module is not None:
+      module_bytecode = mlir.module_to_bytecode(self.module)
+    return pickle.dumps((self.avals_out, module_bytecode))
+
+  @classmethod
+  def deserialize(
+      cls, blob: SerializedType, ctx: ir.Context | None = None
+  ) -> Self:
+    """Rebuilds an entry from `blob`.
+
+    The module is parsed only when the blob carries bytecode and `ctx` is
+    given; otherwise the returned entry has `module=None`.
+    """
+    # NOTE(review): `pickle.loads` can execute arbitrary code. This is only
+    # safe while blobs come from a trusted cache; revisit before reading blobs
+    # from disk or any shared storage.
+    avals_out, module_bytecode = pickle.loads(blob)
+    if module_bytecode is None or ctx is None:
+      module = None
+    else:
+      with ctx:
+        module = ir.Module.parse(module_bytecode)
+    return cls(avals_out, module)
+
+
+# TODO(dsuo): This should be a protocol.
+class Cache:
+  """In-memory store of serialized cache entries with per-key hit counts."""
+
+  def __init__(self):
+    self._in_memory_cache: dict[ComponentKey, SerializedType] = {}
+    self._in_memory_cache_info: dict[ComponentKey, dict[str, Any]] = {}
+
+  def get(self, key: ComponentKey) -> SerializedType | None:
+    """Returns the blob stored under `key` (counting a hit), or None."""
+    entry = self._in_memory_cache.get(key, None)
+    if entry is not None:
+      # Fall back to zero hits when the entry was first stored with
+      # `update=True` and so never had its counter initialized; indexing the
+      # info dict directly raised KeyError in that case.
+      hits = self._in_memory_cache_info.get(key, {}).get("hits", 0)
+      self._in_memory_cache_info[key] = dict(hits=hits + 1)
+    return entry
+
+  def put(self, key: ComponentKey, data: SerializedType, update: bool):
+    """Stores `data` under `key`.
+
+    Args:
+      key: Key to store under.
+      data: Serialized entry.
+      update: If true, keep the existing hit count for `key` (starting one at
+        zero if there is none); otherwise reset the hit count to zero.
+    """
+    self._in_memory_cache[key] = data
+    if not update or key not in self._in_memory_cache_info:
+      self._in_memory_cache_info[key] = dict(hits=0)
+
+  def cache_keys(
+      self,
+  ) -> list[ComponentKey]:
+    """Returns a list of the keys currently stored."""
+    return list(self._in_memory_cache.keys())
+
+  def cache_clear(self) -> None:
+    """Drops all entries and their hit counts."""
+    self._in_memory_cache.clear()
+    self._in_memory_cache_info.clear()
+
+  def cache_info(self, key: ComponentKey) -> dict[str, Any]:
+    """Returns the info dict (currently just `hits`) for `key`.
+
+    Raises:
+      ValueError: if no info is recorded for `key`.
+    """
+    if key not in self._in_memory_cache_info:
+      raise ValueError(f"`{key}` not found in self._in_memory_cache_info")
+    return self._in_memory_cache_info[key]
+
+
+def get_cache() -> Cache | None:
+  """Returns the configured component cache, or None if none is set."""
+  return component_cache.value
+
+
+def get_entry(
+    key: ComponentKey, ctx: ir.Context | None = None
+) -> CacheEntry | None:
+  """Returns the deserialized entry for `key`, or None on a miss or no cache."""
+  if (cache := get_cache()) is not None:
+    if (blob := cache.get(key)) is not None:
+      return CacheEntry.deserialize(blob, ctx)
+  return None  # sigh pytype
+
+
+def put_entry(
+    key: ComponentKey, entry: CacheEntry, update: bool = False
+) -> None:
+  """Serializes `entry` into the configured cache; no-op without a cache."""
+  if (cache := get_cache()) is not None:
+    cache.put(key, entry.serialize(), update)
+
+
+@lu.cache
+def cached_flat_fun(flat_fun):
+  """Memoized (via `lu.cache`) application of `maybe_reset_stores`."""
+  return maybe_reset_stores(flat_fun)
+
+
+# TODO(dsuo): Share logic with pmap.
+def maybe_reset_stores(fun):
+  """Patches `fun.f_transformed` to reset `fun.stores` before every call.
+
+  `fun` is mutated in place and also returned.
+  """
+  # TODO(dsuo): Hack to clear lu.Store borrowed from pmap.
+  f_transformed = fun.f_transformed
+
+  # TODO(dsuo): Add this as a transformation.
+  def reset_stores_f_transformed(*args, **kwargs):
+    for store in fun.stores:
+      if store is not None:
+        store.reset()
+    return f_transformed(*args, **kwargs)
+
+  fun.f_transformed = reset_stores_f_transformed
+  return fun
+
+
+class WrapperCache:
+  """Maps component keys to their jitted wrappers and tracks hit counts."""
+
+  def __init__(self):
+    self.data = dict()
+    self.info = dict()
+
+  def get(self, key: ComponentKey) -> xc._xla.PjitFunction | None:
+    """Returns the wrapper stored under `key` (counting a hit), or None."""
+    fun = self.data.get(key, None)
+    if fun is not None:
+      self.info.setdefault(key, dict(hits=0))["hits"] += 1
+    return fun
+
+  def put(self, key: ComponentKey, fun: xc._xla.PjitFunction):
+    """Stores `fun` under `key` unless one is present; resets hits to zero."""
+    self.data.setdefault(key, fun)
+    self.info.setdefault(key, dict(hits=0))["hits"] = 0
+
+  def cache_info(self):
+    """Returns the dict mapping each key to its info dict (`hits`)."""
+    return self.info
+
+  def cache_clear(self):
+    """Drops all wrappers and, like `Cache.cache_clear`, their hit counts."""
+    self.data.clear()
+    # Previously only `data` was cleared, leaving stale hit counts behind.
+    self.info.clear()
+
+  def cache_keys(self):
+    """Returns a view of the keys currently stored."""
+    return self.data.keys()
+
+
+_wrapper_cache = WrapperCache()
+util.register_cache(_wrapper_cache, "aot_wrapper_cache")
+
+
+def lower_component_to_module(
+    ctx: mlir.LoweringRuleContext,
+    fun: Callable[..., Any],
+    module_name: str,
+    component_key: ComponentKey,
+) -> ir.Module:
+  """Traces `fun` on `ctx.avals_in` and lowers it to a standalone module.
+
+  `component_key` is accepted but not used here.
+  """
+  traced = api.trace(fun, *ctx.avals_in)
+  lowering_result = mlir.lower_jaxpr_to_module(
+      module_name=module_name,
+      jaxpr=traced.jaxpr,
+      num_const_args=traced._num_consts,
+      in_avals=ctx.avals_in,
+      # TODO(dsuo): What are ordered effects vs effects?
+      ordered_effects=traced.jaxpr.effects,
+      # TODO(dsuo): Figure out why ctx.platforms=None.
+      platforms=["cpu"],
+      backend=ctx.module_context.backend,
+      axis_context=ctx.module_context.axis_context,
+      donated_args=tuple(
+          x.donated for x in tree_util.tree_leaves(traced.args_info)
+      ),
+      lowering_parameters=mlir.LoweringParameters(),
+      # TODO(dsuo): Presumably we need to forward the rest of the arguments to
+      # lower_jaxpr_to_module?
+  )
+  # TODO(dsuo): What should we do about the other attributes on
+  # LoweringResult?
+  # - keepalive: probably not supported.
+  # - host_callbacks: probably not supported.
+  # - shape_poly_state: talk to necula@
+  module = lowering_result.module
+  # TODO(dsuo): Do we need to do this step to strip context?
+  module = ir.Module.parse(mlir.module_to_bytecode(module))
+
+  return module
+
+
+def get_module_name(module: ir.Module) -> str:
+  """Returns the module's `sym_name` attribute as a string (and logs it)."""
+  # TODO(dsuo): Is this reasonable?
+  module_name = str(module.operation.attributes["sym_name"].value)
+  logging.info("module_name: %s", module_name)
+  return module_name
+
+
+def get_module_results(
+    ctx: mlir.LoweringRuleContext, module: ir.Module, module_name: str, *args
+) -> Sequence[ir.Value]:
+  """Merges `module` into the module being lowered and emits a call to it.
+
+  The call targets the symbol returned by `mlir.merge_mlir_modules` and uses
+  the result types of `module`'s `main` function.
+  """
+  symtab = ir.SymbolTable(module.operation)
+  callee = mlir.merge_mlir_modules(
+      ctx.module_context.module,
+      f"component_{module_name}",
+      module,
+      dst_symtab=ctx.module_context.symbol_table,
+  )
+  results = symtab["main"].type.results
+  call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(callee), args)
+
+  return call.results
+
+
+def wrap_init(fun: Callable[..., Any], traced_for: str) -> lu.WrappedFun:
+  """Wraps `fun` with `lu.wrap_init`, attaching placeholder debug info."""
+  # TODO(dsuo): Dummy debug info.
+  target = fun.func if isinstance(fun, functools.partial) else fun
+  # Not every callable has `__name__` (e.g. instances defining `__call__`), so
+  # fall back to the repr rather than raising AttributeError.
+  name = getattr(target, "__name__", repr(target))
+  return lu.wrap_init(
+      fun, debug_info=lu.DebugInfo(traced_for, name, None, None)
+  )
diff --git a/jax/_src/api.py b/jax/_src/api.py index 2752a1af9332..a243ed3c61d0 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3108,11 +3108,17 @@ def eval_shape(fun, *args, **kwargs): >>> print(out.dtype) float32 """ + return trace(fun, *args, **kwargs).out_info # type: ignore + + +@api_boundary +def trace(fun: Callable, *args, **kwargs): if type(fun) is xc._xla.PjitFunction: - return fun.trace(*args, **kwargs).out_info # type: ignore + return fun.trace(*args, **kwargs) try: hash(fun) except TypeError: fun = partial(fun) - return jit(fun).trace(*args, **kwargs).out_info + return jit(fun).trace(*args, **kwargs) + @partial(api_boundary, repro_api_name="jax.named_call") diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 7491fabab7d0..60b75eaf35d9 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -482,14 +482,17 @@ def cache(call: Callable, *, A memoized version of ``call``. 
""" fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + hit_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore + hit = hit_caches.setdefault(fun.f, new_hit := {}) key = (fun.transforms, fun.params, fun.in_type, args, config.trace_context()) result = cache.get(key, None) if result is not None: ans, stores = result fun.populate_stores(stores) + hit[key] += 1 else: if do_explain := explain and config.explain_cache_misses.value: start = time.time() @@ -497,6 +500,7 @@ def memoized_fun(fun: WrappedFun, *args): if do_explain: explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) + hit[key] = 0 return ans @@ -505,6 +509,10 @@ def _evict_function(f): memoized_fun.cache_clear = fun_caches.clear # type: ignore memoized_fun.evict_function = _evict_function # type: ignore + memoized_fun.cache_items = fun_caches.items # type: ignore + memoized_fun.cache_get = fun_caches.get # type: ignore + memoized_fun.cache_keys = fun_caches.keys # type: ignore + memoized_fun.hit_get = hit_caches.get # type: ignore register_cache(memoized_fun, str(call)) return memoized_fun diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 180953c01781..a6c989579670 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -254,6 +254,7 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): + logging.info('cpp_pjit fun: %s %s', id(fun), fun.__name__) # args do not include the const args # See https://docs.jax.dev/en/latest/internals/constants.html. 
if config.no_tracing.value: @@ -651,12 +652,22 @@ def _infer_params_internal( entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) if entry.pjit_params is None: + if isinstance(fun, partial): + name = f'partial({fun.func.__name__})' + else: + name = fun.__name__ + logging.info('missed infer params: %s %s', id(fun), name) + # if fun.__name__ == 'f': + # import traceback + # traceback.print_stack() dbg = dbg_fn() p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) if p.params['jaxpr'].jaxpr.is_high: return p, p.consts + args_flat entry.pjit_params = p + else: + logging.info('hit infer params: %s %s', id(fun), fun.__name__) return entry.pjit_params, entry.pjit_params.consts + dynargs def _infer_input_type(fun: Callable, dbg_fn: Callable[[], core.DebugInfo], @@ -1196,6 +1207,7 @@ def _create_pjit_jaxpr( else: closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) final_consts = [] + logging.info('missed jaxpr: %s', fun.__name__) return closed_jaxpr, final_consts, global_out_avals diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 86a0f56528b9..35e658af6095 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -809,7 +809,7 @@ def backends() -> dict[str, xla_client.Client]: err_msg = f"Unable to initialize backend '{platform}': {err}" if fail_quietly: _backend_errors[platform] = str(err) - logger.info(err_msg) + # logger.info(err_msg) else: if config.jax_platforms.value: err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" diff --git a/tests/aot_test.py b/tests/aot_test.py index 6f3184daef1e..c9371d585828 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -13,18 +13,25 @@ # limitations under the License. 
import contextlib +import logging +from typing import Any, Callable, Sequence import unittest + from absl.testing import absltest import jax from jax import lax +from jax._src import aot +from jax._src import api +from jax._src import aot_util from jax._src import config from jax._src import core +from jax._src import pjit from jax._src import test_util as jtu from jax._src.lib import xla_client as xc from jax.experimental import topologies from jax.experimental.serialize_executable import ( - deserialize_and_load, - serialize, + deserialize_and_load, + serialize, ) import jax.numpy as jnp from jax.sharding import PartitionSpec as P @@ -36,19 +43,19 @@ with contextlib.suppress(ImportError): import pytest + pytestmark = pytest.mark.multiaccelerator class JaxAotTest(jtu.JaxTestCase): - - @jtu.run_on_devices('tpu', 'gpu') + @jtu.run_on_devices("tpu", "gpu") def test_pickle_jit_lower(self): def fun(x): return x * x - with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ('data',))): + with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ("data",))): lowered = jax.jit( - fun, in_shardings=P('data'), out_shardings=P(None, 'data') + fun, in_shardings=P("data"), out_shardings=P(None, "data") ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) def verify_serialization(lowered): @@ -59,33 +66,34 @@ def verify_serialization(lowered): verify_serialization(lowered) verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100))) verify_serialization( - jax.pmap(lambda x: x * x).lower( - np.zeros((len(jax.devices()), 4), dtype=np.float32))) + jax.pmap(lambda x: x * x).lower( + np.zeros((len(jax.devices()), 4), dtype=np.float32) + ) + ) @jtu.skip_on_devices("tpu") # TODO(phawkins): This test is segfaulting on TPU def test_topology_jit_serialize(self): try: aot_topo = topologies.get_topology_desc( - platform=jax.devices()[0].platform + platform=jax.devices()[0].platform ) except NotImplementedError: - raise unittest.SkipTest('PJRT Topology not supported') 
+ raise unittest.SkipTest("PJRT Topology not supported") if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - raise unittest.SkipTest('Compilation caching not yet supported.') + raise unittest.SkipTest("Compilation caching not yet supported.") if jtu.is_device_cuda(): - raise unittest.SkipTest('Broken on GPU: b/442353988') + raise unittest.SkipTest("Broken on GPU: b/442353988") @jax.jit def fn(x): return x * x def lower_and_load(mesh): - s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) x_shape = jax.ShapeDtypeStruct( - shape=(16, 16), - dtype=jnp.dtype('float32'), - sharding=s) + shape=(16, 16), dtype=jnp.dtype("float32"), sharding=s + ) lowered = fn.lower(x_shape) serialized, in_tree, out_tree = serialize(lowered.compile()) compiled = deserialize_and_load(serialized, in_tree, out_tree) @@ -95,30 +103,30 @@ def lower_and_load(mesh): n = max(1, len(ref_topo.devices) // 2) mesh_shape = (len(ref_topo.devices) // n, n) - ref_mesh = topologies.make_mesh(ref_topo, mesh_shape, ('x', 'y')) - aot_mesh = topologies.make_mesh(aot_topo, mesh_shape, ('x', 'y')) + ref_mesh = topologies.make_mesh(ref_topo, mesh_shape, ("x", "y")) + aot_mesh = topologies.make_mesh(aot_topo, mesh_shape, ("x", "y")) self.assertEqual( - lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text() + lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text() ) def test_get_topology_from_devices(self): try: aot_topo = topologies.get_topology_desc( - platform=jax.devices()[0].platform + platform=jax.devices()[0].platform ) except NotImplementedError: - raise unittest.SkipTest('PJRT Topology not supported') + raise unittest.SkipTest("PJRT Topology not supported") topo = xc.get_topology_for_devices(aot_topo.devices) self.assertEqual( - topo.platform_version, aot_topo.devices[0].client.platform_version + topo.platform_version, aot_topo.devices[0].client.platform_version ) def test_lower_as_text_with_and_without_debug_info(self): 
def my_function(x): return jnp.sin(x) - lowered = jax.jit(my_function).lower(42.) + lowered = jax.jit(my_function).lower(42.0) stablehlo = lowered.as_text("stablehlo", debug_info=True) self.assertRegex(stablehlo, r"sine.* loc") stablehlo = lowered.as_text("stablehlo") @@ -131,30 +139,37 @@ def my_function(x): def test_constants_in_lowering_in_aot(self): const_size = 100 - const = jax.random.uniform(jax.random.key(0), (const_size,), - dtype=np.float32) + const = jax.random.uniform( + jax.random.key(0), (const_size,), dtype=np.float32 + ) def my_function(x): return jnp.sin(x) + const - lowered = jax.jit(my_function).lower(np.full_like(const, 42., dtype=const.dtype)) + lowered = jax.jit(my_function).lower( + np.full_like(const, 42.0, dtype=const.dtype) + ) stablehlo = lowered.as_text("stablehlo") if config.use_simplified_jaxpr_constants.value: - self.assertNotRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertNotRegex( + stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x" + ) self.assertLen(lowered._lowering.const_args, 1) self.assertIs(lowered._lowering.const_args[0], const) else: - self.assertRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertRegex( + stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x" + ) self.assertLen(lowered._lowering.const_args, 0) def test_with_constants(self): - const = jnp.arange(16.) + 42. # A distinctive shape and value + const = jnp.arange(16.0) + 42.0 # A distinctive shape and value @jax.jit def f(x): return const[0:8] + x - inp = jnp.arange(8.) 
+ inp = jnp.arange(8.0) compiled = f.lower(inp).compile() self.assertLen(compiled.args_info[0], 1) # Not including const_args self.assertLen(compiled.in_avals[0], 1) @@ -167,13 +182,14 @@ def f(x): self.assertCacheMisses(lambda: compiled(inp), cpp=0, aot_call=0) @jtu.parameterized_filterable( - kwargs=[ - dict(use_np=use_np, lower=lower, compile=compile, exec=exec) - for use_np in (False, True) - for lower in (False, True) - for compile in (False, True) - for exec in (False, True) - ]) + kwargs=[ + dict(use_np=use_np, lower=lower, compile=compile, exec=exec) + for use_np in (False, True) + for lower in (False, True) + for compile in (False, True) + for exec in (False, True) + ] + ) def test_with_constants_enable_x64(self, *, use_np, lower, compile, exec): # Closed-over constant is 64-bit. Each of lowering, compilation, and # execution can be run in 64-bit or 32-bit mode. @@ -185,7 +201,7 @@ def test_with_constants_enable_x64(self, *, use_np, lower, compile, exec): def f(x): return lax.convert_element_type(const, np.float32) + x - inp = np.arange(8., dtype=np.float32) + inp = np.arange(8.0, dtype=np.float32) with config.enable_x64(True) if lower else contextlib.nullcontext(): lowered = f.lower(inp) with config.enable_x64(True) if compile else contextlib.nullcontext(): @@ -216,17 +232,22 @@ def run(): # In some cases we expect errors: in 32-bit mode, lowered with 64-bit mode # and execute in 32-bit mode. 
- if (config.use_simplified_jaxpr_constants.value and - not config.enable_x64.value and - use_np and lower and not exec): + if ( + config.use_simplified_jaxpr_constants.value + and not config.enable_x64.value + and use_np + and lower + and not exec + ): with self.assertRaisesRegex( - xc.XlaRuntimeError, - "got buffer with incompatible size"): + xc.XlaRuntimeError, "got buffer with incompatible size" + ): run() return - self.assertArraysEqual(run(), - lax.convert_element_type(const, inp.dtype) + inp) + self.assertArraysEqual( + run(), lax.convert_element_type(const, inp.dtype) + inp + ) # Trigger cache hit self.assertCacheMisses(run, cpp=0, aot_call=0) @@ -238,10 +259,10 @@ def f(x): x_ref[...] += x f_lowered = f.lower(1) - with self.assertRaisesRegex(ValueError, 'serialize with a closed-over'): + with self.assertRaisesRegex(ValueError, "serialize with a closed-over"): serialized, in_tree, out_tree = serialize(f_lowered.compile()) - @jtu.run_on_devices('gpu', 'tpu') + @jtu.run_on_devices("gpu", "tpu") def test_mismatched_backends_raises(self): @jax.jit def f(x): @@ -251,10 +272,372 @@ def f(x): f_lowered = f.lower(x) serialized, in_tree, out_tree = serialize(f_lowered.compile()) with self.assertRaisesRegex( - ValueError, - 'Execution devices belong to a client other than `backend`'): - deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', - execution_devices=jax.devices()[:1]) + ValueError, "Execution devices belong to a client other than `backend`" + ): + deserialize_and_load( + serialized, + in_tree, + out_tree, + backend="cpu", + execution_devices=jax.devices()[:1], + ) + + +@jtu.thread_unsafe_test_class() +class ComponentTest(jtu.JaxTestCase): + @contextlib.contextmanager + def make_in_memory_cache(self): + cache = aot_util.Cache() + with aot_util.component_cache(cache): + yield + jax.clear_caches() + + # TODO(dsuo): It would be nice to have a way to grab the pjit jaxpr cache + # key easily. 
+ def get_jaxpr_key(self, fun: Callable[..., Any]) -> Callable[..., Any] | None: + for key in pjit._create_pjit_jaxpr.cache_keys(): + f = key if not hasattr(key, "__wrapped__") else key.__wrapped__ + if f == fun: + return key + + def validate_cache_states( + self, + fun: Callable[..., Any], + component_key: aot_util.ComponentKey, + num_jaxpr_entries: int, + num_jaxpr_hits: int | Sequence[int], + num_trace_hits: int, + num_trace_misses: int, + num_wrapper_hits: int | None, + num_disk_hits: int, + ): + cache = aot.get_cache() + + jaxpr_key = self.get_jaxpr_key(fun) + jaxpr_cache = pjit._create_pjit_jaxpr.cache_get(jaxpr_key) + jaxpr_hit = pjit._create_pjit_jaxpr.hit_get(jaxpr_key) + if isinstance(num_jaxpr_hits, int): + num_jaxpr_hits = (num_jaxpr_hits,) + # Verify fun exists in the jaxpr cache. + self.assertIsNotNone(jaxpr_key) + # Verify number of entries in jaxpr cache for fun. + self.assertEqual(len(jaxpr_cache), num_jaxpr_entries) + # Verify number of hits for each entry. + self.assertEqual(tuple(jaxpr_hit.values()), num_jaxpr_hits) + + # Verify the number of hits and misses we expect. + self.assertEqual( + pjit._infer_params_cached.cache_info().hits, num_trace_hits + ) + self.assertEqual( + pjit._infer_params_cached.cache_info().misses, num_trace_misses + ) + + # Verify component key exists in disk cache. + self.assertIn(component_key, cache.cache_keys()) + + # If we don't have wrapper hits (0) or we don't expect component_key (-1). + if num_wrapper_hits is not None: + # Verify the number of wrapper cache hits. + self.assertEqual( + aot_util._wrapper_cache.cache_info()[component_key]["hits"], + num_wrapper_hits, + ) + + # Verify the number of disk hits. + self.assertEqual(cache.cache_info(component_key)["hits"], num_disk_hits) + + # NOTE(dsuo): Disable checks because otherwise we check jaxprs in (at least) + # four places and makes reasoning about cache hits and misses harder. + # 1. After the initial abstract eval. + # 2. Before converting const vars. + # 3. 
# After lifting the jaxpr.
  # 4. After DCE.

  # Basic component round trip: call a component registered under key "f",
  # check its cache counters, then register a second function under the same
  # key and check that the first implementation is the one that runs.
  @config.enable_checks(False)
  def test_component_basic(self):
    with self.make_in_memory_cache():

      @aot.component(key="f")
      def f(x):
        return x + 1.0

      self.assertEqual(f(1.0), 2.0)
      self.validate_cache_states(
          f.fun,
          f.key,
          # Make sure there is only one entry for f.fun. If there are more, then
          # it means the lowering rule missed.
          num_jaxpr_entries=1,
          # There should be no hits in the jaxpr cache for f.fun because we've
          # only just created it.
          num_jaxpr_hits=0,
          # We should have 1 hit from the trace in lowering.
          num_trace_hits=1,
          # We should have 4 misses: add, equal, f, and call_wrapped.
          num_trace_misses=4,
          # We shouldn't have hit the wrapper cache yet.
          num_wrapper_hits=None,
          # We get 1 hit on the disk cache during the lowering rule. However, this
          # hit is for an incomplete CacheEntry; only avals_out were populated
          # and not the lowered module. The lowering rule updates the CacheEntry
          # with the lowered module.
          num_disk_hits=1,
      )

      # g reuses key "f"; its own body raises, so g(1.0) == 2.0 below can only
      # hold if the implementation registered for f is the one that runs.
      @aot.component(key="f")
      def g(x):
        raise NotImplementedError

      self.assertEqual(f.fun, g.fun)
      self.assertEqual(g(1.0), 2.0)
      # Cache state should remain unchanged except we grabbed the wrapped fun.
      # NOTE(review): the positional counters here and in the tests below
      # presumably follow the keyword order used above (jaxpr entries, jaxpr
      # hits, trace hits, trace misses, wrapper hits, disk hits) -- confirm
      # against validate_cache_states.
      self.validate_cache_states(g.fun, g.key, 1, 0, 1, 4, 1, 1)

  # A component called directly and then from inside a jitted, non-component
  # function g.
  @config.enable_checks(False)
  def test_component_in_function(self):
    with self.make_in_memory_cache():

      @aot.component(key="f")
      def f(x):
        return x + 1.0

      @jax.jit
      def g(x):
        return f(x) + 1.0

      self.assertEqual(f(1.0), 2.0)

      # We should have the same cache states as in test_component_basic.
      self.validate_cache_states(f.fun, f.key, 1, 0, 1, 4, None, 1)

      # 1 hit when lowering g. g is not a component, so doesn't look up
      # CacheEntry during abstract_eval.
      self.assertEqual(g(1.0), 3.0)
      # We have one more trace cache hit on f from tracing g, two more misses,
      # one for g and one for add, and one more disk cache hit from lowering g.
      self.validate_cache_states(f.fun, f.key, 1, 0, 2, 6, None, 2)

  # jax.jit applied on top of a component (jit is the outer decorator).
  @config.enable_checks(False)
  def test_jit_of_component(self):
    with self.make_in_memory_cache():

      @jax.jit
      @aot.component(key="f")
      def f(x):
        return x + 1.0

      # Create cache entry when abstract_eval f. 1 hit when lowering f.
      self.assertEqual(f(1.0), 2.0)
      # We should have the same cache states as in test_component_basic except
      # one additional infer params cache miss for the outermost jitted f.
      self.validate_cache_states(f.fun, f.key, 1, 0, 1, 5, None, 1)

      @aot.component(key="f")
      def g(x):
        raise NotImplementedError

      # We ignore g's implementation because it was turned into a component with
      # key "f".
      self.assertEqual(f.fun, g.fun)
      self.assertEqual(g(1.0), 2.0)
      # We have one more hit in infer params cache for the inner f and one more
      # hit in the disk cache for the lowering of f.
      self.validate_cache_states(g.fun, g.key, 1, 0, 2, 5, 1, 2)

  # A component applied on top of jax.jit (jit is the inner decorator).
  @config.enable_checks(False)
  def test_component_of_jit(self):
    with self.make_in_memory_cache():

      @aot.component(key="f")
      @jax.jit
      def f(x):
        return x + 1.0

      self.assertEqual(f(1.0), 2.0)
      # Same cache states as in test_component_basic except for one additional
      # infer params cache miss (5 instead of 4), presumably for the jax.jit
      # applied beneath the component.
      # NOTE(review): this comment previously said "outermost jitted f", which
      # matches the decorator order of test_jit_of_component, not this test --
      # confirm which jit accounts for the extra miss.
      self.validate_cache_states(f.fun, f.key, 1, 0, 1, 5, None, 1)

      @aot.component(key="f")
      def g(x):
        raise NotImplementedError

      # We ignore g's implementation because it was turned into a component with
      # key "f".
      self.assertEqual(f.fun, g.fun)
      self.assertEqual(g(1.0), 2.0)
      # NOTE(review): this invokes g a second time just to log the result; the
      # counters below are validated after both calls.
      logging.info(g(1.0))
      # We have one hit in the wrapper cache.
      self.validate_cache_states(g.fun, g.key, 1, 0, 1, 5, 1, 1)

  # Calls .lower() on a component explicitly instead of executing it.
  @config.enable_checks(False)
  def test_explicit_lowering(self):
    with self.make_in_memory_cache():

      @aot.component(key="f")
      def f(x):
        return x + 1.0

      # `lowered` is never read in this test; only the effect of lowering on the
      # cache counters is checked.
      lowered = f.lower(jax.ShapeDtypeStruct((), "float32"))
      # One less infer params cache miss because we just have add, f, and
      # call_wrapped; no equal.
      self.validate_cache_states(f.fun, f.key, 1, 0, 1, 3, None, 1)

      # Visual separator in the log output.
      logging.info("\n\n\n")

      @aot.component(key="f")
      def g(x):
        raise NotImplementedError

      lowered = g.lower(jax.ShapeDtypeStruct((), "float32"))
      # We hit the wrapper cache because we have the same component key, but
      # because we explicitly call lower, we hit infer params cache again.
      # TODO(dsuo): Do we want this behavior?
      # The counters are read through f.fun / f.key even though g was lowered;
      # both components were registered under key "f".
      self.validate_cache_states(f.fun, f.key, 1, 0, 2, 3, 1, 1)

  # jax.vmap applied to a component; counters are checked for both the plain
  # component key and the key derived with ComponentKey.vmap.
  @config.enable_checks(False)
  def test_vmap_of_component(self):
    with self.make_in_memory_cache():

      def f(x):
        logging.info("running!")
        return x + 1.0

      # The id() log lines below only record object identities for debugging.
      logging.info("\n\nuser fun id %s", id(f))
      component_f = aot.component(key="f")(f)
      logging.info("\n\ncomponent_f id %s", id(component_f))
      vmapped_f = jax.vmap(component_f)
      logging.info("\n\nvmapped_f id %s", id(vmapped_f))
      logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__))

      # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack.
      vmapped_key = aot_util.ComponentKey.vmap(component_f.key)

      self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4)
      # Counters under the un-vmapped key ...
      self.validate_cache_states(
          component_f.fun, component_f.key, 1, 0, 1, 7, 0, 0
      )
      # ... and under the vmapped key.
      self.validate_cache_states(
          component_f.fun, vmapped_key, 1, 0, 1, 7, None, 1
      )

  # A component called directly, then under one jax.vmap, then under two nested
  # jax.vmaps, validating the counters after each stage.
  @config.enable_checks(False)
  def test_vmap_of_vmap_of_component(self):
    with self.make_in_memory_cache():

      def f(x):
        logging.info("running!")
        return x + 1.0

      logging.info("\n\nuser fun id %s", id(f))
      c_f = aot.component(key="f")(f)
      logging.info("\n\nc_f id %s", id(c_f))
      self.assertEqual(c_f(1.0), 2.0)
      self.validate_cache_states(c_f.fun, c_f.key, 1, 0, 1, 4, 0, 1)
      v_f = jax.vmap(c_f)
      logging.info("\n\nv_f id %s", id(v_f))
      self.assertArraysEqual(v_f(jnp.ones((4,))), jnp.ones((4,)) + 1.0)
      # TODO(dsuo): How to put component key on vmapped?
      # We now have 2 entries, one for f and one for vmap(f).
      self.validate_cache_states(
          c_f.fun, aot_util.ComponentKey.vmap(c_f.key), 2, (0, 0), 1, 12, None, 1
      )
      vv_f = jax.vmap(v_f)
      logging.info("\n\nvv_f id %s", id(vv_f))
      self.assertArraysEqual(
          vv_f(
              jnp.ones((
                  4,
                  4,
              ))
          ),
          jnp.ones((4, 4)) + 1.0,
      )
      # Counters under the doubly-vmapped key.
      self.validate_cache_states(
          c_f.fun,
          aot_util.ComponentKey.vmap(aot_util.ComponentKey.vmap(c_f.key)),
          2,
          (0, 0),
          3,
          16,
          None,
          1,
      )

  # jax.vmap applied to jax.jit of a component.
  @config.enable_checks(False)
  def test_vmap_of_jit_of_component(self):
    # NOTE: This should be the same as test_vmap_of_component except for one
    # more infer params cache miss because of the extra jit.
    with self.make_in_memory_cache():

      def f(x):
        logging.info("running!")
        return x + 1.0

      logging.info("\n\nuser fun id %s", id(f))
      component_f = jax.jit(aot.component(key="f")(f))
      logging.info("\n\ncomponent_f id %s", id(component_f))
      vmapped_f = jax.vmap(component_f)
      logging.info("\n\nvmapped_f id %s", id(vmapped_f))
      logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__))

      # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack.
      vmapped_key = aot_util.ComponentKey.vmap(component_f.key)

      self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4)
      self.validate_cache_states(
          component_f.fun, component_f.key, 1, 0, 1, 8, 0, 0
      )
      self.validate_cache_states(
          component_f.fun, vmapped_key, 1, 0, 1, 8, None, 1
      )

  # jax.jit applied to jax.vmap of a component.
  @config.enable_checks(False)
  def test_jit_of_vmap_of_component(self):
    # NOTE: This should be the same as test_vmap_of_component except for one
    # more infer params cache miss because of the extra jit.
    with self.make_in_memory_cache():

      def f(x):
        logging.info("running!")
        return x + 1.0

      logging.info("\n\nuser fun id %s", id(f))
      component_f = aot.component(key="f")(f)
      logging.info("\n\ncomponent_f id %s", id(component_f))
      vmapped_f = jax.jit(jax.vmap(component_f))
      logging.info("\n\nvmapped_f id %s", id(vmapped_f))
      logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__))

      # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack.
      vmapped_key = aot_util.ComponentKey.vmap(component_f.key)

      self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4)
      self.validate_cache_states(
          component_f.fun, component_f.key, 1, 0, 1, 8, 0, 0
      )
      self.validate_cache_states(
          component_f.fun, vmapped_key, 1, 0, 1, 8, None, 1
      )

  # A component used twice per step inside a jax.lax.scan body: once on the
  # carry and once on the scanned element.
  @config.enable_checks(False)
  def test_scan_of_component(self):
    with self.make_in_memory_cache():

      @aot.component(key="f")
      def f(x):
        logging.info("running!")
        return x + 1.0

      def body(carry, x):
        return f(carry), f(x)

      # The carry starts at 0 and is incremented by f on each of the 10 steps.
      carry, ys = jax.lax.scan(body, 0, jnp.arange(10, dtype="float32"))

      self.assertEqual(carry, 10)
      self.assertArraysEqual(ys, jnp.arange(10, dtype="float32") + 1)
      self.validate_cache_states(f.fun, f.key, 2, (0, 0), 2, 13, None, 2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())