From cfe02d82b80d3ad240c2776f882766d6cfbf25d1 Mon Sep 17 00:00:00 2001
From: Adam Cogdell
Date: Fri, 5 Dec 2025 16:12:08 -0800
Subject: [PATCH] parent CL, DO NOT SUBMIT

PiperOrigin-RevId: 840914090
---
 .../v1/_src/partial/partial_merge_main.py     | 329 ++++++++++++++++
 .../_src/partial/partial_merge_main_test.py   | 357 ++++++++++++++++++
 2 files changed, 686 insertions(+)
 create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main.py
 create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main_test.py

diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main.py
new file mode 100644
index 000000000..1c6070249
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main.py
@@ -0,0 +1,329 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Partial merge binary."""
+
+import collections
+import dataclasses
+import random
+from typing import Any, Iterator, List
+
+from absl import app
+from absl import flags
+from etils import epath
+import jax
+import numpy as np
+from orbax.checkpoint._src.arrays import fragments as array_fragments
+from orbax.checkpoint._src.arrays import sharding as array_sharding
+from orbax.checkpoint._src.tree import parts_of
+from orbax.checkpoint._src.tree import structure_utils
+from orbax.checkpoint._src.tree import utils as tree_utils
+from orbax.checkpoint.experimental.model_surgery import source_checkpoint
+from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
+from orbax.checkpoint.experimental.v1._src.partial import saving as partial_saving
+
+# Note: ensure you have access to numpy_utils if copying _from_fragments;
+# otherwise, rely on public APIs if available.
+from learning.deepmind.jax.roc import numpy_utils
+from learning.deepmind.jax.roc.experimental import eval_fragments
+
+FLAGS = flags.FLAGS
+
+_IN_PATHS = flags.DEFINE_multi_string(
+    'in_paths',
+    None,
+    'Paths of checkpoints to merge.',
+    required=True,
+)
+_OUT_PATH = flags.DEFINE_string(
+    'out_path',
+    None,
+    'Output checkpoint path.',
+    required=True,
+)
+_PER_HOST_MEMORY_LIMIT_GB = flags.DEFINE_integer(
+    'per_host_memory_limit_gb',
+    16,
+    'Memory limit in GB per CPU host for partial loading and saving.'
+    ' Non-uniform memory limits are not supported.',
+)
+
+PyTree = Any
+Keypath = tuple[Any, ...]
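+
+# Example invocation (paths below are illustrative only). When the same key
+# appears in several inputs, the value from the later --in_paths entry wins,
+# since merge_transform_fn calls merge_trees with overwrite=True:
+#
+#   partial_merge_main \
+#     --in_paths=/path/to/ckpt_a \
+#     --in_paths=/path/to/ckpt_b \
+#     --out_path=/path/to/merged \
+#     --per_host_memory_limit_gb=8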
+PartsOf = parts_of.PartsOf + + +def fragments_to_arrays( + fragments_or_arrays: PyTree, + target: PyTree, +) -> PyTree: + """Creates jax.Array from a tree of Fragments.""" + + def _to_jax_array(frags_or_arr, abstract_target): + if not isinstance(frags_or_arr, eval_fragments.ConcreteFragments): + return frags_or_arr + + def extract_shard(idx) -> jax.Array: + idx = numpy_utils.resolve_slice(idx, abstract_target.shape) + shard_data = eval_fragments._extract_fragment( # pylint: disable=protected-access + frags_or_arr.fragments, + eval_fragments.AbstractFragment(index=idx), + ).value + assert shard_data is not None + return jax.numpy.asarray(shard_data) + + sharding = abstract_target.sharding + return jax.make_array_from_callback( + abstract_target.shape, sharding, extract_shard + ) + + return jax.tree.map(_to_jax_array, fragments_or_arrays, target) + + +@dataclasses.dataclass(frozen=True) +class FragmentInfo: + """Information about a fragment to be used for batching.""" + + ckpt_idx: int + keypath: Keypath + fragment: array_fragments.AbstractFragment + dtype: np.dtype + + @property + def size_bytes(self) -> int: + return self.fragment.nbytes_astype(self.dtype) + + +def merge_transform_fn(*args: PyTree) -> PyTree: + """Merges trees, overwriting existing keys.""" + return structure_utils.merge_trees(*args, overwrite=True) + + +def batch_fragments( + fragment_infos: list[FragmentInfo], memory_limit_gb: int +) -> Iterator[list[FragmentInfo]]: + """Groups leaves into batches based on memory availability.""" + memory_limit_bytes = memory_limit_gb * 1024**3 + current_batch_leaves = [] + current_batch_size = 0 + + for finfo in fragment_infos: + if finfo.size_bytes > memory_limit_bytes: + raise ValueError( + f'Fragment size {finfo.size_bytes} is larger than memory limit.' + ) + + if current_batch_size + finfo.size_bytes > memory_limit_bytes: + # Yield the current batch and start a new one. + yield current_batch_leaves + current_batch_leaves = [finfo] + current_batch_size = finfo.size_bytes + else: + # Add the leaf to the current batch. + current_batch_leaves.append(finfo) + current_batch_size += finfo.size_bytes + + if current_batch_leaves: + # Yield the final batch. 
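+    # Note that packing is first-fit in iteration order, not optimal bin
+    # packing: e.g. with a 1 GiB limit, fragment sizes of 0.6, 0.5 and
+    # 0.2 GiB produce the batches [0.6] and [0.5, 0.2].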
+ yield current_batch_leaves + + +def resolve_pytree_path(path: epath.Path) -> epath.Path: + """Resolves the path to the pytree checkpoint.""" + if not (path / checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY).exists(): + raise ValueError(f'Path {path} does not contain a pytree checkpoint.') + + return path / checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY + + +def resolve_target_structure( + abstract_sources: list[PyTree], host_cpus: list[jax.Device] +) -> PyTree: + """Resolves output structure and output sharding.""" + abstract_target = jax.eval_shape(merge_transform_fn, *abstract_sources) + + shardings = array_sharding.construct_maximal_shardings( + abstract_target, devices=host_cpus + ) + sharded_abstract_target = jax.tree.map( + lambda x, s: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=s), + abstract_target, + shardings, + ) + + return sharded_abstract_target + + +def resolve_merge_topology( + sharded_abstract_target: PyTree, abstract_sources: list[PyTree] +) -> tuple[PyTree, Any]: + """Uses Model Surgery to resolve topology.""" + + # Determine Fragments + abstract_fragments_to_load = jax.tree.map( + array_fragments.abstract_fragments, sharded_abstract_target + ) + + # The "Surgery": Map inputs to outputs + return eval_fragments.eval_fragments( + merge_transform_fn, + abstract_sources, + abstract_fragments_to_load, + ) + + +def create_fragment_infos(required_input_fragments: Any) -> list[FragmentInfo]: + """Flattens fragments into FragmentInfos for batching.""" + fragment_infos = [] + for ckpt_idx, fragments_tree in enumerate(required_input_fragments): + flat_fragments = tree_utils.to_flat_dict(fragments_tree) + ckpt_fragment_infos = [] + + for keypath, fragments in flat_fragments.items(): + for fragment in fragments.fragments: + ckpt_fragment_infos.append( + FragmentInfo( + ckpt_idx=ckpt_idx, + keypath=keypath, + fragment=fragment, + dtype=fragments.dtype, + ) + ) + + # Randomize the order of leaves *within* this checkpoint. This helps mix + # large and small arrays in batches to avoid wasting batch space. + random.shuffle(ckpt_fragment_infos) + fragment_infos.extend(ckpt_fragment_infos) + return fragment_infos + + +def load_batch_fragments( + abstract_sources: list[PyTree], + batch_fragments_map: dict[ + int, dict[tuple[Any, ...], list[array_fragments.AbstractFragment]] + ], + source_checkpoints: list[source_checkpoint.SourceCheckpoint], + memory_limit_gb: int, +) -> list[PyTree]: + """Loads fragments for a batch.""" + loaded_fragments = [] + # Reconstruct trees for loading + for i, abstract_source in enumerate(abstract_sources): + # We need to construct a request tree that matches the source structure + # but only contains the fragments for this batch. 
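+    # For example (hypothetical batch): if the batch holds fragments only for
+    # 'layer1' of this checkpoint, the request tree maps 'layer1' to those
+    # fragments and every other leaf to an AbstractFragments with an empty
+    # fragment list, so nothing else is read for this batch.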
+ + def _get_fragments_for_leaf( + path, meta, keypath_fragments=batch_fragments_map[i] + ): + # Convert JAX KeyPath to tuple for dict lookup + path_tuple = tree_utils.tuple_path_from_keypath(path) + + frags = keypath_fragments.get(path_tuple) + + if frags: + return array_fragments.AbstractFragments( + shape=meta.shape, + dtype=meta.dtype, # Use source dtype + fragments=frags, + ) + return array_fragments.AbstractFragments( + shape=meta.shape, dtype=meta.dtype, fragments=[] + ) + + source_request_tree = jax.tree_util.tree_map_with_path( + _get_fragments_for_leaf, abstract_source + ) + + loaded_fragments.append( + source_checkpoints[i].load_fragments( + source_request_tree, concurrent_gb=memory_limit_gb + ) + ) + return loaded_fragments + + +def main(argv: List[str] | None = None) -> None: + if argv is not None and len(argv) > 1: + raise app.UsageError(f'Too many command-line arguments: {argv[1:]}') + + all_cpus = jax.devices('cpu') + host_cpus = all_cpus[: jax.process_count()] + + random.seed(0) + + ckpts_to_merge = [epath.Path(path) for path in _IN_PATHS.value] + merged_ckpt_path = epath.Path(_OUT_PATH.value) + + # Load metadata for all input checkpoints to understand their structure and + # contents. + source_checkpoints = [ + source_checkpoint.checkpoint_at(resolve_pytree_path(path)) + for path in ckpts_to_merge + ] + abstract_sources = [sc.metadata for sc in source_checkpoints] + + # Determine the structure and sharding of the final merged checkpoint. This + # acts as the blueprint for the output, derived by merging the metadata of the + # input checkpoints. + sharded_abstract_target = resolve_target_structure( + abstract_sources, host_cpus + ) + + # Plan the merge operation by identifying exactly which data fragments need to + # be read from the inputs to construct the output. This also prepares a + # transformation function to assemble the loaded data. + required_input_fragments, fragment_transform_fn = resolve_merge_topology( + sharded_abstract_target, abstract_sources + ) + + # Prepare for execution by flattening the required data fragments into a list + # of tasks. This allows us to process the merge in memory-constrained batches. + fragment_infos = create_fragment_infos(required_input_fragments) + + for batch in batch_fragments(fragment_infos, _PER_HOST_MEMORY_LIMIT_GB.value): + # Group the fragments in the current batch by their source checkpoint and + # original keypath. + batch_fragments_map = collections.defaultdict( + lambda: collections.defaultdict(list) + ) + for finfo in batch: + batch_fragments_map[finfo.ckpt_idx][finfo.keypath].append(finfo.fragment) + + # Execute the load for the current batch: fetch the specific data fragments + # from the source checkpoints into memory. + loaded_fragments = load_batch_fragments( + abstract_sources, + batch_fragments_map, + source_checkpoints, + _PER_HOST_MEMORY_LIMIT_GB.value, + ) + + # Apply the transformation function to assemble the loaded fragments into + # the desired target structure. + target_fragments = fragment_transform_fn(loaded_fragments) + + # Convert the assembled fragments into concrete, sharded JAX arrays. + target_tree = fragments_to_arrays(target_fragments, sharded_abstract_target) + + # Save the current batch of merged arrays to the output checkpoint + # directory. + partial_saving.save_pytree(merged_ckpt_path, target_tree) + + # Finalize the checkpoint, completing the merge process. 
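+  # Each loop iteration above saved one batch via partial_saving.save_pytree;
+  # finalize commits those partial writes as a single complete checkpoint.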
+  partial_saving.finalize(merged_ckpt_path)
+
+
+if __name__ == '__main__':
+  jax.config.config_with_absl()
+  app.run(main)
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main_test.py
new file mode 100644
index 000000000..498af8886
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/partial_merge_main_test.py
@@ -0,0 +1,357 @@
+# Copyright 2025 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import random
+from unittest import mock
+
+from absl import flags
+from absl.testing import flagsaver
+from absl.testing import parameterized
+from etils import epath
+import jax
+import numpy as np
+from orbax.checkpoint import test_utils
+from orbax.checkpoint._src.metadata import value as value_metadata
+from orbax.checkpoint._src.testing import multiprocess_test
+from orbax.checkpoint._src.tree import utils as tree_utils
+from orbax.checkpoint.experimental.v1._src.loading import loading as loading_lib
+from orbax.checkpoint.experimental.v1._src.partial import partial_merge_main
+from orbax.checkpoint.experimental.v1._src.saving import saving as saving_lib
+from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils
+
+from learning.deepmind.jax.roc.experimental import eval_fragments
+
+FLAGS = flags.FLAGS
+
+MEMORY_LIMIT_BYTES = 2**30 * 16
+MB = 2**20
+
+
+def _create_mock_metadata(shape, dtype):
+  return value_metadata.ArrayMetadata(
+      name='test_array',
+      directory=None,
+      shape=shape,
+      dtype=dtype,
+      sharding=None,
+  )
+
+
+def _setup_large_pytree(sharding: jax.sharding.Sharding):
+  """Creates a large pytree with arrays of random sizes."""
+  # 100 arrays
+  array_sizes_mb = [50] * 2 + [10] * 5 + [2] * 13 + [1] * 80
+  rng = random.Random(42)
+  rng.shuffle(array_sizes_mb)
+
+  pytree = {}
+  for i in range(10):
+    for j in range(10):
+      array_size = array_sizes_mb[i * 10 + j] * MB // 4
+      pytree.setdefault(f'param{i}', {})[f'param{j}'] = np.ones(
+          array_size, dtype=np.float32
+      )
+  pytree = jax.device_put(pytree, sharding)
+  return pytree
+
+
+def _permute_pytree(pytree, idx):
+  def _permute(x):
+    if isinstance(x, np.ndarray):
+      rng = np.random.default_rng(seed=idx)
+      x *= rng.random(x.shape).astype(x.dtype)
+    return x
+
+  return jax.tree.map(_permute, pytree)
+
+
+def _get_abstract_pytree(pytree):
+  return jax.tree.map(array_test_utils.as_abstract_type, pytree)
+
+
+class PartialMergeTest(
+    parameterized.TestCase, multiprocess_test.MultiProcessTest
+):
+
+  def setUp(self):
+    super().setUp()
+
+    self.directory = epath.Path(
+        self.create_tempdir(name='partial_merging_test').full_path
+    )
+    self.pytree, self.abstract_pytree = array_test_utils.create_sharded_pytree()
+
+    test_utils.set_tensorstore_driver_for_test()
+    test_utils.sync_global_processes('PartialMergingTest:setUp:complete')
+
+  def tearDown(self):
+    super().tearDown()
+    test_utils.sync_global_processes('PartialMergingTest:tearDown:complete')
+
+  def test_batch_fragments_logic(self):
+    """Tests the greedy batching logic based on memory limits."""
+
+    # Helper to create a mock FragmentInfo with a specific size.
+    def _mock_info(size_bytes, name):
+      m = mock.Mock(spec=partial_merge_main.FragmentInfo)
+      type(m).size_bytes = mock.PropertyMock(return_value=size_bytes)
+      m.name = name  # Just for debugging/identification in test
+      return m
+
+    # 1 GB limit
+    limit_gb = 1
+    limit_bytes = limit_gb * 1024**3
+
+    # Case 1: Items fit perfectly into one batch
+    infos = [
+        _mock_info(int(limit_bytes * 0.4), 'A'),
+        _mock_info(int(limit_bytes * 0.4), 'B'),
+    ]
+    batches = list(partial_merge_main.batch_fragments(infos, limit_gb))
+    self.assertLen(batches, 1)
+    self.assertEqual(batches[0], infos)
+
+    # Case 2: Items spill over to second batch
+    infos = [
+        _mock_info(int(limit_bytes * 0.6), 'A'),
+        _mock_info(int(limit_bytes * 0.5), 'B'),  # 0.6 + 0.5 > 1.0
+        _mock_info(int(limit_bytes * 0.2), 'C'),
+    ]
+    batches = list(partial_merge_main.batch_fragments(infos, limit_gb))
+    self.assertLen(batches, 2)
+    self.assertEqual(batches[0], [infos[0]])  # A (0.6)
+    self.assertEqual(batches[1], [infos[1], infos[2]])  # B (0.5) + C (0.2)
+
+    # Case 3: Single item equals limit
+    infos = [_mock_info(limit_bytes, 'A')]
+    batches = list(partial_merge_main.batch_fragments(infos, limit_gb))
+    self.assertLen(batches, 1)
+
+    # Case 4: Single item exceeds limit (should raise ValueError)
+    infos = [_mock_info(limit_bytes + 1, 'TooBig')]
+    with self.assertRaisesRegex(ValueError, 'larger than memory limit'):
+      list(partial_merge_main.batch_fragments(infos, limit_gb))
+
+  @mock.patch('random.shuffle')
+  def test_create_fragment_infos(self, mock_shuffle):
+    """Tests flattening of fragment trees into FragmentInfo objects."""
+    # Ensure shuffle does nothing so we can assert order deterministically.
+    mock_shuffle.side_effect = lambda x: x
+
+    # Mock the input structure: List[PyTree[Fragments]]
+    # We simulate 2 checkpoints.
+
+    # Mock Fragment object (normally comes from eval_fragments)
+    MockFragment = collections.namedtuple('MockFragment', ['index'])
+
+    # Mock array_fragments.Fragments
+    class MockFragmentsContainer:
+
+      def __init__(self, frags, dtype):
+        self.fragments = frags
+        self.dtype = dtype
+
+    # Checkpoint 0 structure
+    ckpt0_tree = {
+        'layer1': MockFragmentsContainer(
+            [MockFragment(index=1), MockFragment(index=2)], np.float32
+        )
+    }
+
+    # Checkpoint 1 structure
+    ckpt1_tree = {
+        'layer1': MockFragmentsContainer([MockFragment(index=3)], np.float32),
+        'layer2': MockFragmentsContainer([MockFragment(index=4)], np.int32),
+    }
+
+    required_input_fragments = [ckpt0_tree, ckpt1_tree]
+
+    infos = partial_merge_main.create_fragment_infos(required_input_fragments)
+
+    # We expect:
+    # Ckpt0: 2 fragments from layer1
+    # Ckpt1: 1 fragment from layer1, 1 fragment from layer2
+    # Total = 4 FragmentInfos
+    self.assertLen(infos, 4)
+
+    # Verify content of the first info (Ckpt 0, layer 1, first frag)
+    self.assertEqual(infos[0].ckpt_idx, 0)
+    self.assertEqual(infos[0].keypath, ('layer1',))
+    self.assertEqual(infos[0].fragment.index, 1)
+    self.assertEqual(infos[0].dtype, np.float32)
+
+    # Verify content of the last info (Ckpt 1, layer 2)
+    # Note: dict iteration order is insertion order in modern Python,
+    # but create_fragment_infos iterates the list of ckpts.
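+    # To stay robust to that ordering, locate the layer2 entry by its dtype
+    # rather than by position.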
+
+    # Find the int32 fragment
+    int_frag = next(x for x in infos if x.dtype == np.int32)
+    self.assertEqual(int_frag.ckpt_idx, 1)
+    self.assertEqual(int_frag.keypath, ('layer2',))
+    self.assertEqual(int_frag.fragment.index, 4)
+
+  def test_fragments_to_arrays_passthrough(self):
+    """Tests that non-fragment leaves are passed through unchanged."""
+    target = {'a': jax.ShapeDtypeStruct((1,), np.float32)}
+    # If the input isn't an eval_fragments.ConcreteFragments instance, it is
+    # returned as is. This happens if the tree structure doesn't match
+    # perfectly or for non-array leaves.
+    inputs = {'a': 123}
+
+    result = partial_merge_main.fragments_to_arrays(inputs, target)
+    self.assertEqual(result['a'], 123)
+
+  def test_fragments_to_arrays_conversion(self):
+    """Tests conversion of Fragments to jax.Array via callback."""
+    shape = (2, 2)
+    dtype = np.float32
+
+    local_devices = [
+        d for d in jax.devices('cpu') if d.process_index == jax.process_index()
+    ]
+    if not local_devices:
+      self.skipTest('No local CPU devices found')
+
+    sharding = jax.sharding.NamedSharding(
+        jax.sharding.Mesh(local_devices[:1], ('x',)),
+        jax.sharding.PartitionSpec('x'),
+    )
+
+    target_leaf = jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
+    target = {'param': target_leaf}
+
+    mock_fragment_data = np.ones(shape, dtype)
+
+    with mock.patch.object(
+        partial_merge_main.eval_fragments, '_extract_fragment'
+    ) as mock_extract:
+      mock_extract.return_value.value = mock_fragment_data
+
+      fragments_input = eval_fragments.ConcreteFragments(
+          shape=shape,
+          dtype=np.dtype(dtype),
+          fragments=[
+              eval_fragments.ConcreteFragment(
+                  index=(slice(0, 2, 1), slice(0, 2, 1)),
+                  value=np.zeros(shape, dtype),
+              )
+          ],
+      )
+
+      inputs = {'param': fragments_input}
+
+      result_tree = partial_merge_main.fragments_to_arrays(inputs, target)
+
+      self.assertIsInstance(result_tree['param'], jax.Array)
+      self.assertEqual(result_tree['param'].shape, shape)
+      self.assertEqual(result_tree['param'].dtype, dtype)
+
+      # Force materialization to trigger the callback
+      result_data = np.array(result_tree['param'])
+      np.testing.assert_array_equal(result_data, mock_fragment_data)
+
+  def test_resolve_pytree_path(self):
+    """Tests path resolution logic."""
+    with self.subTest('valid_path'):
+      temp_dir = self.create_tempdir()
+      path = epath.Path(temp_dir.full_path)
+      (
+          path / partial_merge_main.checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
+      ).mkdir()
+
+      result = partial_merge_main.resolve_pytree_path(path)
+      self.assertEqual(
+          result,
+          path / partial_merge_main.checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY,
+      )
+
+    with self.subTest('invalid_path'):
+      temp_dir = self.create_tempdir('invalid_path')
+      path = epath.Path(temp_dir.full_path)
+
+      with self.assertRaisesRegex(ValueError, 'does not contain a pytree'):
+        partial_merge_main.resolve_pytree_path(path)
+
+  @flagsaver.flagsaver
+  def test_main(self):
+    ckpt_paths = [
+        self.directory / 'ckpt0',
+        self.directory / 'ckpt1',
+        self.directory / 'ckpt2',
+    ]
+    out_path = self.directory / 'out'
+
+    trees = [
+        {
+            'a': self.pytree['a'],
+            'b': self.pytree['b'],
+            'c': {
+                'a': self.pytree['c']['a'],
+                'e': self.pytree['c']['e'],
+            },
+            # skip 'x' and 'y'
+        },
+        {
+            'a': jax.tree.map(lambda x: x * 2, self.pytree['a']),
+            # skip 'b'
+            'c': {
+                'a': jax.tree.map(lambda x: x * 2, self.pytree['c']['a']),
+                # skip 'c.e'
+            },
+        },
+        {
+            # skip 'a' and 'b'
+            'c': {
+                'a': jax.tree.map(lambda x: x * 3, self.pytree['c']['a']),
+                # skip 'c.e'
+            },
+            'x': jax.tree.map(lambda x: x * 3, self.pytree['x']),
+            'y': jax.tree.map(lambda x: x * 3, self.pytree['y']),
+        },
+    ]
+
+    for path, pytree in zip(ckpt_paths, trees):
+      saving_lib.save_pytree(path, pytree)
+
+    expected_tree = {
+        'a': trees[1]['a'],
+        'b': trees[0]['b'],
+        'c': {
+            'a': trees[2]['c']['a'],
+            'e': trees[0]['c']['e'],
+        },
+        'x': trees[2]['x'],
+        'y': trees[2]['y'],
+    }
+    abstract_expected_tree = jax.tree.map(
+        tree_utils.to_shape_dtype_struct, expected_tree
+    )
+
+    FLAGS.in_paths = [str(path) for path in ckpt_paths]
+    FLAGS.out_path = str(out_path)
+    FLAGS.per_host_memory_limit_gb = 1
+
+    partial_merge_main.main()
+
+    # Check that the merged checkpoint exists and has the correct contents.
+    merged_ckpt = loading_lib.load_pytree(out_path, abstract_expected_tree)
+    test_utils.assert_tree_equal(self, merged_ckpt, expected_tree)
+
+
+if __name__ == '__main__':
+  # Initialize required flags with dummy values before multiprocess_test.main()
+  # runs, since they are marked as required in the binary.
+  FLAGS.in_paths = ['default_in']
+  FLAGS.out_path = 'default_out'
+  multiprocess_test.main()