Skip to content

Commit 15e8dba

Browse files
committed
misc
1 parent c380f0c commit 15e8dba

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

checkpoint_engine/ps.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def _normal_pin_memory(
532532
files: list[str],
533533
named_tensors: dict[str, torch.Tensor],
534534
rank: int | None = None,
535+
shared_pin_memory: list[MemoryBuffer] | None = None,
535536
) -> list[MemoryBuffer]:
536537
parameters = _load_checkpoint(files)
537538
if named_tensors:
@@ -559,22 +560,22 @@ class MemoryBucket(BaseModel):
559560
for bucket in buckets
560561
]
561562

562-
def register_pin_memory(
563-
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
564-
) -> tuple[int, torch.Tensor]:
565-
if shared_pin_memory:
566-
# If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
567-
# Reusing pin memory only supports a fixed shape of checkpoints, which is registered the first time
568-
assert idx < len(shared_pin_memory), (
569-
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
570-
)
571-
assert shared_pin_memory[idx].size == size, (
572-
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
573-
)
574-
return idx, shared_pin_memory[idx].buffer
575-
else:
576-
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
577-
return idx, buffer
563+
def register_pin_memory(
564+
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
565+
) -> tuple[int, torch.Tensor]:
566+
if shared_pin_memory:
567+
# If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
568+
# Reusing pin memory only supports a fixed shape of checkpoints, which is registered the first time
569+
assert idx < len(shared_pin_memory), (
570+
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
571+
)
572+
assert shared_pin_memory[idx].size == size, (
573+
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
574+
)
575+
return idx, shared_pin_memory[idx].buffer
576+
else:
577+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
578+
return idx, buffer
578579

579580
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
580581
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -620,6 +621,7 @@ def _register_checkpoint(
620621
files: list[str],
621622
named_tensors: dict[str, torch.Tensor],
622623
rank: int | None = None,
624+
shared_pin_memory: list[MemoryBuffer] | None = None,
623625
) -> list[MemoryBuffer]:
624626
logger.info(
625627
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -635,7 +637,12 @@ def _register_checkpoint(
635637
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
636638
if files_to_normal_pin or named_tensors:
637639
memory_buffers.extend(
638-
_normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors, rank=rank)
640+
_normal_pin_memory(
641+
files=files_to_normal_pin,
642+
named_tensors=named_tensors,
643+
rank=rank,
644+
shared_pin_memory=shared_pin_memory,
645+
)
639646
)
640647
if files_to_inplace_pin:
641648
memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))

0 commit comments

Comments (0)