From 8b0bb102ff702f597ceb133edc1234ce62bcbf37 Mon Sep 17 00:00:00 2001 From: specture724 Date: Tue, 2 Dec 2025 12:23:05 +0000 Subject: [PATCH 01/10] feat: inplace pin memory for safetensors in /dev/shm/ --- checkpoint_engine/ps.py | 115 +++++++++++++++++++++++++++++++--------- 1 file changed, 91 insertions(+), 24 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 61eafb7..aebff97 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -1,6 +1,7 @@ import argparse import concurrent.futures import ctypes +import json import os import pickle import random @@ -18,7 +19,7 @@ import zmq from loguru import logger from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema -from safetensors.torch import safe_open +from safetensors.torch import _getdtype, safe_open from torch.multiprocessing.reductions import reduce_tensor from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid @@ -461,28 +462,94 @@ def _register_checkpoint( ) if not files and not named_tensors: return [] - parameters = _load_checkpoint(files) - if named_tensors: - parameters.update(named_tensors) - bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) + memory_buffers: list[MemoryBuffer] = [] + inplace_pin = all( + file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 + for file in files or [] + ) + if inplace_pin: + + def _pin(t: torch.Tensor): + """ + Pin the memory of tensor in-place. + See: https://github.com/pytorch/pytorch/issues/32167 + """ + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) + assert r == 0, f"pin memory error, error code: {r.value}" + + def _inplace_pin_memory(file_path: str) -> MemoryBuffer: + # TODO: should only support /dev/shm? but we found files in disk also work? + size = os.stat(file_path).st_size + t = torch.from_file(file_path, True, size, dtype=torch.uint8) + + # safetensors format see https://huggingface.co/docs/safetensors/en/index#format. + # We load the safetensors file as bytes, then parse the header manually to get parameter metas. + # and the actual tensor data is in the remaining bytes. + # We pin the remaining bytes as the buffer, making pinning faster. 
+ flag_size = 8 + with open(file_path, "rb") as f: + n = bytearray(flag_size) + data = f.readinto(n) + assert data == flag_size, f"data {data} should be equal to flag_size {flag_size}" + n = int.from_bytes(n, byteorder="little", signed=False) + start_pos = n + flag_size + + time.sleep(3) + header_tensor = t[flag_size:start_pos] + header = json.loads(header_tensor.numpy().tobytes()) + + metas: list[ParameterMeta] = [] + offset = 0 + for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): + start, end = meta["data_offsets"] + # safetensors format ensures offsets are aligned + assert offset == start, f"offset {offset} should be equal to start {start}" + metas.append( + ParameterMeta( + name=name, dtype=_getdtype(meta["dtype"]), shape=torch.Size(meta["shape"]) + ) + ) + offset = end - class MemoryBucket(BaseModel): - size: int - metas: list[ParameterMeta] - - buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])] - for name, tensor in sorted(parameters.items()): - size = _align_size(tensor.dtype, tensor.shape) - if buckets[-1].size + size > bucket_size: - assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" - buckets.append(MemoryBucket(size=0, metas=[])) - buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype)) - buckets[-1].size += size - - memory_buffers = [ - MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) - for bucket in buckets - ] + buffer = t[start_pos:] + assert offset == buffer.nbytes, ( + f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}" + ) + _pin(buffer) + return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas) + + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + futures = [executor.submit(_inplace_pin_memory, file) for file in files] + for future in concurrent.futures.as_completed(futures): + memory_buffer = future.result() + memory_buffers.append(memory_buffer) + + else: + parameters = _load_checkpoint(files) + if named_tensors: + parameters.update(named_tensors) + bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) + + class MemoryBucket(BaseModel): + size: int + metas: list[ParameterMeta] + + buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])] + for name, tensor in sorted(parameters.items()): + size = _align_size(tensor.dtype, tensor.shape) + if buckets[-1].size + size > bucket_size: + assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" + buckets.append(MemoryBucket(size=0, metas=[])) + buckets[-1].metas.append( + ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype) + ) + buckets[-1].size += size + + memory_buffers = [ + MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) + for bucket in buckets + ] def register_pin_memory( idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None @@ -501,8 +568,8 @@ def register_pin_memory( buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) return idx, buffer - def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): - buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) + def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): + buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: futures = [ From 3d2ad5d9e0a7aa08f606676fa72fd7199427e778 Mon Sep 17 00:00:00 2001 From: specture724 
Date: Fri, 5 Dec 2025 06:00:17 +0000 Subject: [PATCH 02/10] feat: inplace pin and normal pin compatible --- checkpoint_engine/ps.py | 100 ++++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 30 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index aebff97..487d4c0 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -93,6 +93,7 @@ class ParameterMeta(BaseModel): name: str dtype: _TorchDtype shape: _TorchSize + manually_aligned: bool = True class BucketRange(NamedTuple): @@ -141,7 +142,11 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int: def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]: ret = [] for meta in metas: - size = _align_size(meta.dtype, meta.shape) + size = ( + _align_size(meta.dtype, meta.shape) + if meta.manually_aligned + else meta.dtype.itemsize * meta.shape.numel() + ) ret.append( { "name": meta.name, @@ -463,12 +468,8 @@ def _register_checkpoint( if not files and not named_tensors: return [] memory_buffers: list[MemoryBuffer] = [] - inplace_pin = all( - file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 - for file in files or [] - ) - if inplace_pin: + def inplace_pin_memory(files: list[str]) -> list[MemoryBuffer]: def _pin(t: torch.Tensor): """ Pin the memory of tensor in-place. @@ -495,6 +496,7 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer: n = int.from_bytes(n, byteorder="little", signed=False) start_pos = n + flag_size + os.remove(file_path) time.sleep(3) header_tensor = t[flag_size:start_pos] header = json.loads(header_tensor.numpy().tobytes()) @@ -507,7 +509,10 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer: assert offset == start, f"offset {offset} should be equal to start {start}" metas.append( ParameterMeta( - name=name, dtype=_getdtype(meta["dtype"]), shape=torch.Size(meta["shape"]) + name=name, + dtype=_getdtype(meta["dtype"]), + shape=torch.Size(meta["shape"]), + manually_aligned=False, ) ) offset = end @@ -519,13 +524,24 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer: _pin(buffer) return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas) + local_memory_buffers: list[MemoryBuffer] = [] + lock = threading.Lock() + idx = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: futures = [executor.submit(_inplace_pin_memory, file) for file in files] for future in concurrent.futures.as_completed(futures): memory_buffer = future.result() - memory_buffers.append(memory_buffer) + with lock: + local_memory_buffers.append(memory_buffer) + logger.info( + f"[rank{rank}] register pin_memory for file in /dev/shm {idx + 1}/{len(files)} finished" + ) + idx += 1 + return local_memory_buffers - else: + def normal_pin_memory( + files: list[str], named_tensors: dict[str, torch.Tensor] + ) -> list[MemoryBuffer]: parameters = _load_checkpoint(files) if named_tensors: parameters.update(named_tensors) @@ -535,7 +551,8 @@ class MemoryBucket(BaseModel): size: int metas: list[ParameterMeta] - buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])] + buckets: list[MemoryBucket] = [] + buckets.append(MemoryBucket(size=0, metas=[])) for name, tensor in sorted(parameters.items()): size = _align_size(tensor.dtype, tensor.shape) if buckets[-1].size + size > bucket_size: @@ -546,27 +563,27 @@ class MemoryBucket(BaseModel): ) buckets[-1].size += size - memory_buffers = [ + local_memory_buffers = [ MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in 
buckets ] - def register_pin_memory( - idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None - ) -> tuple[int, torch.Tensor]: - if shared_pin_memory: - # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one - # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time - assert idx < len(shared_pin_memory), ( - f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" - ) - assert shared_pin_memory[idx].size == size, ( - f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" - ) - return idx, shared_pin_memory[idx].buffer - else: - buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) - return idx, buffer + def register_pin_memory( + idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None + ) -> tuple[int, torch.Tensor]: + if shared_pin_memory: + # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one + # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time + assert idx < len(shared_pin_memory), ( + f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" + ) + assert shared_pin_memory[idx].size == size, ( + f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" + ) + return idx, shared_pin_memory[idx].buffer + else: + buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) + return idx, buffer def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) @@ -587,7 +604,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): assert buffer.numel() == buckets[idx].size, ( f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" ) - memory_buffers[idx].buffer = buffer + local_memory_buffers[idx].buffer = buffer logger.info( f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer" @@ -604,6 +621,20 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): offset += size for future in concurrent.futures.as_completed(new_futures): future.result() + return local_memory_buffers + + files_to_inplace_pin = [ + file + for file in files + if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108 + ] + files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin] + if files_to_normal_pin or named_tensors: + memory_buffers.extend( + normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors) + ) + if files_to_inplace_pin: + memory_buffers.extend(inplace_pin_memory(files_to_inplace_pin)) return memory_buffers @@ -652,7 +683,11 @@ def _gen_h2d_buckets( for idx, metas in enumerate(items.memory_buffer_metas_list): start_offset, offset = 0, 0 for meta in metas.metas: - s = _align_size(meta.dtype, meta.shape) + s = ( + _align_size(meta.dtype, meta.shape) + if meta.manually_aligned + else meta.dtype.itemsize * meta.shape.numel() + ) if buckets[-1][1].size + s > bucket_size: if offset - start_offset > 0: buckets[-1][1].ranges.append( @@ -1197,7 +1232,12 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, for items in self._current_global_parameter_metas.values(): for metas_list in items.memory_buffer_metas_list: for meta in 
metas_list.metas: - max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape)) + max_tensor_bytes = max( + max_tensor_bytes, + _align_size(meta.dtype, meta.shape) + if meta.manually_aligned + else meta.dtype.itemsize * meta.shape.numel(), + ) free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer: self._logger_rank0(f"[rank{self._rank}] use h2d buffer") From 20a8bf5f04898bfa32b18d65d6b75574f15c10ab Mon Sep 17 00:00:00 2001 From: specture724 Date: Fri, 5 Dec 2025 06:01:23 +0000 Subject: [PATCH 03/10] feat: inplace-pin-memory need synchronization barrier --- examples/update.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/update.py b/examples/update.py index af8fe69..51cb189 100644 --- a/examples/update.py +++ b/examples/update.py @@ -100,8 +100,9 @@ def update_weights( update_method: Literal["broadcast", "p2p", "all"] = "broadcast", uds: str | None = None, ): - ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors) ps.init_process_group() + dist.barrier() + ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors) check_vllm_ready(endpoint, inference_parallel_size, uds) dist.barrier() with timer("Gather metas"): @@ -173,7 +174,9 @@ def join( args.uds, ) else: - if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")): + if os.path.exists( + os.path.join(args.checkpoint_path, "model.safetensors.index.json") + ) and not args.checkpoint_path.startswith("/dev/shm/"): # noqa: S108 named_tensors = split_tensors(args.checkpoint_path, rank, world_size) checkpoint_files = [] else: From df3bd0e524384dc31dc2ebce9c163f7d51cd43af Mon Sep 17 00:00:00 2001 From: specture724 Date: Fri, 5 Dec 2025 06:21:16 +0000 Subject: [PATCH 04/10] feat: test for inplace pin memory added --- tests/test_update.py | 92 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/tests/test_update.py b/tests/test_update.py index d13e0e1..9650cbe 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]): try: trigger_error(socket_paths) except RuntimeError as e: - assert str(e) == "Failed to update weights due to remote errors" + assert str(e) == "Some workers failed to update weights" def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue): @@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor for name, weight in weights: if name not in named_tensors: continue - assert (weight == named_tensors[name]).all() + assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!" 
names_to_check[name] = True def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]): @@ -163,6 +163,61 @@ def run( assert proc.exitcode == 0 +def run_with_files( + checker_func: callable, +): + rank = int(os.getenv("RANK")) + ctx = get_context("spawn") + queue = ctx.Queue() + _device_uuid = _get_physical_gpu_id(device_manager, rank) + ps = ParameterServer(auto_pg=True) + _device_uuid = _get_physical_gpu_id(ps.device_manager, rank) + named_tensors = dict(gen_test_tensors(rank)) + + # Save 1/3 tensors to /dev/shm/ as .safetensors files + # Save 1/3 tensors to ./tmp (disk) as .safetensors files + # Keep 1/3 tensors in memory + import safetensors.torch + + files = [] + dev_shm_dir = "/dev/shm/checkpoint_engine_tests" # noqa: S108 + disk_dir = "/tmp/checkpoint_engine_tests" # noqa: S108 + os.makedirs(dev_shm_dir, exist_ok=True) + os.makedirs(disk_dir, exist_ok=True) + tensors_items = list(named_tensors.items()) + tensors_in_dev_shm = named_tensors + tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2]) + tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3]) + tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :]) + disk_files = [ + os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors") + for _rank in range(get_world_size()) + ] + safetensors.torch.save_file(tensors_in_disk, disk_files[rank]) + time.sleep(1) + files.append(disk_files[rank]) + dev_shm_files = [ + os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors") + for _ in range(get_world_size()) + ] + safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank]) + time.sleep(1) + files.append(dev_shm_files[rank]) + + checkpoint_name = "test_with_files" + proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue)) + proc.start() + ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files) + ps.gather_metas(checkpoint_name) + ps.update(checkpoint_name, queue.put, ranks=[]) + # sleep 3s to wait process group is destroyed + time.sleep(3) + ps.unregister_checkpoint(checkpoint_name) + queue.put(None) + proc.join() + assert proc.exitcode == 0 + + @pytest.mark.gpu @pytest.mark.parametrize( "test_name,rank_list", @@ -211,6 +266,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None): assert result.returncode == 0 +@pytest.mark.gpu +def test_update_with_files(test_name: str = "test_with_files"): + world_size = device_manager.device_module.device_count() + assert world_size >= 2, "This test requires at least 2 GPUs." 
+ master_addr = "localhost" + master_port = 25400 + cmd = [ + "torchrun", + "--nproc_per_node", + str(world_size), + "--master_addr", + master_addr, + "--master_port", + str(master_port), + __file__, + test_name, + "[]", + ] + + result = subprocess.run( # noqa: S603 + cmd, + capture_output=False, + text=True, + cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + shell=False, + check=False, + ) + + assert result.returncode == 0 + + if __name__ == "__main__": run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ if not run_with_pytest: @@ -230,5 +316,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None): expected_exception=RuntimeError, exception_msg="Failed to update weights due to remote errors", ) + elif test_type == "test_with_files": + run_with_files(checker_proc) else: raise ValueError(f"Unknown TEST_TYPE: {test_type}") From 9794cf318b3d7f610f2a41df455e0ae552d67644 Mon Sep 17 00:00:00 2001 From: specture724 Date: Fri, 5 Dec 2025 07:26:20 +0000 Subject: [PATCH 05/10] feat: add header format check and key `__metadata__` ignored --- checkpoint_engine/ps.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 487d4c0..23b1417 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -19,7 +19,7 @@ import zmq from loguru import logger from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema -from safetensors.torch import _getdtype, safe_open +from safetensors.torch import _TYPES, _getdtype, safe_open from torch.multiprocessing.reductions import reduce_tensor from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid @@ -500,22 +500,28 @@ def _inplace_pin_memory(file_path: str) -> MemoryBuffer: time.sleep(3) header_tensor = t[flag_size:start_pos] header = json.loads(header_tensor.numpy().tobytes()) + if "__metadata__" in header: + header.pop("__metadata__") metas: list[ParameterMeta] = [] offset = 0 - for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): - start, end = meta["data_offsets"] - # safetensors format ensures offsets are aligned - assert offset == start, f"offset {offset} should be equal to start {start}" - metas.append( - ParameterMeta( - name=name, - dtype=_getdtype(meta["dtype"]), - shape=torch.Size(meta["shape"]), - manually_aligned=False, + try: + for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): + start, end = meta["data_offsets"] + # safetensors format ensures offsets are aligned + assert offset == start, f"offset {offset} should be equal to start {start}" + metas.append( + ParameterMeta( + name=name, + dtype=_getdtype(meta["dtype"]), + shape=torch.Size(meta["shape"]), + manually_aligned=False, + ) ) - ) - offset = end + offset = end + except Exception as e: + logger.error(f"fail to parse safetensors header from {file_path}: {e}") + raise buffer = t[start_pos:] assert offset == buffer.nbytes, ( From ded28655eef5a94c4a47572c9cc9600cac89ac22 Mon Sep 17 00:00:00 2001 From: specture724 Date: Fri, 5 Dec 2025 08:02:00 +0000 Subject: [PATCH 06/10] fix: fix PR issues --- checkpoint_engine/ps.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 23b1417..6f63ea9 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -19,7 +19,7 @@ import zmq from loguru import logger from pydantic import BaseModel, PlainSerializer, 
PlainValidator, WithJsonSchema -from safetensors.torch import _TYPES, _getdtype, safe_open +from safetensors.torch import _getdtype, safe_open from torch.multiprocessing.reductions import reduce_tensor from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid @@ -477,27 +477,27 @@ def _pin(t: torch.Tensor): """ cudart = torch.cuda.cudart() r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) - assert r == 0, f"pin memory error, error code: {r.value}" + assert r == 0, f"pin memory error, error code: {r}" def _inplace_pin_memory(file_path: str) -> MemoryBuffer: + """ + safetensors format see https://huggingface.co/docs/safetensors/en/index#format. + We load the safetensors file as bytes, then parse the header manually to get parameter metas. + The actual tensor data is in the remaining bytes and is naturally aligned. + We pin the remaining bytes as the buffer, making pinning faster. + """ # TODO: should only support /dev/shm? but we found files in disk also work? size = os.stat(file_path).st_size - t = torch.from_file(file_path, True, size, dtype=torch.uint8) - - # safetensors format see https://huggingface.co/docs/safetensors/en/index#format. - # We load the safetensors file as bytes, then parse the header manually to get parameter metas. - # and the actual tensor data is in the remaining bytes. - # We pin the remaining bytes as the buffer, making pinning faster. flag_size = 8 - with open(file_path, "rb") as f: - n = bytearray(flag_size) - data = f.readinto(n) - assert data == flag_size, f"data {data} should be equal to flag_size {flag_size}" - n = int.from_bytes(n, byteorder="little", signed=False) - start_pos = n + flag_size - + t = torch.from_file(file_path, True, size, dtype=torch.uint8) + assert t.nbytes > flag_size, ( + f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}" + ) os.remove(file_path) - time.sleep(3) + start_pos = ( + int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False) + + flag_size + ) header_tensor = t[flag_size:start_pos] header = json.loads(header_tensor.numpy().tobytes()) if "__metadata__" in header: From 892aea5baea3f546bee98878af0e41b209595224 Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 8 Dec 2025 10:28:15 +0000 Subject: [PATCH 07/10] feat: remove temp files in test --- tests/test_update.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_update.py b/tests/test_update.py index 9650cbe..97b7f60 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -215,6 +215,11 @@ def run_with_files( ps.unregister_checkpoint(checkpoint_name) queue.put(None) proc.join() + if rank == 0: + import shutil + + os.removedirs(dev_shm_dir) + shutil.rmtree(disk_dir) assert proc.exitcode == 0 From b7906db564fc00bb9e3c41b295323fe9fde7425f Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 8 Dec 2025 10:29:26 +0000 Subject: [PATCH 08/10] fix: fix PR issues --- checkpoint_engine/ps.py | 241 +++++++++++++++++++--------------------- 1 file changed, 115 insertions(+), 126 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 6f63ea9..83faa67 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -93,7 +93,7 @@ class ParameterMeta(BaseModel): name: str dtype: _TorchDtype shape: _TorchSize - manually_aligned: bool = True + aligned_size: int class BucketRange(NamedTuple): @@ -142,11 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int: def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> 
list[dict]: ret = [] for meta in metas: - size = ( - _align_size(meta.dtype, meta.shape) - if meta.manually_aligned - else meta.dtype.itemsize * meta.shape.numel() - ) + size = meta.aligned_size ret.append( { "name": meta.name, @@ -428,6 +424,7 @@ class TPMeta(BaseModel): name=parameter_name, shape=meta["shape"], dtype=meta["dtype"], + aligned_size=_align_size(meta["dtype"], meta["shape"]), ) tp_meta = tp_metas[parameter_name] if tp_meta.concat_dim != -1: @@ -437,7 +434,10 @@ class TPMeta(BaseModel): shape = list(parameter_metas[name].shape) shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size parameter_metas[name] = ParameterMeta( - name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype + name=name, + shape=torch.Size(shape), + dtype=parameter_metas[name].dtype, + aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)), ) weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])] # TODO: here concat is serial, which may be slow @@ -455,21 +455,15 @@ class TPMeta(BaseModel): return parameters -def _register_checkpoint( - *, - files: list[str], - named_tensors: dict[str, torch.Tensor], - rank: int | None = None, - shared_pin_memory: list[MemoryBuffer] | None = None, -) -> list[MemoryBuffer]: - logger.info( - f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" - ) - if not files and not named_tensors: - return [] - memory_buffers: list[MemoryBuffer] = [] +def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]: + def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer: + """ + safetensors format see https://huggingface.co/docs/safetensors/en/index#format. + We load the safetensors file as bytes, then parse the header manually to get parameter metas. + The actual tensor data is in the remaining bytes and is naturally aligned. + We pin the remaining bytes as the buffer, making pinning faster. + """ - def inplace_pin_memory(files: list[str]) -> list[MemoryBuffer]: def _pin(t: torch.Tensor): """ Pin the memory of tensor in-place. @@ -479,100 +473,91 @@ def _pin(t: torch.Tensor): r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) assert r == 0, f"pin memory error, error code: {r}" - def _inplace_pin_memory(file_path: str) -> MemoryBuffer: - """ - safetensors format see https://huggingface.co/docs/safetensors/en/index#format. - We load the safetensors file as bytes, then parse the header manually to get parameter metas. - The actual tensor data is in the remaining bytes and is naturally aligned. - We pin the remaining bytes as the buffer, making pinning faster. - """ - # TODO: should only support /dev/shm? but we found files in disk also work? - size = os.stat(file_path).st_size - flag_size = 8 - t = torch.from_file(file_path, True, size, dtype=torch.uint8) - assert t.nbytes > flag_size, ( - f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}" - ) - os.remove(file_path) - start_pos = ( - int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False) - + flag_size - ) - header_tensor = t[flag_size:start_pos] - header = json.loads(header_tensor.numpy().tobytes()) - if "__metadata__" in header: - header.pop("__metadata__") + # TODO: should only support /dev/shm? but we found files in disk also work? 
+ size = os.stat(file_path).st_size + flag_size = 8 + t = torch.from_file(file_path, True, size, dtype=torch.uint8) + assert t.nbytes > flag_size, ( + f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}" + ) + start_pos = ( + int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False) + + flag_size + ) + header_tensor = t[flag_size:start_pos] + header = json.loads(header_tensor.numpy().tobytes()) + if "__metadata__" in header: + header.pop("__metadata__") - metas: list[ParameterMeta] = [] - offset = 0 - try: - for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): - start, end = meta["data_offsets"] - # safetensors format ensures offsets are aligned - assert offset == start, f"offset {offset} should be equal to start {start}" - metas.append( - ParameterMeta( - name=name, - dtype=_getdtype(meta["dtype"]), - shape=torch.Size(meta["shape"]), - manually_aligned=False, - ) - ) - offset = end - except Exception as e: - logger.error(f"fail to parse safetensors header from {file_path}: {e}") - raise - - buffer = t[start_pos:] - assert offset == buffer.nbytes, ( - f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}" - ) - _pin(buffer) - return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas) - - local_memory_buffers: list[MemoryBuffer] = [] - lock = threading.Lock() - idx = 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - futures = [executor.submit(_inplace_pin_memory, file) for file in files] - for future in concurrent.futures.as_completed(futures): - memory_buffer = future.result() - with lock: - local_memory_buffers.append(memory_buffer) - logger.info( - f"[rank{rank}] register pin_memory for file in /dev/shm {idx + 1}/{len(files)} finished" + metas: list[ParameterMeta] = [] + offset = 0 + try: + for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]): + start, end = meta["data_offsets"] + # safetensors format ensures offsets are aligned + assert offset == start, f"offset {offset} should be equal to start {start}" + metas.append( + ParameterMeta( + name=name, + dtype=_getdtype(meta["dtype"]), + shape=torch.Size(meta["shape"]), + aligned_size=end - start, ) - idx += 1 - return local_memory_buffers + ) + offset = end + except Exception as e: + logger.error(f"fail to parse safetensors header from {file_path}: {e}") + raise - def normal_pin_memory( - files: list[str], named_tensors: dict[str, torch.Tensor] - ) -> list[MemoryBuffer]: - parameters = _load_checkpoint(files) - if named_tensors: - parameters.update(named_tensors) - bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) - - class MemoryBucket(BaseModel): - size: int - metas: list[ParameterMeta] - - buckets: list[MemoryBucket] = [] - buckets.append(MemoryBucket(size=0, metas=[])) - for name, tensor in sorted(parameters.items()): - size = _align_size(tensor.dtype, tensor.shape) - if buckets[-1].size + size > bucket_size: - assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" - buckets.append(MemoryBucket(size=0, metas=[])) - buckets[-1].metas.append( - ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype) - ) - buckets[-1].size += size + buffer = t[start_pos:] + assert offset == buffer.nbytes, ( + f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}" + ) + # Remove the file after successfully loading. This will avoid doubling the memory usage. + # We assume files in /dev/shm/ are temporary files. 
So it's safe to remove them after loading. + os.remove(file_path) + _pin(buffer) + logger.info( + f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB" + ) + return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas) - local_memory_buffers = [ - MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) - for bucket in buckets - ] + local_memory_buffers: list[MemoryBuffer] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + local_memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files)) + return local_memory_buffers + + +def _normal_pin_memory( + files: list[str], + named_tensors: dict[str, torch.Tensor], + rank: int | None = None, +) -> list[MemoryBuffer]: + parameters = _load_checkpoint(files) + if named_tensors: + parameters.update(named_tensors) + bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values())) + + class MemoryBucket(BaseModel): + size: int + metas: list[ParameterMeta] + + buckets: list[MemoryBucket] = [] + buckets.append(MemoryBucket(size=0, metas=[])) + for name, tensor in sorted(parameters.items()): + size = _align_size(tensor.dtype, tensor.shape) + if buckets[-1].size + size > bucket_size: + assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty" + buckets.append(MemoryBucket(size=0, metas=[])) + buckets[-1].metas.append( + ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size) + ) + buckets[-1].size += size + + local_memory_buffers = [ + MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) + for bucket in buckets + ] def register_pin_memory( idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None @@ -591,8 +576,8 @@ def register_pin_memory( buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) return idx, buffer - def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): - buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) + def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): + buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: futures = [ @@ -629,6 +614,19 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): future.result() return local_memory_buffers + +def _register_checkpoint( + *, + files: list[str], + named_tensors: dict[str, torch.Tensor], + rank: int | None = None, +) -> list[MemoryBuffer]: + logger.info( + f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" + ) + if not files and not named_tensors: + return [] + memory_buffers: list[MemoryBuffer] = [] files_to_inplace_pin = [ file for file in files @@ -637,10 +635,10 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin] if files_to_normal_pin or named_tensors: memory_buffers.extend( - normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors) + _normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors, rank=rank) ) if files_to_inplace_pin: - memory_buffers.extend(inplace_pin_memory(files_to_inplace_pin)) + memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank)) return memory_buffers @@ -689,11 +687,7 @@ def _gen_h2d_buckets( for idx, metas in 
enumerate(items.memory_buffer_metas_list): start_offset, offset = 0, 0 for meta in metas.metas: - s = ( - _align_size(meta.dtype, meta.shape) - if meta.manually_aligned - else meta.dtype.itemsize * meta.shape.numel() - ) + s = meta.aligned_size if buckets[-1][1].size + s > bucket_size: if offset - start_offset > 0: buckets[-1][1].ranges.append( @@ -1238,12 +1232,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, for items in self._current_global_parameter_metas.values(): for metas_list in items.memory_buffer_metas_list: for meta in metas_list.metas: - max_tensor_bytes = max( - max_tensor_bytes, - _align_size(meta.dtype, meta.shape) - if meta.manually_aligned - else meta.dtype.itemsize * meta.shape.numel(), - ) + max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size) free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer: self._logger_rank0(f"[rank{self._rank}] use h2d buffer") From c380f0c85bdccf97143f7d43e672932ca84c79f2 Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 10 Dec 2025 06:29:17 +0000 Subject: [PATCH 09/10] misc: add "files deleted" warning in doc string --- checkpoint_engine/ps.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 83faa67..35410c7 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -969,6 +969,8 @@ def register_checkpoint( ) -> None: """ Register a checkpoint to the parameter server. Both files and named_tensors will be registered together. + Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning. + Please make sure to copy the files to disks if you need to keep them. Args: checkpoint_name: The name of the checkpoint. 
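A note on the warning introduced in this patch: .safetensors files under /dev/shm/ are pinned in place and the series removes each file right after mmapping it (to avoid doubling shared-memory usage), so callers that stage checkpoints into /dev/shm/ should treat those copies as consumable. Below is a minimal usage sketch under that assumption; the staging and source paths are illustrative, while ParameterServer(auto_pg=True) and the register_checkpoint(name, files=..., named_tensors=...) signature are taken from the tests and examples in this series.

import os
import shutil

from checkpoint_engine.ps import ParameterServer

src_dir = "/data/checkpoints/step_100"   # durable copy kept on disk (illustrative path)
stage_dir = "/dev/shm/step_100"          # consumable copy, removed after in-place pinning
os.makedirs(stage_dir, exist_ok=True)

files = []
for name in os.listdir(src_dir):
    if name.endswith(".safetensors"):
        # copy (not move) so the durable checkpoint survives the removal during pinning
        files.append(shutil.copy(os.path.join(src_dir, name), stage_dir))

ps = ParameterServer(auto_pg=True)
ps.register_checkpoint("step_100", files=files, named_tensors={})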
From 4d68fb30cd0af59e49ddad81cf279d11b20c4788 Mon Sep 17 00:00:00 2001 From: specture724 Date: Thu, 11 Dec 2025 05:30:16 +0000 Subject: [PATCH 10/10] misc --- checkpoint_engine/ps.py | 53 +++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 35410c7..8e37519 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -522,16 +522,17 @@ def _pin(t: torch.Tensor): ) return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas) - local_memory_buffers: list[MemoryBuffer] = [] + memory_buffers: list[MemoryBuffer] = [] with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: - local_memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files)) - return local_memory_buffers + memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files)) + return memory_buffers def _normal_pin_memory( files: list[str], named_tensors: dict[str, torch.Tensor], rank: int | None = None, + shared_pin_memory: list[MemoryBuffer] | None = None, ) -> list[MemoryBuffer]: parameters = _load_checkpoint(files) if named_tensors: @@ -554,27 +555,27 @@ class MemoryBucket(BaseModel): ) buckets[-1].size += size - local_memory_buffers = [ + memory_buffers = [ MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in buckets ] - def register_pin_memory( - idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None - ) -> tuple[int, torch.Tensor]: - if shared_pin_memory: - # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one - # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time - assert idx < len(shared_pin_memory), ( - f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" - ) - assert shared_pin_memory[idx].size == size, ( - f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" - ) - return idx, shared_pin_memory[idx].buffer - else: - buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) - return idx, buffer + def register_pin_memory( + idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None + ) -> tuple[int, torch.Tensor]: + if shared_pin_memory: + # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one + # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time + assert idx < len(shared_pin_memory), ( + f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" + ) + assert shared_pin_memory[idx].size == size, ( + f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" + ) + return idx, shared_pin_memory[idx].buffer + else: + buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) + return idx, buffer def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) @@ -595,7 +596,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): assert buffer.numel() == buckets[idx].size, ( f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}" ) - local_memory_buffers[idx].buffer = buffer + memory_buffers[idx].buffer = buffer logger.info( f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, " f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors 
to buffer" @@ -612,7 +613,7 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): offset += size for future in concurrent.futures.as_completed(new_futures): future.result() - return local_memory_buffers + return memory_buffers def _register_checkpoint( @@ -620,6 +621,7 @@ def _register_checkpoint( files: list[str], named_tensors: dict[str, torch.Tensor], rank: int | None = None, + shared_pin_memory: list[MemoryBuffer] | None = None, ) -> list[MemoryBuffer]: logger.info( f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" @@ -635,7 +637,12 @@ def _register_checkpoint( files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin] if files_to_normal_pin or named_tensors: memory_buffers.extend( - _normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors, rank=rank) + _normal_pin_memory( + files=files_to_normal_pin, + named_tensors=named_tensors, + rank=rank, + shared_pin_memory=shared_pin_memory, + ) ) if files_to_inplace_pin: memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
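For reference, the in-place pin path that the series converges on can be condensed into a standalone sketch: mmap a .safetensors file with torch.from_file(shared=True), read the 8-byte little-endian header length, parse the JSON header for parameter metadata, and register only the tensor-data region as pinned host memory via cudaHostRegister. This mirrors _parse_and_pin_from_safetensors in the later patches; the function name below and the omission of file removal and MemoryBuffer construction are simplifications for illustration, not repo code.

import json
import os

import torch
from safetensors.torch import _getdtype  # private helper, used the same way in the patches


def demo_pin_safetensors(file_path: str) -> tuple[torch.Tensor, dict]:
    """Mmap a safetensors file and pin its data region in place (sketch)."""
    size = os.stat(file_path).st_size
    # shared=True maps the file instead of copying it into anonymous memory
    t = torch.from_file(file_path, True, size, dtype=torch.uint8)

    # safetensors layout: 8-byte little-endian header length, JSON header, then raw tensor data
    header_len = int.from_bytes(t[0:8].numpy().tobytes(), byteorder="little", signed=False)
    start_pos = 8 + header_len
    header = json.loads(t[8:start_pos].numpy().tobytes())
    header.pop("__metadata__", None)

    # collect (dtype, shape) per tensor in data_offsets order; offsets are contiguous
    metas = {
        name: (_getdtype(meta["dtype"]), torch.Size(meta["shape"]))
        for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"])
    }

    # pin only the data region: one cudaHostRegister call over a single contiguous span
    data = t[start_pos:]
    r = torch.cuda.cudart().cudaHostRegister(data.data_ptr(), data.nbytes, 0)
    assert r == 0, f"cudaHostRegister failed with error code {r}"
    return data, metas

Pinning one large contiguous span is what makes this faster than allocating fresh pinned buffers and copying tensors into them; the registration can later be undone with cudaHostUnregister before the mapping is dropped.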