
Commit 1c8199e

BlaziusMaximus authored and Orbax Authors committed
parent CL, DO NOT SUBMIT
PiperOrigin-RevId: 840914090
1 parent ecb86c8 commit 1c8199e

3 files changed: +702 -0 lines changed


checkpoint/orbax/checkpoint/_src/tree/utils.py

Lines changed: 16 additions & 0 deletions
@@ -472,3 +472,19 @@ def select_by_tree_path(
      raise ValueError(f'Path {path} does not exist in {tree=}.')
    case ():
      return tree


def get_leaf_by_keypath(tree: PyTree, keypath: PyTreePath) -> Any:
  """Returns the leaf value at the given keypath."""
  node = tree
  for key in keypath:
    node = node[key]
  return node


def set_leaf_by_keypath(tree: PyTree, keypath: PyTreePath, value: Any) -> Any:
  """Sets the leaf value at the given keypath."""
  parent = tree
  for key in keypath[:-1]:
    parent = parent[key]
  parent[keypath[-1]] = value
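
For reference, a minimal usage sketch of the two new helpers (not part of the commit), assuming the tree is a nested dict, the keypath is a plain tuple of keys, and the module is importable as orbax.checkpoint._src.tree.utils (as it is imported elsewhere in this commit):

  from orbax.checkpoint._src.tree import utils as tree_utils

  tree = {'params': {'dense': {'kernel': [[1.0, 2.0]], 'bias': [0.0]}}}

  # Read a leaf by walking the keypath.
  kernel = tree_utils.get_leaf_by_keypath(tree, ('params', 'dense', 'kernel'))

  # Overwrite a leaf in place; the tree itself is mutated.
  tree_utils.set_leaf_by_keypath(tree, ('params', 'dense', 'bias'), [1.0])
  assert tree['params']['dense']['bias'] == [1.0]

Note that set_leaf_by_keypath mutates the containing node in place rather than returning a new tree.
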
Lines changed: 329 additions & 0 deletions
@@ -0,0 +1,329 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial merge binary."""

import collections
import dataclasses
import random
from typing import Any, Iterator, List

from absl import app
from absl import flags
from etils import epath
import jax
import numpy as np
from orbax.checkpoint._src.arrays import fragments as array_fragments
from orbax.checkpoint._src.arrays import sharding as array_sharding
from orbax.checkpoint._src.tree import parts_of
from orbax.checkpoint._src.tree import structure_utils
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint.experimental.model_surgery import source_checkpoint
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.partial import saving as partial_saving

# Note: Ensure you have access to numpy_utils if copying _from_fragments,
# otherwise rely on public APIs if available.
from .learning.deepmind.jax.roc import numpy_utils
from .learning.deepmind.jax.roc.experimental import eval_fragments

FLAGS = flags.FLAGS

_IN_PATHS = flags.DEFINE_multi_string(
    'in_paths',
    None,
    'Paths of checkpoints to merge.',
    required=True,
)
_OUT_PATH = flags.DEFINE_string(
    'out_path',
    None,
    'Output checkpoint path.',
    required=True,
)
_PER_HOST_MEMORY_LIMIT_GB = flags.DEFINE_integer(
    'per_host_memory_limit_gb',
    16,
    'Memory limit in GB per CPU host for partial loading and saving.'
    ' Non-uniform memory limits are not supported.',
)

PyTree = Any
Keypath = tuple[Any, ...]
PartsOf = parts_of.PartsOf


def fragments_to_arrays(
    fragments_or_arrays: PyTree,
    target: PyTree,
) -> PyTree:
  """Creates jax.Array from a tree of Fragments."""

  def _to_jax_array(frags_or_arr, abstract_target):
    if not isinstance(frags_or_arr, eval_fragments.ConcreteFragments):
      return frags_or_arr

    def extract_shard(idx) -> jax.Array:
      idx = numpy_utils.resolve_slice(idx, abstract_target.shape)
      shard_data = eval_fragments._extract_fragment(  # pylint: disable=protected-access
          frags_or_arr.fragments,
          eval_fragments.AbstractFragment(index=idx),
      ).value
      assert shard_data is not None
      return jax.numpy.asarray(shard_data)

    sharding = abstract_target.sharding
    return jax.make_array_from_callback(
        abstract_target.shape, sharding, extract_shard
    )

  return jax.tree.map(_to_jax_array, fragments_or_arrays, target)


@dataclasses.dataclass(frozen=True)
class FragmentInfo:
  """Information about a fragment to be used for batching."""

  ckpt_idx: int
  keypath: Keypath
  fragment: array_fragments.AbstractFragment
  dtype: np.dtype

  @property
  def size_bytes(self) -> int:
    return self.fragment.nbytes_astype(self.dtype)


def merge_transform_fn(*args: PyTree) -> PyTree:
  """Merges trees, overwriting existing keys."""
  return structure_utils.merge_trees(*args, overwrite=True)


def batch_fragments(
    fragment_infos: list[FragmentInfo], memory_limit_gb: int
) -> Iterator[list[FragmentInfo]]:
  """Groups leaves into batches based on memory availability."""
  memory_limit_bytes = memory_limit_gb * 1024**3
  current_batch_leaves = []
  current_batch_size = 0

  for finfo in fragment_infos:
    if finfo.size_bytes > memory_limit_bytes:
      raise ValueError(
          f'Fragment size {finfo.size_bytes} is larger than memory limit.'
      )

    if current_batch_size + finfo.size_bytes > memory_limit_bytes:
      # Yield the current batch and start a new one.
      yield current_batch_leaves
      current_batch_leaves = [finfo]
      current_batch_size = finfo.size_bytes
    else:
      # Add the leaf to the current batch.
      current_batch_leaves.append(finfo)
      current_batch_size += finfo.size_bytes

  if current_batch_leaves:
    # Yield the final batch.
    yield current_batch_leaves


def resolve_pytree_path(path: epath.Path) -> epath.Path:
  """Resolves the path to the pytree checkpoint."""
  if not (path / checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY).exists():
    raise ValueError(f'Path {path} does not contain a pytree checkpoint.')

  return path / checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY


def resolve_target_structure(
    abstract_sources: list[PyTree], host_cpus: list[jax.Device]
) -> PyTree:
  """Resolves output structure and output sharding."""
  abstract_target = jax.eval_shape(merge_transform_fn, *abstract_sources)

  shardings = array_sharding.construct_maximal_shardings(
      abstract_target, devices=host_cpus
  )
  sharded_abstract_target = jax.tree.map(
      lambda x, s: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=s),
      abstract_target,
      shardings,
  )

  return sharded_abstract_target


def resolve_merge_topology(
    sharded_abstract_target: PyTree, abstract_sources: list[PyTree]
) -> tuple[PyTree, Any]:
  """Uses Model Surgery to resolve topology."""

  # Determine Fragments
  abstract_fragments_to_load = jax.tree.map(
      array_fragments.abstract_fragments, sharded_abstract_target
  )

  # The "Surgery": Map inputs to outputs
  return eval_fragments.eval_fragments(
      merge_transform_fn,
      abstract_sources,
      abstract_fragments_to_load,
  )


def create_fragment_infos(required_input_fragments: Any) -> list[FragmentInfo]:
  """Flattens fragments into FragmentInfos for batching."""
  fragment_infos = []
  for ckpt_idx, fragments_tree in enumerate(required_input_fragments):
    flat_fragments = tree_utils.to_flat_dict(fragments_tree)
    ckpt_fragment_infos = []

    for keypath, fragments in flat_fragments.items():
      for fragment in fragments.fragments:
        ckpt_fragment_infos.append(
            FragmentInfo(
                ckpt_idx=ckpt_idx,
                keypath=keypath,
                fragment=fragment,
                dtype=fragments.dtype,
            )
        )

    # Randomize the order of leaves *within* this checkpoint. This helps mix
    # large and small arrays in batches to avoid wasting batch space.
    random.shuffle(ckpt_fragment_infos)
    fragment_infos.extend(ckpt_fragment_infos)
  return fragment_infos


def load_batch_fragments(
    abstract_sources: list[PyTree],
    batch_fragments_map: dict[
        int, dict[tuple[Any, ...], list[array_fragments.AbstractFragment]]
    ],
    source_checkpoints: list[source_checkpoint.SourceCheckpoint],
    memory_limit_gb: int,
) -> list[PyTree]:
  """Loads fragments for a batch."""
  loaded_fragments = []
  # Reconstruct trees for loading
  for i, abstract_source in enumerate(abstract_sources):
    # We need to construct a request tree that matches the source structure
    # but only contains the fragments for this batch.

    def _get_fragments_for_leaf(
        path, meta, keypath_fragments=batch_fragments_map[i]
    ):
      # Convert JAX KeyPath to tuple for dict lookup
      path_tuple = tree_utils.tuple_path_from_keypath(path)

      frags = keypath_fragments.get(path_tuple)

      if frags:
        return array_fragments.AbstractFragments(
            shape=meta.shape,
            dtype=meta.dtype,  # Use source dtype
            fragments=frags,
        )
      return array_fragments.AbstractFragments(
          shape=meta.shape, dtype=meta.dtype, fragments=[]
      )

    source_request_tree = jax.tree_util.tree_map_with_path(
        _get_fragments_for_leaf, abstract_source
    )

    loaded_fragments.append(
        source_checkpoints[i].load_fragments(
            source_request_tree, concurrent_gb=memory_limit_gb
        )
    )
  return loaded_fragments


def main(argv: List[str] | None = None) -> None:
  if argv is not None and len(argv) > 1:
    raise app.UsageError(f'Too many command-line arguments: {argv[1:]}')

  all_cpus = jax.devices('cpu')
  host_cpus = all_cpus[: jax.process_count()]

  random.seed(0)

  ckpts_to_merge = [epath.Path(path) for path in _IN_PATHS.value]
  merged_ckpt_path = epath.Path(_OUT_PATH.value)

  # Load metadata for all input checkpoints to understand their structure and
  # contents.
  source_checkpoints = [
      source_checkpoint.checkpoint_at(resolve_pytree_path(path))
      for path in ckpts_to_merge
  ]
  abstract_sources = [sc.metadata for sc in source_checkpoints]

  # Determine the structure and sharding of the final merged checkpoint. This
  # acts as the blueprint for the output, derived by merging the metadata of
  # the input checkpoints.
  sharded_abstract_target = resolve_target_structure(
      abstract_sources, host_cpus
  )

  # Plan the merge operation by identifying exactly which data fragments need
  # to be read from the inputs to construct the output. This also prepares a
  # transformation function to assemble the loaded data.
  required_input_fragments, fragment_transform_fn = resolve_merge_topology(
      sharded_abstract_target, abstract_sources
  )

  # Prepare for execution by flattening the required data fragments into a
  # list of tasks. This allows us to process the merge in memory-constrained
  # batches.
  fragment_infos = create_fragment_infos(required_input_fragments)

  for batch in batch_fragments(fragment_infos, _PER_HOST_MEMORY_LIMIT_GB.value):
    # Group the fragments in the current batch by their source checkpoint and
    # original keypath.
    batch_fragments_map = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )
    for finfo in batch:
      batch_fragments_map[finfo.ckpt_idx][finfo.keypath].append(finfo.fragment)

    # Execute the load for the current batch: fetch the specific data
    # fragments from the source checkpoints into memory.
    loaded_fragments = load_batch_fragments(
        abstract_sources,
        batch_fragments_map,
        source_checkpoints,
        _PER_HOST_MEMORY_LIMIT_GB.value,
    )

    # Apply the transformation function to assemble the loaded fragments into
    # the desired target structure.
    target_fragments = fragment_transform_fn(loaded_fragments)

    # Convert the assembled fragments into concrete, sharded JAX arrays.
    target_tree = fragments_to_arrays(target_fragments, sharded_abstract_target)

    # Save the current batch of merged arrays to the output checkpoint
    # directory.
    partial_saving.save_pytree(merged_ckpt_path, target_tree)

  # Finalize the checkpoint, completing the merge process.
  partial_saving.finalize(merged_ckpt_path)


if __name__ == '__main__':
  jax.config.config_with_absl()
  app.run(main)
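
For illustration, a standalone sketch (not from the commit) of the greedy, memory-bounded batching strategy that batch_fragments implements, using hypothetical stand-in items with only a size_bytes field instead of real FragmentInfo objects:

  import dataclasses
  from typing import Iterator


  @dataclasses.dataclass(frozen=True)
  class _Item:
    # Stand-in for FragmentInfo: only the size matters for batching.
    name: str
    size_bytes: int


  def batch_by_memory(
      items: list[_Item], limit_bytes: int
  ) -> Iterator[list[_Item]]:
    """Greedily packs items into batches whose total size stays under the limit."""
    batch, batch_size = [], 0
    for item in items:
      if item.size_bytes > limit_bytes:
        raise ValueError(f'{item.name} ({item.size_bytes} B) exceeds the limit.')
      if batch_size + item.size_bytes > limit_bytes:
        yield batch  # Current batch is full; start a new one with this item.
        batch, batch_size = [item], item.size_bytes
      else:
        batch.append(item)
        batch_size += item.size_bytes
    if batch:
      yield batch  # Final, possibly partial batch.


  items = [_Item('a', 6), _Item('b', 5), _Item('c', 4), _Item('d', 2)]
  print([[i.name for i in b] for b in batch_by_memory(items, limit_bytes=10)])
  # -> [['a'], ['b', 'c'], ['d']]

As in the binary, an item larger than the limit is rejected outright, and the final, partially filled batch is still yielded.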
