From f7271d4e01480598df96a2ad648bb867e74a9034 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Thu, 22 Jan 2026 21:26:49 +0000 Subject: [PATCH 1/2] Chore: Fix Flash Ref Defaults --- src/flashpack/deserialization.py | 18 +++++------------- src/flashpack/integrations/diffusers/model.py | 2 +- .../integrations/transformers/model.py | 2 +- src/flashpack/mixin.py | 2 +- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/flashpack/deserialization.py b/src/flashpack/deserialization.py index e92cd3a..579b1de 100644 --- a/src/flashpack/deserialization.py +++ b/src/flashpack/deserialization.py @@ -209,12 +209,8 @@ def _copy_memmaps_into_storage( block_tensor = storage.block(idx) num_pipeline_buffers = max(1, min(num_streams, 8)) - - # For dtypes that require bit-reinterpretation (e.g. bfloat16 stored as uint16), - # allocate staging buffers in the packing dtype - packing_dtype = get_packing_dtype(spec.dtype) staging_bufs = [ - torch.empty(elems_per_chunk, dtype=packing_dtype, pin_memory=True) + torch.empty(elems_per_chunk, dtype=spec.dtype, pin_memory=True) for _ in range(num_pipeline_buffers) ] num_cuda_streams = max(1, min(num_streams, 8)) @@ -226,7 +222,7 @@ def _copy_memmaps_into_storage( sz = end - start buf_idx = chunk_idx % num_pipeline_buffers - buf_raw = staging_bufs[buf_idx].narrow(0, 0, sz) + buf = staging_bufs[buf_idx].narrow(0, 0, sz) stream = streams[chunk_idx % num_cuda_streams] if chunk_idx >= num_pipeline_buffers: @@ -236,13 +232,9 @@ def _copy_memmaps_into_storage( with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) src_t = torch.from_numpy(np_view) - buf_raw.copy_(src_t, non_blocking=False) - - # Reinterpret bits if needed (e.g. 
uint16 -> bfloat16) - if spec.dtype != packing_dtype: - buf = buf_raw.view(spec.dtype) - else: - buf = buf_raw + if src_t.dtype != spec.dtype: + src_t = src_t.to(dtype=spec.dtype) + buf.copy_(src_t, non_blocking=False) with torch.cuda.stream(stream): block_tensor.narrow(0, start, sz).copy_(buf, non_blocking=True) diff --git a/src/flashpack/integrations/diffusers/model.py b/src/flashpack/integrations/diffusers/model.py index 97290ae..ca75d9d 100644 --- a/src/flashpack/integrations/diffusers/model.py +++ b/src/flashpack/integrations/diffusers/model.py @@ -87,7 +87,7 @@ def from_pretrained_flashpack( strict: bool | None = None, strict_params: bool = True, strict_buffers: bool = False, - keep_flash_ref_on_model: bool = True, + keep_flash_ref_on_model: bool = False, num_streams: int = DEFAULT_NUM_STREAMS, chunk_bytes: int = DEFAULT_CHUNK_BYTES, ignore_names: list[str] | None = None, diff --git a/src/flashpack/integrations/transformers/model.py b/src/flashpack/integrations/transformers/model.py index ddfadc0..b1eec89 100644 --- a/src/flashpack/integrations/transformers/model.py +++ b/src/flashpack/integrations/transformers/model.py @@ -90,7 +90,7 @@ def from_pretrained_flashpack( strict: bool | None = None, strict_params: bool = True, strict_buffers: bool = False, - keep_flash_ref_on_model: bool = True, + keep_flash_ref_on_model: bool = False, num_streams: int = DEFAULT_NUM_STREAMS, chunk_bytes: int = DEFAULT_CHUNK_BYTES, ignore_names: list[str] | None = None, diff --git a/src/flashpack/mixin.py b/src/flashpack/mixin.py index a4ce397..8aafd5c 100644 --- a/src/flashpack/mixin.py +++ b/src/flashpack/mixin.py @@ -34,7 +34,7 @@ def from_flashpack( strict: bool | None = None, strict_params: bool = True, strict_buffers: bool = False, - keep_flash_ref_on_model: bool = True, + keep_flash_ref_on_model: bool = False, ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, From 
d1762aaaddf63f954f01755aa841e2f8deb007d7 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Thu, 22 Jan 2026 21:27:57 +0000 Subject: [PATCH 2/2] wrong branch --- src/flashpack/deserialization.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/flashpack/deserialization.py b/src/flashpack/deserialization.py index 579b1de..e92cd3a 100644 --- a/src/flashpack/deserialization.py +++ b/src/flashpack/deserialization.py @@ -209,8 +209,12 @@ def _copy_memmaps_into_storage( block_tensor = storage.block(idx) num_pipeline_buffers = max(1, min(num_streams, 8)) + + # For dtypes that require bit-reinterpretation (e.g. bfloat16 stored as uint16), + # allocate staging buffers in the packing dtype + packing_dtype = get_packing_dtype(spec.dtype) staging_bufs = [ - torch.empty(elems_per_chunk, dtype=spec.dtype, pin_memory=True) + torch.empty(elems_per_chunk, dtype=packing_dtype, pin_memory=True) for _ in range(num_pipeline_buffers) ] num_cuda_streams = max(1, min(num_streams, 8)) @@ -222,7 +226,7 @@ def _copy_memmaps_into_storage( sz = end - start buf_idx = chunk_idx % num_pipeline_buffers - buf = staging_bufs[buf_idx].narrow(0, 0, sz) + buf_raw = staging_bufs[buf_idx].narrow(0, 0, sz) stream = streams[chunk_idx % num_cuda_streams] if chunk_idx >= num_pipeline_buffers: @@ -232,9 +236,13 @@ def _copy_memmaps_into_storage( with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) src_t = torch.from_numpy(np_view) - if src_t.dtype != spec.dtype: - src_t = src_t.to(dtype=spec.dtype) - buf.copy_(src_t, non_blocking=False) + buf_raw.copy_(src_t, non_blocking=False) + + # Reinterpret bits if needed (e.g. uint16 -> bfloat16) + if spec.dtype != packing_dtype: + buf = buf_raw.view(spec.dtype) + else: + buf = buf_raw with torch.cuda.stream(stream): block_tensor.narrow(0, start, sz).copy_(buf, non_blocking=True)