diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index 61eafb7..8e37519 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@ import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int
 
 
 class BucketRange(NamedTuple):
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size = _align_size(meta.dtype, meta.shape)
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ class TPMeta(BaseModel):
                 name=parameter_name,
                 shape=meta["shape"],
                 dtype=meta["dtype"],
+                aligned_size=_align_size(meta["dtype"], meta["shape"]),
             )
             tp_meta = tp_metas[parameter_name]
             if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ class TPMeta(BaseModel):
         shape = list(parameter_metas[name].shape)
         shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
         parameter_metas[name] = ParameterMeta(
-            name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+            name=name,
+            shape=torch.Size(shape),
+            dtype=parameter_metas[name].dtype,
+            aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
         )
         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
         # TODO: here concat is serial, which may be slow
@@ -449,18 +455,85 @@ class TPMeta(BaseModel):
     return parameters
 
 
-def _register_checkpoint(
-    *,
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        For the safetensors file format, see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as raw bytes and parse the header manually to build the parameter metas.
+        The actual tensor data lives in the remaining bytes and is already naturally aligned,
+        so we pin those bytes in place as the buffer, which makes pinning much faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the tensor's memory in place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should this be restricted to /dev/shm? Files on regular disks appear to work as well.
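+        # Layout of a .safetensors file: an 8-byte little-endian unsigned integer giving the JSON
+        # header size, the JSON header itself (name -> dtype/shape/data_offsets, with offsets
+        # relative to the start of the data section), then the raw tensor bytes.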
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file once it has been loaded to avoid doubling the memory usage.
+        # Files in /dev/shm/ are assumed to be temporary, so removing them after loading is safe.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
     shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
@@ -470,13 +543,16 @@ class MemoryBucket(BaseModel):
         size: int
         metas: list[ParameterMeta]
 
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size
 
     memory_buffers = [
@@ -537,6 +613,39 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
             offset += size
     for future in concurrent.futures.as_completed(new_futures):
         future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    files_to_inplace_pin = [
+        file
+        for file in files
+        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+    ]
+    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers
 
 
@@ -585,7 +694,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s = _align_size(meta.dtype, meta.shape)
+            s = meta.aligned_size
             if buckets[-1][1].size + s > bucket_size:
                 if offset - start_offset > 0:
                     buckets[-1][1].ranges.append(
@@ -867,6 +976,8 @@ def register_checkpoint(
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+        Warning: .safetensors files under /dev/shm/ are pinned in place and REMOVED after pinning.
+        Copy them to disk first if you need to keep the files.
 
         Args:
             checkpoint_name: The name of the checkpoint.
@@ -1130,7 +1241,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
diff --git a/examples/update.py b/examples/update.py
index af8fe69..51cb189 100644
--- a/examples/update.py
+++ b/examples/update.py
@@ -100,8 +100,9 @@ def update_weights(
     update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
     uds: str | None = None,
 ):
-    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     ps.init_process_group()
+    dist.barrier()
+    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     check_vllm_ready(endpoint, inference_parallel_size, uds)
     dist.barrier()
     with timer("Gather metas"):
@@ -173,7 +174,9 @@ def join(
             args.uds,
         )
     else:
-        if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
+        if os.path.exists(
+            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
+        ) and not args.checkpoint_path.startswith("/dev/shm/"):  # noqa: S108
             named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
             checkpoint_files = []
         else:
diff --git a/tests/test_update.py b/tests/test_update.py
index d13e0e1..97b7f60 100644
--- a/tests/test_update.py
+++ b/tests/test_update.py
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
     try:
         trigger_error(socket_paths)
     except RuntimeError as e:
-        assert str(e) == "Failed to update weights due to remote errors"
+        assert str(e) == "Some workers failed to update weights"
 
 
 def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor
         for name, weight in weights:
             if name not in named_tensors:
                 continue
-            assert (weight == named_tensors[name]).all()
+            assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
             names_to_check[name] = True
 
     def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,66 @@ def run(
     assert proc.exitcode == 0
 
 
+def run_with_files(
+    checker_func: callable,
+):
+    rank = int(os.getenv("RANK"))
+    ctx = get_context("spawn")
+    queue = ctx.Queue()
+    ps = ParameterServer(auto_pg=True)
+    _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
+    named_tensors = dict(gen_test_tensors(rank))
+
+    # Save one third of the tensors to /dev/shm/ as .safetensors files,
+    # save one third to /tmp (disk) as .safetensors files,
+    # and keep the remaining third in memory.
+    import safetensors.torch
+
+    files = []
+    dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+    disk_dir = "/tmp/checkpoint_engine_tests"  # noqa: S108
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    os.makedirs(disk_dir, exist_ok=True)
+    tensors_items = list(named_tensors.items())
+    tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 3])
+    tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
+    tensors_in_memory = dict(tensors_items[2 * len(tensors_items) // 3 :])
+    disk_files = [
+        os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
+        for _rank in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
+    time.sleep(1)
+    files.append(disk_files[rank])
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{_rank}_checkpoint.safetensors")
+        for _rank in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+
+    checkpoint_name = "test_with_files"
+    proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
+    proc.start()
+    ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
+    ps.gather_metas(checkpoint_name)
+    ps.update(checkpoint_name, queue.put, ranks=[])
+    # sleep 3s to wait for the process group to be destroyed
+    time.sleep(3)
+    ps.unregister_checkpoint(checkpoint_name)
+    queue.put(None)
+    proc.join()
+    if rank == 0:
+        import shutil
+
+        os.removedirs(dev_shm_dir)
+        shutil.rmtree(disk_dir)
+    assert proc.exitcode == 0
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize(
     "test_name,rank_list",
@@ -211,6 +271,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
     assert result.returncode == 0
 
 
+@pytest.mark.gpu
+def test_update_with_files(test_name: str = "test_with_files"):
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
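+    # Re-launch this test file under torchrun so that every rank executes it;
+    # __main__ dispatches to run_with_files based on the test name passed as an argument.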
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+        test_name,
+        "[]",
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
 if __name__ == "__main__":
     run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
     if not run_with_pytest:
@@ -230,5 +321,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
                 expected_exception=RuntimeError,
                 exception_msg="Failed to update weights due to remote errors",
             )
+        elif test_type == "test_with_files":
+            run_with_files(checker_proc)
         else:
             raise ValueError(f"Unknown TEST_TYPE: {test_type}")