@@ -532,6 +532,7 @@ def _normal_pin_memory(
532532 files : list [str ],
533533 named_tensors : dict [str , torch .Tensor ],
534534 rank : int | None = None ,
535+ shared_pin_memory : list [MemoryBuffer ] | None = None ,
535536) -> list [MemoryBuffer ]:
536537 parameters = _load_checkpoint (files )
537538 if named_tensors :
@@ -559,22 +560,22 @@ class MemoryBucket(BaseModel):
559560 for bucket in buckets
560561 ]
561562
562- def register_pin_memory (
563- idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
564- ) -> tuple [int , torch .Tensor ]:
565- if shared_pin_memory :
566- # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
567- # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
568- assert idx < len (shared_pin_memory ), (
569- f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
570- )
571- assert shared_pin_memory [idx ].size == size , (
572- f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
573- )
574- return idx , shared_pin_memory [idx ].buffer
575- else :
576- buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
577- return idx , buffer
563+ def register_pin_memory (
564+ idx : int , size : int , shared_pin_memory : list [MemoryBuffer ] | None = None
565+ ) -> tuple [int , torch .Tensor ]:
566+ if shared_pin_memory :
567+ # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
568+ # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
569+ assert idx < len (shared_pin_memory ), (
570+ f"idx { idx } should be less than shared_pin_memory length { len (shared_pin_memory )} "
571+ )
572+ assert shared_pin_memory [idx ].size == size , (
573+ f"shared_pin_memory[{ idx } ].size { shared_pin_memory [idx ].size } should be equal to { size } "
574+ )
575+ return idx , shared_pin_memory [idx ].buffer
576+ else :
577+ buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
578+ return idx , buffer
578579
579580 def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
580581 buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
@@ -620,6 +621,7 @@ def _register_checkpoint(
620621 files : list [str ],
621622 named_tensors : dict [str , torch .Tensor ],
622623 rank : int | None = None ,
624+ shared_pin_memory : list [MemoryBuffer ] | None = None ,
623625) -> list [MemoryBuffer ]:
624626 logger .info (
625627 f"[rank{ rank } ] start to register checkpoint with { len (files )} files and { len (named_tensors )} named_tensors"
@@ -635,7 +637,12 @@ def _register_checkpoint(
635637 files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin ]
636638 if files_to_normal_pin or named_tensors :
637639 memory_buffers .extend (
638- _normal_pin_memory (files = files_to_normal_pin , named_tensors = named_tensors , rank = rank )
640+ _normal_pin_memory (
641+ files = files_to_normal_pin ,
642+ named_tensors = named_tensors ,
643+ rank = rank ,
644+ shared_pin_memory = shared_pin_memory ,
645+ )
639646 )
640647 if files_to_inplace_pin :
641648 memory_buffers .extend (_inplace_pin_memory (files_to_inplace_pin , rank = rank ))
0 commit comments