139 changes: 125 additions & 14 deletions checkpoint_engine/ps.py
@@ -1,6 +1,7 @@
import argparse
import concurrent.futures
import ctypes
import json
import os
import pickle
import random
@@ -18,7 +19,7 @@
import zmq
from loguru import logger
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
from safetensors.torch import safe_open
from safetensors.torch import _getdtype, safe_open
from torch.multiprocessing.reductions import reduce_tensor

from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
name: str
dtype: _TorchDtype
shape: _TorchSize
aligned_size: int


class BucketRange(NamedTuple):
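
The new `aligned_size` field on `ParameterMeta` caches the value that was previously recomputed by `_align_size` at each use site below. The helper's body is not part of this diff, so the following is only an illustrative sketch of what such an alignment helper typically looks like; the constant value and the exact rounding rule are assumptions, not the project's implementation.

import torch

_ALIGN_SIZE = 256  # assumed alignment boundary; the real constant is defined elsewhere in ps.py

def _align_size_sketch(dtype: torch.dtype, shape: torch.Size) -> int:
    """Round the raw byte size of a tensor up to the next _ALIGN_SIZE boundary."""
    nbytes = shape.numel() * torch.empty((), dtype=dtype).element_size()
    return (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE

# Example: a (1024, 1024) bfloat16 tensor is exactly 2 MiB, already a multiple of 256 bytes.
assert _align_size_sketch(torch.bfloat16, torch.Size((1024, 1024))) == 2 * 1024 * 1024
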
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
ret = []
for meta in metas:
size = _align_size(meta.dtype, meta.shape)
size = meta.aligned_size
ret.append(
{
"name": meta.name,
@@ -422,6 +424,7 @@ class TPMeta(BaseModel):
name=parameter_name,
shape=meta["shape"],
dtype=meta["dtype"],
aligned_size=_align_size(meta["dtype"], meta["shape"]),
)
tp_meta = tp_metas[parameter_name]
if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ class TPMeta(BaseModel):
shape = list(parameter_metas[name].shape)
shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
parameter_metas[name] = ParameterMeta(
name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
name=name,
shape=torch.Size(shape),
dtype=parameter_metas[name].dtype,
aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
)
weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
# TODO: here concat is serial, which may be slow
@@ -449,18 +455,85 @@ class TPMeta(BaseModel):
return parameters


def _register_checkpoint(
*,
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
"""
The safetensors format is described at https://huggingface.co/docs/safetensors/en/index#format.
We load the safetensors file as bytes, then parse the header manually to get the parameter metas.
The actual tensor data is in the remaining bytes and is naturally aligned, so we pin those
remaining bytes directly as the buffer, which makes pinning faster.
(See the standalone header-layout sketch after _inplace_pin_memory.)
"""

def _pin(t: torch.Tensor):
"""
Pin the memory of tensor in-place.
See: https://github.com/pytorch/pytorch/issues/32167
"""
cudart = torch.cuda.cudart()
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
assert r == 0, f"pin memory error, error code: {r}"

# TODO: should this only support /dev/shm? Files on regular disks also appear to work.
size = os.stat(file_path).st_size
flag_size = 8
t = torch.from_file(file_path, True, size, dtype=torch.uint8)
assert t.nbytes > flag_size, (
f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
)
start_pos = (
int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+ flag_size
)
header_tensor = t[flag_size:start_pos]
header = json.loads(header_tensor.numpy().tobytes())
if "__metadata__" in header:
header.pop("__metadata__")

metas: list[ParameterMeta] = []
offset = 0
try:
for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
start, end = meta["data_offsets"]
# the safetensors format stores tensors back-to-back, so sorted data_offsets are contiguous
assert offset == start, f"offset {offset} should be equal to start {start}"
metas.append(
ParameterMeta(
name=name,
dtype=_getdtype(meta["dtype"]),
shape=torch.Size(meta["shape"]),
aligned_size=end - start,
)
)
offset = end
except Exception as e:
logger.error(f"fail to parse safetensors header from {file_path}: {e}")
raise

buffer = t[start_pos:]
assert offset == buffer.nbytes, (
f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
)
# Remove the file after it has been loaded successfully to avoid doubling the memory usage.
# Files in /dev/shm/ are assumed to be temporary, so it is safe to remove them after loading.
os.remove(file_path)
_pin(buffer)
logger.info(
f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
)
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)

memory_buffers: list[MemoryBuffer] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
return memory_buffers
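
For readers unfamiliar with the layout that `_parse_and_pin_from_safetensors` walks through, here is a minimal, standalone sketch of the same header parsing (an 8-byte little-endian header length, a JSON header, then the packed tensor data), using plain file I/O instead of `torch.from_file` and without any pinning or file removal. It is an illustration, not code from this PR.

import json

def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Return the parsed JSON header and the byte offset where the tensor data begins."""
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), byteorder="little", signed=False)
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor entry
    return header, 8 + header_len

Each remaining header entry maps a tensor name to {"dtype": ..., "shape": [...], "data_offsets": [start, end]}, with offsets relative to the returned data start and no gaps between consecutive tensors, which is exactly what the data_offsets assert above relies on.
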


def _normal_pin_memory(
files: list[str],
named_tensors: dict[str, torch.Tensor],
rank: int | None = None,
shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
logger.info(
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
)
if not files and not named_tensors:
return []
parameters = _load_checkpoint(files)
if named_tensors:
parameters.update(named_tensors)
@@ -470,13 +543,16 @@ class MemoryBucket(BaseModel):
size: int
metas: list[ParameterMeta]

buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
buckets: list[MemoryBucket] = []
buckets.append(MemoryBucket(size=0, metas=[]))
for name, tensor in sorted(parameters.items()):
size = _align_size(tensor.dtype, tensor.shape)
if buckets[-1].size + size > bucket_size:
assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
buckets.append(MemoryBucket(size=0, metas=[]))
buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
buckets[-1].metas.append(
ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
)
buckets[-1].size += size

memory_buffers = [
@@ -537,6 +613,39 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
offset += size
for future in concurrent.futures.as_completed(new_futures):
future.result()
return memory_buffers


def _register_checkpoint(
*,
files: list[str],
named_tensors: dict[str, torch.Tensor],
rank: int | None = None,
shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
logger.info(
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
)
if not files and not named_tensors:
return []
memory_buffers: list[MemoryBuffer] = []
files_to_inplace_pin = [
file
for file in files
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
]
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
if files_to_normal_pin or named_tensors:
memory_buffers.extend(
_normal_pin_memory(
files=files_to_normal_pin,
named_tensors=named_tensors,
rank=rank,
shared_pin_memory=shared_pin_memory,
)
)
if files_to_inplace_pin:
memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
return memory_buffers


@@ -585,7 +694,7 @@ def _gen_h2d_buckets(
for idx, metas in enumerate(items.memory_buffer_metas_list):
start_offset, offset = 0, 0
for meta in metas.metas:
s = _align_size(meta.dtype, meta.shape)
s = meta.aligned_size
if buckets[-1][1].size + s > bucket_size:
if offset - start_offset > 0:
buckets[-1][1].ranges.append(
@@ -867,6 +976,8 @@ def register_checkpoint(
) -> None:
"""
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
Please make sure to copy the files elsewhere first if you need to keep them.

Args:
checkpoint_name: The name of the checkpoint.
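
To make the warning above concrete, here is a hypothetical caller-side sketch (the checkpoint name, the paths, and the import path are made up for illustration): it stages copies of the .safetensors files into /dev/shm/ so the originals survive the in-place pin-and-remove behavior, then registers the staged copies.

import glob
import os
import shutil

from checkpoint_engine.ps import ParameterServer  # import path assumed from this repo layout

ps = ParameterServer(auto_pg=True)

os.makedirs("/dev/shm/my_ckpt", exist_ok=True)
staged = []
for src in glob.glob("/data/checkpoints/my_ckpt/*.safetensors"):
    dst = os.path.join("/dev/shm/my_ckpt", os.path.basename(src))
    shutil.copy(src, dst)  # these staged copies will be pinned in place and then removed
    staged.append(dst)

# Only the /dev/shm/ copies are consumed; the originals under /data/checkpoints/ are untouched.
ps.register_checkpoint("my_ckpt", files=staged, named_tensors={})
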
@@ -1130,7 +1241,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
for items in self._current_global_parameter_metas.values():
for metas_list in items.memory_buffer_metas_list:
for meta in metas_list.metas:
max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
7 changes: 5 additions & 2 deletions examples/update.py
@@ -100,8 +100,9 @@ def update_weights(
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
uds: str | None = None,
):
ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
ps.init_process_group()
dist.barrier()
ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
check_vllm_ready(endpoint, inference_parallel_size, uds)
dist.barrier()
with timer("Gather metas"):
@@ -173,7 +174,9 @@ def join(
args.uds,
)
else:
if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
if os.path.exists(
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
) and not args.checkpoint_path.startswith("/dev/shm/"): # noqa: S108
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
checkpoint_files = []
else:
97 changes: 95 additions & 2 deletions tests/test_update.py
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
try:
trigger_error(socket_paths)
except RuntimeError as e:
assert str(e) == "Failed to update weights due to remote errors"
assert str(e) == "Some workers failed to update weights"


def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor
for name, weight in weights:
if name not in named_tensors:
continue
assert (weight == named_tensors[name]).all()
assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
names_to_check[name] = True

def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,66 @@ def run(
assert proc.exitcode == 0


def run_with_files(
checker_func: callable,
):
rank = int(os.getenv("RANK"))
ctx = get_context("spawn")
queue = ctx.Queue()
ps = ParameterServer(auto_pg=True)
_device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
named_tensors = dict(gen_test_tensors(rank))

# Save 1/3 of the tensors to /dev/shm/ as .safetensors files,
# save 1/3 to /tmp (disk) as .safetensors files,
# and keep the remaining 1/3 in memory.
import safetensors.torch

files = []
dev_shm_dir = "/dev/shm/checkpoint_engine_tests" # noqa: S108
disk_dir = "/tmp/checkpoint_engine_tests" # noqa: S108
os.makedirs(dev_shm_dir, exist_ok=True)
os.makedirs(disk_dir, exist_ok=True)
tensors_items = list(named_tensors.items())
one_third = len(tensors_items) // 3
tensors_in_dev_shm = dict(tensors_items[:one_third])
tensors_in_disk = dict(tensors_items[one_third : 2 * one_third])
tensors_in_memory = dict(tensors_items[2 * one_third :])
disk_files = [
os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
for _rank in range(get_world_size())
]
safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
time.sleep(1)
files.append(disk_files[rank])
dev_shm_files = [
os.path.join(dev_shm_dir, f"rank{_rank}_checkpoint.safetensors")
for _rank in range(get_world_size())
]
safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
time.sleep(1)
files.append(dev_shm_files[rank])

checkpoint_name = "test_with_files"
proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
proc.start()
ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
ps.gather_metas(checkpoint_name)
ps.update(checkpoint_name, queue.put, ranks=[])
# sleep 3s to wait process group is destroyed
time.sleep(3)
ps.unregister_checkpoint(checkpoint_name)
queue.put(None)
proc.join()
if rank == 0:
import shutil

os.removedirs(dev_shm_dir)
shutil.rmtree(disk_dir)
assert proc.exitcode == 0


@pytest.mark.gpu
@pytest.mark.parametrize(
"test_name,rank_list",
@@ -211,6 +271,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
assert result.returncode == 0


@pytest.mark.gpu
def test_update_with_files(test_name: str = "test_with_files"):
world_size = device_manager.device_module.device_count()
assert world_size >= 2, "This test requires at least 2 GPUs."
master_addr = "localhost"
master_port = 25400
cmd = [
"torchrun",
"--nproc_per_node",
str(world_size),
"--master_addr",
master_addr,
"--master_port",
str(master_port),
__file__,
test_name,
"[]",
]

result = subprocess.run( # noqa: S603
cmd,
capture_output=False,
text=True,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
shell=False,
check=False,
)

assert result.returncode == 0


if __name__ == "__main__":
run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
if not run_with_pytest:
@@ -230,5 +321,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
expected_exception=RuntimeError,
exception_msg="Failed to update weights due to remote errors",
)
elif test_type == "test_with_files":
run_with_files(checker_proc)
else:
raise ValueError(f"Unknown TEST_TYPE: {test_type}")