From ef8cb0a51bced5697187592c7d08001cf8650c5b Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 3 Jul 2025 23:58:14 +0300 Subject: [PATCH 01/24] Add array-api copy semantics to dlpack MakePjrtBuffer --- jaxlib/_jax/__init__.pyi | 178 ++++++++++++++++++--------------------- jaxlib/dlpack.cc | 16 ++-- jaxlib/jax.cc | 13 ++- jaxlib/xla_client.py | 2 +- 4 files changed, 97 insertions(+), 112 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 055c1dc2d1de..f3399a83b705 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -134,8 +134,8 @@ class Shape: def array_shape( type: PrimitiveType, dims: Sequence[int], - layout: Sequence[int] | None = ..., - dynamic_dimensions: Sequence[bool] | None = ..., + layout: Sequence[int] | None = None, + dynamic_dimensions: Sequence[bool] | None = None, ) -> Shape: """Constructs an array shape.""" @@ -144,8 +144,8 @@ class Shape: def array_shape( type: numpy.dtype, dims: Sequence[int], - layout: Sequence[int] | None = ..., - dynamic_dimensions: Sequence[bool] | None = ..., + layout: Sequence[int] | None = None, + dynamic_dimensions: Sequence[bool] | None = None, ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @@ -191,7 +191,7 @@ class Literal: def __init__(self, arg: Shape, /) -> None: ... def __repr__(self) -> str: ... def __array__( - self, dtype: object | None = ..., copy: bool | None = ... + self, dtype: object | None = None, copy: bool | None = None ) -> NDArray: ... def shape(self) -> Shape: ... @@ -201,7 +201,7 @@ class XlaComputation: def program_shape(self) -> ProgramShape: ... def name(self) -> str: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - def as_hlo_text(self, print_large_constants: bool = ...) -> str: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... def as_hlo_dot_graph(self) -> str: ... def hash(self) -> int: ... def as_hlo_module(self) -> HloModule: ... @@ -368,8 +368,8 @@ def register_custom_call_target( fn_name: object, fn: object, platform: str, - api_version: int = ..., - traits: int = ..., + api_version: int = 0, + traits: int = 0, ) -> None: ... def custom_call_targets(platform: str) -> dict: ... @@ -724,9 +724,9 @@ class HloSharding: @staticmethod def iota_tile( dims: Sequence[int], - reshape_dims: Sequence[int] = ..., - transpose_perm: Sequence[int] = ..., - subgroup_types: Sequence[OpSharding_Type] = ..., + reshape_dims: Sequence[int] = [], + transpose_perm: Sequence[int] = [], + subgroup_types: Sequence[OpSharding_Type] = [], ) -> HloSharding: ... @staticmethod def manual() -> HloSharding: ... @@ -739,7 +739,7 @@ class HloSharding: @staticmethod def subgroup_with_device_ordering( tile_assignment: Annotated[NDArray[numpy.int64], dict(order='C')], - subgroup_types: Sequence[OpSharding_Type] = ..., + subgroup_types: Sequence[OpSharding_Type] = [], ) -> HloSharding: ... def __eq__(self, other: object, /) -> bool: ... def __ne__(self, other: object, /) -> bool: ... @@ -872,8 +872,8 @@ class Client: def buffer_from_pyval( self, argument: object, - device: Device | None = ..., - force_copy: bool = ..., + device: Device | None = None, + force_copy: bool = False, host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, ) -> object: ... def compile( @@ -924,23 +924,15 @@ class Client: self, serialized: bytes, executable_devices: DeviceList, - compile_options: CompileOptions | None = ..., - host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., - ) -> LoadedExecutable: ... - @overload - def deserialize_executable( - self, - serialized: bytes, - executable_devices: DeviceList, - compile_options: CompileOptions | None = ..., - host_callbacks: Sequence[Callable] = ..., + compile_options: CompileOptions | None = None, + host_callbacks: Sequence[typing_extensions.CapsuleType] = [], ) -> LoadedExecutable: ... @overload def deserialize_executable( self, serialized: bytes, executable_devices: Sequence, - compile_options: CompileOptions | None = ..., + compile_options: CompileOptions | None = None, ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> None: ... @@ -951,7 +943,7 @@ class Client: result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], recv_channel_ids: Sequence[int], - serializer: Callable | None = ..., + serializer: Callable | None = None, ) -> object: ... def get_default_layout( self, dtype: numpy.dtype, shard_shape: Sequence, device: Device @@ -979,34 +971,34 @@ class CpuCollectives: def make_gloo_tcp_collectives( distributed_client: DistributedRuntimeClient, - hostname: str | None = ..., - interface: str | None = ..., + hostname: str | None = None, + interface: str | None = None, ) -> CpuCollectives: ... def make_mpi_collectives() -> CpuCollectives: ... def get_tfrt_cpu_client( - asynchronous: bool = ..., - distributed_client: DistributedRuntimeClient | None = ..., - node_id: int = ..., - num_nodes: int = ..., - collectives: CpuCollectives | None = ..., - num_devices: int | None = ..., - get_local_topology_timeout_minutes: int | None = ..., - get_global_topology_timeout_minutes: int | None = ..., - transfer_server_factory: TransferServerInterfaceFactory | None = ..., + asynchronous: bool = True, + distributed_client: DistributedRuntimeClient | None = None, + node_id: int = 0, + num_nodes: int = 1, + collectives: CpuCollectives | None = None, + num_devices: int | None = None, + get_local_topology_timeout_minutes: int | None = None, + get_global_topology_timeout_minutes: int | None = None, + transfer_server_factory: TransferServerInterfaceFactory | None = None, ) -> Client: ... def pjrt_plugin_loaded(arg: str, /) -> bool: ... def load_pjrt_plugin( platform_name: str, - library_path: str | None = ..., - c_api: typing_extensions.CapsuleType | None = ..., + library_path: str | None = None, + c_api: typing_extensions.CapsuleType | None = None, ) -> typing_extensions.CapsuleType: ... def pjrt_plugin_initialized(arg: str, /) -> bool: ... def initialize_pjrt_plugin(arg: str, /) -> None: ... def get_c_api_client( platform_name: str, - options: Mapping[str, str | bool | int | Sequence[int] | float] = ..., - distributed_client: DistributedRuntimeClient | None = ..., - transfer_server_factory: TransferServerInterfaceFactory | None = ..., + options: Mapping[str, str | bool | int | Sequence[int] | float] = {}, + distributed_client: DistributedRuntimeClient | None = None, + transfer_server_factory: TransferServerInterfaceFactory | None = None, ) -> Client: ... def get_default_c_api_topology( arg0: str, @@ -1034,7 +1026,7 @@ def batched_copy_array_to_devices_with_sharding( /, ) -> list[Array]: ... def array_result_handler( - aval: object, sharding: object, committed: bool, _skip_checks: bool = ... + aval: object, sharding: object, committed: bool, _skip_checks: bool = False ) -> ResultHandler: ... class ResultHandler: @@ -1076,8 +1068,8 @@ class NamedSharding(Sharding): self, mesh: object, spec: PartitionSpec, - memory_kind: object | None = ..., - _logical_device_ids: object | None = ..., + memory_kind: object | None = None, + _logical_device_ids: object | None = None, ) -> None: ... @property def mesh(self) -> object: ... @@ -1094,7 +1086,7 @@ class NamedSharding(Sharding): class SingleDeviceSharding(Sharding): def __init__( - self, device: object, memory_kind: object | None = ... + self, device: object, memory_kind: object | None = None ) -> None: ... @property def _device(self) -> object: ... @@ -1120,28 +1112,28 @@ class GSPMDSharding(Sharding): self, devices: DeviceList, op_sharding: OpSharding, - memory_kind: object | None = ..., + memory_kind: object | None = None, ) -> None: ... @overload def __init__( self, devices: DeviceList, op_sharding: HloSharding, - memory_kind: object | None = ..., + memory_kind: object | None = None, ) -> None: ... @overload def __init__( self, devices: Sequence[Device], op_sharding: OpSharding, - memory_kind: object | None = ..., + memory_kind: object | None = None, ) -> None: ... @overload def __init__( self, devices: Sequence[Device], op_sharding: HloSharding, - memory_kind: object | None = ..., + memory_kind: object | None = None, ) -> None: ... @property def _devices(self) -> DeviceList: ... @@ -1223,7 +1215,9 @@ class LoadedExecutable: def size_of_generated_code_in_bytes(self) -> int: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def execute_sharded( - self, arguments: Sequence[Array], with_tokens: bool = ... + self, + arguments: Sequence[Array | Sequence[Array]], + with_tokens: bool = False, ) -> ExecuteResults: ... def hlo_modules(self) -> list[HloModule]: ... def get_output_memory_kinds(self) -> list[list[str]]: ... @@ -1248,16 +1242,13 @@ class ShardedToken: def get_token(self, arg: int, /) -> Token: ... def buffer_to_dlpack_managed_tensor( - buffer: object, stream: int | None = ... + buffer: object, stream: int | None = None ) -> typing_extensions.CapsuleType: ... def dlpack_managed_tensor_to_buffer( - dlpack: typing_extensions.CapsuleType, - device: Device, - stream: int | None, - copy: bool | None = ..., + dlpack: typing_extensions.CapsuleType, device: Device, stream: int | None, copy: bool | None = None ) -> ArrayImpl: ... def cuda_array_interface_to_buffer( - cai: dict, gpu_backend: Client | None = ..., device_id: int | None = ... + cai: dict, gpu_backend: Client | None = None, device_id: int | None = None ) -> object: ... class RuntimeTracebackMode(enum.Enum): @@ -1278,7 +1269,7 @@ def set_send_traceback_to_runtime_thread_local( ) -> None: ... class PjitFunctionCache: - def __init__(self, capacity: int = ...) -> None: ... + def __init__(self, capacity: int = 4096) -> None: ... def size(self) -> int: ... def capacity(self) -> int: ... def clear(self) -> None: ... @@ -1294,7 +1285,7 @@ class PjitFunction: def __call__(self, /, *args, **kwargs): """Call self as a function.""" - def __get__(self, instance, owner=..., /): + def __get__(self, instance, owner=None, /): """Return an attribute of instance, which is of type owner.""" __vectorcalloffset__: types.MemberDescriptorType = ... @@ -1391,8 +1382,8 @@ def register_custom_call_partitioner( prop_user_sharding: object, partition: object, infer_sharding_from_operands: object, - can_side_effecting_have_replicated_sharding: bool = ..., - c_api: typing_extensions.CapsuleType | None = ..., + can_side_effecting_have_replicated_sharding: bool = False, + c_api: typing_extensions.CapsuleType | None = None, ) -> None: """Registers a partitioner for a custom-call operation. @@ -1413,7 +1404,7 @@ def register_custom_call_partitioner( def encode_inspect_sharding_callback(arg: object, /) -> bytes: ... def register_custom_call_as_batch_partitionable( - target_name: str, c_api: typing_extensions.CapsuleType | None = ... + target_name: str, c_api: typing_extensions.CapsuleType | None = None ) -> None: """Registers a custom call as batch partitionable. @@ -1430,7 +1421,6 @@ def register_custom_call_as_batch_partitionable( class TransferConnection: def _testonly_inject_failure(self) -> None: ... - def _poison_connection(self) -> None: ... def _pull_flat( self, arg0: int, arg1: Client, arg2: Sequence[object], / ) -> list[Array]: ... @@ -1447,19 +1437,19 @@ class TransferServer: def _make_error_array(arg0: Client, arg1: object, arg2: str, /) -> Array: ... def start_transfer_server( client: Client, - address: str = ..., - transport_addresses: Sequence[str] = ..., - max_num_parallel_copies: int = ..., - transfer_size: int = ..., - supports_pinned_allocator: bool = ..., - use_raw_buffers: bool = ..., + address: str = '[::]:0', + transport_addresses: Sequence[str] = [], + max_num_parallel_copies: int = 8, + transfer_size: int = 268435456, + supports_pinned_allocator: bool = False, + use_raw_buffers: bool = False, ) -> TransferServer: ... def make_transfer_server_interface_factory( - transfer_size: int = ..., - cross_host_transfer_timeout_seconds: int = ..., - distributed_client: DistributedRuntimeClient | None = ..., - socket_address: str = ..., - transport_addresses: Sequence[str] = ..., + transfer_size: int = 268435456, + cross_host_transfer_timeout_seconds: int = 60, + distributed_client: DistributedRuntimeClient | None = None, + socket_address: str = '[::]:0', + transport_addresses: Sequence[str] = [], ) -> TransferServerInterfaceFactory: ... class PreemptionSyncManager: @@ -1488,14 +1478,14 @@ class DistributedRuntimeClient: self, barrier_id: str, timeout_in_ms: int, - process_ids: Sequence[int] | None = ..., + process_ids: Sequence[int] | None = None, ) -> None: ... def get_live_nodes(self, process_ids: Sequence[int]) -> dict[int, int]: ... def key_value_set( - self, key: str, value: str, allow_overwrite: bool = ... + self, key: str, value: str, allow_overwrite: bool = False ) -> None: ... def key_value_set_bytes( - self, key: str, value: bytes, allow_overwrite: bool = ... + self, key: str, value: bytes, allow_overwrite: bool = False ) -> None: ... def key_value_dir_get(self, key: str) -> list[tuple[str, str]]: ... def key_value_dir_get_bytes(self, key: str) -> list[tuple[str, bytes]]: ... @@ -1504,21 +1494,21 @@ class DistributedRuntimeClient: def get_distributed_runtime_service( address: str, num_nodes: int, - heartbeat_timeout: int | None = ..., - cluster_register_timeout: int | None = ..., - shutdown_timeout: int | None = ..., + heartbeat_timeout: int | None = None, + cluster_register_timeout: int | None = None, + shutdown_timeout: int | None = None, ) -> DistributedRuntimeService: ... def get_distributed_runtime_client( address: str, node_id: int, - rpc_timeout: int | None = ..., - init_timeout: int | None = ..., - shutdown_timeout: int | None = ..., - heartbeat_timeout: int | None = ..., - missed_heartbeat_callback: Callable | None = ..., - shutdown_on_destruction: bool | None = ..., - use_compression: bool | None = ..., - recoverable: bool | None = ..., + rpc_timeout: int | None = None, + init_timeout: int | None = None, + shutdown_timeout: int | None = None, + heartbeat_timeout: int | None = None, + missed_heartbeat_callback: Callable | None = None, + shutdown_on_destruction: bool | None = None, + use_compression: bool | None = None, + recoverable: bool | None = None, ) -> DistributedRuntimeClient: ... def collect_garbage() -> None: ... def is_optimized_build() -> bool: ... @@ -1571,10 +1561,10 @@ def batched_device_put( sharding: object, xs: Sequence[object], devices: Sequence[Device], - committed: bool = ..., - force_copy: bool = ..., + committed: bool = True, + force_copy: bool = False, host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, - enable_x64: bool | None = ..., + enable_x64: bool | None = None, ) -> object: ... def reorder_shards( x: Array, dst_sharding: object, array_copy_semantics: ArrayCopySemantics @@ -1584,15 +1574,15 @@ def check_and_canonicalize_memory_kind( memory_kind: object | None, device_list: DeviceList ) -> object: ... -ifrt_version_number: int = ... +ifrt_version_number: int = 34 def approx_top_k_reduction_output_size( input_size: int, rank: int, top_k: int, recall_target: float, - aggregate_to_topk: bool = ..., - input_size_override: int = ..., + aggregate_to_topk: bool = True, + input_size_override: int = -1, ) -> tuple[int, int]: ... def get_internal_device_put_info() -> dict[str, int]: ... diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc index ca90dd1a5c5e..48d4a7329876 100644 --- a/jaxlib/dlpack.cc +++ b/jaxlib/dlpack.cc @@ -194,16 +194,13 @@ absl::StatusOr> MakePjrtBuffer( // On CPU, creating a view may fail because of unaligned data buffer // in which case we'll fallback to copy. On non-CPU, array-api copy // semantics is handled in dlpack._place_array function. - bool fallback_to_copy = - !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU; + bool fallback_to_copy = !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU; // Create a view. if (!copy.value_or(false)) { auto result = device.client()->CreateViewOfDeviceBuffer( - data, shape, *device.default_memory_space(), on_delete_callback, - stream); - if (!(result.status().code() == absl::StatusCode::kInvalidArgument && - fallback_to_copy)) { + data, shape, *device.default_memory_space(), on_delete_callback, stream); + if (!(result.status().code() == absl::StatusCode::kInvalidArgument && fallback_to_copy)) { return result; } } @@ -219,8 +216,8 @@ absl::StatusOr> MakePjrtBuffer( // Create a copy. return device.client()->BufferFromHostBuffer( data, element_type, dimensions, byte_strides, - xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, - on_delete_callback, memory_space, /*device_layout=*/nullptr); + xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + memory_space, /*device_layout=*/nullptr); } } // namespace @@ -320,8 +317,7 @@ absl::StatusOr BufferToDLPackManagedTensor( absl::StatusOr DLPackManagedTensorToBuffer( const nb::capsule& tensor, ifrt::Device* ifrt_device, - nb_class_ptr client, std::optional stream, - std::optional copy) { + nb_class_ptr client, std::optional stream, std::optional copy) { ifrt::PjRtDevice* device = llvm::dyn_cast_or_null(ifrt_device); if (device == nullptr) { diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index f914467ab322..cbda01880089 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -590,18 +590,17 @@ NB_MODULE(_jax, m) { return xla::ValueOrThrow(DLPackManagedTensorToBuffer( tensor, device->device(), device->client(), stream, copy)); }, - nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), - nb::arg("copy").none() = nb::none(), - nb::sig( - // clang-format off + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), nb::arg("copy").none() = nb::none(), + nb::sig( + // clang-format off "def dlpack_managed_tensor_to_buffer(" "dlpack: typing_extensions.CapsuleType, " "device: Device, " "stream: int | None, " - "copy: bool | None = ..." + "copy: bool | None" ") -> ArrayImpl" - // clang-format on - )); + // clang-format on + )); m.def("cuda_array_interface_to_buffer", xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), nb::arg("gpu_backend").none() = nb::none(), diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index ce0379c4cb03..7d7983f11483 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 387 # Introduce lowering support for lax.ragged_dot + collectives. +_version = 384 # Add a new copy argument to dlpack_managed_tensor_to_buffer # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 2b8b32cfb98ca4e676e695e59f44e07271c8dff9 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 29 Oct 2025 09:51:38 -0400 Subject: [PATCH 02/24] [aot] WIP --- jax/_src/aot.py | 155 +++++++++++++++++++++++++++++++++++++++++++ jax/_src/aot_util.py | 124 ++++++++++++++++++++++++++++++++++ tests/aot_test.py | 65 ++++++++++++++++++ 3 files changed, 344 insertions(+) create mode 100644 jax/_src/aot.py create mode 100644 jax/_src/aot_util.py diff --git a/jax/_src/aot.py b/jax/_src/aot.py new file mode 100644 index 000000000000..5aac06ffb440 --- /dev/null +++ b/jax/_src/aot.py @@ -0,0 +1,155 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX AOT API""" + +import functools +from collections.abc import Hashable +from typing import Any, Callable, NamedTuple + + +from absl import logging +from jax._src import aot_util +from jax._src import api +from jax._src import core +from jax._src import traceback_util +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import func as func_dialect + +component_p = core.Primitive("component") + + +ComponentKey = Hashable | Callable[..., Hashable] + + +def component(component_key: ComponentKey = None) -> Callable[..., Any]: + def _component(f: Callable[..., Any]): + @api.jit + @util.wraps(f) + @traceback_util.api_boundary + def wrapper(*args): + # TODO(dsuo): Flatten function as in shard_map pmap. + # TODO(dsuo): Need to consider static args. + return component_p.bind(*args, f=f, component_key=component_key) + + # NOTE(dsuo): Using a component means we'll jit you in this dummy + # implementation. + return wrapper + + return _component + + +def component_impl(*args, f: Callable[..., Any], **_): + # TODO(dsuo): Call should not re-trace. + logging.info("component_impl") + return f(*args) + + +def component_abstract_eval( + *args, f: Callable[..., Any], component_key: ComponentKey +): + logging.info("component_abstract_eval: %s", component_key) + key = aot_util.make_abstract_eval_key(component_key) + + def abstract_eval(): + # TODO(dsuo): How / when to check if we're tracing again by mistake. + if key not in aot_util.traced_cache: + aot_util.traced_cache[key] = aot_util.trace(f, *args) + return tree_util.tree_map( + lambda x: core.ShapedArray(x.shape, x.dtype), + aot_util.traced_cache[key].out_info, + ) + + return aot_util.get_cached_or_put( + key, + abstract_eval, + aot_util.serialize_abstract_eval, + aot_util.deserialize_abstract_eval, + ) + + +def component_lowering( + ctx, *args, f: Callable[..., Any], component_key: ComponentKey +): + logging.info("component_lowering: %s", component_key) + key = aot_util.make_lowering_key(component_key) + + # TODO(dsuo): Is this something we can grab from LoweringRuleContext or + # traced? + module_name = f"{component_key}.module" + traced_key = aot_util.make_abstract_eval_key(component_key) + # TODO(dsuo): Expect entry exists. TBD for transformations. + traced = aot_util.traced_cache[traced_key] + + def lower_jaxpr_to_module(): + # with ctx.module_context.module.context: + lowering_result = mlir.lower_jaxpr_to_module( + module_name=module_name, + jaxpr=traced.jaxpr, + num_const_args=traced._num_consts, + in_avals=ctx.avals_in, + # TODO(dsuo): What are ordered effects vs effects? + ordered_effects=traced.jaxpr.effects, + # TODO(dsuo): Figure out why ctx.platforms=None. + platforms=['cpu'], + backend=ctx.module_context.backend, + axis_context=ctx.module_context.axis_context, + donated_args=tuple( + x.donated for x in tree_util.tree_leaves(traced.args_info) + ), + lowering_parameters=mlir.LoweringParameters(), + # TODO(dsuo): Presumably we need to forward the rest of the arguments to + # lower_jaxpr_to_module? + ) + # TODO(dsuo): What should we do about the other attributes on + # LoweringResult? + # - keepalive: probably not supported. + # - host_callbacks: probably not supported. + # - shape_poly_state: talk to necula@ + submodule = lowering_result.module + # TODO(dsuo): We have this to ensure the source and destination modules have + # the same context, but is it necessary? Perhaps yes, since we need to get + # rid of the submodule context before merging. Could we just create it with + # the right context? + submodule = ir.Module.parse(mlir.module_to_bytecode(submodule)) + return submodule + + submodule = aot_util.get_cached_or_put( + key, + lower_jaxpr_to_module, + aot_util.serialize_lowering, + aot_util.deserialize_lowering, + ) + + symtab = ir.SymbolTable(submodule.operation) + fn = mlir.merge_mlir_modules( + ctx.module_context.module, + f"component_{module_name}", + submodule, + dst_symtab=ctx.module_context.symbol_table, + ) + # TODO(dsuo): There's quite a bit of logic from jax.export, but we just strip + # away most of that for this demo. e.g., ordered effects, platforms. + # submodule_args = [mlir.aval_to_ir_type(x) for x in ctx.avals_in] + results = symtab["main"].type.results + call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(fn), args) + + return call.results + + +component_p.def_impl(component_impl) +component_p.def_abstract_eval(component_abstract_eval) +mlir.register_lowering(component_p, component_lowering) diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py new file mode 100644 index 000000000000..1ae2d9b37778 --- /dev/null +++ b/jax/_src/aot_util.py @@ -0,0 +1,124 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX AOT API utilities.""" + +from collections.abc import Hashable +import dataclasses +import functools +import pickle +from typing import Any, Callable, NamedTuple + +from absl import logging +from jax._src import api +from jax._src import config +from jax._src import stages +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib import xla_client as xc + + +# For now, we don't worry about serialization. +SerializedType = bytes | Any + + +def _validate_component_cache(val): + assert val is None or isinstance(val, Cache) + + +component_cache = config.string_or_object_state( + name="jax_component_cache", + default=None, + help="Cache dir for components. Components won't be cached if None.", + validator=_validate_component_cache, +) + + +traced_cache: dict[Hashable, stages.Traced] = {} + + +def trace(f: Callable[..., Any], *args, **kwargs): + if type(f) is xc._xla.PjitFunction: + return f.trace(*args, **kwargs) + try: + hash(f) + except TypeError: + fun = functools.partial(f) + return api.jit(f).trace(*args, **kwargs) + + +class CacheEntry: + def __init__(self, blob: SerializedType, hits: int = 0): + self.blob = blob + self.hits = hits + +class Cache(NamedTuple): + get: Callable[[Hashable, bool], bytes | None] + put: Callable[[Hashable, bytes], None] + keys: Callable[[], list[Hashable]] + clear: Callable[[], None] + + +_in_memory_cache: dict[Hashable, CacheEntry] = {} + + +def make_in_memory_cache(): + def get(key: Hashable, update_hits: bool = True) -> SerializedType | None: + entry = _in_memory_cache.get(key, None) + if entry is not None and update_hits: + _in_memory_cache[key].hits += 1 + return entry.blob + return entry + + def put(key: Hashable, data: SerializedType): + _in_memory_cache[key] = CacheEntry(data) + + def keys() -> list[Hashable]: + return list(_in_memory_cache.keys()) + + def clear(): + _in_memory_cache.clear() + + return Cache(get, put, keys, clear) + + +KeyFn = Callable[[Hashable], Hashable] +SerFn = Callable[[Any], SerializedType] +DesFn = Callable[[SerializedType], Any] + +make_abstract_eval_key: KeyFn = lambda k: f"{k}.abstract_eval" +serialize_abstract_eval: SerFn = lambda obj: pickle.dumps(obj) +deserialize_abstract_eval: DesFn = lambda blob: pickle.loads(blob) +make_lowering_key: KeyFn = lambda k: f"{k}.lowering" + +# TODO(dsuo): When is serialize/deserialize portable artifact needed? +serialize_lowering: SerFn = lambda obj: mlir.module_to_bytecode(obj) +deserialize_lowering: DesFn = lambda blob: ir.Module.parse(blob) + + +def get_cached_or_put(key, make, serialize, deserialize): + if (cache := component_cache.value) is None: + logging.info("Component cache is not set.") + return make() + + if blob := cache.get(key): # pytype: disable=attribute-error + logging.info("Key %s found with blob %s.", key, blob) + return deserialize(blob) + + logging.info("Key %s missing.", key) + obj = make() + blob = serialize(obj) + logging.info("Putting key %s with blob %s.", key, blob) + cache.put(key, blob) # pytype: disable=attribute-error + logging.info("Cache keys: %s", cache.keys()) + return obj diff --git a/tests/aot_test.py b/tests/aot_test.py index 6f3184daef1e..5cfec7954dc4 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -17,6 +17,9 @@ from absl.testing import absltest import jax from jax import lax +from jax._src import aot +from jax._src import api +from jax._src import aot_util from jax._src import config from jax._src import core from jax._src import test_util as jtu @@ -256,5 +259,67 @@ def f(x): deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', execution_devices=jax.devices()[:1]) + +@jtu.thread_unsafe_test_class() +class ComponentTest(jtu.JaxTestCase): + + @contextlib.contextmanager + def make_in_memory_cache(self): + cache = aot_util.make_in_memory_cache() + with aot_util.component_cache(cache): + yield + aot_util.component_cache.value.clear() + + def test_component_lowering_cache_hit(self): + with self.make_in_memory_cache(): + cache = aot_util.component_cache.value + @aot.component(component_key='f') + def f(x): + return x + 1.0 + + self.assertEqual(f(1.0), 2.0) + self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) + + @aot.component(component_key='f') + def g(x): + raise NotImplementedError + + self.assertEqual(g(1.0), 2.0) + # TODO(dsuo): Why is abstract_eval rule called so many times? + self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + + def test_component_call_in_function(self): + with self.make_in_memory_cache(): + cache = aot_util.component_cache.value + @aot.component(component_key='f') + def f(x): + return x + 1.0 + + @jax.jit + def g(x): + return f(x) + 1.0 + + self.assertEqual(g(1.0), 3.0) + self.assertEqual(f(1.0), 2.0) + self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + + def test_explicit_cached_lowering(self): + with self.make_in_memory_cache(): + cache = aot_util.component_cache.value + @aot.component(component_key='f') + def f(x): + return x + 1.0 + + lowered = f.lower(jax.ShapeDtypeStruct((), 'float32')) + self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) + + @aot.component(component_key='f') + def g(x): + raise NotImplementedError + + lowered = g.lower(jax.ShapeDtypeStruct((), 'float32')) + self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From f675d37474b4e19fc87479e8d4f43499fa6b20dd Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 29 Oct 2025 16:18:25 -0400 Subject: [PATCH 03/24] Update --- jax/_src/aot.py | 48 +++++++++++++++++++++++++++++--------------- jax/_src/aot_util.py | 25 ++++++++--------------- jax/_src/api.py | 10 +++++++-- jax/_src/pjit.py | 1 + tests/aot_test.py | 42 +++++++++++++++++++++++++------------- 5 files changed, 77 insertions(+), 49 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 5aac06ffb440..f00a13d59e6c 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -25,6 +25,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util +from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -36,14 +37,15 @@ def component(component_key: ComponentKey = None) -> Callable[..., Any]: - def _component(f: Callable[..., Any]): + def _component(fun: Callable[..., Any]): @api.jit - @util.wraps(f) + @util.wraps(fun) @traceback_util.api_boundary def wrapper(*args): + logging.info("wrapper: %s", args) # TODO(dsuo): Flatten function as in shard_map pmap. # TODO(dsuo): Need to consider static args. - return component_p.bind(*args, f=f, component_key=component_key) + return component_p.bind(*args, fun=fun, component_key=component_key) # NOTE(dsuo): Using a component means we'll jit you in this dummy # implementation. @@ -52,25 +54,29 @@ def wrapper(*args): return _component -def component_impl(*args, f: Callable[..., Any], **_): +def component_impl(*args, fun: Callable[..., Any], **_): # TODO(dsuo): Call should not re-trace. logging.info("component_impl") - return f(*args) + return fun(*args) def component_abstract_eval( - *args, f: Callable[..., Any], component_key: ComponentKey + *args, fun: Callable[..., Any], component_key: ComponentKey ): logging.info("component_abstract_eval: %s", component_key) key = aot_util.make_abstract_eval_key(component_key) def abstract_eval(): - # TODO(dsuo): How / when to check if we're tracing again by mistake. - if key not in aot_util.traced_cache: - aot_util.traced_cache[key] = aot_util.trace(f, *args) + logging.info("component_abstract_eval args: %s", args) + # NOTE(dsuo): The claim is tracing cache will handle caching jaxprs for us. + # However, we'll need to convert ir.Values in the lowering rule to avals to + # trace in lowering with args. There are two further downsides: + # 1. `fun` must have the same id (in addition to same everything else) in + # order for us to use this cache within the same process. + # 2. It's not easy to inspect the _infer_params_cached.cache_info() to + # understand if we've gotten a cache hit or not (more relevant for testing). return tree_util.tree_map( - lambda x: core.ShapedArray(x.shape, x.dtype), - aot_util.traced_cache[key].out_info, + lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) ) return aot_util.get_cached_or_put( @@ -82,7 +88,7 @@ def abstract_eval(): def component_lowering( - ctx, *args, f: Callable[..., Any], component_key: ComponentKey + ctx, *args, fun: Callable[..., Any], component_key: ComponentKey ): logging.info("component_lowering: %s", component_key) key = aot_util.make_lowering_key(component_key) @@ -92,10 +98,11 @@ def component_lowering( module_name = f"{component_key}.module" traced_key = aot_util.make_abstract_eval_key(component_key) # TODO(dsuo): Expect entry exists. TBD for transformations. - traced = aot_util.traced_cache[traced_key] + logging.info("component_lowering avals_in: %s", ctx.avals_in) def lower_jaxpr_to_module(): # with ctx.module_context.module.context: + traced = api.trace(fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, jaxpr=traced.jaxpr, @@ -104,7 +111,7 @@ def lower_jaxpr_to_module(): # TODO(dsuo): What are ordered effects vs effects? ordered_effects=traced.jaxpr.effects, # TODO(dsuo): Figure out why ctx.platforms=None. - platforms=['cpu'], + platforms=["cpu"], backend=ctx.module_context.backend, axis_context=ctx.module_context.axis_context, donated_args=tuple( @@ -135,7 +142,7 @@ def lower_jaxpr_to_module(): ) symtab = ir.SymbolTable(submodule.operation) - fn = mlir.merge_mlir_modules( + module = mlir.merge_mlir_modules( ctx.module_context.module, f"component_{module_name}", submodule, @@ -145,11 +152,20 @@ def lower_jaxpr_to_module(): # away most of that for this demo. e.g., ordered effects, platforms. # submodule_args = [mlir.aval_to_ir_type(x) for x in ctx.avals_in] results = symtab["main"].type.results - call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(fn), args) + call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(module), args) return call.results +def component_batcher( + vals_in, dims_in, fun: Callable[..., Any], component_key: ComponentKey +): + return fun(vals_in[0]), dims_in[0] + + component_p.def_impl(component_impl) component_p.def_abstract_eval(component_abstract_eval) +# TODO(dsuo): Figure out multiple_results i.e., distinguishing between (1,) and +# 1. mlir.register_lowering(component_p, component_lowering) +batching.primitive_batchers[component_p] = component_batcher diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 1ae2d9b37778..3d6d8adbadc0 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -14,17 +14,20 @@ """JAX AOT API utilities.""" from collections.abc import Hashable -import dataclasses import functools import pickle from typing import Any, Callable, NamedTuple from absl import logging from jax._src import api +from jax._src import api_util from jax._src import config -from jax._src import stages +from jax._src import mesh as mesh_lib +from jax._src import pjit +from jax._src import tree_util from jax._src.interpreters import mlir from jax._src.lib.mlir import ir +from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -44,24 +47,12 @@ def _validate_component_cache(val): ) -traced_cache: dict[Hashable, stages.Traced] = {} - - -def trace(f: Callable[..., Any], *args, **kwargs): - if type(f) is xc._xla.PjitFunction: - return f.trace(*args, **kwargs) - try: - hash(f) - except TypeError: - fun = functools.partial(f) - return api.jit(f).trace(*args, **kwargs) - - class CacheEntry: def __init__(self, blob: SerializedType, hits: int = 0): self.blob = blob self.hits = hits + class Cache(NamedTuple): get: Callable[[Hashable, bool], bytes | None] put: Callable[[Hashable, bytes], None] @@ -112,13 +103,13 @@ def get_cached_or_put(key, make, serialize, deserialize): return make() if blob := cache.get(key): # pytype: disable=attribute-error - logging.info("Key %s found with blob %s.", key, blob) + logging.info("Key %s found.", key) return deserialize(blob) logging.info("Key %s missing.", key) obj = make() blob = serialize(obj) - logging.info("Putting key %s with blob %s.", key, blob) + logging.info("Putting key %s.", key) cache.put(key, blob) # pytype: disable=attribute-error logging.info("Cache keys: %s", cache.keys()) return obj diff --git a/jax/_src/api.py b/jax/_src/api.py index 2752a1af9332..a243ed3c61d0 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3108,11 +3108,17 @@ def eval_shape(fun, *args, **kwargs): >>> print(out.dtype) float32 """ + return trace(fun, *args, **kwargs).out_info # type: ignore + + +@api_boundary +def trace(fun: Callable, *args, **kwargs): if type(fun) is xc._xla.PjitFunction: - return fun.trace(*args, **kwargs).out_info # type: ignore + return fun.trace(*args, **kwargs) try: hash(fun) except TypeError: fun = partial(fun) - return jit(fun).trace(*args, **kwargs).out_info + return jit(fun).trace(*args, **kwargs) + @partial(api_boundary, repro_api_name="jax.named_call") diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 180953c01781..6ef1dc6ac5b7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -633,6 +633,7 @@ def _infer_params( def _infer_params_internal( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: + logging.info("infer_params fun: %s", fun) ctx_mesh = mesh_lib.get_concrete_mesh() dbg_fn = lambda: debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, diff --git a/tests/aot_test.py b/tests/aot_test.py index 5cfec7954dc4..3fb29741ee1b 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -13,7 +13,9 @@ # limitations under the License. import contextlib +import logging import unittest + from absl.testing import absltest import jax from jax import lax @@ -22,6 +24,7 @@ from jax._src import aot_util from jax._src import config from jax._src import core +from jax._src import pjit from jax._src import test_util as jtu from jax._src.lib import xla_client as xc from jax.experimental import topologies @@ -269,23 +272,27 @@ def make_in_memory_cache(self): with aot_util.component_cache(cache): yield aot_util.component_cache.value.clear() + jax.clear_caches() def test_component_lowering_cache_hit(self): with self.make_in_memory_cache(): cache = aot_util.component_cache.value - @aot.component(component_key='f') + @jax.jit def f(x): return x + 1.0 - self.assertEqual(f(1.0), 2.0) + component_f = aot.component(component_key='f')(f) + self.assertEqual(component_f(1.0), 2.0) self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) - @aot.component(component_key='f') - def g(x): - raise NotImplementedError + logging.info(pjit._infer_params_cached.cache_info()) + + logging.info("GGGGGG") + component_g = aot.component(component_key='f')(f) - self.assertEqual(g(1.0), 2.0) - # TODO(dsuo): Why is abstract_eval rule called so many times? + self.assertNotEqual(component_f, component_g) + self.assertEqual(component_g(1.0), 2.0) + logging.info(pjit._infer_params_cached.cache_info()) self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) def test_component_call_in_function(self): @@ -306,20 +313,27 @@ def g(x): def test_explicit_cached_lowering(self): with self.make_in_memory_cache(): cache = aot_util.component_cache.value - @aot.component(component_key='f') def f(x): return x + 1.0 - lowered = f.lower(jax.ShapeDtypeStruct((), 'float32')) + component_f = aot.component(component_key='f')(f) + lowered = component_f.lower(jax.ShapeDtypeStruct((), 'float32')) self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) - @aot.component(component_key='f') - def g(x): - raise NotImplementedError - - lowered = g.lower(jax.ShapeDtypeStruct((), 'float32')) + component_g = aot.component(component_key='f')(f) + lowered = component_g.lower(jax.ShapeDtypeStruct((), 'float32')) self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + def test_vmap_of_component(self): + with self.make_in_memory_cache(): + cache = aot_util.component_cache.value + @aot.component(component_key='f') + def f(x): + return x + 1.0 + + vmapped_f = jax.vmap(f) + + self.assertArraysEqual(vmapped_f(jax.numpy.ones(8,)), jax.numpy.ones(8,) + 1.0) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 037b4961103237e52552bbc66d13b0ba843002ac Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 31 Oct 2025 09:02:36 -0400 Subject: [PATCH 04/24] Update --- jax/_src/aot.py | 43 ++++++++++++++++++++++--------------- jax/_src/aot_util.py | 50 +++++++++++++++++++++++++++++++++++++------- tests/aot_test.py | 37 ++++++++++++++++++-------------- 3 files changed, 90 insertions(+), 40 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index f00a13d59e6c..707b1e160edc 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -13,9 +13,8 @@ # limitations under the License. """JAX AOT API""" -import functools from collections.abc import Hashable -from typing import Any, Callable, NamedTuple +from typing import Any, Callable from absl import logging @@ -30,21 +29,22 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect -component_p = core.Primitive("component") + +UserKey = Hashable | Callable[..., Hashable] -ComponentKey = Hashable | Callable[..., Hashable] +ComponentKey = aot_util.ComponentKey -def component(component_key: ComponentKey = None) -> Callable[..., Any]: +def component(key: UserKey = None) -> Callable[..., Any]: def _component(fun: Callable[..., Any]): @api.jit @util.wraps(fun) @traceback_util.api_boundary def wrapper(*args): - logging.info("wrapper: %s", args) # TODO(dsuo): Flatten function as in shard_map pmap. # TODO(dsuo): Need to consider static args. + component_key = ComponentKey(key) return component_p.bind(*args, fun=fun, component_key=component_key) # NOTE(dsuo): Using a component means we'll jit you in this dummy @@ -56,27 +56,30 @@ def wrapper(*args): def component_impl(*args, fun: Callable[..., Any], **_): # TODO(dsuo): Call should not re-trace. - logging.info("component_impl") + logging.debug("component_impl") return fun(*args) def component_abstract_eval( *args, fun: Callable[..., Any], component_key: ComponentKey ): - logging.info("component_abstract_eval: %s", component_key) + logging.debug("component_abstract_eval: %s, %s", component_key, fun) key = aot_util.make_abstract_eval_key(component_key) def abstract_eval(): - logging.info("component_abstract_eval args: %s", args) + logging.debug("component_abstract_eval args: %s", args) # NOTE(dsuo): The claim is tracing cache will handle caching jaxprs for us. # However, we'll need to convert ir.Values in the lowering rule to avals to - # trace in lowering with args. There are two further downsides: + # trace in lowering with args. There are three issues: # 1. `fun` must have the same id (in addition to same everything else) in # order for us to use this cache within the same process. - # 2. It's not easy to inspect the _infer_params_cached.cache_info() to + # 2. It's not easy to inspect the _infer_params_cached.cache_debug() to # understand if we've gotten a cache hit or not (more relevant for testing). + # 3. lu.cache on _create_pjit_jaxpr keys on fun transformations. + # So... just do the easy thing for now and worry about this later. + traced = aot_util.get_traced(key, fun, *args) return tree_util.tree_map( - lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) + lambda x: core.ShapedArray(x.shape, x.dtype), traced.out_info ) return aot_util.get_cached_or_put( @@ -90,19 +93,19 @@ def abstract_eval(): def component_lowering( ctx, *args, fun: Callable[..., Any], component_key: ComponentKey ): - logging.info("component_lowering: %s", component_key) - key = aot_util.make_lowering_key(component_key) + logging.debug("component_lowering: %s, %s", component_key, fun) # TODO(dsuo): Is this something we can grab from LoweringRuleContext or # traced? module_name = f"{component_key}.module" traced_key = aot_util.make_abstract_eval_key(component_key) + lowering_key = aot_util.make_lowering_key(component_key) # TODO(dsuo): Expect entry exists. TBD for transformations. - logging.info("component_lowering avals_in: %s", ctx.avals_in) + logging.debug("component_lowering avals_in: %s", ctx.avals_in) def lower_jaxpr_to_module(): # with ctx.module_context.module.context: - traced = api.trace(fun, *ctx.avals_in) + traced = aot_util.get_traced(traced_key, fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, jaxpr=traced.jaxpr, @@ -135,7 +138,7 @@ def lower_jaxpr_to_module(): return submodule submodule = aot_util.get_cached_or_put( - key, + lowering_key, lower_jaxpr_to_module, aot_util.serialize_lowering, aot_util.deserialize_lowering, @@ -163,6 +166,12 @@ def component_batcher( return fun(vals_in[0]), dims_in[0] +def clear_caches(): + aot_util.component_cache.value.clear() + aot_util._traced_cache.clear() + + +component_p = core.Primitive("component") component_p.def_impl(component_impl) component_p.def_abstract_eval(component_abstract_eval) # TODO(dsuo): Figure out multiple_results i.e., distinguishing between (1,) and diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 3d6d8adbadc0..bf5c47d27d56 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -24,6 +24,7 @@ from jax._src import config from jax._src import mesh as mesh_lib from jax._src import pjit +from jax._src import stages from jax._src import tree_util from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -35,6 +36,23 @@ SerializedType = bytes | Any +class ComponentKey: + def __init__(self, user_key: Hashable): + self.user_key = user_key + + def __hash__(self): + return hash(self.user_key) + + def __eq__(self, other): + return hash(self) == hash(other) + + def __str__(self): + return self.user_key + + def __repr__(self): + return self.__str__ + + def _validate_component_cache(val): assert val is None or isinstance(val, Cache) @@ -47,6 +65,24 @@ def _validate_component_cache(val): ) +class TracedCacheEntry: + def __init__(self, traced: stages.Traced, hits: int = 0): + self.traced = traced + self.hits = hits + + +_traced_cache: dict[Hashable, TracedCacheEntry] = {} + + +def get_traced(key: Hashable, fun: Callable[..., Any], *args): + entry = _traced_cache.get(key, None) + if entry: + entry.hits += 1 + else: + entry = _traced_cache[key] = TracedCacheEntry(api.trace(fun, *args)) + return entry.traced + + class CacheEntry: def __init__(self, blob: SerializedType, hits: int = 0): self.blob = blob @@ -87,23 +123,23 @@ def clear(): SerFn = Callable[[Any], SerializedType] DesFn = Callable[[SerializedType], Any] -make_abstract_eval_key: KeyFn = lambda k: f"{k}.abstract_eval" +make_abstract_eval_key: KeyFn = lambda k: ComponentKey( + f"{k.user_key}.abstract_eval" +) serialize_abstract_eval: SerFn = lambda obj: pickle.dumps(obj) deserialize_abstract_eval: DesFn = lambda blob: pickle.loads(blob) -make_lowering_key: KeyFn = lambda k: f"{k}.lowering" - -# TODO(dsuo): When is serialize/deserialize portable artifact needed? +make_lowering_key: KeyFn = lambda k: ComponentKey(f"{k.user_key}.lowering") serialize_lowering: SerFn = lambda obj: mlir.module_to_bytecode(obj) deserialize_lowering: DesFn = lambda blob: ir.Module.parse(blob) def get_cached_or_put(key, make, serialize, deserialize): if (cache := component_cache.value) is None: - logging.info("Component cache is not set.") + logging.debug("Component cache is not set.") return make() if blob := cache.get(key): # pytype: disable=attribute-error - logging.info("Key %s found.", key) + logging.debug("Key %s found.", key) return deserialize(blob) logging.info("Key %s missing.", key) @@ -111,5 +147,5 @@ def get_cached_or_put(key, make, serialize, deserialize): blob = serialize(obj) logging.info("Putting key %s.", key) cache.put(key, blob) # pytype: disable=attribute-error - logging.info("Cache keys: %s", cache.keys()) + logging.debug("Cache keys: %s", cache.keys()) return obj diff --git a/tests/aot_test.py b/tests/aot_test.py index 3fb29741ee1b..0c5bc34003c8 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -271,34 +271,37 @@ def make_in_memory_cache(self): cache = aot_util.make_in_memory_cache() with aot_util.component_cache(cache): yield - aot_util.component_cache.value.clear() + aot.clear_caches() jax.clear_caches() def test_component_lowering_cache_hit(self): with self.make_in_memory_cache(): cache = aot_util.component_cache.value - @jax.jit + @aot.component(key='f') def f(x): return x + 1.0 - component_f = aot.component(component_key='f')(f) - self.assertEqual(component_f(1.0), 2.0) + self.assertEqual(f(1.0), 2.0) + # We get 1 hit on traced cache during the lowering rule. + traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) + self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) - logging.info(pjit._infer_params_cached.cache_info()) - - logging.info("GGGGGG") - component_g = aot.component(component_key='f')(f) + @aot.component(key='f') + def g(x): + raise NotImplementedError - self.assertNotEqual(component_f, component_g) - self.assertEqual(component_g(1.0), 2.0) - logging.info(pjit._infer_params_cached.cache_info()) + self.assertEqual(g(1.0), 2.0) + # We get 1 hit for component cache and so we don't even check traced + # cache. self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) + self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) def test_component_call_in_function(self): with self.make_in_memory_cache(): cache = aot_util.component_cache.value - @aot.component(component_key='f') + @aot.component(key='f') def f(x): return x + 1.0 @@ -306,9 +309,11 @@ def f(x): def g(x): return f(x) + 1.0 - self.assertEqual(g(1.0), 3.0) self.assertEqual(f(1.0), 2.0) + self.assertEqual(g(1.0), 3.0) self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) + self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) def test_explicit_cached_lowering(self): with self.make_in_memory_cache(): @@ -316,18 +321,18 @@ def test_explicit_cached_lowering(self): def f(x): return x + 1.0 - component_f = aot.component(component_key='f')(f) + component_f = aot.component(key='f')(f) lowered = component_f.lower(jax.ShapeDtypeStruct((), 'float32')) self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) - component_g = aot.component(component_key='f')(f) + component_g = aot.component(key='f')(f) lowered = component_g.lower(jax.ShapeDtypeStruct((), 'float32')) self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) def test_vmap_of_component(self): with self.make_in_memory_cache(): cache = aot_util.component_cache.value - @aot.component(component_key='f') + @aot.component(key='f') def f(x): return x + 1.0 From 1b9eb3b31577fc50f13ca5e8b25bfc3b9e9f0fa2 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 31 Oct 2025 10:20:54 -0400 Subject: [PATCH 05/24] Update --- jax/_src/aot.py | 55 +++++++--------------------- jax/_src/aot_util.py | 87 +++++++++++++++++++++----------------------- 2 files changed, 56 insertions(+), 86 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 707b1e160edc..54ddbe454bb6 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -14,7 +14,7 @@ """JAX AOT API""" from collections.abc import Hashable -from typing import Any, Callable +from typing import Any, Callable, Sequence from absl import logging @@ -44,68 +44,40 @@ def _component(fun: Callable[..., Any]): def wrapper(*args): # TODO(dsuo): Flatten function as in shard_map pmap. # TODO(dsuo): Need to consider static args. + # TODO(dsuo): Do we have all the information we need at this point to make + # the component key? component_key = ComponentKey(key) return component_p.bind(*args, fun=fun, component_key=component_key) - # NOTE(dsuo): Using a component means we'll jit you in this dummy - # implementation. return wrapper return _component def component_impl(*args, fun: Callable[..., Any], **_): - # TODO(dsuo): Call should not re-trace. - logging.debug("component_impl") return fun(*args) def component_abstract_eval( *args, fun: Callable[..., Any], component_key: ComponentKey -): - logging.debug("component_abstract_eval: %s, %s", component_key, fun) - key = aot_util.make_abstract_eval_key(component_key) - - def abstract_eval(): - logging.debug("component_abstract_eval args: %s", args) - # NOTE(dsuo): The claim is tracing cache will handle caching jaxprs for us. - # However, we'll need to convert ir.Values in the lowering rule to avals to - # trace in lowering with args. There are three issues: - # 1. `fun` must have the same id (in addition to same everything else) in - # order for us to use this cache within the same process. - # 2. It's not easy to inspect the _infer_params_cached.cache_debug() to - # understand if we've gotten a cache hit or not (more relevant for testing). - # 3. lu.cache on _create_pjit_jaxpr keys on fun transformations. - # So... just do the easy thing for now and worry about this later. - traced = aot_util.get_traced(key, fun, *args) - return tree_util.tree_map( +) -> Sequence[core.AbstractValue] | None: + def abstract_eval() -> aot_util.CacheEntry: + traced = aot_util.get_traced(component_key, fun, *args) + avals_out = tree_util.tree_map( lambda x: core.ShapedArray(x.shape, x.dtype), traced.out_info ) + return aot_util.CacheEntry(avals_out) - return aot_util.get_cached_or_put( - key, - abstract_eval, - aot_util.serialize_abstract_eval, - aot_util.deserialize_abstract_eval, - ) + return aot_util.get_entry(component_key, abstract_eval).avals_out def component_lowering( ctx, *args, fun: Callable[..., Any], component_key: ComponentKey -): - logging.debug("component_lowering: %s, %s", component_key, fun) - - # TODO(dsuo): Is this something we can grab from LoweringRuleContext or - # traced? +) -> Sequence[ir.Value]: module_name = f"{component_key}.module" - traced_key = aot_util.make_abstract_eval_key(component_key) - lowering_key = aot_util.make_lowering_key(component_key) - # TODO(dsuo): Expect entry exists. TBD for transformations. - logging.debug("component_lowering avals_in: %s", ctx.avals_in) - - def lower_jaxpr_to_module(): - # with ctx.module_context.module.context: - traced = aot_util.get_traced(traced_key, fun, *ctx.avals_in) + + def lower_jaxpr_to_module() -> aot_util.CacheEntry: + traced = aot_util.get_traced(component_key, fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, jaxpr=traced.jaxpr, @@ -135,6 +107,7 @@ def lower_jaxpr_to_module(): # rid of the submodule context before merging. Could we just create it with # the right context? submodule = ir.Module.parse(mlir.module_to_bytecode(submodule)) + entry = aot_util.get_entry(component_key) return submodule submodule = aot_util.get_cached_or_put( diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index bf5c47d27d56..d6e0999d0a6c 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -16,12 +16,13 @@ from collections.abc import Hashable import functools import pickle -from typing import Any, Callable, NamedTuple +from typing import Any, Callable, NamedTuple, Self, Sequence from absl import logging from jax._src import api from jax._src import api_util from jax._src import config +from jax._src import core from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import stages @@ -84,33 +85,47 @@ def get_traced(key: Hashable, fun: Callable[..., Any], *args): class CacheEntry: - def __init__(self, blob: SerializedType, hits: int = 0): - self.blob = blob + def __init__( + self, + avals_out: Sequence[core.AbstractValue] | None, + module: ir.Module | None = None, + hits: int = 0, + ): + self.avals_out = avals_out + self.module = module self.hits = hits + def serialize(self) -> SerializedType: + module_bytecode = mlir.module_to_bytecode(self.module) + return pickle.dumps((self.avals_out, module_bytecode)) + + @classmethod + def deserialize(cls, blob: SerializedType) -> Self: + avals_out, module = pickle.loads(blob) + return cls(avals_out, module) + class Cache(NamedTuple): - get: Callable[[Hashable, bool], bytes | None] - put: Callable[[Hashable, bytes], None] - keys: Callable[[], list[Hashable]] + get: Callable[[ComponentKey], bytes | None] + put: Callable[[ComponentKey, bytes], None] + keys: Callable[[], list[ComponentKey]] clear: Callable[[], None] -_in_memory_cache: dict[Hashable, CacheEntry] = {} +_in_memory_cache: dict[ComponentKey, SerializedType] = {} +_in_memory_cache_hits: dict[ComponentKey, int] = {} def make_in_memory_cache(): - def get(key: Hashable, update_hits: bool = True) -> SerializedType | None: - entry = _in_memory_cache.get(key, None) - if entry is not None and update_hits: - _in_memory_cache[key].hits += 1 - return entry.blob - return entry + def get(key: ComponentKey) -> SerializedType | None: + hits = _in_memory_cache_hits.setdefault(key, 0) + _in_memory_cache_hits[key] += 1 + return _in_memory_cache.get(key, None) - def put(key: Hashable, data: SerializedType): - _in_memory_cache[key] = CacheEntry(data) + def put(key: ComponentKey, data: SerializedType): + _in_memory_cache[key] = data - def keys() -> list[Hashable]: + def keys() -> list[ComponentKey]: return list(_in_memory_cache.keys()) def clear(): @@ -119,33 +134,15 @@ def clear(): return Cache(get, put, keys, clear) -KeyFn = Callable[[Hashable], Hashable] -SerFn = Callable[[Any], SerializedType] -DesFn = Callable[[SerializedType], Any] - -make_abstract_eval_key: KeyFn = lambda k: ComponentKey( - f"{k.user_key}.abstract_eval" -) -serialize_abstract_eval: SerFn = lambda obj: pickle.dumps(obj) -deserialize_abstract_eval: DesFn = lambda blob: pickle.loads(blob) -make_lowering_key: KeyFn = lambda k: ComponentKey(f"{k.user_key}.lowering") -serialize_lowering: SerFn = lambda obj: mlir.module_to_bytecode(obj) -deserialize_lowering: DesFn = lambda blob: ir.Module.parse(blob) - - -def get_cached_or_put(key, make, serialize, deserialize): - if (cache := component_cache.value) is None: - logging.debug("Component cache is not set.") +def get_entry( + key: ComponentKey, make: Callable[[], CacheEntry] +) -> CacheEntry: + cache: Cache = component_cache.value + if cache is None: return make() - - if blob := cache.get(key): # pytype: disable=attribute-error - logging.debug("Key %s found.", key) - return deserialize(blob) - - logging.info("Key %s missing.", key) - obj = make() - blob = serialize(obj) - logging.info("Putting key %s.", key) - cache.put(key, blob) # pytype: disable=attribute-error - logging.debug("Cache keys: %s", cache.keys()) - return obj + if blob := cache.get(key): + return CacheEntry.deserialize(blob) + entry = make() + blob = entry.serialize() + cache.put(key, blob) + return entry From 9a4c4e207a3babf3cfd5c7e2f37390564944fa52 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 31 Oct 2025 13:32:45 -0400 Subject: [PATCH 06/24] Update --- jax/_src/aot.py | 46 ++++++++++++++--------------- jax/_src/aot_util.py | 69 +++++++++++++++++++++++++------------------- tests/aot_test.py | 32 ++++++++++---------- 3 files changed, 79 insertions(+), 68 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 54ddbe454bb6..215602d21d9d 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -31,24 +31,25 @@ UserKey = Hashable | Callable[..., Hashable] - - ComponentKey = aot_util.ComponentKey +get_cache = aot_util.get_cache def component(key: UserKey = None) -> Callable[..., Any]: def _component(fun: Callable[..., Any]): + # TODO(dsuo): Do we have all the information we need at this point to make + # the component key? + component_key = ComponentKey(key) + @api.jit @util.wraps(fun) @traceback_util.api_boundary def wrapper(*args): # TODO(dsuo): Flatten function as in shard_map pmap. # TODO(dsuo): Need to consider static args. - # TODO(dsuo): Do we have all the information we need at this point to make - # the component key? - component_key = ComponentKey(key) return component_p.bind(*args, fun=fun, component_key=component_key) + wrapper.component_key = component_key return wrapper return _component @@ -61,22 +62,27 @@ def component_impl(*args, fun: Callable[..., Any], **_): def component_abstract_eval( *args, fun: Callable[..., Any], component_key: ComponentKey ) -> Sequence[core.AbstractValue] | None: - def abstract_eval() -> aot_util.CacheEntry: + entry = aot_util.get_entry(component_key) + if entry is None: traced = aot_util.get_traced(component_key, fun, *args) avals_out = tree_util.tree_map( lambda x: core.ShapedArray(x.shape, x.dtype), traced.out_info ) - return aot_util.CacheEntry(avals_out) - - return aot_util.get_entry(component_key, abstract_eval).avals_out + aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out)) + return entry.avals_out def component_lowering( ctx, *args, fun: Callable[..., Any], component_key: ComponentKey ) -> Sequence[ir.Value]: - module_name = f"{component_key}.module" + with ctx.module_context.context as ir_ctx: + entry = aot_util.get_entry(component_key, ir_ctx) + if entry is None: + raise ValueError("Should hit abstract_eval already, which would populate.") - def lower_jaxpr_to_module() -> aot_util.CacheEntry: + module_name = f"{component_key}.module" + if (module := entry.module) is None: + logging.info('missed lowering: %s', fun) traced = aot_util.get_traced(component_key, fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, @@ -101,27 +107,19 @@ def lower_jaxpr_to_module() -> aot_util.CacheEntry: # - keepalive: probably not supported. # - host_callbacks: probably not supported. # - shape_poly_state: talk to necula@ - submodule = lowering_result.module + module = lowering_result.module # TODO(dsuo): We have this to ensure the source and destination modules have # the same context, but is it necessary? Perhaps yes, since we need to get # rid of the submodule context before merging. Could we just create it with # the right context? - submodule = ir.Module.parse(mlir.module_to_bytecode(submodule)) - entry = aot_util.get_entry(component_key) - return submodule - - submodule = aot_util.get_cached_or_put( - lowering_key, - lower_jaxpr_to_module, - aot_util.serialize_lowering, - aot_util.deserialize_lowering, - ) + entry.module = module = ir.Module.parse(mlir.module_to_bytecode(module)) + aot_util.put_entry(component_key, entry) - symtab = ir.SymbolTable(submodule.operation) + symtab = ir.SymbolTable(module.operation) module = mlir.merge_mlir_modules( ctx.module_context.module, f"component_{module_name}", - submodule, + module, dst_symtab=ctx.module_context.symbol_table, ) # TODO(dsuo): There's quite a bit of logic from jax.export, but we just strip diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index d6e0999d0a6c..5d0fd562ed4d 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -14,23 +14,17 @@ """JAX AOT API utilities.""" from collections.abc import Hashable -import functools import pickle from typing import Any, Callable, NamedTuple, Self, Sequence from absl import logging from jax._src import api -from jax._src import api_util from jax._src import config from jax._src import core -from jax._src import mesh as mesh_lib -from jax._src import pjit from jax._src import stages -from jax._src import tree_util from jax._src.interpreters import mlir from jax._src.lib.mlir import ir -from jax._src.lib import jax_jit -from jax._src.lib import xla_client as xc +from jax._src.lib.mlir.dialects import func as func_dialect # For now, we don't worry about serialization. @@ -77,10 +71,10 @@ def __init__(self, traced: stages.Traced, hits: int = 0): def get_traced(key: Hashable, fun: Callable[..., Any], *args): entry = _traced_cache.get(key, None) - if entry: - entry.hits += 1 - else: + if entry is None: entry = _traced_cache[key] = TracedCacheEntry(api.trace(fun, *args)) + else: + entry.hits += 1 return entry.traced @@ -96,12 +90,19 @@ def __init__( self.hits = hits def serialize(self) -> SerializedType: - module_bytecode = mlir.module_to_bytecode(self.module) + module_bytecode = None + if self.module is not None: + module_bytecode = mlir.module_to_bytecode(self.module) return pickle.dumps((self.avals_out, module_bytecode)) @classmethod - def deserialize(cls, blob: SerializedType) -> Self: - avals_out, module = pickle.loads(blob) + def deserialize(cls, blob: SerializedType, ctx: ir.Context | None = None) -> Self: + avals_out, module_bytecode = pickle.loads(blob) + if module_bytecode is None or ctx is None: + module = None + else: + with ctx: + module = ir.Module.parse(module_bytecode) return cls(avals_out, module) @@ -110,6 +111,7 @@ class Cache(NamedTuple): put: Callable[[ComponentKey, bytes], None] keys: Callable[[], list[ComponentKey]] clear: Callable[[], None] + hits: Callable[[ComponentKey], int] _in_memory_cache: dict[ComponentKey, SerializedType] = {} @@ -118,7 +120,8 @@ class Cache(NamedTuple): def make_in_memory_cache(): def get(key: ComponentKey) -> SerializedType | None: - hits = _in_memory_cache_hits.setdefault(key, 0) + logging.info('getting key') + hits = _in_memory_cache_hits.setdefault(key, -1) _in_memory_cache_hits[key] += 1 return _in_memory_cache.get(key, None) @@ -130,19 +133,27 @@ def keys() -> list[ComponentKey]: def clear(): _in_memory_cache.clear() + _in_memory_cache_hits.clear() + + def hits(key: ComponentKey): + if key not in _in_memory_cache_hits: + raise ValueError(f"key {key} not found in cache hits") + return _in_memory_cache_hits[key] + + return Cache(get, put, keys, clear, hits) + + +def get_cache() -> Cache | None: + return component_cache.value + + +def get_entry(key: ComponentKey, ctx: ir.Context | None = None) -> CacheEntry | None: + if (cache := get_cache()) is not None: + if (blob := cache.get(key)) is not None: + return CacheEntry.deserialize(blob, ctx) + return None # sigh pytype + - return Cache(get, put, keys, clear) - - -def get_entry( - key: ComponentKey, make: Callable[[], CacheEntry] -) -> CacheEntry: - cache: Cache = component_cache.value - if cache is None: - return make() - if blob := cache.get(key): - return CacheEntry.deserialize(blob) - entry = make() - blob = entry.serialize() - cache.put(key, blob) - return entry +def put_entry(key: ComponentKey, entry: CacheEntry) -> None: + if (cache := get_cache()) is not None: + cache.put(key, entry.serialize()) diff --git a/tests/aot_test.py b/tests/aot_test.py index 0c5bc34003c8..385b417b2551 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -276,16 +276,15 @@ def make_in_memory_cache(self): def test_component_lowering_cache_hit(self): with self.make_in_memory_cache(): - cache = aot_util.component_cache.value + cache = aot.get_cache() @aot.component(key='f') def f(x): return x + 1.0 self.assertEqual(f(1.0), 2.0) # We get 1 hit on traced cache during the lowering rule. - traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) - self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) - self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) + self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + # self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) @aot.component(key='f') def g(x): @@ -294,13 +293,13 @@ def g(x): self.assertEqual(g(1.0), 2.0) # We get 1 hit for component cache and so we don't even check traced # cache. - self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + self.assertEqual(cache.hits(f.component_key), 1) traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) def test_component_call_in_function(self): with self.make_in_memory_cache(): - cache = aot_util.component_cache.value + cache = aot.get_cache() @aot.component(key='f') def f(x): return x + 1.0 @@ -311,27 +310,30 @@ def g(x): self.assertEqual(f(1.0), 2.0) self.assertEqual(g(1.0), 3.0) - self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + self.assertEqual(cache.hits(f.component_key), 1) traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) - self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) + self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) def test_explicit_cached_lowering(self): with self.make_in_memory_cache(): - cache = aot_util.component_cache.value + cache = aot.get_cache() + + @aot.component(key='f') def f(x): return x + 1.0 - component_f = aot.component(key='f')(f) - lowered = component_f.lower(jax.ShapeDtypeStruct((), 'float32')) + lowered = f.lower(jax.ShapeDtypeStruct((), 'float32')) self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) - component_g = aot.component(key='f')(f) - lowered = component_g.lower(jax.ShapeDtypeStruct((), 'float32')) - self.assertEqual(cache.get('f.lowering', update_hits=False).hits, 1) + @aot.component(key='f') + def g(x): + raise NotImplementedError + lowered = g.lower(jax.ShapeDtypeStruct((), 'float32')) + self.assertEqual(cache.get('f.lowering').hits, 1) def test_vmap_of_component(self): with self.make_in_memory_cache(): - cache = aot_util.component_cache.value + cache = aot.get_cache() @aot.component(key='f') def f(x): return x + 1.0 From 0fab387926216928bfcab9d76c89a46c8989a425 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 3 Nov 2025 14:20:43 -0500 Subject: [PATCH 07/24] Update --- jax/_src/aot.py | 4 +++- jax/_src/aot_util.py | 51 +++++++++++++++++++++++++++----------------- jax/_src/pjit.py | 1 - tests/aot_test.py | 31 ++++++++++++++++++++------- 4 files changed, 57 insertions(+), 30 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 215602d21d9d..75628a0ddc3d 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -63,6 +63,7 @@ def component_abstract_eval( *args, fun: Callable[..., Any], component_key: ComponentKey ) -> Sequence[core.AbstractValue] | None: entry = aot_util.get_entry(component_key) + logging.info('component_abstract_eval got entry %s', component_key) if entry is None: traced = aot_util.get_traced(component_key, fun, *args) avals_out = tree_util.tree_map( @@ -77,6 +78,7 @@ def component_lowering( ) -> Sequence[ir.Value]: with ctx.module_context.context as ir_ctx: entry = aot_util.get_entry(component_key, ir_ctx) + logging.info('component_lowering got entry %s', component_key) if entry is None: raise ValueError("Should hit abstract_eval already, which would populate.") @@ -113,7 +115,7 @@ def component_lowering( # rid of the submodule context before merging. Could we just create it with # the right context? entry.module = module = ir.Module.parse(mlir.module_to_bytecode(module)) - aot_util.put_entry(component_key, entry) + aot_util.put_entry(component_key, entry, update=True) symtab = ir.SymbolTable(module.operation) module = mlir.merge_mlir_modules( diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 5d0fd562ed4d..814bd8f6e974 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -15,6 +15,7 @@ from collections.abc import Hashable import pickle +import traceback from typing import Any, Callable, NamedTuple, Self, Sequence from absl import logging @@ -83,11 +84,9 @@ def __init__( self, avals_out: Sequence[core.AbstractValue] | None, module: ir.Module | None = None, - hits: int = 0, ): self.avals_out = avals_out self.module = module - self.hits = hits def serialize(self) -> SerializedType: module_bytecode = None @@ -96,7 +95,9 @@ def serialize(self) -> SerializedType: return pickle.dumps((self.avals_out, module_bytecode)) @classmethod - def deserialize(cls, blob: SerializedType, ctx: ir.Context | None = None) -> Self: + def deserialize( + cls, blob: SerializedType, ctx: ir.Context | None = None + ) -> Self: avals_out, module_bytecode = pickle.loads(blob) if module_bytecode is None or ctx is None: module = None @@ -111,49 +112,59 @@ class Cache(NamedTuple): put: Callable[[ComponentKey, bytes], None] keys: Callable[[], list[ComponentKey]] clear: Callable[[], None] - hits: Callable[[ComponentKey], int] + info: Callable[[ComponentKey], dict[str, Any]] _in_memory_cache: dict[ComponentKey, SerializedType] = {} -_in_memory_cache_hits: dict[ComponentKey, int] = {} +_in_memory_cache_info: dict[ComponentKey, dict[str, Any]] = {} def make_in_memory_cache(): def get(key: ComponentKey) -> SerializedType | None: - logging.info('getting key') - hits = _in_memory_cache_hits.setdefault(key, -1) - _in_memory_cache_hits[key] += 1 - return _in_memory_cache.get(key, None) + entry = _in_memory_cache.get(key, None) + if entry is None: + _in_memory_cache_info[key] = dict(hits=0) + else: + _in_memory_cache_info[key] = dict( + hits=_in_memory_cache_info[key]["hits"] + 1 + ) + return entry - def put(key: ComponentKey, data: SerializedType): + def put(key: ComponentKey, data: SerializedType, update: bool): _in_memory_cache[key] = data + if not update: + _in_memory_cache_info[key] = dict(hits=0) def keys() -> list[ComponentKey]: return list(_in_memory_cache.keys()) - def clear(): + def clear() -> None: _in_memory_cache.clear() - _in_memory_cache_hits.clear() + _in_memory_cache_info.clear() - def hits(key: ComponentKey): - if key not in _in_memory_cache_hits: - raise ValueError(f"key {key} not found in cache hits") - return _in_memory_cache_hits[key] + def info(key: ComponentKey) -> dict[str, Any]: + if key not in _in_memory_cache_info: + raise ValueError(f"`{key}` not found in _in_memory_cache_info") + return _in_memory_cache_info[key] - return Cache(get, put, keys, clear, hits) + return Cache(get, put, keys, clear, info) def get_cache() -> Cache | None: return component_cache.value -def get_entry(key: ComponentKey, ctx: ir.Context | None = None) -> CacheEntry | None: +def get_entry( + key: ComponentKey, ctx: ir.Context | None = None +) -> CacheEntry | None: if (cache := get_cache()) is not None: if (blob := cache.get(key)) is not None: return CacheEntry.deserialize(blob, ctx) return None # sigh pytype -def put_entry(key: ComponentKey, entry: CacheEntry) -> None: +def put_entry( + key: ComponentKey, entry: CacheEntry, update: bool = False +) -> None: if (cache := get_cache()) is not None: - cache.put(key, entry.serialize()) + cache.put(key, entry.serialize(), update) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6ef1dc6ac5b7..180953c01781 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -633,7 +633,6 @@ def _infer_params( def _infer_params_internal( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: - logging.info("infer_params fun: %s", fun) ctx_mesh = mesh_lib.get_concrete_mesh() dbg_fn = lambda: debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, diff --git a/tests/aot_test.py b/tests/aot_test.py index 385b417b2551..b07c4f644b2f 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -274,6 +274,13 @@ def make_in_memory_cache(self): aot.clear_caches() jax.clear_caches() + # NOTE(dsuo): Disable checks because otherwise we check jaxprs in (at least) + # four places and makes reasoning about cache hits and misses harder. + # 1. After the initial abstract eval. + # 2. Before converting const vars. + # 3. After lifting the jaxpr. + # 4. After DCE. + @config.enable_checks(False) def test_component_lowering_cache_hit(self): with self.make_in_memory_cache(): cache = aot.get_cache() @@ -282,9 +289,11 @@ def f(x): return x + 1.0 self.assertEqual(f(1.0), 2.0) + self.assertEqual(cache.keys(), [f.component_key]) # We get 1 hit on traced cache during the lowering rule. self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) - # self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) + # We get 1 hit on the disk cache during the lowering rule. + self.assertEqual(cache.info(f.component_key)['hits'], 1) @aot.component(key='f') def g(x): @@ -293,10 +302,12 @@ def g(x): self.assertEqual(g(1.0), 2.0) # We get 1 hit for component cache and so we don't even check traced # cache. - self.assertEqual(cache.hits(f.component_key), 1) - traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) - self.assertEqual(aot_util._traced_cache[traced_key].hits, 1) + self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + # We get two additional hits on the disk cache during abstract eval and + # lowering for g. + self.assertEqual(cache.info(f.component_key)['hits'], 3) + @config.enable_checks(False) def test_component_call_in_function(self): with self.make_in_memory_cache(): cache = aot.get_cache() @@ -308,12 +319,14 @@ def f(x): def g(x): return f(x) + 1.0 + # 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) + # 1 hit when lowering g. Why no abstract eval? self.assertEqual(g(1.0), 3.0) - self.assertEqual(cache.hits(f.component_key), 1) - traced_key = aot_util.make_abstract_eval_key(aot_util.ComponentKey('f')) + self.assertEqual(cache.info(f.component_key)['hits'], 2) self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + @config.enable_checks(False) def test_explicit_cached_lowering(self): with self.make_in_memory_cache(): cache = aot.get_cache() @@ -323,14 +336,16 @@ def f(x): return x + 1.0 lowered = f.lower(jax.ShapeDtypeStruct((), 'float32')) - self.assertEqual(cache.keys(), ['f.abstract_eval', 'f.lowering']) + self.assertEqual(cache.keys(), [f.component_key]) @aot.component(key='f') def g(x): raise NotImplementedError lowered = g.lower(jax.ShapeDtypeStruct((), 'float32')) - self.assertEqual(cache.get('f.lowering').hits, 1) + self.assertEqual(cache.info(f.component_key)['hits'], 3) + self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + @config.enable_checks(False) def test_vmap_of_component(self): with self.make_in_memory_cache(): cache = aot.get_cache() From b182e5c655cd576c36461195ce90bd29ebb1c421 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 4 Nov 2025 07:11:38 -0500 Subject: [PATCH 08/24] Update --- jax/_src/aot.py | 92 ++++++++++++++++++++++++++++++++++++++------ jax/_src/aot_util.py | 4 ++ tests/aot_test.py | 3 +- 3 files changed, 87 insertions(+), 12 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 75628a0ddc3d..c4f84766168d 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -20,12 +20,15 @@ from absl import logging from jax._src import aot_util from jax._src import api +from jax._src import api_util from jax._src import core +from jax._src import linear_util as lu from jax._src import traceback_util from jax._src import tree_util from jax._src import util from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -35,19 +38,40 @@ get_cache = aot_util.get_cache -def component(key: UserKey = None) -> Callable[..., Any]: +def component( + key: UserKey = None, + # TODO(dsuo): This is really only useful for if we have return of length 1. + # I.e., interpret (x,) as x. If this is True, then interpret (x,) as (x,). + # If we see more than one return value, then of course you have multiple + # results. Need to think about this a bit more since component_p is similar to + # a call primitive, which requires multiple_results=True, but want to + # distinguish between (x,) and x. + multiple_results: bool = False, +) -> Callable[..., Any]: def _component(fun: Callable[..., Any]): # TODO(dsuo): Do we have all the information we need at this point to make # the component key? component_key = ComponentKey(key) + # TODO(dsuo): Jit your function if it isn't. This is so we can produce the + # debug_info object we need in order to wrap fun later on in batching, but + # might be the wrong way of doing things. + if not isinstance(fun, xc._xla.PjitFunction): + fun = api.jit(fun) + @api.jit @util.wraps(fun) @traceback_util.api_boundary - def wrapper(*args): + def wrapper(*args, **kwargs): # TODO(dsuo): Flatten function as in shard_map pmap. + # TODO(dsuo): Do something about kwargs. # TODO(dsuo): Need to consider static args. - return component_p.bind(*args, fun=fun, component_key=component_key) + return component_p.bind( + *args, + fun=fun, + component_key=component_key, + multiple_results=multiple_results, + ) wrapper.component_key = component_key return wrapper @@ -56,14 +80,19 @@ def wrapper(*args): def component_impl(*args, fun: Callable[..., Any], **_): + if isinstance(fun, lu.WrappedFun): + return fun.call_wrapped(*args) return fun(*args) def component_abstract_eval( - *args, fun: Callable[..., Any], component_key: ComponentKey + *args, + fun: Callable[..., Any], + component_key: ComponentKey, + multiple_results: bool, ) -> Sequence[core.AbstractValue] | None: entry = aot_util.get_entry(component_key) - logging.info('component_abstract_eval got entry %s', component_key) + logging.info("component_abstract_eval got entry %s", component_key) if entry is None: traced = aot_util.get_traced(component_key, fun, *args) avals_out = tree_util.tree_map( @@ -74,17 +103,21 @@ def component_abstract_eval( def component_lowering( - ctx, *args, fun: Callable[..., Any], component_key: ComponentKey + ctx, + *args, + fun: Callable[..., Any], + component_key: ComponentKey, + multiple_results: bool, ) -> Sequence[ir.Value]: with ctx.module_context.context as ir_ctx: entry = aot_util.get_entry(component_key, ir_ctx) - logging.info('component_lowering got entry %s', component_key) + logging.info("component_lowering got entry %s", component_key) if entry is None: raise ValueError("Should hit abstract_eval already, which would populate.") module_name = f"{component_key}.module" if (module := entry.module) is None: - logging.info('missed lowering: %s', fun) + logging.info("missed lowering: %s", fun) traced = aot_util.get_traced(component_key, fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, @@ -134,9 +167,46 @@ def component_lowering( def component_batcher( - vals_in, dims_in, fun: Callable[..., Any], component_key: ComponentKey + axis_data, + vals_in, + dims_in, + fun: Callable[..., Any], + component_key: ComponentKey, + multiple_results: bool, ): - return fun(vals_in[0]), dims_in[0] + # Missing from batching process_call: + # TODO(dsuo): Ignore ragged. + # TODO(dsuo): Ignore updating annotations. + + ji = fun._jit_info + + # TODO(dsuo): Dummy debug info + debug_info = lu.DebugInfo( + "component_batcher", fun.__name__, arg_names=None, result_paths=None + ) + # ????(dsuo): Should we be wrapping fun? + f_ = lu.wrap_init(fun, debug_info=debug_info) + + # ????(dsuo): I don't understand trace tags. + f_, dims_out = batching.batch_subtrace( + f_, core.TraceTag(), axis_data, tuple(dims_in) + ) + + # TODO(dsuo): Derp how to actually do this. + # with core.set_current_trace(vals_in[0]._trace.parent_trace): + # assert False + component_key.vmap() + vals_out = component_p.bind( + *vals_in, + fun=f_, + component_key=component_key, + multiple_results=multiple_results, + ) + # if not multiple_results and len(vals_out) == 1: + # vals_out = vals_out[0] + # dims_out = dims_out[0] + # logging.info("component_batcher vals_out %s, %s", vals_out, dims_out) + return vals_out, dims_out def clear_caches(): @@ -150,4 +220,4 @@ def clear_caches(): # TODO(dsuo): Figure out multiple_results i.e., distinguishing between (1,) and # 1. mlir.register_lowering(component_p, component_lowering) -batching.primitive_batchers[component_p] = component_batcher +batching.fancy_primitive_batchers[component_p] = component_batcher diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 814bd8f6e974..76f31c8f30b0 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -48,6 +48,10 @@ def __str__(self): def __repr__(self): return self.__str__ + # TODO(dsuo): This is just a hack for now. + def vmap(self): + self.user_key = f'vmap({self.user_key})' + def _validate_component_cache(val): assert val is None or isinstance(val, Cache) diff --git a/tests/aot_test.py b/tests/aot_test.py index b07c4f644b2f..8c1326ea891c 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -355,7 +355,8 @@ def f(x): vmapped_f = jax.vmap(f) - self.assertArraysEqual(vmapped_f(jax.numpy.ones(8,)), jax.numpy.ones(8,) + 1.0) + with config.checking_leaks(): + self.assertArraysEqual(vmapped_f(jax.numpy.ones(8,)), jax.numpy.ones(8,) + 1.0) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 14633b51a30ee60ae68d04afe9eb49b51b9d136f Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 4 Nov 2025 14:18:43 -0500 Subject: [PATCH 09/24] Update --- jax/_src/aot.py | 76 +++++++------------ jax/_src/aot_util.py | 19 ++++- tests/aot_test.py | 175 +++++++++++++++++++++++++------------------ 3 files changed, 145 insertions(+), 125 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index c4f84766168d..1110d6555843 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -14,6 +14,7 @@ """JAX AOT API""" from collections.abc import Hashable +import traceback from typing import Any, Callable, Sequence @@ -40,38 +41,36 @@ def component( key: UserKey = None, - # TODO(dsuo): This is really only useful for if we have return of length 1. - # I.e., interpret (x,) as x. If this is True, then interpret (x,) as (x,). - # If we see more than one return value, then of course you have multiple - # results. Need to think about this a bit more since component_p is similar to - # a call primitive, which requires multiple_results=True, but want to - # distinguish between (x,) and x. - multiple_results: bool = False, ) -> Callable[..., Any]: def _component(fun: Callable[..., Any]): + # TODO(dsuo): Need to consider static args, etc if fun is jitted. # TODO(dsuo): Do we have all the information we need at this point to make # the component key? component_key = ComponentKey(key) - # TODO(dsuo): Jit your function if it isn't. This is so we can produce the - # debug_info object we need in order to wrap fun later on in batching, but - # might be the wrong way of doing things. - if not isinstance(fun, xc._xla.PjitFunction): - fun = api.jit(fun) - @api.jit @util.wraps(fun) @traceback_util.api_boundary def wrapper(*args, **kwargs): - # TODO(dsuo): Flatten function as in shard_map pmap. - # TODO(dsuo): Do something about kwargs. - # TODO(dsuo): Need to consider static args. - return component_p.bind( + args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) + dbg = api_util.debug_info("component", fun, args, kwargs) + wrapped_fun = lu.wrap_init(fun, debug_info=dbg) + flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) + + # TODO(dsuo): Hack to clear lu.Store borrowed from pmap. + f_transformed = flat_fun.f_transformed + def reset_stores_f_transformed(*args, **kwargs): + for store in flat_fun.stores: + if store is not None: + store.reset() + return f_transformed(*args, **kwargs) + flat_fun.f_transformed = reset_stores_f_transformed + out_flat = component_p.bind( *args, - fun=fun, + fun=flat_fun, component_key=component_key, - multiple_results=multiple_results, ) + return tree_util.tree_unflatten(out_tree(), out_flat) wrapper.component_key = component_key return wrapper @@ -89,11 +88,13 @@ def component_abstract_eval( *args, fun: Callable[..., Any], component_key: ComponentKey, - multiple_results: bool, ) -> Sequence[core.AbstractValue] | None: + # ????(dsuo): Is this an effectful rule? entry = aot_util.get_entry(component_key) logging.info("component_abstract_eval got entry %s", component_key) if entry is None: + logging.info("missed abstract_eval %s", component_key) + traceback.print_stack() traced = aot_util.get_traced(component_key, fun, *args) avals_out = tree_util.tree_map( lambda x: core.ShapedArray(x.shape, x.dtype), traced.out_info @@ -107,7 +108,6 @@ def component_lowering( *args, fun: Callable[..., Any], component_key: ComponentKey, - multiple_results: bool, ) -> Sequence[ir.Value]: with ctx.module_context.context as ir_ctx: entry = aot_util.get_entry(component_key, ir_ctx) @@ -117,7 +117,7 @@ def component_lowering( module_name = f"{component_key}.module" if (module := entry.module) is None: - logging.info("missed lowering: %s", fun) + logging.info("missed lowering: %s", component_key) traced = aot_util.get_traced(component_key, fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, @@ -172,41 +172,22 @@ def component_batcher( dims_in, fun: Callable[..., Any], component_key: ComponentKey, - multiple_results: bool, ): # Missing from batching process_call: # TODO(dsuo): Ignore ragged. # TODO(dsuo): Ignore updating annotations. - ji = fun._jit_info - - # TODO(dsuo): Dummy debug info - debug_info = lu.DebugInfo( - "component_batcher", fun.__name__, arg_names=None, result_paths=None - ) - # ????(dsuo): Should we be wrapping fun? - f_ = lu.wrap_init(fun, debug_info=debug_info) - # ????(dsuo): I don't understand trace tags. - f_, dims_out = batching.batch_subtrace( - f_, core.TraceTag(), axis_data, tuple(dims_in) + batched_fun, dims_out = batching.batch_subtrace( + fun, core.TraceTag(), axis_data, tuple(dims_in) ) - # TODO(dsuo): Derp how to actually do this. - # with core.set_current_trace(vals_in[0]._trace.parent_trace): - # assert False - component_key.vmap() vals_out = component_p.bind( *vals_in, - fun=f_, - component_key=component_key, - multiple_results=multiple_results, + fun=batched_fun, + component_key=ComponentKey.vmap(component_key), ) - # if not multiple_results and len(vals_out) == 1: - # vals_out = vals_out[0] - # dims_out = dims_out[0] - # logging.info("component_batcher vals_out %s, %s", vals_out, dims_out) - return vals_out, dims_out + return vals_out, dims_out() def clear_caches(): @@ -215,9 +196,8 @@ def clear_caches(): component_p = core.Primitive("component") +component_p.multiple_results = True component_p.def_impl(component_impl) component_p.def_abstract_eval(component_abstract_eval) -# TODO(dsuo): Figure out multiple_results i.e., distinguishing between (1,) and -# 1. mlir.register_lowering(component_p, component_lowering) batching.fancy_primitive_batchers[component_p] = component_batcher diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 76f31c8f30b0..82160e3c4864 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -46,11 +46,12 @@ def __str__(self): return self.user_key def __repr__(self): - return self.__str__ + return self.__str__() # TODO(dsuo): This is just a hack for now. - def vmap(self): - self.user_key = f'vmap({self.user_key})' + @classmethod + def vmap(cls, key): + return ComponentKey(f"vmap({key.user_key})") def _validate_component_cache(val): @@ -70,6 +71,12 @@ def __init__(self, traced: stages.Traced, hits: int = 0): self.traced = traced self.hits = hits + def __str__(self): + return f"{self.traced.fun_name}: {self.hits}" + + def __repr__(self): + return self.__str__() + _traced_cache: dict[Hashable, TracedCacheEntry] = {} @@ -77,8 +84,12 @@ def __init__(self, traced: stages.Traced, hits: int = 0): def get_traced(key: Hashable, fun: Callable[..., Any], *args): entry = _traced_cache.get(key, None) if entry is None: - entry = _traced_cache[key] = TracedCacheEntry(api.trace(fun, *args)) + logging.info("missed trace cache %s", key) + entry = _traced_cache[key] = TracedCacheEntry( + api.trace(fun.f_transformed, *args) + ) else: + logging.info("hit trace cache %s", key) entry.hits += 1 return entry.traced diff --git a/tests/aot_test.py b/tests/aot_test.py index 8c1326ea891c..df5e8d6f8d01 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -29,8 +29,8 @@ from jax._src.lib import xla_client as xc from jax.experimental import topologies from jax.experimental.serialize_executable import ( - deserialize_and_load, - serialize, + deserialize_and_load, + serialize, ) import jax.numpy as jnp from jax.sharding import PartitionSpec as P @@ -42,19 +42,19 @@ with contextlib.suppress(ImportError): import pytest + pytestmark = pytest.mark.multiaccelerator class JaxAotTest(jtu.JaxTestCase): - - @jtu.run_on_devices('tpu', 'gpu') + @jtu.run_on_devices("tpu", "gpu") def test_pickle_jit_lower(self): def fun(x): return x * x - with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ('data',))): + with jax.set_mesh(jax.sharding.Mesh(np.array(jax.devices()), ("data",))): lowered = jax.jit( - fun, in_shardings=P('data'), out_shardings=P(None, 'data') + fun, in_shardings=P("data"), out_shardings=P(None, "data") ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) def verify_serialization(lowered): @@ -65,33 +65,34 @@ def verify_serialization(lowered): verify_serialization(lowered) verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100))) verify_serialization( - jax.pmap(lambda x: x * x).lower( - np.zeros((len(jax.devices()), 4), dtype=np.float32))) + jax.pmap(lambda x: x * x).lower( + np.zeros((len(jax.devices()), 4), dtype=np.float32) + ) + ) @jtu.skip_on_devices("tpu") # TODO(phawkins): This test is segfaulting on TPU def test_topology_jit_serialize(self): try: aot_topo = topologies.get_topology_desc( - platform=jax.devices()[0].platform + platform=jax.devices()[0].platform ) except NotImplementedError: - raise unittest.SkipTest('PJRT Topology not supported') + raise unittest.SkipTest("PJRT Topology not supported") if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: - raise unittest.SkipTest('Compilation caching not yet supported.') + raise unittest.SkipTest("Compilation caching not yet supported.") if jtu.is_device_cuda(): - raise unittest.SkipTest('Broken on GPU: b/442353988') + raise unittest.SkipTest("Broken on GPU: b/442353988") @jax.jit def fn(x): return x * x def lower_and_load(mesh): - s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + s = jax.sharding.NamedSharding(mesh, P("x", "y")) x_shape = jax.ShapeDtypeStruct( - shape=(16, 16), - dtype=jnp.dtype('float32'), - sharding=s) + shape=(16, 16), dtype=jnp.dtype("float32"), sharding=s + ) lowered = fn.lower(x_shape) serialized, in_tree, out_tree = serialize(lowered.compile()) compiled = deserialize_and_load(serialized, in_tree, out_tree) @@ -101,30 +102,30 @@ def lower_and_load(mesh): n = max(1, len(ref_topo.devices) // 2) mesh_shape = (len(ref_topo.devices) // n, n) - ref_mesh = topologies.make_mesh(ref_topo, mesh_shape, ('x', 'y')) - aot_mesh = topologies.make_mesh(aot_topo, mesh_shape, ('x', 'y')) + ref_mesh = topologies.make_mesh(ref_topo, mesh_shape, ("x", "y")) + aot_mesh = topologies.make_mesh(aot_topo, mesh_shape, ("x", "y")) self.assertEqual( - lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text() + lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text() ) def test_get_topology_from_devices(self): try: aot_topo = topologies.get_topology_desc( - platform=jax.devices()[0].platform + platform=jax.devices()[0].platform ) except NotImplementedError: - raise unittest.SkipTest('PJRT Topology not supported') + raise unittest.SkipTest("PJRT Topology not supported") topo = xc.get_topology_for_devices(aot_topo.devices) self.assertEqual( - topo.platform_version, aot_topo.devices[0].client.platform_version + topo.platform_version, aot_topo.devices[0].client.platform_version ) def test_lower_as_text_with_and_without_debug_info(self): def my_function(x): return jnp.sin(x) - lowered = jax.jit(my_function).lower(42.) + lowered = jax.jit(my_function).lower(42.0) stablehlo = lowered.as_text("stablehlo", debug_info=True) self.assertRegex(stablehlo, r"sine.* loc") stablehlo = lowered.as_text("stablehlo") @@ -137,30 +138,37 @@ def my_function(x): def test_constants_in_lowering_in_aot(self): const_size = 100 - const = jax.random.uniform(jax.random.key(0), (const_size,), - dtype=np.float32) + const = jax.random.uniform( + jax.random.key(0), (const_size,), dtype=np.float32 + ) def my_function(x): return jnp.sin(x) + const - lowered = jax.jit(my_function).lower(np.full_like(const, 42., dtype=const.dtype)) + lowered = jax.jit(my_function).lower( + np.full_like(const, 42.0, dtype=const.dtype) + ) stablehlo = lowered.as_text("stablehlo") if config.use_simplified_jaxpr_constants.value: - self.assertNotRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertNotRegex( + stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x" + ) self.assertLen(lowered._lowering.const_args, 1) self.assertIs(lowered._lowering.const_args[0], const) else: - self.assertRegex(stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x") + self.assertRegex( + stablehlo, rf"stablehlo.constant dense.*tensor<{const_size}x" + ) self.assertLen(lowered._lowering.const_args, 0) def test_with_constants(self): - const = jnp.arange(16.) + 42. # A distinctive shape and value + const = jnp.arange(16.0) + 42.0 # A distinctive shape and value @jax.jit def f(x): return const[0:8] + x - inp = jnp.arange(8.) + inp = jnp.arange(8.0) compiled = f.lower(inp).compile() self.assertLen(compiled.args_info[0], 1) # Not including const_args self.assertLen(compiled.in_avals[0], 1) @@ -173,13 +181,14 @@ def f(x): self.assertCacheMisses(lambda: compiled(inp), cpp=0, aot_call=0) @jtu.parameterized_filterable( - kwargs=[ - dict(use_np=use_np, lower=lower, compile=compile, exec=exec) - for use_np in (False, True) - for lower in (False, True) - for compile in (False, True) - for exec in (False, True) - ]) + kwargs=[ + dict(use_np=use_np, lower=lower, compile=compile, exec=exec) + for use_np in (False, True) + for lower in (False, True) + for compile in (False, True) + for exec in (False, True) + ] + ) def test_with_constants_enable_x64(self, *, use_np, lower, compile, exec): # Closed-over constant is 64-bit. Each of lowering, compilation, and # execution can be run in 64-bit or 32-bit mode. @@ -191,7 +200,7 @@ def test_with_constants_enable_x64(self, *, use_np, lower, compile, exec): def f(x): return lax.convert_element_type(const, np.float32) + x - inp = np.arange(8., dtype=np.float32) + inp = np.arange(8.0, dtype=np.float32) with config.enable_x64(True) if lower else contextlib.nullcontext(): lowered = f.lower(inp) with config.enable_x64(True) if compile else contextlib.nullcontext(): @@ -222,17 +231,22 @@ def run(): # In some cases we expect errors: in 32-bit mode, lowered with 64-bit mode # and execute in 32-bit mode. - if (config.use_simplified_jaxpr_constants.value and - not config.enable_x64.value and - use_np and lower and not exec): + if ( + config.use_simplified_jaxpr_constants.value + and not config.enable_x64.value + and use_np + and lower + and not exec + ): with self.assertRaisesRegex( - xc.XlaRuntimeError, - "got buffer with incompatible size"): + xc.XlaRuntimeError, "got buffer with incompatible size" + ): run() return - self.assertArraysEqual(run(), - lax.convert_element_type(const, inp.dtype) + inp) + self.assertArraysEqual( + run(), lax.convert_element_type(const, inp.dtype) + inp + ) # Trigger cache hit self.assertCacheMisses(run, cpp=0, aot_call=0) @@ -244,10 +258,10 @@ def f(x): x_ref[...] += x f_lowered = f.lower(1) - with self.assertRaisesRegex(ValueError, 'serialize with a closed-over'): + with self.assertRaisesRegex(ValueError, "serialize with a closed-over"): serialized, in_tree, out_tree = serialize(f_lowered.compile()) - @jtu.run_on_devices('gpu', 'tpu') + @jtu.run_on_devices("gpu", "tpu") def test_mismatched_backends_raises(self): @jax.jit def f(x): @@ -257,15 +271,19 @@ def f(x): f_lowered = f.lower(x) serialized, in_tree, out_tree = serialize(f_lowered.compile()) with self.assertRaisesRegex( - ValueError, - 'Execution devices belong to a client other than `backend`'): - deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', - execution_devices=jax.devices()[:1]) + ValueError, "Execution devices belong to a client other than `backend`" + ): + deserialize_and_load( + serialized, + in_tree, + out_tree, + backend="cpu", + execution_devices=jax.devices()[:1], + ) @jtu.thread_unsafe_test_class() class ComponentTest(jtu.JaxTestCase): - @contextlib.contextmanager def make_in_memory_cache(self): cache = aot_util.make_in_memory_cache() @@ -281,10 +299,11 @@ def make_in_memory_cache(self): # 3. After lifting the jaxpr. # 4. After DCE. @config.enable_checks(False) - def test_component_lowering_cache_hit(self): + def test_component_basic(self): with self.make_in_memory_cache(): cache = aot.get_cache() - @aot.component(key='f') + + @aot.component(key="f") def f(x): return x + 1.0 @@ -293,9 +312,9 @@ def f(x): # We get 1 hit on traced cache during the lowering rule. self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) # We get 1 hit on the disk cache during the lowering rule. - self.assertEqual(cache.info(f.component_key)['hits'], 1) + self.assertEqual(cache.info(f.component_key)["hits"], 1) - @aot.component(key='f') + @aot.component(key="f") def g(x): raise NotImplementedError @@ -305,13 +324,14 @@ def g(x): self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) # We get two additional hits on the disk cache during abstract eval and # lowering for g. - self.assertEqual(cache.info(f.component_key)['hits'], 3) + self.assertEqual(cache.info(f.component_key)["hits"], 3) @config.enable_checks(False) def test_component_call_in_function(self): with self.make_in_memory_cache(): cache = aot.get_cache() - @aot.component(key='f') + + @aot.component(key="f") def f(x): return x + 1.0 @@ -323,40 +343,49 @@ def g(x): self.assertEqual(f(1.0), 2.0) # 1 hit when lowering g. Why no abstract eval? self.assertEqual(g(1.0), 3.0) - self.assertEqual(cache.info(f.component_key)['hits'], 2) self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + self.assertEqual(cache.info(f.component_key)["hits"], 2) @config.enable_checks(False) def test_explicit_cached_lowering(self): with self.make_in_memory_cache(): cache = aot.get_cache() - @aot.component(key='f') + @aot.component(key="f") def f(x): return x + 1.0 - lowered = f.lower(jax.ShapeDtypeStruct((), 'float32')) + lowered = f.lower(jax.ShapeDtypeStruct((), "float32")) self.assertEqual(cache.keys(), [f.component_key]) - @aot.component(key='f') + @aot.component(key="f") def g(x): raise NotImplementedError - lowered = g.lower(jax.ShapeDtypeStruct((), 'float32')) - self.assertEqual(cache.info(f.component_key)['hits'], 3) + + lowered = g.lower(jax.ShapeDtypeStruct((), "float32")) self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + self.assertEqual(cache.info(f.component_key)["hits"], 3) @config.enable_checks(False) def test_vmap_of_component(self): - with self.make_in_memory_cache(): - cache = aot.get_cache() - @aot.component(key='f') - def f(x): - return x + 1.0 + with self.make_in_memory_cache(): + cache = aot.get_cache() + + @aot.component(key="f") + def f(x): + logging.info('running!') + return x + 1.0 + + vmapped_f = jax.vmap(f) + + # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack. + vmapped_key = aot_util.ComponentKey.vmap(aot_util.ComponentKey("f")) - vmapped_f = jax.vmap(f) + self.assertArraysEqual(vmapped_f(jnp.ones((8,))), jnp.ones((8,)) + 1.0) + self.assertEqual(cache.keys(), [f.component_key, vmapped_key]) + self.assertEqual(aot_util._traced_cache[f.component_key].hits, 0) + self.assertEqual(aot_util._traced_cache[vmapped_key].hits, 1) - with config.checking_leaks(): - self.assertArraysEqual(vmapped_f(jax.numpy.ones(8,)), jax.numpy.ones(8,) + 1.0) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 291f383ad6363ec472b62b434b6b44979079f0cd Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 4 Nov 2025 16:37:42 -0500 Subject: [PATCH 10/24] Update --- jax/_src/aot.py | 44 +++++++----- jax/_src/aot_util.py | 49 ++++++-------- jax/_src/linear_util.py | 3 + tests/aot_test.py | 147 ++++++++++++++++++++++++++++++++++++---- 4 files changed, 182 insertions(+), 61 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 1110d6555843..8e829bee8cfc 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -39,6 +39,10 @@ get_cache = aot_util.get_cache +_fun_cache: dict[ComponentKey, Callable[..., Any]] = {} + + + def component( key: UserKey = None, ) -> Callable[..., Any]: @@ -47,32 +51,30 @@ def _component(fun: Callable[..., Any]): # TODO(dsuo): Do we have all the information we need at this point to make # the component key? component_key = ComponentKey(key) + fun = _fun_cache.setdefault(component_key, fun) @api.jit @util.wraps(fun) @traceback_util.api_boundary def wrapper(*args, **kwargs): args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) - dbg = api_util.debug_info("component", fun, args, kwargs) - wrapped_fun = lu.wrap_init(fun, debug_info=dbg) + wrapped_fun = lu.wrap_init( + fun, debug_info=api_util.debug_info("component", fun, args, kwargs) + ) flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) + flat_fun = aot_util.cached_flat_fun(flat_fun) + logging.info("component flat_fun %s:", id(flat_fun)) + jitted_fun = api.jit(flat_fun.call_wrapped) - # TODO(dsuo): Hack to clear lu.Store borrowed from pmap. - f_transformed = flat_fun.f_transformed - def reset_stores_f_transformed(*args, **kwargs): - for store in flat_fun.stores: - if store is not None: - store.reset() - return f_transformed(*args, **kwargs) - flat_fun.f_transformed = reset_stores_f_transformed out_flat = component_p.bind( *args, - fun=flat_fun, + fun=jitted_fun, component_key=component_key, ) return tree_util.tree_unflatten(out_tree(), out_flat) wrapper.component_key = component_key + wrapper.fun = fun return wrapper return _component @@ -94,10 +96,10 @@ def component_abstract_eval( logging.info("component_abstract_eval got entry %s", component_key) if entry is None: logging.info("missed abstract_eval %s", component_key) - traceback.print_stack() - traced = aot_util.get_traced(component_key, fun, *args) + if isinstance(fun, lu.WrappedFun): + fun = aot_util.maybe_reset_stores(fun).call_wrapped avals_out = tree_util.tree_map( - lambda x: core.ShapedArray(x.shape, x.dtype), traced.out_info + lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) ) aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out)) return entry.avals_out @@ -118,7 +120,9 @@ def component_lowering( module_name = f"{component_key}.module" if (module := entry.module) is None: logging.info("missed lowering: %s", component_key) - traced = aot_util.get_traced(component_key, fun, *ctx.avals_in) + if isinstance(fun, lu.WrappedFun): + fun = aot_util.maybe_reset_stores(fun).call_wrapped + traced = api.trace(fun, *ctx.avals_in) lowering_result = mlir.lower_jaxpr_to_module( module_name=module_name, jaxpr=traced.jaxpr, @@ -177,9 +181,14 @@ def component_batcher( # TODO(dsuo): Ignore ragged. # TODO(dsuo): Ignore updating annotations. + # TODO(dsuo): Dummy debug info. + wrapped_fun = lu.wrap_init( + fun, debug_info=lu.DebugInfo("vmap(component)", fun.__name__, None, None) + ) + # ????(dsuo): I don't understand trace tags. batched_fun, dims_out = batching.batch_subtrace( - fun, core.TraceTag(), axis_data, tuple(dims_in) + wrapped_fun, core.TraceTag(), axis_data, tuple(dims_in) ) vals_out = component_p.bind( @@ -191,8 +200,7 @@ def component_batcher( def clear_caches(): - aot_util.component_cache.value.clear() - aot_util._traced_cache.clear() + get_cache().clear() component_p = core.Primitive("component") diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 82160e3c4864..dbd9ab0898ec 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -22,6 +22,7 @@ from jax._src import api from jax._src import config from jax._src import core +from jax._src import linear_util as lu from jax._src import stages from jax._src.interpreters import mlir from jax._src.lib.mlir import ir @@ -66,34 +67,6 @@ def _validate_component_cache(val): ) -class TracedCacheEntry: - def __init__(self, traced: stages.Traced, hits: int = 0): - self.traced = traced - self.hits = hits - - def __str__(self): - return f"{self.traced.fun_name}: {self.hits}" - - def __repr__(self): - return self.__str__() - - -_traced_cache: dict[Hashable, TracedCacheEntry] = {} - - -def get_traced(key: Hashable, fun: Callable[..., Any], *args): - entry = _traced_cache.get(key, None) - if entry is None: - logging.info("missed trace cache %s", key) - entry = _traced_cache[key] = TracedCacheEntry( - api.trace(fun.f_transformed, *args) - ) - else: - logging.info("hit trace cache %s", key) - entry.hits += 1 - return entry.traced - - class CacheEntry: def __init__( self, @@ -183,3 +156,23 @@ def put_entry( ) -> None: if (cache := get_cache()) is not None: cache.put(key, entry.serialize(), update) + +@lu.cache +def cached_flat_fun(flat_fun): + return maybe_reset_stores(flat_fun) + + +# TODO(dsuo): Share logic with pmap. +def maybe_reset_stores(fun): + # TODO(dsuo): Hack to clear lu.Store borrowed from pmap. + f_transformed = fun.f_transformed + + # TODO(dsuo): Add this as a transformation. + def reset_stores_f_transformed(*args, **kwargs): + for store in fun.stores: + if store is not None: + store.reset() + return f_transformed(*args, **kwargs) + + fun.f_transformed = reset_stores_f_transformed + return fun diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 7491fabab7d0..967a2e23abea 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -505,6 +505,9 @@ def _evict_function(f): memoized_fun.cache_clear = fun_caches.clear # type: ignore memoized_fun.evict_function = _evict_function # type: ignore + memoized_fun.cache_items = fun_caches.items # type: ignore + memoized_fun.cache_get = fun_caches.get # type: ignore + memoized_fun.cache_keys = fun_caches.keys # type: ignore register_cache(memoized_fun, str(call)) return memoized_fun diff --git a/tests/aot_test.py b/tests/aot_test.py index df5e8d6f8d01..8c5764147c09 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -14,6 +14,7 @@ import contextlib import logging +from typing import Any, Callable import unittest from absl.testing import absltest @@ -292,6 +293,21 @@ def make_in_memory_cache(self): aot.clear_caches() jax.clear_caches() + # TODO(dsuo): It would be nice to have a way to grab the pjit jaxpr cache + # key easily. + def get_pjit_jaxpr_key( + self, fun: Callable[..., Any] + ) -> Callable[..., Any] | None: + for key in pjit._create_pjit_jaxpr.cache_keys(): + f = key if not hasattr(key, "__wrapped__") else key.__wrapped__ + if f == fun: + return key + + def get_pjit_jaxpr_entry( + self, key: Callable[..., Any] + ) -> dict[Callable[..., Any], dict[Any, Any]]: + return pjit._create_pjit_jaxpr.cache_get(key).items() + # NOTE(dsuo): Disable checks because otherwise we check jaxprs in (at least) # four places and makes reasoning about cache hits and misses harder. # 1. After the initial abstract eval. @@ -309,25 +325,37 @@ def f(x): self.assertEqual(f(1.0), 2.0) self.assertEqual(cache.keys(), [f.component_key]) - # We get 1 hit on traced cache during the lowering rule. - self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) - # We get 1 hit on the disk cache during the lowering rule. + + # Make sure the underlying function f.fun exists in the jaxpr cache. + pjit_key = self.get_pjit_jaxpr_key(f.fun) + self.assertIsNotNone(pjit_key) + # Make sure there is only one entry for f.fun. If there are more, then it + # means the lowering rule missed. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + # We get 1 hit on the disk cache during the lowering rule. However, this + # hit is for an incomplete CacheEntry i.e., only avals_out were populated + # and not the lowered module. The lowering rule updates the CacheEntry + # with the lowered module. self.assertEqual(cache.info(f.component_key)["hits"], 1) @aot.component(key="f") def g(x): raise NotImplementedError + # We ignore g's implementation because it was turned into a component with + # key "f". + self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) - # We get 1 hit for component cache and so we don't even check traced - # cache. - self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + # Confirm we still have just one entry in the jaxpr cache. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) # We get two additional hits on the disk cache during abstract eval and # lowering for g. self.assertEqual(cache.info(f.component_key)["hits"], 3) @config.enable_checks(False) - def test_component_call_in_function(self): + def test_component_in_function(self): with self.make_in_memory_cache(): cache = aot.get_cache() @@ -339,15 +367,98 @@ def f(x): def g(x): return f(x) + 1.0 - # 1 hit when lowering f. + # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) - # 1 hit when lowering g. Why no abstract eval? + self.assertEqual(cache.keys(), [f.component_key]) + self.assertEqual(cache.info(f.component_key)["hits"], 1) + # Make sure the underlying function f.fun exists in the jaxpr cache. + pjit_key = self.get_pjit_jaxpr_key(f.fun) + self.assertIsNotNone(pjit_key) + # Make sure there is only one entry for f.fun. If there are more, then it + # means the lowering rule missed. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + # 1 hit when lowering g. g is not a component, so doesn't look up + # CacheEntry during abstract_eval. self.assertEqual(g(1.0), 3.0) - self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) + # Make sure we didn't add any new entries for f.fun. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) self.assertEqual(cache.info(f.component_key)["hits"], 2) @config.enable_checks(False) - def test_explicit_cached_lowering(self): + def test_jit_of_component(self): + with self.make_in_memory_cache(): + cache = aot.get_cache() + + @jax.jit + @aot.component(key="f") + def f(x): + return x + 1.0 + + # Create cache entry when abstract_eval f. 1 hit when lowering f. + self.assertEqual(f(1.0), 2.0) + # Make sure the underlying function f.fun exists in the jaxpr cache. + pjit_key = self.get_pjit_jaxpr_key(f.fun) + self.assertIsNotNone(pjit_key) + # Make sure there is only one entry for f.fun. If there are more, then it + # means the lowering rule missed. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + + @aot.component(key="f") + def g(x): + raise NotImplementedError + + # We ignore g's implementation because it was turned into a component with + # key "f". + self.assertEqual(f.fun, g.fun) + self.assertEqual(g(1.0), 2.0) + # Confirm we still have just one entry in the jaxpr cache. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + # We get two additional hits on the disk cache during abstract eval and + # lowering for g. + self.assertEqual(cache.info(f.component_key)["hits"], 3) + + + @config.enable_checks(False) + def test_component_of_jit(self): + with self.make_in_memory_cache(): + cache = aot.get_cache() + + @aot.component(key="f") + @jax.jit + def f(x): + return x + 1.0 + + # Create cache entry when abstract_eval f. 1 hit when lowering f. + self.assertEqual(f(1.0), 2.0) + # Make sure the underlying function f.fun exists in the jaxpr cache. + pjit_key = self.get_pjit_jaxpr_key(f.fun) + self.assertIsNotNone(pjit_key) + # Make sure there is only one entry for f.fun. If there are more, then it + # means the lowering rule missed. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + + @aot.component(key="f") + def g(x): + raise NotImplementedError + + # We ignore g's implementation because it was turned into a component with + # key "f". + self.assertEqual(f.fun, g.fun) + self.assertEqual(g(1.0), 2.0) + # Confirm we still have just one entry in the jaxpr cache. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + # We get two additional hits on the disk cache during abstract eval and + # lowering for g. + self.assertEqual(cache.info(f.component_key)["hits"], 3) + + @config.enable_checks(False) + def test_explicit_lowering(self): with self.make_in_memory_cache(): cache = aot.get_cache() @@ -358,12 +469,18 @@ def f(x): lowered = f.lower(jax.ShapeDtypeStruct((), "float32")) self.assertEqual(cache.keys(), [f.component_key]) + pjit_key = self.get_pjit_jaxpr_key(f.fun) + self.assertIsNotNone(pjit_key) + # Make sure there is only one entry for f.fun. If there are more, then it + # means the lowering rule missed. + num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + self.assertEqual(num_entries, 1) + @aot.component(key="f") def g(x): raise NotImplementedError lowered = g.lower(jax.ShapeDtypeStruct((), "float32")) - self.assertEqual(aot_util._traced_cache[f.component_key].hits, 1) self.assertEqual(cache.info(f.component_key)["hits"], 3) @config.enable_checks(False) @@ -373,7 +490,7 @@ def test_vmap_of_component(self): @aot.component(key="f") def f(x): - logging.info('running!') + logging.info("running!") return x + 1.0 vmapped_f = jax.vmap(f) @@ -383,8 +500,8 @@ def f(x): self.assertArraysEqual(vmapped_f(jnp.ones((8,))), jnp.ones((8,)) + 1.0) self.assertEqual(cache.keys(), [f.component_key, vmapped_key]) - self.assertEqual(aot_util._traced_cache[f.component_key].hits, 0) - self.assertEqual(aot_util._traced_cache[vmapped_key].hits, 1) + # self.assertEqual(aot_util._traced_cache[f.component_key].hits, 0) + # self.assertEqual(aot_util._traced_cache[vmapped_key].hits, 1) if __name__ == "__main__": From 340466d2e0fdefa62ce3bf89cc84bf98bc09114b Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 5 Nov 2025 11:47:18 -0500 Subject: [PATCH 11/24] Update --- jax/_src/aot.py | 25 ++++++---- jax/_src/linear_util.py | 5 ++ jax/_src/pjit.py | 10 ++++ jax/_src/xla_bridge.py | 2 +- tests/aot_test.py | 102 ++++++++++++++++++++++++++-------------- 5 files changed, 98 insertions(+), 46 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 8e829bee8cfc..386f53cbbbd5 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -24,6 +24,7 @@ from jax._src import api_util from jax._src import core from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib from jax._src import traceback_util from jax._src import tree_util from jax._src import util @@ -38,9 +39,7 @@ ComponentKey = aot_util.ComponentKey get_cache = aot_util.get_cache - -_fun_cache: dict[ComponentKey, Callable[..., Any]] = {} - +_wrapper_cache: dict[ComponentKey, xc._xla.PjitFunction] = {} def component( @@ -51,7 +50,9 @@ def _component(fun: Callable[..., Any]): # TODO(dsuo): Do we have all the information we need at this point to make # the component key? component_key = ComponentKey(key) - fun = _fun_cache.setdefault(component_key, fun) + + if component_key in _wrapper_cache: + return _wrapper_cache[component_key] @api.jit @util.wraps(fun) @@ -63,8 +64,9 @@ def wrapper(*args, **kwargs): ) flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) flat_fun = aot_util.cached_flat_fun(flat_fun) - logging.info("component flat_fun %s:", id(flat_fun)) + logging.info("miss component flat_fun %s:", id(flat_fun)) jitted_fun = api.jit(flat_fun.call_wrapped) + logging.info("miss component jitted_fun %s:", id(jitted_fun)) out_flat = component_p.bind( *args, @@ -75,6 +77,8 @@ def wrapper(*args, **kwargs): wrapper.component_key = component_key wrapper.fun = fun + logging.info("wrapper id %s", id(wrapper._fun)) + _wrapper_cache[component_key] = wrapper return wrapper return _component @@ -96,11 +100,12 @@ def component_abstract_eval( logging.info("component_abstract_eval got entry %s", component_key) if entry is None: logging.info("missed abstract_eval %s", component_key) - if isinstance(fun, lu.WrappedFun): - fun = aot_util.maybe_reset_stores(fun).call_wrapped - avals_out = tree_util.tree_map( - lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) - ) + # TODO(dsuo): By the time we get to lowering, our trace context has picked + # up an empty AbstractMesh. Don't know why. + with mesh_lib.use_abstract_mesh(mesh_lib.AbstractMesh((), (), ())): + avals_out = tree_util.tree_map( + lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) + ) aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out)) return entry.avals_out diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 967a2e23abea..60b75eaf35d9 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -482,14 +482,17 @@ def cache(call: Callable, *, A memoized version of ``call``. """ fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + hit_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, new_cache := {}) # type: ignore + hit = hit_caches.setdefault(fun.f, new_hit := {}) key = (fun.transforms, fun.params, fun.in_type, args, config.trace_context()) result = cache.get(key, None) if result is not None: ans, stores = result fun.populate_stores(stores) + hit[key] += 1 else: if do_explain := explain and config.explain_cache_misses.value: start = time.time() @@ -497,6 +500,7 @@ def memoized_fun(fun: WrappedFun, *args): if do_explain: explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) + hit[key] = 0 return ans @@ -508,6 +512,7 @@ def _evict_function(f): memoized_fun.cache_items = fun_caches.items # type: ignore memoized_fun.cache_get = fun_caches.get # type: ignore memoized_fun.cache_keys = fun_caches.keys # type: ignore + memoized_fun.hit_get = hit_caches.get # type: ignore register_cache(memoized_fun, str(call)) return memoized_fun diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 180953c01781..4c6079d5d680 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -254,6 +254,7 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): + logging.info('cpp_pjit fun: %s', id(fun)) # args do not include the const args # See https://docs.jax.dev/en/latest/internals/constants.html. if config.no_tracing.value: @@ -630,6 +631,7 @@ def _infer_params( return _infer_params_internal(fun, ji, args, kwargs) +cache = dict() def _infer_params_internal( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: @@ -651,12 +653,19 @@ def _infer_params_internal( entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) if entry.pjit_params is None: + # if fun.__name__ not in ['add', 'equal']: + logging.info('missed infer params: %s %s', id(fun), fun.__name__) + # if fun.__name__ == 'f': + # import traceback + # traceback.print_stack() dbg = dbg_fn() p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) if p.params['jaxpr'].jaxpr.is_high: return p, p.consts + args_flat entry.pjit_params = p + else: + logging.info('hit infer params: %s %s', id(fun), fun.__name__) return entry.pjit_params, entry.pjit_params.consts + dynargs def _infer_input_type(fun: Callable, dbg_fn: Callable[[], core.DebugInfo], @@ -1196,6 +1205,7 @@ def _create_pjit_jaxpr( else: closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) final_consts = [] + logging.info('missed jaxpr: %s', fun.__name__) return closed_jaxpr, final_consts, global_out_avals diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 86a0f56528b9..35e658af6095 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -809,7 +809,7 @@ def backends() -> dict[str, xla_client.Client]: err_msg = f"Unable to initialize backend '{platform}': {err}" if fail_quietly: _backend_errors[platform] = str(err) - logger.info(err_msg) + # logger.info(err_msg) else: if config.jax_platforms.value: err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" diff --git a/tests/aot_test.py b/tests/aot_test.py index 8c5764147c09..8d32c54b5246 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -14,7 +14,7 @@ import contextlib import logging -from typing import Any, Callable +from typing import Any, Callable, Sequence import unittest from absl.testing import absltest @@ -295,18 +295,53 @@ def make_in_memory_cache(self): # TODO(dsuo): It would be nice to have a way to grab the pjit jaxpr cache # key easily. - def get_pjit_jaxpr_key( - self, fun: Callable[..., Any] - ) -> Callable[..., Any] | None: + def get_jaxpr_key(self, fun: Callable[..., Any]) -> Callable[..., Any] | None: for key in pjit._create_pjit_jaxpr.cache_keys(): f = key if not hasattr(key, "__wrapped__") else key.__wrapped__ if f == fun: return key - def get_pjit_jaxpr_entry( - self, key: Callable[..., Any] - ) -> dict[Callable[..., Any], dict[Any, Any]]: - return pjit._create_pjit_jaxpr.cache_get(key).items() + def validate_cache_states( + self, + fun: Callable[..., Any], + num_jaxpr_entries: int, + num_jaxpr_hits: int | Sequence[int], + num_trace_hits: int, + num_trace_misses: int, + num_disk_hits: int, + ): + cache = aot.get_cache() + component_key = fun.component_key + + # Verify component key exists in disk cache. + self.assertIn(component_key, cache.keys()) + + # Verify the number of disk hits. + self.assertEqual(cache.info(component_key)["hits"], num_disk_hits) + + jaxpr_key = self.get_jaxpr_key(fun.fun) + jaxpr_cache = pjit._create_pjit_jaxpr.cache_get(jaxpr_key) + jaxpr_hit = pjit._create_pjit_jaxpr.hit_get(jaxpr_key) + if isinstance(num_jaxpr_hits, int): + num_jaxpr_hits = (num_jaxpr_hits,) + + # Verify fun exists in the jaxpr cache. + self.assertIsNotNone(jaxpr_key) + + # Verify number of entries in jaxpr cache for fun. + self.assertEqual(len(jaxpr_cache), num_jaxpr_entries) + + # Verify number of hits for each entry. + self.assertEqual(tuple(jaxpr_hit.values()), num_jaxpr_hits) + + # Verify the number of hits and misses we expect. + self.assertEqual( + pjit._infer_params_cached.cache_info().hits, num_trace_hits + ) + + self.assertEqual( + pjit._infer_params_cached.cache_info().misses, num_trace_misses + ) # NOTE(dsuo): Disable checks because otherwise we check jaxprs in (at least) # four places and makes reasoning about cache hits and misses harder. @@ -324,35 +359,33 @@ def f(x): return x + 1.0 self.assertEqual(f(1.0), 2.0) - self.assertEqual(cache.keys(), [f.component_key]) - - # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_pjit_jaxpr_key(f.fun) - self.assertIsNotNone(pjit_key) - # Make sure there is only one entry for f.fun. If there are more, then it - # means the lowering rule missed. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) - # We get 1 hit on the disk cache during the lowering rule. However, this - # hit is for an incomplete CacheEntry i.e., only avals_out were populated - # and not the lowered module. The lowering rule updates the CacheEntry - # with the lowered module. - self.assertEqual(cache.info(f.component_key)["hits"], 1) + self.validate_cache_states( + f, + # Make sure there is only one entry for f.fun. If there are more, then + # it means the lowering rule missed. + num_jaxpr_entries=1, + # There should be no hits in the jaxpr cache for f.fun because we've + # only just created it. + num_jaxpr_hits=0, + # We should have 1 hit from the trace in lowering. + num_trace_hits=1, + # We should have 4 misses: add, equal, f, and call_wrapped. + num_trace_misses=4, + # We get 1 hit on the disk cache during the lowering rule. However, this + # hit is for an incomplete CacheEntry; only avals_out were populated + # and not the lowered module. The lowering rule updates the CacheEntry + # with the lowered module. + num_disk_hits=1, + ) @aot.component(key="f") def g(x): raise NotImplementedError - # We ignore g's implementation because it was turned into a component with - # key "f". self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) - # Confirm we still have just one entry in the jaxpr cache. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) - # We get two additional hits on the disk cache during abstract eval and - # lowering for g. - self.assertEqual(cache.info(f.component_key)["hits"], 3) + # Cache state should remain unchanged. + self.validate_cache_states(g, 1, 0, 1, 4, 1) @config.enable_checks(False) def test_component_in_function(self): @@ -372,7 +405,7 @@ def g(x): self.assertEqual(cache.keys(), [f.component_key]) self.assertEqual(cache.info(f.component_key)["hits"], 1) # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_pjit_jaxpr_key(f.fun) + pjit_key = self.get_jaxpr_key(f.fun) self.assertIsNotNone(pjit_key) # Make sure there is only one entry for f.fun. If there are more, then it # means the lowering rule missed. @@ -399,7 +432,7 @@ def f(x): # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_pjit_jaxpr_key(f.fun) + pjit_key = self.get_jaxpr_key(f.fun) self.assertIsNotNone(pjit_key) # Make sure there is only one entry for f.fun. If there are more, then it # means the lowering rule missed. @@ -421,7 +454,6 @@ def g(x): # lowering for g. self.assertEqual(cache.info(f.component_key)["hits"], 3) - @config.enable_checks(False) def test_component_of_jit(self): with self.make_in_memory_cache(): @@ -435,7 +467,7 @@ def f(x): # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_pjit_jaxpr_key(f.fun) + pjit_key = self.get_jaxpr_key(f.fun) self.assertIsNotNone(pjit_key) # Make sure there is only one entry for f.fun. If there are more, then it # means the lowering rule missed. @@ -469,7 +501,7 @@ def f(x): lowered = f.lower(jax.ShapeDtypeStruct((), "float32")) self.assertEqual(cache.keys(), [f.component_key]) - pjit_key = self.get_pjit_jaxpr_key(f.fun) + pjit_key = self.get_jaxpr_key(f.fun) self.assertIsNotNone(pjit_key) # Make sure there is only one entry for f.fun. If there are more, then it # means the lowering rule missed. From 06a60ddb54e4bcf8adb88a5639a4d6d7a1e07b36 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 5 Nov 2025 13:03:46 -0500 Subject: [PATCH 12/24] Update --- jax/_src/aot.py | 12 ++---- jax/_src/aot_util.py | 91 +++++++++++++++++++++++++++++--------------- tests/aot_test.py | 48 ++++++++++++----------- 3 files changed, 89 insertions(+), 62 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 386f53cbbbd5..ee1ffc5ed5a9 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -30,7 +30,6 @@ from jax._src import util from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -39,7 +38,6 @@ ComponentKey = aot_util.ComponentKey get_cache = aot_util.get_cache -_wrapper_cache: dict[ComponentKey, xc._xla.PjitFunction] = {} def component( @@ -51,8 +49,8 @@ def _component(fun: Callable[..., Any]): # the component key? component_key = ComponentKey(key) - if component_key in _wrapper_cache: - return _wrapper_cache[component_key] + if component_key in aot_util._wrapper_cache.cache_keys(): + return aot_util._wrapper_cache.get(component_key) @api.jit @util.wraps(fun) @@ -78,7 +76,7 @@ def wrapper(*args, **kwargs): wrapper.component_key = component_key wrapper.fun = fun logging.info("wrapper id %s", id(wrapper._fun)) - _wrapper_cache[component_key] = wrapper + aot_util._wrapper_cache.put(component_key, wrapper) return wrapper return _component @@ -204,10 +202,6 @@ def component_batcher( return vals_out, dims_out() -def clear_caches(): - get_cache().clear() - - component_p = core.Primitive("component") component_p.multiple_results = True component_p.def_impl(component_impl) diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index dbd9ab0898ec..be6d01ab2557 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -24,7 +24,9 @@ from jax._src import core from jax._src import linear_util as lu from jax._src import stages +from jax._src import util from jax._src.interpreters import mlir +from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -56,7 +58,10 @@ def vmap(cls, key): def _validate_component_cache(val): + logging.info("Validating component cache config.") assert val is None or isinstance(val, Cache) + if val is not None: + util.register_cache(val, "aot_cache") component_cache = config.string_or_object_state( @@ -95,47 +100,40 @@ def deserialize( return cls(avals_out, module) -class Cache(NamedTuple): - get: Callable[[ComponentKey], bytes | None] - put: Callable[[ComponentKey, bytes], None] - keys: Callable[[], list[ComponentKey]] - clear: Callable[[], None] - info: Callable[[ComponentKey], dict[str, Any]] +# TODO(dsuo): This should be a protocol. +class Cache: + def __init__(self): + self._in_memory_cache: dict[ComponentKey, SerializedType] = {} + self._in_memory_cache_info: dict[ComponentKey, dict[str, Any]] = {} - -_in_memory_cache: dict[ComponentKey, SerializedType] = {} -_in_memory_cache_info: dict[ComponentKey, dict[str, Any]] = {} - - -def make_in_memory_cache(): - def get(key: ComponentKey) -> SerializedType | None: - entry = _in_memory_cache.get(key, None) + def get(self, key: ComponentKey) -> SerializedType | None: + entry = self._in_memory_cache.get(key, None) if entry is None: - _in_memory_cache_info[key] = dict(hits=0) + self._in_memory_cache_info[key] = dict(hits=0) else: - _in_memory_cache_info[key] = dict( - hits=_in_memory_cache_info[key]["hits"] + 1 + self._in_memory_cache_info[key] = dict( + hits=self._in_memory_cache_info[key]["hits"] + 1 ) return entry - def put(key: ComponentKey, data: SerializedType, update: bool): - _in_memory_cache[key] = data + def put(self, key: ComponentKey, data: SerializedType, update: bool): + self._in_memory_cache[key] = data if not update: - _in_memory_cache_info[key] = dict(hits=0) - - def keys() -> list[ComponentKey]: - return list(_in_memory_cache.keys()) + self._in_memory_cache_info[key] = dict(hits=0) - def clear() -> None: - _in_memory_cache.clear() - _in_memory_cache_info.clear() + def cache_keys( + self, + ) -> list[ComponentKey]: + return list(self._in_memory_cache.keys()) - def info(key: ComponentKey) -> dict[str, Any]: - if key not in _in_memory_cache_info: - raise ValueError(f"`{key}` not found in _in_memory_cache_info") - return _in_memory_cache_info[key] + def cache_clear(self) -> None: + self._in_memory_cache.clear() + self._in_memory_cache_info.clear() - return Cache(get, put, keys, clear, info) + def cache_info(self, key: ComponentKey) -> dict[str, Any]: + if key not in self._in_memory_cache_info: + raise ValueError(f"`{key}` not found in self._in_memory_cache_info") + return self._in_memory_cache_info[key] def get_cache() -> Cache | None: @@ -157,6 +155,7 @@ def put_entry( if (cache := get_cache()) is not None: cache.put(key, entry.serialize(), update) + @lu.cache def cached_flat_fun(flat_fun): return maybe_reset_stores(flat_fun) @@ -176,3 +175,33 @@ def reset_stores_f_transformed(*args, **kwargs): fun.f_transformed = reset_stores_f_transformed return fun + + +class WrapperCache: + def __init__(self): + self.data = dict() + self.info = dict() + + def get(self, key: ComponentKey) -> xc._xla.PjitFunction | None: + fun = self.data.get(key, None) + if fun is not None: + self.info[key]["hits"] += 1 + return fun + + def put(self, key: ComponentKey, fun: xc._xla.PjitFunction): + fun = self.data.setdefault(key, fun) + info = self.info.setdefault(key, dict(hits=0)) + self.info[key]["hits"] = 0 + + def cache_info(self): + return self.info + + def cache_clear(self): + self.data.clear() + + def cache_keys(self): + return self.data.keys() + + +_wrapper_cache = WrapperCache() +util.register_cache(_wrapper_cache, "aot_wrapper_cache") diff --git a/tests/aot_test.py b/tests/aot_test.py index 8d32c54b5246..6cb84886a943 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -287,10 +287,9 @@ def f(x): class ComponentTest(jtu.JaxTestCase): @contextlib.contextmanager def make_in_memory_cache(self): - cache = aot_util.make_in_memory_cache() + cache = aot_util.Cache() with aot_util.component_cache(cache): yield - aot.clear_caches() jax.clear_caches() # TODO(dsuo): It would be nice to have a way to grab the pjit jaxpr cache @@ -308,16 +307,22 @@ def validate_cache_states( num_jaxpr_hits: int | Sequence[int], num_trace_hits: int, num_trace_misses: int, + num_wrapper_hits: int, num_disk_hits: int, ): cache = aot.get_cache() - component_key = fun.component_key + component_key = fun.component_key # type: ignore # Verify component key exists in disk cache. - self.assertIn(component_key, cache.keys()) + self.assertIn(component_key, cache.cache_keys()) + # Verify the number of wrapper cache hits. + self.assertEqual( + aot_util._wrapper_cache.cache_info()[component_key]["hits"], + num_wrapper_hits, + ) # Verify the number of disk hits. - self.assertEqual(cache.info(component_key)["hits"], num_disk_hits) + self.assertEqual(cache.cache_info(component_key)["hits"], num_disk_hits) jaxpr_key = self.get_jaxpr_key(fun.fun) jaxpr_cache = pjit._create_pjit_jaxpr.cache_get(jaxpr_key) @@ -338,7 +343,6 @@ def validate_cache_states( self.assertEqual( pjit._infer_params_cached.cache_info().hits, num_trace_hits ) - self.assertEqual( pjit._infer_params_cached.cache_info().misses, num_trace_misses ) @@ -371,6 +375,8 @@ def f(x): num_trace_hits=1, # We should have 4 misses: add, equal, f, and call_wrapped. num_trace_misses=4, + # We shouldn't have hit the wrapper cache yet. + num_wrapper_hits=0, # We get 1 hit on the disk cache during the lowering rule. However, this # hit is for an incomplete CacheEntry; only avals_out were populated # and not the lowered module. The lowering rule updates the CacheEntry @@ -384,8 +390,8 @@ def g(x): self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) - # Cache state should remain unchanged. - self.validate_cache_states(g, 1, 0, 1, 4, 1) + # Cache state should remain unchanged except we grabbed the wrapped fun. + self.validate_cache_states(g, 1, 0, 1, 4, 1, 1) @config.enable_checks(False) def test_component_in_function(self): @@ -400,24 +406,22 @@ def f(x): def g(x): return f(x) + 1.0 - # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) - self.assertEqual(cache.keys(), [f.component_key]) - self.assertEqual(cache.info(f.component_key)["hits"], 1) - # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_jaxpr_key(f.fun) - self.assertIsNotNone(pjit_key) - # Make sure there is only one entry for f.fun. If there are more, then it - # means the lowering rule missed. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) + + # We should have the same cache states as in test_component_basic. + self.validate_cache_states(f, 1, 0, 1, 4, 1) + + logging.info("\n\n\n") + # 1 hit when lowering g. g is not a component, so doesn't look up # CacheEntry during abstract_eval. self.assertEqual(g(1.0), 3.0) - # Make sure we didn't add any new entries for f.fun. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) - self.assertEqual(cache.info(f.component_key)["hits"], 2) + # We incur one more missed trace for g and + self.validate_cache_states(f, 1, 0, 2, 6, 2) + # # Make sure we didn't add any new entries for f.fun. + # num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) + # self.assertEqual(num_entries, 1) + # self.assertEqual(cache.info(f.component_key)["hits"], 2) @config.enable_checks(False) def test_jit_of_component(self): From 7b04e0b5ad8c7ec8b9fad83e2dc812499bc6c39a Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 5 Nov 2025 14:39:01 -0500 Subject: [PATCH 13/24] Update --- tests/aot_test.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/aot_test.py b/tests/aot_test.py index 6cb84886a943..0457b410683e 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -409,19 +409,16 @@ def g(x): self.assertEqual(f(1.0), 2.0) # We should have the same cache states as in test_component_basic. - self.validate_cache_states(f, 1, 0, 1, 4, 1) + self.validate_cache_states(f, 1, 0, 1, 4, 0, 1) logging.info("\n\n\n") # 1 hit when lowering g. g is not a component, so doesn't look up # CacheEntry during abstract_eval. self.assertEqual(g(1.0), 3.0) - # We incur one more missed trace for g and - self.validate_cache_states(f, 1, 0, 2, 6, 2) - # # Make sure we didn't add any new entries for f.fun. - # num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - # self.assertEqual(num_entries, 1) - # self.assertEqual(cache.info(f.component_key)["hits"], 2) + # We have one more trace cache hit on f from tracing g, two more misses, + # one for g and one for add, and one more disk cache hit from lowering g. + self.validate_cache_states(f, 1, 0, 2, 6, 0, 2) @config.enable_checks(False) def test_jit_of_component(self): From ecb9c721c5520a9fa20ed8acd408d53136912e58 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 10:34:08 -0500 Subject: [PATCH 14/24] Update --- jax/_src/aot.py | 16 +++++- jax/_src/aot_util.py | 4 +- jax/_src/pjit.py | 9 ++- jax/_src/traceback_util.py | 34 +++++------ tests/aot_test.py | 112 ++++++++++++++++--------------------- 5 files changed, 83 insertions(+), 92 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index ee1ffc5ed5a9..be4afc95a2b1 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -50,6 +50,7 @@ def _component(fun: Callable[..., Any]): component_key = ComponentKey(key) if component_key in aot_util._wrapper_cache.cache_keys(): + logging.info('hit wrapper_cache: %s', component_key) return aot_util._wrapper_cache.get(component_key) @api.jit @@ -63,7 +64,9 @@ def wrapper(*args, **kwargs): flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) flat_fun = aot_util.cached_flat_fun(flat_fun) logging.info("miss component flat_fun %s:", id(flat_fun)) - jitted_fun = api.jit(flat_fun.call_wrapped) + flat_fun = flat_fun.f_transformed + flat_fun.__name__ = 'wrapped(flat_fun)' + jitted_fun = api.jit(flat_fun) logging.info("miss component jitted_fun %s:", id(jitted_fun)) out_flat = component_p.bind( @@ -75,7 +78,10 @@ def wrapper(*args, **kwargs): wrapper.component_key = component_key wrapper.fun = fun - logging.info("wrapper id %s", id(wrapper._fun)) + logging.info("jit(wrapper(fun)) wrapper id %s", id(wrapper)) + logging.info("wrapper(fun) wrapper._fun id %s", id(wrapper._fun)) + logging.info("fun wrapper._fun.__wrapped__ id %s", id(wrapper._fun.__wrapped__)) + logging.info("user fun id %s", id(fun)) aot_util._wrapper_cache.put(component_key, wrapper) return wrapper @@ -105,6 +111,8 @@ def component_abstract_eval( lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) ) aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out)) + else: + logging.info("hit abstract_eval %s", component_key) return entry.avals_out @@ -156,6 +164,8 @@ def component_lowering( # the right context? entry.module = module = ir.Module.parse(mlir.module_to_bytecode(module)) aot_util.put_entry(component_key, entry, update=True) + else: + logging.info("hit lowering: %s", component_key) symtab = ir.SymbolTable(module.operation) module = mlir.merge_mlir_modules( @@ -196,7 +206,7 @@ def component_batcher( vals_out = component_p.bind( *vals_in, - fun=batched_fun, + fun=batched_fun.f_transformed, component_key=ComponentKey.vmap(component_key), ) return vals_out, dims_out() diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index be6d01ab2557..22e24db00388 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -108,9 +108,7 @@ def __init__(self): def get(self, key: ComponentKey) -> SerializedType | None: entry = self._in_memory_cache.get(key, None) - if entry is None: - self._in_memory_cache_info[key] = dict(hits=0) - else: + if entry is not None: self._in_memory_cache_info[key] = dict( hits=self._in_memory_cache_info[key]["hits"] + 1 ) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 4c6079d5d680..689c63dd2b7e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -254,7 +254,7 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): - logging.info('cpp_pjit fun: %s', id(fun)) + logging.info('cpp_pjit fun: %s %s', id(fun), fun.__name__) # args do not include the const args # See https://docs.jax.dev/en/latest/internals/constants.html. if config.no_tracing.value: @@ -653,8 +653,11 @@ def _infer_params_internal( entry = _infer_params_cached(fun, ji, signature, avals, ctx_mesh) if entry.pjit_params is None: - # if fun.__name__ not in ['add', 'equal']: - logging.info('missed infer params: %s %s', id(fun), fun.__name__) + if isinstance(fun, partial): + name = f'partial({fun.func.__name__})' + else: + name = fun.__name__ + logging.info('missed infer params: %s %s', id(fun), name) # if fun.__name__ == 'f': # import traceback # traceback.print_stack() diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index b7641c209589..e99db5ca546c 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -160,10 +160,8 @@ def _filtering_mode() -> str: mode = "quiet_remove_frames" return mode -def api_boundary( - fun: C, *, - repro_api_name: str | None = None, - repro_user_func: bool = False) -> C: +import logging +def api_boundary(fun: C) -> C: '''Wraps ``fun`` to form a boundary for filtering exception tracebacks. When an exception occurs below ``fun``, this appends to it a custom @@ -184,10 +182,10 @@ def api_boundary( ``g``. Because the function returned by :func:`~jax.jit` is annotated as an :func:`~api_boundary`, such an exception is accompanied by an additional traceback that excludes the frames specific to JAX's implementation. - - For the "repro" kwargs, see the comments for `repro.boundary`. ''' + if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: + logging.info("api_boundary fun id %s", id(fun)) @functools.wraps(fun) def reraise_with_filtered_traceback(*args, **kwargs): __tracebackhide__ = True @@ -223,17 +221,13 @@ def reraise_with_filtered_traceback(*args, **kwargs): raise finally: del mode, tb - if (repro_api_name or repro_user_func) and repro: - reraise_with_filtered_traceback = repro.boundary( - reraise_with_filtered_traceback, api_name=repro_api_name, - is_user=repro_user_func) - return cast(C, reraise_with_filtered_traceback) - -try: - # TODO: import from the final location - from jax._src import repro # type: ignore - repro_is_enabled = repro.is_enabled - -except ImportError: - repro = None # type: ignore - def repro_is_enabled(): return False # type: ignore + if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: + logging.info("reraise_with_filtered_traceback id %s", + id(reraise_with_filtered_traceback)) + casted = cast(C, reraise_with_filtered_traceback) + + if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: + logging.info("casted reraise id %s", + id(casted)) + + return reraise_with_filtered_traceback diff --git a/tests/aot_test.py b/tests/aot_test.py index 0457b410683e..0e9135186f8a 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -303,6 +303,7 @@ def get_jaxpr_key(self, fun: Callable[..., Any]) -> Callable[..., Any] | None: def validate_cache_states( self, fun: Callable[..., Any], + component_key: aot_util.ComponentKey, num_jaxpr_entries: int, num_jaxpr_hits: int | Sequence[int], num_trace_hits: int, @@ -311,31 +312,16 @@ def validate_cache_states( num_disk_hits: int, ): cache = aot.get_cache() - component_key = fun.component_key # type: ignore - # Verify component key exists in disk cache. - self.assertIn(component_key, cache.cache_keys()) - # Verify the number of wrapper cache hits. - self.assertEqual( - aot_util._wrapper_cache.cache_info()[component_key]["hits"], - num_wrapper_hits, - ) - - # Verify the number of disk hits. - self.assertEqual(cache.cache_info(component_key)["hits"], num_disk_hits) - - jaxpr_key = self.get_jaxpr_key(fun.fun) + jaxpr_key = self.get_jaxpr_key(fun) jaxpr_cache = pjit._create_pjit_jaxpr.cache_get(jaxpr_key) jaxpr_hit = pjit._create_pjit_jaxpr.hit_get(jaxpr_key) if isinstance(num_jaxpr_hits, int): num_jaxpr_hits = (num_jaxpr_hits,) - # Verify fun exists in the jaxpr cache. self.assertIsNotNone(jaxpr_key) - # Verify number of entries in jaxpr cache for fun. self.assertEqual(len(jaxpr_cache), num_jaxpr_entries) - # Verify number of hits for each entry. self.assertEqual(tuple(jaxpr_hit.values()), num_jaxpr_hits) @@ -347,6 +333,19 @@ def validate_cache_states( pjit._infer_params_cached.cache_info().misses, num_trace_misses ) + # Verify component key exists in disk cache. + self.assertIn(component_key, cache.cache_keys()) + + if num_wrapper_hits > 0: + # Verify the number of wrapper cache hits. + self.assertEqual( + aot_util._wrapper_cache.cache_info()[component_key]["hits"], + num_wrapper_hits, + ) + + # Verify the number of disk hits. + self.assertEqual(cache.cache_info(component_key)["hits"], num_disk_hits) + # NOTE(dsuo): Disable checks because otherwise we check jaxprs in (at least) # four places and makes reasoning about cache hits and misses harder. # 1. After the initial abstract eval. @@ -364,7 +363,8 @@ def f(x): self.assertEqual(f(1.0), 2.0) self.validate_cache_states( - f, + f.fun, + f.component_key, # Make sure there is only one entry for f.fun. If there are more, then # it means the lowering rule missed. num_jaxpr_entries=1, @@ -391,7 +391,7 @@ def g(x): self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) # Cache state should remain unchanged except we grabbed the wrapped fun. - self.validate_cache_states(g, 1, 0, 1, 4, 1, 1) + self.validate_cache_states(g.fun, g.component_key, 1, 0, 1, 4, 1, 1) @config.enable_checks(False) def test_component_in_function(self): @@ -409,16 +409,14 @@ def g(x): self.assertEqual(f(1.0), 2.0) # We should have the same cache states as in test_component_basic. - self.validate_cache_states(f, 1, 0, 1, 4, 0, 1) - - logging.info("\n\n\n") + self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 4, 0, 1) # 1 hit when lowering g. g is not a component, so doesn't look up # CacheEntry during abstract_eval. self.assertEqual(g(1.0), 3.0) # We have one more trace cache hit on f from tracing g, two more misses, # one for g and one for add, and one more disk cache hit from lowering g. - self.validate_cache_states(f, 1, 0, 2, 6, 0, 2) + self.validate_cache_states(f.fun, f.component_key, 1, 0, 2, 6, 0, 2) @config.enable_checks(False) def test_jit_of_component(self): @@ -432,13 +430,9 @@ def f(x): # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) - # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_jaxpr_key(f.fun) - self.assertIsNotNone(pjit_key) - # Make sure there is only one entry for f.fun. If there are more, then it - # means the lowering rule missed. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) + # We should have the same cache states as in test_component_basic except + # one additional infer params cache miss for the outermost jitted f. + self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 5, 0, 1) @aot.component(key="f") def g(x): @@ -448,12 +442,9 @@ def g(x): # key "f". self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) - # Confirm we still have just one entry in the jaxpr cache. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) - # We get two additional hits on the disk cache during abstract eval and - # lowering for g. - self.assertEqual(cache.info(f.component_key)["hits"], 3) + # We have one more hit in infer params cache for the inner f and one more + # hit in the disk cache for the lowering of f. + self.validate_cache_states(g.fun, g.component_key, 1, 0, 2, 5, 1, 2) @config.enable_checks(False) def test_component_of_jit(self): @@ -465,15 +456,10 @@ def test_component_of_jit(self): def f(x): return x + 1.0 - # Create cache entry when abstract_eval f. 1 hit when lowering f. self.assertEqual(f(1.0), 2.0) - # Make sure the underlying function f.fun exists in the jaxpr cache. - pjit_key = self.get_jaxpr_key(f.fun) - self.assertIsNotNone(pjit_key) - # Make sure there is only one entry for f.fun. If there are more, then it - # means the lowering rule missed. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) + # We should have the same cache states as in test_component_basic except + # one additional infer params cache miss for the outermost jitted f. + self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 5, 0, 1) @aot.component(key="f") def g(x): @@ -483,12 +469,9 @@ def g(x): # key "f". self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) - # Confirm we still have just one entry in the jaxpr cache. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) - # We get two additional hits on the disk cache during abstract eval and - # lowering for g. - self.assertEqual(cache.info(f.component_key)["hits"], 3) + logging.info(g(1.0)) + # We have one hit in the wrapper cache. + self.validate_cache_states(g.fun, g.component_key, 1, 0, 1, 5, 1, 1) @config.enable_checks(False) def test_explicit_lowering(self): @@ -500,41 +483,44 @@ def f(x): return x + 1.0 lowered = f.lower(jax.ShapeDtypeStruct((), "float32")) - self.assertEqual(cache.keys(), [f.component_key]) + # One less infer params cache miss because we just have add, f, and + # call_wrapped; no equal. + self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 3, 0, 1) - pjit_key = self.get_jaxpr_key(f.fun) - self.assertIsNotNone(pjit_key) - # Make sure there is only one entry for f.fun. If there are more, then it - # means the lowering rule missed. - num_entries = len(list(self.get_pjit_jaxpr_entry(pjit_key))) - self.assertEqual(num_entries, 1) + logging.info("\n\n\n") @aot.component(key="f") def g(x): raise NotImplementedError lowered = g.lower(jax.ShapeDtypeStruct((), "float32")) - self.assertEqual(cache.info(f.component_key)["hits"], 3) + # We hit the wrapper cache because we have the same component key, but + # because we explicitly call lower, we hit infer params cache again. + # TODO(dsuo): Do we want this behavior? + self.validate_cache_states(f.fun, f.component_key, 1, 0, 2, 3, 1, 1) @config.enable_checks(False) def test_vmap_of_component(self): with self.make_in_memory_cache(): cache = aot.get_cache() - @aot.component(key="f") def f(x): logging.info("running!") return x + 1.0 - vmapped_f = jax.vmap(f) + logging.info("\n\nuser fun id %s", id(f)) + component_f = aot.component(key="f")(f) + logging.info("\n\ncomponent_f id %s", id(component_f)) + vmapped_f = jax.vmap(component_f) + logging.info("\n\nvmapped_f id %s", id(vmapped_f)) + logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__)) # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack. vmapped_key = aot_util.ComponentKey.vmap(aot_util.ComponentKey("f")) - self.assertArraysEqual(vmapped_f(jnp.ones((8,))), jnp.ones((8,)) + 1.0) - self.assertEqual(cache.keys(), [f.component_key, vmapped_key]) - # self.assertEqual(aot_util._traced_cache[f.component_key].hits, 0) - # self.assertEqual(aot_util._traced_cache[vmapped_key].hits, 1) + self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4) + self.validate_cache_states(component_f.fun, component_f.component_key, 1, 0, 1, 7, 0, 0) + self.validate_cache_states(component_f.fun, vmapped_key, 1, 0, 1, 7, 0, 1) if __name__ == "__main__": From 2c923cc88479a21d77e88126789b6c4d0eedddde Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:21:36 -0800 Subject: [PATCH 15/24] Update --- jax/_src/aot.py | 19 +++++++++++-- tests/aot_test.py | 71 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index be4afc95a2b1..5206d7ee4fe6 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -14,6 +14,7 @@ """JAX AOT API""" from collections.abc import Hashable +import functools import traceback from typing import Any, Callable, Sequence @@ -76,7 +77,7 @@ def wrapper(*args, **kwargs): ) return tree_util.tree_unflatten(out_tree(), out_flat) - wrapper.component_key = component_key + wrapper.key = component_key wrapper.fun = fun logging.info("jit(wrapper(fun)) wrapper id %s", id(wrapper)) logging.info("wrapper(fun) wrapper._fun id %s", id(wrapper._fun)) @@ -103,9 +104,14 @@ def component_abstract_eval( entry = aot_util.get_entry(component_key) logging.info("component_abstract_eval got entry %s", component_key) if entry is None: - logging.info("missed abstract_eval %s", component_key) + logging.info("missed abstract_eval %s %s", component_key, type(fun)) # TODO(dsuo): By the time we get to lowering, our trace context has picked # up an empty AbstractMesh. Don't know why. + if isinstance(fun, functools.partial): + logging.info("abstract_eval partial %s", fun.func.__name__) + if isinstance(fun, lu.WrappedFun): + logging.info("abstract_eval lu.WrappedFun") + fun = aot_util.maybe_reset_stores(fun).call_wrapped with mesh_lib.use_abstract_mesh(mesh_lib.AbstractMesh((), (), ())): avals_out = tree_util.tree_map( lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) @@ -195,14 +201,21 @@ def component_batcher( # TODO(dsuo): Ignore updating annotations. # TODO(dsuo): Dummy debug info. + if isinstance(fun, functools.partial): + name = fun.func.__name__ + else: + name = fun.__name__ + if isinstance(fun, lu.WrappedFun): + fun = aot_util.maybe_reset_stores(fun) wrapped_fun = lu.wrap_init( - fun, debug_info=lu.DebugInfo("vmap(component)", fun.__name__, None, None) + fun, debug_info=lu.DebugInfo("vmap(component)", name, None, None) ) # ????(dsuo): I don't understand trace tags. batched_fun, dims_out = batching.batch_subtrace( wrapped_fun, core.TraceTag(), axis_data, tuple(dims_in) ) + batched_fun = aot_util.maybe_reset_stores(batched_fun) vals_out = component_p.bind( *vals_in, diff --git a/tests/aot_test.py b/tests/aot_test.py index 0e9135186f8a..e3bd51393f25 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -308,7 +308,7 @@ def validate_cache_states( num_jaxpr_hits: int | Sequence[int], num_trace_hits: int, num_trace_misses: int, - num_wrapper_hits: int, + num_wrapper_hits: int | None, num_disk_hits: int, ): cache = aot.get_cache() @@ -336,7 +336,8 @@ def validate_cache_states( # Verify component key exists in disk cache. self.assertIn(component_key, cache.cache_keys()) - if num_wrapper_hits > 0: + # If we don't have wrapper hits (0) or we don't expect component_key (-1). + if num_wrapper_hits is not None: # Verify the number of wrapper cache hits. self.assertEqual( aot_util._wrapper_cache.cache_info()[component_key]["hits"], @@ -364,7 +365,7 @@ def f(x): self.assertEqual(f(1.0), 2.0) self.validate_cache_states( f.fun, - f.component_key, + f.key, # Make sure there is only one entry for f.fun. If there are more, then # it means the lowering rule missed. num_jaxpr_entries=1, @@ -376,7 +377,7 @@ def f(x): # We should have 4 misses: add, equal, f, and call_wrapped. num_trace_misses=4, # We shouldn't have hit the wrapper cache yet. - num_wrapper_hits=0, + num_wrapper_hits=None, # We get 1 hit on the disk cache during the lowering rule. However, this # hit is for an incomplete CacheEntry; only avals_out were populated # and not the lowered module. The lowering rule updates the CacheEntry @@ -391,7 +392,7 @@ def g(x): self.assertEqual(f.fun, g.fun) self.assertEqual(g(1.0), 2.0) # Cache state should remain unchanged except we grabbed the wrapped fun. - self.validate_cache_states(g.fun, g.component_key, 1, 0, 1, 4, 1, 1) + self.validate_cache_states(g.fun, g.key, 1, 0, 1, 4, 1, 1) @config.enable_checks(False) def test_component_in_function(self): @@ -409,14 +410,14 @@ def g(x): self.assertEqual(f(1.0), 2.0) # We should have the same cache states as in test_component_basic. - self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 4, 0, 1) + self.validate_cache_states(f.fun, f.key, 1, 0, 1, 4, None, 1) # 1 hit when lowering g. g is not a component, so doesn't look up # CacheEntry during abstract_eval. self.assertEqual(g(1.0), 3.0) # We have one more trace cache hit on f from tracing g, two more misses, # one for g and one for add, and one more disk cache hit from lowering g. - self.validate_cache_states(f.fun, f.component_key, 1, 0, 2, 6, 0, 2) + self.validate_cache_states(f.fun, f.key, 1, 0, 2, 6, None, 2) @config.enable_checks(False) def test_jit_of_component(self): @@ -432,7 +433,7 @@ def f(x): self.assertEqual(f(1.0), 2.0) # We should have the same cache states as in test_component_basic except # one additional infer params cache miss for the outermost jitted f. - self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 5, 0, 1) + self.validate_cache_states(f.fun, f.key, 1, 0, 1, 5, None, 1) @aot.component(key="f") def g(x): @@ -444,7 +445,7 @@ def g(x): self.assertEqual(g(1.0), 2.0) # We have one more hit in infer params cache for the inner f and one more # hit in the disk cache for the lowering of f. - self.validate_cache_states(g.fun, g.component_key, 1, 0, 2, 5, 1, 2) + self.validate_cache_states(g.fun, g.key, 1, 0, 2, 5, 1, 2) @config.enable_checks(False) def test_component_of_jit(self): @@ -459,7 +460,7 @@ def f(x): self.assertEqual(f(1.0), 2.0) # We should have the same cache states as in test_component_basic except # one additional infer params cache miss for the outermost jitted f. - self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 5, 0, 1) + self.validate_cache_states(f.fun, f.key, 1, 0, 1, 5, None, 1) @aot.component(key="f") def g(x): @@ -471,7 +472,7 @@ def g(x): self.assertEqual(g(1.0), 2.0) logging.info(g(1.0)) # We have one hit in the wrapper cache. - self.validate_cache_states(g.fun, g.component_key, 1, 0, 1, 5, 1, 1) + self.validate_cache_states(g.fun, g.key, 1, 0, 1, 5, 1, 1) @config.enable_checks(False) def test_explicit_lowering(self): @@ -485,7 +486,7 @@ def f(x): lowered = f.lower(jax.ShapeDtypeStruct((), "float32")) # One less infer params cache miss because we just have add, f, and # call_wrapped; no equal. - self.validate_cache_states(f.fun, f.component_key, 1, 0, 1, 3, 0, 1) + self.validate_cache_states(f.fun, f.key, 1, 0, 1, 3, None, 1) logging.info("\n\n\n") @@ -497,7 +498,7 @@ def g(x): # We hit the wrapper cache because we have the same component key, but # because we explicitly call lower, we hit infer params cache again. # TODO(dsuo): Do we want this behavior? - self.validate_cache_states(f.fun, f.component_key, 1, 0, 2, 3, 1, 1) + self.validate_cache_states(f.fun, f.key, 1, 0, 2, 3, 1, 1) @config.enable_checks(False) def test_vmap_of_component(self): @@ -516,11 +517,49 @@ def f(x): logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__)) # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack. - vmapped_key = aot_util.ComponentKey.vmap(aot_util.ComponentKey("f")) + vmapped_key = aot_util.ComponentKey.vmap(component_f.key) self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4) - self.validate_cache_states(component_f.fun, component_f.component_key, 1, 0, 1, 7, 0, 0) - self.validate_cache_states(component_f.fun, vmapped_key, 1, 0, 1, 7, 0, 1) + self.validate_cache_states( + component_f.fun, component_f.key, 1, 0, 1, 7, 0, 0 + ) + self.validate_cache_states(component_f.fun, vmapped_key, 1, 0, 1, 7, None, 1) + + @config.enable_checks(False) + def test_vmap_of_vmap_of_component(self): + with self.make_in_memory_cache(): + cache = aot.get_cache() + + def f(x): + logging.info("running!") + return x + 1.0 + + logging.info("\n\nuser fun id %s", id(f)) + c_f = aot.component(key="f")(f) + logging.info("\n\nc_f id %s", id(c_f)) + self.assertEqual(c_f(1.0), 2.0) + self.validate_cache_states(c_f.fun, c_f.key, 1, 0, 1, 4, 0, 1) + v_f = jax.vmap(c_f) + logging.info("\n\nv_f id %s", id(v_f)) + self.assertArraysEqual(v_f(jnp.ones((4,))), jnp.ones((4,)) + 1.0) + # TODO(dsuo): How to put component key on vmapped? + # We now have 2 entries, one for f and one for vmap(f). + self.validate_cache_states( + c_f.fun, aot_util.ComponentKey.vmap(c_f.key), 2, (0, 0), 1, 12, None, 1 + ) + vv_f = jax.vmap(v_f) + logging.info("\n\nvv_f id %s", id(vv_f)) + self.assertArraysEqual(vv_f(jnp.ones((4, 4,))), jnp.ones((4, 4)) + 1.0) + self.validate_cache_states( + c_f.fun, + aot_util.ComponentKey.vmap(aot_util.ComponentKey.vmap(c_f.key)), + 2, + (0, 0), + 3, + 16, + None, + 1, + ) if __name__ == "__main__": From db5764bf9c546c8b61a27b0e43f08afcec8db867 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:32:09 -0800 Subject: [PATCH 16/24] Update --- tests/aot_test.py | 72 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/tests/aot_test.py b/tests/aot_test.py index e3bd51393f25..cde8c79f60ad 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -523,7 +523,9 @@ def f(x): self.validate_cache_states( component_f.fun, component_f.key, 1, 0, 1, 7, 0, 0 ) - self.validate_cache_states(component_f.fun, vmapped_key, 1, 0, 1, 7, None, 1) + self.validate_cache_states( + component_f.fun, vmapped_key, 1, 0, 1, 7, None, 1 + ) @config.enable_checks(False) def test_vmap_of_vmap_of_component(self): @@ -549,7 +551,15 @@ def f(x): ) vv_f = jax.vmap(v_f) logging.info("\n\nvv_f id %s", id(vv_f)) - self.assertArraysEqual(vv_f(jnp.ones((4, 4,))), jnp.ones((4, 4)) + 1.0) + self.assertArraysEqual( + vv_f( + jnp.ones(( + 4, + 4, + )) + ), + jnp.ones((4, 4)) + 1.0, + ) self.validate_cache_states( c_f.fun, aot_util.ComponentKey.vmap(aot_util.ComponentKey.vmap(c_f.key)), @@ -561,6 +571,64 @@ def f(x): 1, ) + @config.enable_checks(False) + def test_vmap_of_jit_of_component(self): + # NOTE: This should be the same as test_vmap_of_component except for one + # more infer params cache miss because of the extra jit. + with self.make_in_memory_cache(): + cache = aot.get_cache() + + def f(x): + logging.info("running!") + return x + 1.0 + + logging.info("\n\nuser fun id %s", id(f)) + component_f = jax.jit(aot.component(key="f")(f)) + logging.info("\n\ncomponent_f id %s", id(component_f)) + vmapped_f = jax.vmap(component_f) + logging.info("\n\nvmapped_f id %s", id(vmapped_f)) + logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__)) + + # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack. + vmapped_key = aot_util.ComponentKey.vmap(component_f.key) + + self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4) + self.validate_cache_states( + component_f.fun, component_f.key, 1, 0, 1, 8, 0, 0 + ) + self.validate_cache_states( + component_f.fun, vmapped_key, 1, 0, 1, 8, None, 1 + ) + + @config.enable_checks(False) + def test_jit_of_vmap_of_component(self): + # NOTE: This should be the same as test_vmap_of_component except for one + # more infer params cache miss because of the extra jit. + with self.make_in_memory_cache(): + cache = aot.get_cache() + + def f(x): + logging.info("running!") + return x + 1.0 + + logging.info("\n\nuser fun id %s", id(f)) + component_f = aot.component(key="f")(f) + logging.info("\n\ncomponent_f id %s", id(component_f)) + vmapped_f = jax.jit(jax.vmap(component_f)) + logging.info("\n\nvmapped_f id %s", id(vmapped_f)) + logging.info("vmapped_f.__wrapped__ id %s", id(vmapped_f.__wrapped__)) + + # TODO(dsuo): How to put component_key on vmapped_f? This is just a hack. + vmapped_key = aot_util.ComponentKey.vmap(component_f.key) + + self.assertArraysEqual(vmapped_f(jnp.ones((4,))), [2.0] * 4) + self.validate_cache_states( + component_f.fun, component_f.key, 1, 0, 1, 8, 0, 0 + ) + self.validate_cache_states( + component_f.fun, vmapped_key, 1, 0, 1, 8, None, 1 + ) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From b97b4f3e1d0a61fd406f2c9ff438490238a66bc5 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:35:36 -0800 Subject: [PATCH 17/24] Update jax/_src/aot_util.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- jax/_src/aot_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 22e24db00388..03b7278f6520 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -43,7 +43,7 @@ def __hash__(self): return hash(self.user_key) def __eq__(self, other): - return hash(self) == hash(other) + return isinstance(other, ComponentKey) and self.user_key == other.user_key def __str__(self): return self.user_key From 1cabe638b280a8d304e19924f613392ad1c3f533 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:55:36 -0800 Subject: [PATCH 18/24] Update --- jax/_src/traceback_util.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index e99db5ca546c..b7641c209589 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -160,8 +160,10 @@ def _filtering_mode() -> str: mode = "quiet_remove_frames" return mode -import logging -def api_boundary(fun: C) -> C: +def api_boundary( + fun: C, *, + repro_api_name: str | None = None, + repro_user_func: bool = False) -> C: '''Wraps ``fun`` to form a boundary for filtering exception tracebacks. When an exception occurs below ``fun``, this appends to it a custom @@ -182,10 +184,10 @@ def api_boundary(fun: C) -> C: ``g``. Because the function returned by :func:`~jax.jit` is annotated as an :func:`~api_boundary`, such an exception is accompanied by an additional traceback that excludes the frames specific to JAX's implementation. + + For the "repro" kwargs, see the comments for `repro.boundary`. ''' - if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: - logging.info("api_boundary fun id %s", id(fun)) @functools.wraps(fun) def reraise_with_filtered_traceback(*args, **kwargs): __tracebackhide__ = True @@ -221,13 +223,17 @@ def reraise_with_filtered_traceback(*args, **kwargs): raise finally: del mode, tb - if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: - logging.info("reraise_with_filtered_traceback id %s", - id(reraise_with_filtered_traceback)) - casted = cast(C, reraise_with_filtered_traceback) - - if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: - logging.info("casted reraise id %s", - id(casted)) - - return reraise_with_filtered_traceback + if (repro_api_name or repro_user_func) and repro: + reraise_with_filtered_traceback = repro.boundary( + reraise_with_filtered_traceback, api_name=repro_api_name, + is_user=repro_user_func) + return cast(C, reraise_with_filtered_traceback) + +try: + # TODO: import from the final location + from jax._src import repro # type: ignore + repro_is_enabled = repro.is_enabled + +except ImportError: + repro = None # type: ignore + def repro_is_enabled(): return False # type: ignore From f41db3941a101983b5ec7abf9ac6b5b6f6a14a90 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:56:20 -0800 Subject: [PATCH 19/24] Update --- jaxlib/_jax/__init__.pyi | 178 +++++++++++++++++++++------------------ 1 file changed, 94 insertions(+), 84 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index f3399a83b705..055c1dc2d1de 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -134,8 +134,8 @@ class Shape: def array_shape( type: PrimitiveType, dims: Sequence[int], - layout: Sequence[int] | None = None, - dynamic_dimensions: Sequence[bool] | None = None, + layout: Sequence[int] | None = ..., + dynamic_dimensions: Sequence[bool] | None = ..., ) -> Shape: """Constructs an array shape.""" @@ -144,8 +144,8 @@ class Shape: def array_shape( type: numpy.dtype, dims: Sequence[int], - layout: Sequence[int] | None = None, - dynamic_dimensions: Sequence[bool] | None = None, + layout: Sequence[int] | None = ..., + dynamic_dimensions: Sequence[bool] | None = ..., ) -> Shape: ... @staticmethod def token_shape() -> Shape: ... @@ -191,7 +191,7 @@ class Literal: def __init__(self, arg: Shape, /) -> None: ... def __repr__(self) -> str: ... def __array__( - self, dtype: object | None = None, copy: bool | None = None + self, dtype: object | None = ..., copy: bool | None = ... ) -> NDArray: ... def shape(self) -> Shape: ... @@ -201,7 +201,7 @@ class XlaComputation: def program_shape(self) -> ProgramShape: ... def name(self) -> str: ... def as_serialized_hlo_module_proto(self) -> bytes: ... - def as_hlo_text(self, print_large_constants: bool = False) -> str: ... + def as_hlo_text(self, print_large_constants: bool = ...) -> str: ... def as_hlo_dot_graph(self) -> str: ... def hash(self) -> int: ... def as_hlo_module(self) -> HloModule: ... @@ -368,8 +368,8 @@ def register_custom_call_target( fn_name: object, fn: object, platform: str, - api_version: int = 0, - traits: int = 0, + api_version: int = ..., + traits: int = ..., ) -> None: ... def custom_call_targets(platform: str) -> dict: ... @@ -724,9 +724,9 @@ class HloSharding: @staticmethod def iota_tile( dims: Sequence[int], - reshape_dims: Sequence[int] = [], - transpose_perm: Sequence[int] = [], - subgroup_types: Sequence[OpSharding_Type] = [], + reshape_dims: Sequence[int] = ..., + transpose_perm: Sequence[int] = ..., + subgroup_types: Sequence[OpSharding_Type] = ..., ) -> HloSharding: ... @staticmethod def manual() -> HloSharding: ... @@ -739,7 +739,7 @@ class HloSharding: @staticmethod def subgroup_with_device_ordering( tile_assignment: Annotated[NDArray[numpy.int64], dict(order='C')], - subgroup_types: Sequence[OpSharding_Type] = [], + subgroup_types: Sequence[OpSharding_Type] = ..., ) -> HloSharding: ... def __eq__(self, other: object, /) -> bool: ... def __ne__(self, other: object, /) -> bool: ... @@ -872,8 +872,8 @@ class Client: def buffer_from_pyval( self, argument: object, - device: Device | None = None, - force_copy: bool = False, + device: Device | None = ..., + force_copy: bool = ..., host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, ) -> object: ... def compile( @@ -924,15 +924,23 @@ class Client: self, serialized: bytes, executable_devices: DeviceList, - compile_options: CompileOptions | None = None, - host_callbacks: Sequence[typing_extensions.CapsuleType] = [], + compile_options: CompileOptions | None = ..., + host_callbacks: Sequence[typing_extensions.CapsuleType] = ..., + ) -> LoadedExecutable: ... + @overload + def deserialize_executable( + self, + serialized: bytes, + executable_devices: DeviceList, + compile_options: CompileOptions | None = ..., + host_callbacks: Sequence[Callable] = ..., ) -> LoadedExecutable: ... @overload def deserialize_executable( self, serialized: bytes, executable_devices: Sequence, - compile_options: CompileOptions | None = None, + compile_options: CompileOptions | None = ..., ) -> LoadedExecutable: ... def heap_profile(self) -> bytes: ... def defragment(self) -> None: ... @@ -943,7 +951,7 @@ class Client: result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], recv_channel_ids: Sequence[int], - serializer: Callable | None = None, + serializer: Callable | None = ..., ) -> object: ... def get_default_layout( self, dtype: numpy.dtype, shard_shape: Sequence, device: Device @@ -971,34 +979,34 @@ class CpuCollectives: def make_gloo_tcp_collectives( distributed_client: DistributedRuntimeClient, - hostname: str | None = None, - interface: str | None = None, + hostname: str | None = ..., + interface: str | None = ..., ) -> CpuCollectives: ... def make_mpi_collectives() -> CpuCollectives: ... def get_tfrt_cpu_client( - asynchronous: bool = True, - distributed_client: DistributedRuntimeClient | None = None, - node_id: int = 0, - num_nodes: int = 1, - collectives: CpuCollectives | None = None, - num_devices: int | None = None, - get_local_topology_timeout_minutes: int | None = None, - get_global_topology_timeout_minutes: int | None = None, - transfer_server_factory: TransferServerInterfaceFactory | None = None, + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: CpuCollectives | None = ..., + num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., + transfer_server_factory: TransferServerInterfaceFactory | None = ..., ) -> Client: ... def pjrt_plugin_loaded(arg: str, /) -> bool: ... def load_pjrt_plugin( platform_name: str, - library_path: str | None = None, - c_api: typing_extensions.CapsuleType | None = None, + library_path: str | None = ..., + c_api: typing_extensions.CapsuleType | None = ..., ) -> typing_extensions.CapsuleType: ... def pjrt_plugin_initialized(arg: str, /) -> bool: ... def initialize_pjrt_plugin(arg: str, /) -> None: ... def get_c_api_client( platform_name: str, - options: Mapping[str, str | bool | int | Sequence[int] | float] = {}, - distributed_client: DistributedRuntimeClient | None = None, - transfer_server_factory: TransferServerInterfaceFactory | None = None, + options: Mapping[str, str | bool | int | Sequence[int] | float] = ..., + distributed_client: DistributedRuntimeClient | None = ..., + transfer_server_factory: TransferServerInterfaceFactory | None = ..., ) -> Client: ... def get_default_c_api_topology( arg0: str, @@ -1026,7 +1034,7 @@ def batched_copy_array_to_devices_with_sharding( /, ) -> list[Array]: ... def array_result_handler( - aval: object, sharding: object, committed: bool, _skip_checks: bool = False + aval: object, sharding: object, committed: bool, _skip_checks: bool = ... ) -> ResultHandler: ... class ResultHandler: @@ -1068,8 +1076,8 @@ class NamedSharding(Sharding): self, mesh: object, spec: PartitionSpec, - memory_kind: object | None = None, - _logical_device_ids: object | None = None, + memory_kind: object | None = ..., + _logical_device_ids: object | None = ..., ) -> None: ... @property def mesh(self) -> object: ... @@ -1086,7 +1094,7 @@ class NamedSharding(Sharding): class SingleDeviceSharding(Sharding): def __init__( - self, device: object, memory_kind: object | None = None + self, device: object, memory_kind: object | None = ... ) -> None: ... @property def _device(self) -> object: ... @@ -1112,28 +1120,28 @@ class GSPMDSharding(Sharding): self, devices: DeviceList, op_sharding: OpSharding, - memory_kind: object | None = None, + memory_kind: object | None = ..., ) -> None: ... @overload def __init__( self, devices: DeviceList, op_sharding: HloSharding, - memory_kind: object | None = None, + memory_kind: object | None = ..., ) -> None: ... @overload def __init__( self, devices: Sequence[Device], op_sharding: OpSharding, - memory_kind: object | None = None, + memory_kind: object | None = ..., ) -> None: ... @overload def __init__( self, devices: Sequence[Device], op_sharding: HloSharding, - memory_kind: object | None = None, + memory_kind: object | None = ..., ) -> None: ... @property def _devices(self) -> DeviceList: ... @@ -1215,9 +1223,7 @@ class LoadedExecutable: def size_of_generated_code_in_bytes(self) -> int: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def execute_sharded( - self, - arguments: Sequence[Array | Sequence[Array]], - with_tokens: bool = False, + self, arguments: Sequence[Array], with_tokens: bool = ... ) -> ExecuteResults: ... def hlo_modules(self) -> list[HloModule]: ... def get_output_memory_kinds(self) -> list[list[str]]: ... @@ -1242,13 +1248,16 @@ class ShardedToken: def get_token(self, arg: int, /) -> Token: ... def buffer_to_dlpack_managed_tensor( - buffer: object, stream: int | None = None + buffer: object, stream: int | None = ... ) -> typing_extensions.CapsuleType: ... def dlpack_managed_tensor_to_buffer( - dlpack: typing_extensions.CapsuleType, device: Device, stream: int | None, copy: bool | None = None + dlpack: typing_extensions.CapsuleType, + device: Device, + stream: int | None, + copy: bool | None = ..., ) -> ArrayImpl: ... def cuda_array_interface_to_buffer( - cai: dict, gpu_backend: Client | None = None, device_id: int | None = None + cai: dict, gpu_backend: Client | None = ..., device_id: int | None = ... ) -> object: ... class RuntimeTracebackMode(enum.Enum): @@ -1269,7 +1278,7 @@ def set_send_traceback_to_runtime_thread_local( ) -> None: ... class PjitFunctionCache: - def __init__(self, capacity: int = 4096) -> None: ... + def __init__(self, capacity: int = ...) -> None: ... def size(self) -> int: ... def capacity(self) -> int: ... def clear(self) -> None: ... @@ -1285,7 +1294,7 @@ class PjitFunction: def __call__(self, /, *args, **kwargs): """Call self as a function.""" - def __get__(self, instance, owner=None, /): + def __get__(self, instance, owner=..., /): """Return an attribute of instance, which is of type owner.""" __vectorcalloffset__: types.MemberDescriptorType = ... @@ -1382,8 +1391,8 @@ def register_custom_call_partitioner( prop_user_sharding: object, partition: object, infer_sharding_from_operands: object, - can_side_effecting_have_replicated_sharding: bool = False, - c_api: typing_extensions.CapsuleType | None = None, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: typing_extensions.CapsuleType | None = ..., ) -> None: """Registers a partitioner for a custom-call operation. @@ -1404,7 +1413,7 @@ def register_custom_call_partitioner( def encode_inspect_sharding_callback(arg: object, /) -> bytes: ... def register_custom_call_as_batch_partitionable( - target_name: str, c_api: typing_extensions.CapsuleType | None = None + target_name: str, c_api: typing_extensions.CapsuleType | None = ... ) -> None: """Registers a custom call as batch partitionable. @@ -1421,6 +1430,7 @@ def register_custom_call_as_batch_partitionable( class TransferConnection: def _testonly_inject_failure(self) -> None: ... + def _poison_connection(self) -> None: ... def _pull_flat( self, arg0: int, arg1: Client, arg2: Sequence[object], / ) -> list[Array]: ... @@ -1437,19 +1447,19 @@ class TransferServer: def _make_error_array(arg0: Client, arg1: object, arg2: str, /) -> Array: ... def start_transfer_server( client: Client, - address: str = '[::]:0', - transport_addresses: Sequence[str] = [], - max_num_parallel_copies: int = 8, - transfer_size: int = 268435456, - supports_pinned_allocator: bool = False, - use_raw_buffers: bool = False, + address: str = ..., + transport_addresses: Sequence[str] = ..., + max_num_parallel_copies: int = ..., + transfer_size: int = ..., + supports_pinned_allocator: bool = ..., + use_raw_buffers: bool = ..., ) -> TransferServer: ... def make_transfer_server_interface_factory( - transfer_size: int = 268435456, - cross_host_transfer_timeout_seconds: int = 60, - distributed_client: DistributedRuntimeClient | None = None, - socket_address: str = '[::]:0', - transport_addresses: Sequence[str] = [], + transfer_size: int = ..., + cross_host_transfer_timeout_seconds: int = ..., + distributed_client: DistributedRuntimeClient | None = ..., + socket_address: str = ..., + transport_addresses: Sequence[str] = ..., ) -> TransferServerInterfaceFactory: ... class PreemptionSyncManager: @@ -1478,14 +1488,14 @@ class DistributedRuntimeClient: self, barrier_id: str, timeout_in_ms: int, - process_ids: Sequence[int] | None = None, + process_ids: Sequence[int] | None = ..., ) -> None: ... def get_live_nodes(self, process_ids: Sequence[int]) -> dict[int, int]: ... def key_value_set( - self, key: str, value: str, allow_overwrite: bool = False + self, key: str, value: str, allow_overwrite: bool = ... ) -> None: ... def key_value_set_bytes( - self, key: str, value: bytes, allow_overwrite: bool = False + self, key: str, value: bytes, allow_overwrite: bool = ... ) -> None: ... def key_value_dir_get(self, key: str) -> list[tuple[str, str]]: ... def key_value_dir_get_bytes(self, key: str) -> list[tuple[str, bytes]]: ... @@ -1494,21 +1504,21 @@ class DistributedRuntimeClient: def get_distributed_runtime_service( address: str, num_nodes: int, - heartbeat_timeout: int | None = None, - cluster_register_timeout: int | None = None, - shutdown_timeout: int | None = None, + heartbeat_timeout: int | None = ..., + cluster_register_timeout: int | None = ..., + shutdown_timeout: int | None = ..., ) -> DistributedRuntimeService: ... def get_distributed_runtime_client( address: str, node_id: int, - rpc_timeout: int | None = None, - init_timeout: int | None = None, - shutdown_timeout: int | None = None, - heartbeat_timeout: int | None = None, - missed_heartbeat_callback: Callable | None = None, - shutdown_on_destruction: bool | None = None, - use_compression: bool | None = None, - recoverable: bool | None = None, + rpc_timeout: int | None = ..., + init_timeout: int | None = ..., + shutdown_timeout: int | None = ..., + heartbeat_timeout: int | None = ..., + missed_heartbeat_callback: Callable | None = ..., + shutdown_on_destruction: bool | None = ..., + use_compression: bool | None = ..., + recoverable: bool | None = ..., ) -> DistributedRuntimeClient: ... def collect_garbage() -> None: ... def is_optimized_build() -> bool: ... @@ -1561,10 +1571,10 @@ def batched_device_put( sharding: object, xs: Sequence[object], devices: Sequence[Device], - committed: bool = True, - force_copy: bool = False, + committed: bool = ..., + force_copy: bool = ..., host_buffer_semantics: HostBufferSemantics = HostBufferSemantics.ZERO_COPY, - enable_x64: bool | None = None, + enable_x64: bool | None = ..., ) -> object: ... def reorder_shards( x: Array, dst_sharding: object, array_copy_semantics: ArrayCopySemantics @@ -1574,15 +1584,15 @@ def check_and_canonicalize_memory_kind( memory_kind: object | None, device_list: DeviceList ) -> object: ... -ifrt_version_number: int = 34 +ifrt_version_number: int = ... def approx_top_k_reduction_output_size( input_size: int, rank: int, top_k: int, recall_target: float, - aggregate_to_topk: bool = True, - input_size_override: int = -1, + aggregate_to_topk: bool = ..., + input_size_override: int = ..., ) -> tuple[int, int]: ... def get_internal_device_put_info() -> dict[str, int]: ... From 9ed761f7c30b5da2a10a8a3fd358f7c124fb51ac Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:56:44 -0800 Subject: [PATCH 20/24] Update --- jaxlib/dlpack.cc | 16 ++++++++++------ jaxlib/jax.cc | 13 +++++++------ jaxlib/xla_client.py | 2 +- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc index 48d4a7329876..ca90dd1a5c5e 100644 --- a/jaxlib/dlpack.cc +++ b/jaxlib/dlpack.cc @@ -194,13 +194,16 @@ absl::StatusOr> MakePjrtBuffer( // On CPU, creating a view may fail because of unaligned data buffer // in which case we'll fallback to copy. On non-CPU, array-api copy // semantics is handled in dlpack._place_array function. - bool fallback_to_copy = !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU; + bool fallback_to_copy = + !copy.has_value() && dlmt->dl_tensor.device.device_type == kDLCPU; // Create a view. if (!copy.value_or(false)) { auto result = device.client()->CreateViewOfDeviceBuffer( - data, shape, *device.default_memory_space(), on_delete_callback, stream); - if (!(result.status().code() == absl::StatusCode::kInvalidArgument && fallback_to_copy)) { + data, shape, *device.default_memory_space(), on_delete_callback, + stream); + if (!(result.status().code() == absl::StatusCode::kInvalidArgument && + fallback_to_copy)) { return result; } } @@ -216,8 +219,8 @@ absl::StatusOr> MakePjrtBuffer( // Create a copy. return device.client()->BufferFromHostBuffer( data, element_type, dimensions, byte_strides, - xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, - memory_space, /*device_layout=*/nullptr); + xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, + on_delete_callback, memory_space, /*device_layout=*/nullptr); } } // namespace @@ -317,7 +320,8 @@ absl::StatusOr BufferToDLPackManagedTensor( absl::StatusOr DLPackManagedTensorToBuffer( const nb::capsule& tensor, ifrt::Device* ifrt_device, - nb_class_ptr client, std::optional stream, std::optional copy) { + nb_class_ptr client, std::optional stream, + std::optional copy) { ifrt::PjRtDevice* device = llvm::dyn_cast_or_null(ifrt_device); if (device == nullptr) { diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index cbda01880089..f914467ab322 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -590,17 +590,18 @@ NB_MODULE(_jax, m) { return xla::ValueOrThrow(DLPackManagedTensorToBuffer( tensor, device->device(), device->client(), stream, copy)); }, - nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), nb::arg("copy").none() = nb::none(), - nb::sig( - // clang-format off + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none(), + nb::arg("copy").none() = nb::none(), + nb::sig( + // clang-format off "def dlpack_managed_tensor_to_buffer(" "dlpack: typing_extensions.CapsuleType, " "device: Device, " "stream: int | None, " - "copy: bool | None" + "copy: bool | None = ..." ") -> ArrayImpl" - // clang-format on - )); + // clang-format on + )); m.def("cuda_array_interface_to_buffer", xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), nb::arg("gpu_backend").none() = nb::none(), diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 7d7983f11483..ce0379c4cb03 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 384 # Add a new copy argument to dlpack_managed_tensor_to_buffer +_version = 387 # Introduce lowering support for lax.ragged_dot + collectives. # An internal increasing version number for protecting jaxlib code against # ifrt changes. From d953bddc9da46f39f0d3baf608dee2d9dfadd1f1 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 10 Nov 2025 14:57:32 -0800 Subject: [PATCH 21/24] Update --- jax/_src/pjit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 689c63dd2b7e..a6c989579670 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -631,7 +631,6 @@ def _infer_params( return _infer_params_internal(fun, ji, args, kwargs) -cache = dict() def _infer_params_internal( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: From 053ba96e916322a08395a7fb25c6ea78753ffe5b Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 12 Nov 2025 10:40:39 -0800 Subject: [PATCH 22/24] Update --- jax/_src/aot.py | 38 +++++++++----------------------------- jax/_src/aot_util.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 5206d7ee4fe6..7f3cc4967d97 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -40,7 +40,6 @@ get_cache = aot_util.get_cache - def component( key: UserKey = None, ) -> Callable[..., Any]: @@ -51,7 +50,7 @@ def _component(fun: Callable[..., Any]): component_key = ComponentKey(key) if component_key in aot_util._wrapper_cache.cache_keys(): - logging.info('hit wrapper_cache: %s', component_key) + logging.info("hit wrapper_cache: %s", component_key) return aot_util._wrapper_cache.get(component_key) @api.jit @@ -63,10 +62,11 @@ def wrapper(*args, **kwargs): fun, debug_info=api_util.debug_info("component", fun, args, kwargs) ) flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) + # TODO(dsuo): do we need this cached? flat_fun = aot_util.cached_flat_fun(flat_fun) logging.info("miss component flat_fun %s:", id(flat_fun)) flat_fun = flat_fun.f_transformed - flat_fun.__name__ = 'wrapped(flat_fun)' + flat_fun.__name__ = "wrapped(flat_fun)" jitted_fun = api.jit(flat_fun) logging.info("miss component jitted_fun %s:", id(jitted_fun)) @@ -81,7 +81,9 @@ def wrapper(*args, **kwargs): wrapper.fun = fun logging.info("jit(wrapper(fun)) wrapper id %s", id(wrapper)) logging.info("wrapper(fun) wrapper._fun id %s", id(wrapper._fun)) - logging.info("fun wrapper._fun.__wrapped__ id %s", id(wrapper._fun.__wrapped__)) + logging.info( + "fun wrapper._fun.__wrapped__ id %s", id(wrapper._fun.__wrapped__) + ) logging.info("user fun id %s", id(fun)) aot_util._wrapper_cache.put(component_key, wrapper) return wrapper @@ -123,7 +125,7 @@ def component_abstract_eval( def component_lowering( - ctx, + ctx: mlir.LoweringRuleContext, *args, fun: Callable[..., Any], component_key: ComponentKey, @@ -139,31 +141,9 @@ def component_lowering( logging.info("missed lowering: %s", component_key) if isinstance(fun, lu.WrappedFun): fun = aot_util.maybe_reset_stores(fun).call_wrapped - traced = api.trace(fun, *ctx.avals_in) - lowering_result = mlir.lower_jaxpr_to_module( - module_name=module_name, - jaxpr=traced.jaxpr, - num_const_args=traced._num_consts, - in_avals=ctx.avals_in, - # TODO(dsuo): What are ordered effects vs effects? - ordered_effects=traced.jaxpr.effects, - # TODO(dsuo): Figure out why ctx.platforms=None. - platforms=["cpu"], - backend=ctx.module_context.backend, - axis_context=ctx.module_context.axis_context, - donated_args=tuple( - x.donated for x in tree_util.tree_leaves(traced.args_info) - ), - lowering_parameters=mlir.LoweringParameters(), - # TODO(dsuo): Presumably we need to forward the rest of the arguments to - # lower_jaxpr_to_module? + module = aot_util.lower_component_to_module( + ctx, fun, module_name, component_key ) - # TODO(dsuo): What should we do about the other attributes on - # LoweringResult? - # - keepalive: probably not supported. - # - host_callbacks: probably not supported. - # - shape_poly_state: talk to necula@ - module = lowering_result.module # TODO(dsuo): We have this to ensure the source and destination modules have # the same context, but is it necessary? Perhaps yes, since we need to get # rid of the submodule context before merging. Could we just create it with diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 03b7278f6520..5a54d7562561 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -24,6 +24,7 @@ from jax._src import core from jax._src import linear_util as lu from jax._src import stages +from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc @@ -203,3 +204,38 @@ def cache_keys(self): _wrapper_cache = WrapperCache() util.register_cache(_wrapper_cache, "aot_wrapper_cache") + + +def lower_component_to_module( + ctx: mlir.LoweringRuleContext, + fun: Callable[..., Any], + module_name: str, + component_key: ComponentKey, +) -> ir.Module: + traced = api.trace(fun, *ctx.avals_in) + lowering_result = mlir.lower_jaxpr_to_module( + module_name=module_name, + jaxpr=traced.jaxpr, + num_const_args=traced._num_consts, + in_avals=ctx.avals_in, + # TODO(dsuo): What are ordered effects vs effects? + ordered_effects=traced.jaxpr.effects, + # TODO(dsuo): Figure out why ctx.platforms=None. + platforms=["cpu"], + backend=ctx.module_context.backend, + axis_context=ctx.module_context.axis_context, + donated_args=tuple( + x.donated for x in tree_util.tree_leaves(traced.args_info) + ), + lowering_parameters=mlir.LoweringParameters(), + # TODO(dsuo): Presumably we need to forward the rest of the arguments to + # lower_jaxpr_to_module? + ) + # TODO(dsuo): What should we do about the other attributes on + # LoweringResult? + # - keepalive: probably not supported. + # - host_callbacks: probably not supported. + # - shape_poly_state: talk to necula@ + module = lowering_result.module + + return module From 56a3194ed5cab53307b9824d828296e1d2a9bfc7 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 12 Nov 2025 14:46:01 -0800 Subject: [PATCH 23/24] Update --- jax/_src/aot.py | 86 ++++++++------------------------------------ jax/_src/aot_util.py | 37 +++++++++++++++++++ tests/aot_test.py | 27 +++++++++----- 3 files changed, 69 insertions(+), 81 deletions(-) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index 7f3cc4967d97..d6a792c7f590 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -45,12 +45,9 @@ def component( ) -> Callable[..., Any]: def _component(fun: Callable[..., Any]): # TODO(dsuo): Need to consider static args, etc if fun is jitted. - # TODO(dsuo): Do we have all the information we need at this point to make - # the component key? component_key = ComponentKey(key) if component_key in aot_util._wrapper_cache.cache_keys(): - logging.info("hit wrapper_cache: %s", component_key) return aot_util._wrapper_cache.get(component_key) @api.jit @@ -58,33 +55,20 @@ def _component(fun: Callable[..., Any]): @traceback_util.api_boundary def wrapper(*args, **kwargs): args_flat, in_tree = tree_util.tree_flatten((args, kwargs)) - wrapped_fun = lu.wrap_init( - fun, debug_info=api_util.debug_info("component", fun, args, kwargs) - ) + wrapped_fun = aot_util.wrap_init(fun, "component") flat_fun, out_tree = api_util.flatten_fun(wrapped_fun, in_tree) - # TODO(dsuo): do we need this cached? + # TODO(dsuo): We need this in vmap(vmap(f)) because otherwise we create a + # new flat_fun and may not have called it yet; store will be empty. flat_fun = aot_util.cached_flat_fun(flat_fun) - logging.info("miss component flat_fun %s:", id(flat_fun)) - flat_fun = flat_fun.f_transformed - flat_fun.__name__ = "wrapped(flat_fun)" - jitted_fun = api.jit(flat_fun) - logging.info("miss component jitted_fun %s:", id(jitted_fun)) - out_flat = component_p.bind( - *args, - fun=jitted_fun, + *args_flat, + fun=api.jit(flat_fun.f_transformed), component_key=component_key, ) return tree_util.tree_unflatten(out_tree(), out_flat) wrapper.key = component_key wrapper.fun = fun - logging.info("jit(wrapper(fun)) wrapper id %s", id(wrapper)) - logging.info("wrapper(fun) wrapper._fun id %s", id(wrapper._fun)) - logging.info( - "fun wrapper._fun.__wrapped__ id %s", id(wrapper._fun.__wrapped__) - ) - logging.info("user fun id %s", id(fun)) aot_util._wrapper_cache.put(component_key, wrapper) return wrapper @@ -92,8 +76,6 @@ def wrapper(*args, **kwargs): def component_impl(*args, fun: Callable[..., Any], **_): - if isinstance(fun, lu.WrappedFun): - return fun.call_wrapped(*args) return fun(*args) @@ -102,25 +84,16 @@ def component_abstract_eval( fun: Callable[..., Any], component_key: ComponentKey, ) -> Sequence[core.AbstractValue] | None: - # ????(dsuo): Is this an effectful rule? + # TODO(dsuo): Is this an effectful rule since we read/write to disk? entry = aot_util.get_entry(component_key) - logging.info("component_abstract_eval got entry %s", component_key) if entry is None: - logging.info("missed abstract_eval %s %s", component_key, type(fun)) # TODO(dsuo): By the time we get to lowering, our trace context has picked # up an empty AbstractMesh. Don't know why. - if isinstance(fun, functools.partial): - logging.info("abstract_eval partial %s", fun.func.__name__) - if isinstance(fun, lu.WrappedFun): - logging.info("abstract_eval lu.WrappedFun") - fun = aot_util.maybe_reset_stores(fun).call_wrapped with mesh_lib.use_abstract_mesh(mesh_lib.AbstractMesh((), (), ())): avals_out = tree_util.tree_map( lambda x: core.ShapedArray(x.shape, x.dtype), api.eval_shape(fun, *args) ) aot_util.put_entry(component_key, entry := aot_util.CacheEntry(avals_out)) - else: - logging.info("hit abstract_eval %s", component_key) return entry.avals_out @@ -130,43 +103,18 @@ def component_lowering( fun: Callable[..., Any], component_key: ComponentKey, ) -> Sequence[ir.Value]: - with ctx.module_context.context as ir_ctx: - entry = aot_util.get_entry(component_key, ir_ctx) - logging.info("component_lowering got entry %s", component_key) + module_name = aot_util.get_module_name(ctx.module_context.module) + entry = aot_util.get_entry(component_key, ctx.module_context.context) if entry is None: raise ValueError("Should hit abstract_eval already, which would populate.") - module_name = f"{component_key}.module" if (module := entry.module) is None: - logging.info("missed lowering: %s", component_key) - if isinstance(fun, lu.WrappedFun): - fun = aot_util.maybe_reset_stores(fun).call_wrapped - module = aot_util.lower_component_to_module( + entry.module = module = aot_util.lower_component_to_module( ctx, fun, module_name, component_key ) - # TODO(dsuo): We have this to ensure the source and destination modules have - # the same context, but is it necessary? Perhaps yes, since we need to get - # rid of the submodule context before merging. Could we just create it with - # the right context? - entry.module = module = ir.Module.parse(mlir.module_to_bytecode(module)) aot_util.put_entry(component_key, entry, update=True) - else: - logging.info("hit lowering: %s", component_key) - - symtab = ir.SymbolTable(module.operation) - module = mlir.merge_mlir_modules( - ctx.module_context.module, - f"component_{module_name}", - module, - dst_symtab=ctx.module_context.symbol_table, - ) - # TODO(dsuo): There's quite a bit of logic from jax.export, but we just strip - # away most of that for this demo. e.g., ordered effects, platforms. - # submodule_args = [mlir.aval_to_ir_type(x) for x in ctx.avals_in] - results = symtab["main"].type.results - call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(module), args) - return call.results + return aot_util.get_module_results(ctx, module, module_name, *args) def component_batcher( @@ -180,21 +128,15 @@ def component_batcher( # TODO(dsuo): Ignore ragged. # TODO(dsuo): Ignore updating annotations. - # TODO(dsuo): Dummy debug info. - if isinstance(fun, functools.partial): - name = fun.func.__name__ - else: - name = fun.__name__ - if isinstance(fun, lu.WrappedFun): - fun = aot_util.maybe_reset_stores(fun) - wrapped_fun = lu.wrap_init( - fun, debug_info=lu.DebugInfo("vmap(component)", name, None, None) - ) + # TODO(dsuo): This doesn't handle nesting. + wrapped_fun = aot_util.wrap_init(fun, "vmap(component)") # ????(dsuo): I don't understand trace tags. batched_fun, dims_out = batching.batch_subtrace( wrapped_fun, core.TraceTag(), axis_data, tuple(dims_in) ) + # TODO(dsuo): We might need to reset stores because we may be calling a cached + # wrapped fun. batched_fun = aot_util.maybe_reset_stores(batched_fun) vals_out = component_p.bind( diff --git a/jax/_src/aot_util.py b/jax/_src/aot_util.py index 5a54d7562561..d0b14da44409 100644 --- a/jax/_src/aot_util.py +++ b/jax/_src/aot_util.py @@ -14,6 +14,7 @@ """JAX AOT API utilities.""" from collections.abc import Hashable +import functools import pickle import traceback from typing import Any, Callable, NamedTuple, Self, Sequence @@ -237,5 +238,41 @@ def lower_component_to_module( # - host_callbacks: probably not supported. # - shape_poly_state: talk to necula@ module = lowering_result.module + # TODO(dsuo): Do we need to do this step to strip context? + module = ir.Module.parse(mlir.module_to_bytecode(module)) return module + + +def get_module_name(module: ir.Module) -> str: + # TODO(dsuo): Is this reasonable? + module_name = str(module.operation.attributes["sym_name"].value) + logging.info("module_name: %s", module_name) + return module_name + + +def get_module_results( + ctx: mlir.LoweringRuleContext, module: ir.Module, module_name: str, *args +) -> Sequence[ir.Value]: + symtab = ir.SymbolTable(module.operation) + module = mlir.merge_mlir_modules( + ctx.module_context.module, + f"component_{module_name}", + module, + dst_symtab=ctx.module_context.symbol_table, + ) + results = symtab["main"].type.results + call = func_dialect.CallOp(results, ir.FlatSymbolRefAttr.get(module), args) + + return call.results + + +def wrap_init(fun: Callable[..., Any], traced_for: str) -> lu.WrappedFun: + # TODO(dsuo): Dummy debug info. + if isinstance(fun, functools.partial): + name = fun.func.__name__ + else: + name = fun.__name__ + return lu.wrap_init( + fun, debug_info=lu.DebugInfo(traced_for, name, None, None) + ) diff --git a/tests/aot_test.py b/tests/aot_test.py index cde8c79f60ad..c9371d585828 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -356,7 +356,6 @@ def validate_cache_states( @config.enable_checks(False) def test_component_basic(self): with self.make_in_memory_cache(): - cache = aot.get_cache() @aot.component(key="f") def f(x): @@ -397,7 +396,6 @@ def g(x): @config.enable_checks(False) def test_component_in_function(self): with self.make_in_memory_cache(): - cache = aot.get_cache() @aot.component(key="f") def f(x): @@ -422,7 +420,6 @@ def g(x): @config.enable_checks(False) def test_jit_of_component(self): with self.make_in_memory_cache(): - cache = aot.get_cache() @jax.jit @aot.component(key="f") @@ -450,7 +447,6 @@ def g(x): @config.enable_checks(False) def test_component_of_jit(self): with self.make_in_memory_cache(): - cache = aot.get_cache() @aot.component(key="f") @jax.jit @@ -477,7 +473,6 @@ def g(x): @config.enable_checks(False) def test_explicit_lowering(self): with self.make_in_memory_cache(): - cache = aot.get_cache() @aot.component(key="f") def f(x): @@ -503,7 +498,6 @@ def g(x): @config.enable_checks(False) def test_vmap_of_component(self): with self.make_in_memory_cache(): - cache = aot.get_cache() def f(x): logging.info("running!") @@ -530,7 +524,6 @@ def f(x): @config.enable_checks(False) def test_vmap_of_vmap_of_component(self): with self.make_in_memory_cache(): - cache = aot.get_cache() def f(x): logging.info("running!") @@ -576,7 +569,6 @@ def test_vmap_of_jit_of_component(self): # NOTE: This should be the same as test_vmap_of_component except for one # more infer params cache miss because of the extra jit. with self.make_in_memory_cache(): - cache = aot.get_cache() def f(x): logging.info("running!") @@ -605,7 +597,6 @@ def test_jit_of_vmap_of_component(self): # NOTE: This should be the same as test_vmap_of_component except for one # more infer params cache miss because of the extra jit. with self.make_in_memory_cache(): - cache = aot.get_cache() def f(x): logging.info("running!") @@ -629,6 +620,24 @@ def f(x): component_f.fun, vmapped_key, 1, 0, 1, 8, None, 1 ) + @config.enable_checks(False) + def test_scan_of_component(self): + with self.make_in_memory_cache(): + + @aot.component(key="f") + def f(x): + logging.info("running!") + return x + 1.0 + + def body(carry, x): + return f(carry), f(x) + + carry, ys = jax.lax.scan(body, 0, jnp.arange(10, dtype="float32")) + + self.assertEqual(carry, 10) + self.assertArraysEqual(ys, jnp.arange(10, dtype="float32") + 1) + self.validate_cache_states(f.fun, f.key, 2, (0, 0), 2, 13, None, 2) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 9b0fdc7f7de08d5e973afdb05fbd5988f7de710c Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 13 Nov 2025 07:08:55 -0800 Subject: [PATCH 24/24] Update --- jax/_src/aot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax/_src/aot.py b/jax/_src/aot.py index d6a792c7f590..63ccdc537e11 100644 --- a/jax/_src/aot.py +++ b/jax/_src/aot.py @@ -147,6 +147,10 @@ def component_batcher( return vals_out, dims_out() +def component_jvp(arg_values, arg_tangents): + pass + + component_p = core.Primitive("component") component_p.multiple_results = True component_p.def_impl(component_impl)