From 326dc9eac55021802ec6dca27f90dfb96cdef9ac Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Sat, 15 Nov 2025 16:49:38 -0500 Subject: [PATCH 01/11] Feat: support multiple dtypes --- src/flashpack/__main__.py | 55 ++- src/flashpack/commands.py | 24 +- src/flashpack/constants.py | 1 + src/flashpack/deserialization.py | 344 ++++++++++++------ src/flashpack/integrations/diffusers/model.py | 2 +- .../integrations/diffusers/pipeline.py | 17 +- .../integrations/transformers/model.py | 2 +- src/flashpack/mixin.py | 2 +- src/flashpack/serialization.py | 212 ++++++++--- tests/test_baseline.py | 44 +++ 10 files changed, 526 insertions(+), 177 deletions(-) diff --git a/src/flashpack/__main__.py b/src/flashpack/__main__.py index 1e32886..cf3e1a6 100644 --- a/src/flashpack/__main__.py +++ b/src/flashpack/__main__.py @@ -1,10 +1,12 @@ import json as jsonlib import os +import traceback import click from . import __version__ from .commands import convert_to_flashpack +from .constants import FILE_FORMAT_V3, FILE_FORMAT_V4 from .deserialization import get_flashpack_file_metadata from .integrations import patch_integrations @@ -31,6 +33,10 @@ def magenta(text: str) -> str: return click.style(text, fg="magenta") +def cyan(text: str) -> str: + return click.style(text, fg="cyan") + + @click.group(name="flashpack") @click.version_option(__version__) def main() -> None: @@ -54,16 +60,54 @@ def metadata(path: str, show_index: bool, json: bool) -> None: if json: print(jsonlib.dumps(metadata, indent=2)) else: + macroblocks = None + format = metadata.get("format") + + if format == FILE_FORMAT_V4: + num_digits = len(str(metadata.get("total_payload_bytes", 0))) + elif format == FILE_FORMAT_V3: + num_digits = len(str(metadata.get("total_elems", 0))) + else: + num_digits = 10 # default to 10 digits for unknown format + for k, v in metadata.items(): if k == "index": continue - print(f"{green(k)}: {v}") + elif k == "macroblocks": + macroblocks = v + macroblocks.sort(key=lambda x: x["offset_bytes"]) + print(f"{green('macroblocks')}:") + for i, r in enumerate(macroblocks): + start_bytes = r["offset_bytes"] + end_bytes = start_bytes + r["length_bytes"] + print( + f" {magenta(i)}: {r['dtype'].ljust(10)} {start_bytes:0{num_digits}d}:{end_bytes:0{num_digits}d}" + ) + else: + print(f"{green(k)}: {v}") + if "index" in metadata and show_index: print(f"{green('index')}:") num_index_digits = len(str(len(metadata["index"]))) - for i, r in enumerate(metadata["index"]): + index_items = metadata["index"] + index_items.sort(key=lambda x: (x.get("macroblock", 0), x["offset"])) + last_seen_macroblock = None + + for i, r in enumerate(index_items): + macroblock = r.get("macroblock", 0) + if last_seen_macroblock != macroblock: + if macroblocks: + macroblock_dtype = macroblocks[macroblock]["dtype"] + else: + macroblock_dtype = metadata["target_dtype"] + if format == FILE_FORMAT_V4: + print( + f" {cyan('macroblock' + str(macroblock))}: {macroblock_dtype.ljust(10)} {r['offset']:0{num_digits}d}:{r['offset'] + r['length']:0{num_digits}d}" + ) + + last_seen_macroblock = macroblock + offset_end = r["offset"] + r["length"] - num_digits = len(str(metadata["total_elems"])) element_range = ( f"{r['offset']:0{num_digits}d}:{offset_end:0{num_digits}d}" ) @@ -111,6 +155,7 @@ def metadata(path: str, show_index: bool, json: bool) -> None: @click.option( "--use-diffusers", is_flag=True, help="Use diffusers to convert the model." 
) +@click.option("--verbose", "-v", is_flag=True, help="Verbose output.") def convert( path_or_repo_id: str, destination_path: str, @@ -122,6 +167,7 @@ def convert( ignore_suffixes: list[str], use_transformers: bool, use_diffusers: bool, + verbose: bool, ) -> None: """ Convert a model to a flashpack file. @@ -138,10 +184,13 @@ def convert( use_diffusers=use_diffusers, subfolder=subfolder, variant=variant, + silent=not verbose, ) print(green(f"Success: Saved to {os.path.abspath(result_path)}")) except Exception as e: print(red(f"Error: {e}")) + if verbose: + traceback.print_exc() exit(1) diff --git a/src/flashpack/commands.py b/src/flashpack/commands.py index 4325e3e..f97ca06 100644 --- a/src/flashpack/commands.py +++ b/src/flashpack/commands.py @@ -38,6 +38,7 @@ def convert_to_flashpack_from_state_dict( ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, + silent: bool = True, ) -> str: """ Converts a state dictionary to a flashpack file. @@ -54,13 +55,12 @@ def convert_to_flashpack_from_state_dict( if isinstance(dtype, str): dtype = string_to_dtype(dtype) - elif dtype is None: - dtype = next(iter(state_dict.values())).dtype pack_to_file( state_dict, destination_path, dtype, + silent=silent, ) return destination_path @@ -73,16 +73,11 @@ def convert_to_flashpack_from_model( ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, + silent: bool = True, ) -> str: """ Converts a model to a flashpack file. """ - if dtype is None: - try: - dtype = model.dtype - except AttributeError: - dtype = next(model.parameters()).dtype - return convert_to_flashpack_from_state_dict( model.state_dict(), destination_path, @@ -90,6 +85,7 @@ def convert_to_flashpack_from_model( ignore_names, ignore_prefixes, ignore_suffixes, + silent, ) @@ -100,6 +96,7 @@ def convert_to_flashpack_from_state_dict_file( ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, + silent: bool = True, ) -> str: """ Converts a state dictionary file to a flashpack file. @@ -121,6 +118,7 @@ def convert_to_flashpack_from_state_dict_file( ignore_names, ignore_prefixes, ignore_suffixes, + silent, ) @@ -131,12 +129,14 @@ def convert_to_flashpack_from_diffusers_repo_id_or_dir( ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, + silent: bool = True, **kwargs: Any, ) -> str: """ Converts a diffusers model to a flashpack model. """ from diffusers import AutoModel + from .utils import string_to_dtype if isinstance(dtype, str): @@ -151,6 +151,7 @@ def convert_to_flashpack_from_diffusers_repo_id_or_dir( ignore_names=ignore_names, ignore_prefixes=ignore_prefixes, ignore_suffixes=ignore_suffixes, + silent=silent, ) return destination_path @@ -162,12 +163,14 @@ def convert_to_flashpack_from_transformers_repo_id_or_dir( ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, ignore_suffixes: list[str] | None = None, + silent: bool = True, **kwargs: Any, ) -> str: """ Converts a transformers model to a flashpack model. 
""" from transformers import AutoModel + from .utils import string_to_dtype if isinstance(dtype, str): @@ -182,6 +185,7 @@ def convert_to_flashpack_from_transformers_repo_id_or_dir( ignore_names=ignore_names, ignore_prefixes=ignore_prefixes, ignore_suffixes=ignore_suffixes, + silent=silent, ) return destination_path @@ -197,6 +201,7 @@ def convert_to_flashpack( ignore_suffixes: list[str] | None = None, use_transformers: bool = False, use_diffusers: bool = False, + silent: bool = True, **kwargs: Any, ) -> str: """ @@ -215,6 +220,7 @@ def convert_to_flashpack( ignore_names, ignore_prefixes, ignore_suffixes, + silent, ) model_dir = model_or_state_dict_or_path_or_repo_id @@ -266,6 +272,7 @@ def convert_to_flashpack( ignore_names, ignore_prefixes, ignore_suffixes, + silent, **kwargs, ) return convert_to_flashpack_from_diffusers_repo_id_or_dir( @@ -275,5 +282,6 @@ def convert_to_flashpack( ignore_names, ignore_prefixes, ignore_suffixes, + silent, **kwargs, ) diff --git a/src/flashpack/constants.py b/src/flashpack/constants.py index 74c4c0b..a6a5111 100644 --- a/src/flashpack/constants.py +++ b/src/flashpack/constants.py @@ -4,6 +4,7 @@ U64LE = struct.Struct(" torch.Tensor: + return self.blocks[idx] + + def __len__(self) -> int: + return len(self.blocks) + + @property + def device(self) -> torch.device: + if not self.blocks: + return torch.device("cpu") + return self.blocks[0].device + + def get_flashpack_file_metadata(path: str) -> dict[str, Any]: """ Get the metadata from a flashpack file. @@ -49,8 +77,9 @@ def get_flashpack_file_metadata(path: str) -> dict[str, Any]: f.seek(start) meta = json.loads(f.read(json_len).decode("utf-8")) - if meta.get("format") != FILE_FORMAT_V3: - raise ValueError(f"Unexpected format: {meta.get('format')}") + fmt = meta.get("format") + if fmt not in (FILE_FORMAT_V3, FILE_FORMAT_V4): + raise ValueError(f"Unexpected format: {fmt}") return meta @@ -66,130 +95,213 @@ def is_flashpack_file(path: str) -> bool: return False -def read_flashpack_file( - path: str, - device: str | torch.device = "cpu", - chunk_bytes: int = DEFAULT_CHUNK_BYTES, - num_streams: int = DEFAULT_NUM_STREAMS, - silent: bool = True, -) -> tuple[torch.Tensor, dict[str, Any]]: - """ - Read the flashpack file and return the tensor and metadata. - """ - with timer("read_metadata", silent): - meta = get_flashpack_file_metadata(path) +def _ensure_index_macroblocks(meta: dict[str, Any], num_blocks: int) -> None: + index = meta.get("index", []) + for rec in index: + block_id = rec.get("macroblock") + if block_id is None: + block_id = 0 + rec["macroblock"] = block_id + block_id = int(block_id) + if block_id < 0 or block_id >= num_blocks: + raise ValueError( + f"Index entry references macroblock {block_id}, but only {num_blocks} blocks exist." 
+ ) - device = torch.device(device) if isinstance(device, str) else device - target_dtype = string_to_dtype(meta["target_dtype"]) - total_elems = int(meta["total_elems"]) - elem_sz = torch.tensor([], dtype=target_dtype).element_size() - with timer("mmap_payload", silent): - np_dtype = torch_dtype_to_numpy_dtype(target_dtype) - mm = np.memmap(path, dtype=np_dtype, mode="r", shape=(total_elems,)) - - # Fast CPU path - if device.type == "cpu": - with timer("cpu_from_memmap", silent): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - flash_cpu = ( - torch.from_numpy(mm) - if target_dtype != torch.bfloat16 - else torch.from_numpy(mm.view(np.uint16)).view(torch.bfloat16) +def _build_macroblock_specs(meta: dict[str, Any]) -> list[MacroblockSpec]: + fmt = meta.get("format") + specs: list[MacroblockSpec] = [] + if fmt == FILE_FORMAT_V3: + dtype = string_to_dtype(meta["target_dtype"]) + total_elems = int(meta["total_elems"]) + elem_sz = torch.tensor([], dtype=dtype).element_size() + specs.append( + MacroblockSpec( + dtype=dtype, + offset_bytes=0, + length_bytes=total_elems * elem_sz, + length_elems=total_elems, + ) + ) + elif fmt == FILE_FORMAT_V4: + macroblocks = meta.get("macroblocks") + if not macroblocks: + raise ValueError("Missing macroblock metadata for flashpack v4 file.") + for block in macroblocks: + dtype = string_to_dtype(block["dtype"]) + specs.append( + MacroblockSpec( + dtype=dtype, + offset_bytes=int(block["offset_bytes"]), + length_bytes=int(block["length_bytes"]), + length_elems=int(block["length_elems"]), ) - return flash_cpu, meta + ) + else: + raise ValueError(f"Unsupported flashpack format: {fmt}") - if device.type != "cuda": - raise ValueError(f"Unsupported device: {device}") + _ensure_index_macroblocks(meta, len(specs)) + return specs - with timer("alloc_device", silent): - if target_dtype == torch.bfloat16: - flash_dev = torch.empty(total_elems, dtype=torch.bfloat16, device=device) - flash_dev_u16 = flash_dev.view(torch.uint16) - else: - flash_dev = torch.empty(total_elems, dtype=target_dtype, device=device) - flash_dev_u16 = None - # Advise kernel to read ahead (Linux only) +def _madvise_memmap(mm: np.memmap) -> None: try: import mmap as mmap_module - # MADV_WILLNEED: tell kernel we'll need this data - # MADV_SEQUENTIAL: we'll read sequentially mm._mmap.madvise(mmap_module.MADV_WILLNEED) mm._mmap.madvise(mmap_module.MADV_SEQUENTIAL) - except: + except Exception: pass - # Tune chunk size for the specific file - total_bytes = total_elems * elem_sz - - # aim for 100-200 chunks total for good pipelining - target_num_chunks = 150 - optimal_chunk_bytes = max(chunk_bytes, total_bytes // target_num_chunks) - # But cap at 64MB to avoid too much staging memory - optimal_chunk_bytes = min(optimal_chunk_bytes, 64 * 1024 * 1024) - - elems_per_chunk = max(1, (optimal_chunk_bytes // elem_sz)) - n_chunks = (total_elems + elems_per_chunk - 1) // elems_per_chunk - with timer("read_and_copy", silent): - # Pre-allocate a small number of pinned staging buffers for pipelining - num_pipeline_buffers = min(num_streams, 8) # Don't over-allocate - dt = torch.uint16 if target_dtype == torch.bfloat16 else target_dtype +def _open_memmaps(path: str, specs: list[MacroblockSpec]) -> list[np.memmap]: + memmaps = [] + for spec in specs: + np_dtype = torch_dtype_to_numpy_dtype(spec.dtype) + mm = np.memmap( + path, + dtype=np_dtype, + mode="r", + offset=spec.offset_bytes, + shape=(spec.length_elems,), + ) + _madvise_memmap(mm) + memmaps.append(mm) + return memmaps + 
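+# Sketch of the macroblock layout these helpers assume (sizes and offsets are
+# hypothetical): a v4 payload stores one contiguous block per dtype, e.g.
+#   macroblock 0: float32,  offset_bytes=0,  length_elems=N0
+#   macroblock 1: bfloat16, offset_bytes=B1, length_elems=N1  (B1 = next aligned offset)
+# _build_macroblock_specs() recovers these specs from the footer metadata and
+# _open_memmaps() maps each block read-only at its own byte offset.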
+ +def _cpu_storage_from_memmaps( + memmaps: list[np.memmap], specs: list[MacroblockSpec] +) -> FlashTensorStorage: + blocks: list[torch.Tensor] = [] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + for mm, spec in zip(memmaps, specs): + tensor = torch.from_numpy(mm) + if spec.dtype == torch.bfloat16: + tensor = tensor.view(torch.bfloat16) + blocks.append(tensor) + return FlashTensorStorage(blocks=blocks, backing_arrays=memmaps) + + +def _copy_memmaps_into_storage( + memmaps: list[np.memmap], + specs: list[MacroblockSpec], + storage: FlashTensorStorage, + device: torch.device, + chunk_bytes: int, + num_streams: int, +) -> None: + for idx, (mm, spec) in enumerate(zip(memmaps, specs)): + total_elems = spec.length_elems + elem_sz = torch.tensor([], dtype=spec.dtype).element_size() + total_bytes = total_elems * elem_sz + + target_num_chunks = 150 + optimal_chunk_bytes = max(chunk_bytes, total_bytes // max(target_num_chunks, 1)) + optimal_chunk_bytes = min(optimal_chunk_bytes, 64 * 1024 * 1024) + elems_per_chunk = max(1, (optimal_chunk_bytes // max(elem_sz, 1))) + n_chunks = (total_elems + elems_per_chunk - 1) // elems_per_chunk + + block_tensor = storage.block(idx) + num_pipeline_buffers = max(1, min(num_streams, 8)) staging_bufs = [ - torch.empty(elems_per_chunk, dtype=dt, pin_memory=True) + torch.empty(elems_per_chunk, dtype=spec.dtype, pin_memory=True) for _ in range(num_pipeline_buffers) ] + num_cuda_streams = max(1, min(num_streams, 8)) + streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)] - # Pre-allocate streams - streams = [torch.cuda.Stream(device=device) for _ in range(min(num_streams, 8))] - - # Pipeline: fill first buffer, then alternate fill/copy for chunk_idx in range(n_chunks): start = chunk_idx * elems_per_chunk end = min(total_elems, start + elems_per_chunk) sz = end - start - # Select staging buffer (round-robin) buf_idx = chunk_idx % num_pipeline_buffers buf = staging_bufs[buf_idx].narrow(0, 0, sz) + stream = streams[chunk_idx % num_cuda_streams] - # Select stream - stream = streams[chunk_idx % len(streams)] - - # Wait for this buffer's previous use to complete if chunk_idx >= num_pipeline_buffers: stream.synchronize() - # Copy from mmap to staging (CPU) np_view = mm[start:end] with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - src_t = ( - torch.from_numpy(np_view) - if target_dtype != torch.bfloat16 - else torch.from_numpy(np_view.view(np.uint16)) - ) + src_t = torch.from_numpy(np_view) + if src_t.dtype != spec.dtype: + src_t = src_t.to(dtype=spec.dtype) buf.copy_(src_t, non_blocking=False) - # Copy to device (GPU) on the selected stream with torch.cuda.stream(stream): - if target_dtype == torch.bfloat16: - flash_dev_u16.narrow(0, start, sz).copy_(buf, non_blocking=True) - else: - flash_dev.narrow(0, start, sz).copy_(buf, non_blocking=True) + block_tensor.narrow(0, start, sz).copy_(buf, non_blocking=True) - # Final sync torch.cuda.synchronize(device) + return None + + +def _allocate_empty_storage( + specs: list[MacroblockSpec], device: torch.device +) -> FlashTensorStorage: + blocks = [ + torch.empty(spec.length_elems, dtype=spec.dtype, device=device) + for spec in specs + ] + return FlashTensorStorage(blocks=blocks) + + +def _broadcast_storage(storage: FlashTensorStorage, src: int) -> None: + for block in storage.blocks: + dist.broadcast(block, src=src) + + +def read_flashpack_file( + path: str, + device: str | torch.device = "cpu", + chunk_bytes: int = 
DEFAULT_CHUNK_BYTES, + num_streams: int = DEFAULT_NUM_STREAMS, + silent: bool = True, + metadata: dict[str, Any] | None = None, +) -> tuple[FlashTensorStorage, dict[str, Any]]: + """ + Read the flashpack file and return the macroblock storage and metadata. + """ + with timer("read_metadata", silent): + meta = metadata or get_flashpack_file_metadata(path) + + specs = _build_macroblock_specs(meta) + device = torch.device(device) if isinstance(device, str) else device + + with timer("mmap_payload", silent): + memmaps = _open_memmaps(path, specs) - del mm - return flash_dev, meta + if device.type == "cpu": + with timer("cpu_from_memmap", silent): + storage = _cpu_storage_from_memmaps(memmaps, specs) + return storage, meta + + if device.type != "cuda": + raise ValueError(f"Unsupported device: {device}") + + with timer("alloc_device", silent): + storage = _allocate_empty_storage(specs, device) + + with timer("read_and_copy", silent): + _copy_memmaps_into_storage( + memmaps, + specs, + storage, + device=device, + chunk_bytes=chunk_bytes, + num_streams=num_streams, + ) + + del memmaps + return storage, meta def iterate_from_flash_tensor( - flash_tensor: torch.Tensor, + flash_tensor: FlashTensorStorage | torch.Tensor, metadata: dict[str, Any], ignore_names: list[str] | None = None, ignore_prefixes: list[str] | None = None, @@ -198,14 +310,36 @@ def iterate_from_flash_tensor( """ Iterate over the tensors stored in the flash tensor. """ + storage = ( + flash_tensor + if isinstance(flash_tensor, FlashTensorStorage) + else FlashTensorStorage(blocks=[flash_tensor]) + ) index = metadata["index"] align_bytes = int(metadata.get("align_bytes", 0)) + align_cache: dict[int, int] = {} + + def _get_align(block_idx: int) -> int: + if not align_bytes: + return 0 + if block_idx not in align_cache: + esz = storage.block(block_idx).element_size() + g = math.gcd(align_bytes, esz) + align_cache[block_idx] = align_bytes // g if g else 0 + return align_cache[block_idx] + if align_bytes: - esz = flash_tensor.element_size() - g = math.gcd(align_bytes, esz) - align_elems = align_bytes // g - bad = [rec for rec in index if (int(rec["offset"]) % align_elems) != 0] + bad: list[dict[str, Any]] = [] + for rec in index: + block_idx = int(rec.get("macroblock", 0)) + if block_idx < 0 or block_idx >= len(storage): + raise ValueError( + f"Index entry references invalid macroblock {block_idx}." 
+ ) + align_elems = _get_align(block_idx) + if align_elems and (int(rec["offset"]) % align_elems) != 0: + bad.append(rec) if bad: names = ", ".join(r["name"] for r in bad[:3]) raise ValueError( @@ -220,9 +354,13 @@ def iterate_from_flash_tensor( shape = tuple(rec["shape"]) or (1,) off = int(rec["offset"]) n = int(rec["length"]) + block_idx = int(rec.get("macroblock", 0)) + if block_idx < 0 or block_idx >= len(storage): + raise ValueError(f"Index entry references invalid macroblock {block_idx}.") + block_tensor = storage.block(block_idx) try: - view = flash_tensor.narrow(0, off, n).view( + view = block_tensor.narrow(0, off, n).view( *shape ) # contiguous 1D slice -> reshaped yield name, view @@ -271,24 +409,22 @@ def assign_from_file( world_size=world_size, ) rank = dist.get_rank() + meta = get_flashpack_file_metadata(path) + specs = _build_macroblock_specs(meta) if rank == 0: - flash_tensor, meta = read_flashpack_file( + flash_storage, meta = read_flashpack_file( path=path, device=device, silent=silent, num_streams=num_streams, chunk_bytes=chunk_bytes, + metadata=meta, ) else: - meta = get_flashpack_file_metadata(path) - flash_tensor = torch.empty( - meta["total_elems"], - dtype=string_to_dtype(meta["target_dtype"]), - device=device, - ) - dist.broadcast(flash_tensor, src=0) + flash_storage = _allocate_empty_storage(specs, device) + _broadcast_storage(flash_storage, src=0) else: - flash_tensor, meta = read_flashpack_file( + flash_storage, meta = read_flashpack_file( path=path, device=device, silent=silent, @@ -297,11 +433,9 @@ def assign_from_file( ) if keep_flash_ref_on_model: - setattr(model, "_flash_shared_storage", flash_tensor) + setattr(model, "_flash_shared_storage", flash_storage) setattr(model, "_flash_shared_storage_meta", meta) - target_dtype = string_to_dtype(meta["target_dtype"]) - with timer("build_lookups", silent): params = dict(model.named_parameters()) buffers = dict(model.named_buffers()) @@ -314,7 +448,7 @@ def assign_from_file( with timer("assign", silent): try: for name, view in iterate_from_flash_tensor( - flash_tensor, meta, ignore_names, ignore_prefixes, ignore_suffixes + flash_storage, meta, ignore_names, ignore_prefixes, ignore_suffixes ): total_elements += view.numel() @@ -337,12 +471,12 @@ def assign_from_file( raise TypeError( f"Expected Tensor buffer at '{name}', got {type(old_buf)}" ) - if old_buf.dtype != target_dtype: + if old_buf.dtype != view.dtype: if coerce_dtype: view = view.to(old_buf.dtype) else: raise TypeError( - f"dtype mismatch for buffer '{name}': model={old_buf.dtype} vs flash={target_dtype}." + f"dtype mismatch for buffer '{name}': model={old_buf.dtype} vs flash={view.dtype}." 
) module._buffers[attr] = view assigned_buffer_names.append(name) diff --git a/src/flashpack/integrations/diffusers/model.py b/src/flashpack/integrations/diffusers/model.py index 6dabe1b..4ba3849 100644 --- a/src/flashpack/integrations/diffusers/model.py +++ b/src/flashpack/integrations/diffusers/model.py @@ -53,7 +53,7 @@ def save_pretrained_flashpack( self.save_flashpack( model_path, - target_dtype=target_dtype or self.dtype, + target_dtype=target_dtype, align_bytes=align_bytes, silent=silent, num_workers=num_workers, diff --git a/src/flashpack/integrations/diffusers/pipeline.py b/src/flashpack/integrations/diffusers/pipeline.py index b7eaa26..9d2beb8 100644 --- a/src/flashpack/integrations/diffusers/pipeline.py +++ b/src/flashpack/integrations/diffusers/pipeline.py @@ -4,12 +4,12 @@ import sys import warnings from typing import Any -from typing_extensions import Self import torch import torch.distributed as dist from huggingface_hub import DDUFEntry, create_repo, read_dduf_file, snapshot_download from packaging import version +from typing_extensions import Self from diffusers import OnnxRuntimeModel from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin @@ -105,6 +105,7 @@ def save_pretrained_flashpack( align_bytes: int = DEFAULT_ALIGN_BYTES, silent: bool = True, num_workers: int = DEFAULT_NUM_WRITE_WORKERS, + target_dtype: torch.dtype | dict[str, torch.dtype] | None = None, **kwargs: Any, ) -> None: """ @@ -139,21 +140,21 @@ def is_saveable_module(name, value): sub_model_dir = os.path.join(save_directory, pipeline_component_name) model_cls = sub_model.__class__ + sub_model_target_dtype = ( + target_dtype.get(pipeline_component_name, None) + if isinstance(target_dtype, dict) + else target_dtype + ) + if isinstance( sub_model, (FlashPackDiffusersModelMixin, FlashPackTransformersModelMixin), ): os.makedirs(sub_model_dir, exist_ok=True) - target_dtype = getattr(sub_model, "dtype", None) - if target_dtype is None: - try: - target_dtype = next(iter(sub_model.parameters())).dtype - except StopIteration: - pass sub_model.save_pretrained_flashpack( sub_model_dir, is_main_process=is_main_process, - target_dtype=target_dtype, + target_dtype=sub_model_target_dtype, align_bytes=align_bytes, silent=silent, num_workers=num_workers, diff --git a/src/flashpack/integrations/transformers/model.py b/src/flashpack/integrations/transformers/model.py index db7550a..7075193 100644 --- a/src/flashpack/integrations/transformers/model.py +++ b/src/flashpack/integrations/transformers/model.py @@ -55,7 +55,7 @@ def save_pretrained_flashpack( self.save_flashpack( model_path, - target_dtype=target_dtype or self.dtype, + target_dtype=target_dtype, align_bytes=align_bytes, silent=silent, num_workers=num_workers, diff --git a/src/flashpack/mixin.py b/src/flashpack/mixin.py index 776ff34..1a9cc38 100644 --- a/src/flashpack/mixin.py +++ b/src/flashpack/mixin.py @@ -99,7 +99,7 @@ def from_flashpack( def save_flashpack( self, destination_path: str, - target_dtype: torch.dtype, + target_dtype: torch.dtype | None = None, name_order: list[str] | None = None, align_bytes: int = DEFAULT_ALIGN_BYTES, silent: bool = False, diff --git a/src/flashpack/serialization.py b/src/flashpack/serialization.py index 1f7662a..7919893 100644 --- a/src/flashpack/serialization.py +++ b/src/flashpack/serialization.py @@ -12,6 +12,7 @@ DEFAULT_ALIGN_BYTES, DEFAULT_NUM_WRITE_WORKERS, FILE_FORMAT_V3, + FILE_FORMAT_V4, MAGIC, U64LE, ) @@ -24,12 +25,23 @@ class TensorIndexRecord: shape: list[int] offset: int # element 
offset (not bytes) length: int # number of elements + macroblock: int = 0 + + +@dataclass +class MacroblockPlan: + dtype: torch.dtype + offset_bytes: int + length_bytes: int + total_elems: int + align_elems: int + tensors: list[TensorIndexRecord] def pack_to_file( state_dict_or_model: dict[str, torch.Tensor] | torch.nn.Module, destination_path: str, - target_dtype: torch.dtype, + target_dtype: torch.dtype | None, name_order: list[str] | None = None, align_bytes: int = DEFAULT_ALIGN_BYTES, silent: bool = True, @@ -57,33 +69,89 @@ def pack_to_file( if align_bytes < 0: raise ValueError("align_bytes must be >= 0") - element_size = torch.tensor([], dtype=target_dtype).element_size() - g = math.gcd(align_bytes, element_size) if align_bytes else 1 - align_elems = (align_bytes // g) if align_bytes else 0 + def _validate_dtype(dtype: torch.dtype) -> torch.dtype: + if not isinstance(dtype, torch.dtype): + raise ValueError(f"Unsupported dtype in state dict: {dtype}") + torch_dtype_to_numpy_dtype(dtype) + return dtype + + def _lcm(a: int, b: int) -> int: + if a == 0 and b == 0: + return 0 + if a == 0: + return abs(b) + if b == 0: + return abs(a) + return abs(a * b) // math.gcd(a, b) + + resolved_target_dtype = ( + _validate_dtype(target_dtype) if target_dtype is not None else None + ) + + dtype_to_names: dict[torch.dtype, list[str]] = {} + dtype_order: list[torch.dtype] = [] + for name in names: + tensor = state_dict[name] + write_dtype = resolved_target_dtype or _validate_dtype(tensor.dtype) + if write_dtype not in dtype_to_names: + dtype_to_names[write_dtype] = [] + dtype_order.append(write_dtype) + dtype_to_names[write_dtype].append(name) with timer("build_index", silent): + macroblocks: list[MacroblockPlan] = [] index: list[TensorIndexRecord] = [] - elem_cursor = 0 - for name in names: - t = state_dict[name] - n = t.numel() - - if align_elems: - pad_elems = (-elem_cursor) % align_elems - elem_cursor += pad_elems - - index.append( - TensorIndexRecord( - name=name, shape=list(t.shape), offset=elem_cursor, length=n + file_cursor = 0 # bytes + + for block_id, dtype in enumerate(dtype_order): + names_for_dtype = dtype_to_names[dtype] + elem_size = torch.tensor([], dtype=dtype).element_size() + block_alignment = _lcm(align_bytes, elem_size) if align_bytes else elem_size + if block_alignment: + pad_bytes = (-file_cursor) % block_alignment + file_cursor += pad_bytes + block_offset = file_cursor + + g = math.gcd(align_bytes, elem_size) if align_bytes else 1 + align_elems = (align_bytes // g) if align_bytes else 0 + + block_cursor = 0 + block_records: list[TensorIndexRecord] = [] + for name in names_for_dtype: + tensor = state_dict[name] + n = tensor.numel() + if align_elems: + pad_elems = (-block_cursor) % align_elems + block_cursor += pad_elems + + rec = TensorIndexRecord( + name=name, + shape=list(tensor.shape), + offset=block_cursor, + length=n, + macroblock=block_id, + ) + block_records.append(rec) + index.append(rec) + block_cursor += n + + block_size_bytes = block_cursor * elem_size + macroblocks.append( + MacroblockPlan( + dtype=dtype, + offset_bytes=block_offset, + length_bytes=block_size_bytes, + total_elems=block_cursor, + align_elems=align_elems, + tensors=block_records, ) ) - elem_cursor += n + file_cursor = block_offset + block_size_bytes - total_elems = elem_cursor - if total_elems == 0: + total_payload_bytes = file_cursor + if total_payload_bytes == 0: raise ValueError("Nothing to pack after alignment.") - np_dtype = torch_dtype_to_numpy_dtype(target_dtype) dest_dir = 
os.path.dirname(os.path.abspath(destination_path)) or "." os.makedirs(dest_dir, exist_ok=True) fd_tmp = None @@ -95,13 +163,24 @@ def pack_to_file( os.close(fd_tmp) with timer("create_memmap", silent): - mm = np.memmap(tmp_path, dtype=np_dtype, mode="w+", shape=(total_elems,)) - flash_view = ( - torch.from_numpy(mm) - if target_dtype != torch.bfloat16 - else torch.from_numpy(mm.view(np.uint16)) + mm = np.memmap( + tmp_path, dtype=np.uint8, mode="w+", shape=(total_payload_bytes,) ) + block_numpy_views: list[np.ndarray] = [] + block_views: list[torch.Tensor] = [] + for block in macroblocks: + block_slice = mm[ + block.offset_bytes : block.offset_bytes + block.length_bytes + ] + np_dtype = torch_dtype_to_numpy_dtype(block.dtype) + if block.dtype == torch.bfloat16: + typed_view = block_slice.view(np.uint16) + else: + typed_view = block_slice.view(np_dtype) + block_numpy_views.append(typed_view) + block_views.append(torch.from_numpy(typed_view)) + # Optimized copy: sequential with batched progress updates with timer("copy_to_memmap", silent): # Only show progress if not silent @@ -120,17 +199,20 @@ def pack_to_file( actual_workers = min(4, num_workers) def copy_one(rec: TensorIndexRecord) -> None: + block = macroblocks[rec.macroblock] + dst_block = block_views[rec.macroblock] src = state_dict[rec.name] - if target_dtype == torch.bfloat16: - src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") + target = block.dtype + if target == torch.bfloat16: + src_cpu = src.view(-1).to(dtype=target, device="cpu") src_bits = src_cpu.view(torch.uint16) - dst = flash_view.narrow(0, rec.offset, rec.length).view( + dst = dst_block.narrow(0, rec.offset, rec.length).view( torch.uint16 ) dst.copy_(src_bits, non_blocking=False) else: - src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") - dst = flash_view.narrow(0, rec.offset, rec.length) + src_cpu = src.view(-1).to(dtype=target, device="cpu") + dst = dst_block.narrow(0, rec.offset, rec.length) dst.copy_(src_cpu, non_blocking=False) with ThreadPoolExecutor(max_workers=actual_workers) as ex: @@ -153,18 +235,20 @@ def copy_one(rec: TensorIndexRecord) -> None: progress_update_interval = max(1, len(index) // 100) for i, rec in enumerate(index): + block = macroblocks[rec.macroblock] + dst_block = block_views[rec.macroblock] src = state_dict[rec.name] - if target_dtype == torch.bfloat16: - src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") + if block.dtype == torch.bfloat16: + src_cpu = src.view(-1).to(dtype=block.dtype, device="cpu") src_bits = src_cpu.view(torch.uint16) - dst = flash_view.narrow(0, rec.offset, rec.length).view( + dst = dst_block.narrow(0, rec.offset, rec.length).view( torch.uint16 ) dst.copy_(src_bits, non_blocking=False) else: - src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") - dst = flash_view.narrow(0, rec.offset, rec.length) + src_cpu = src.view(-1).to(dtype=block.dtype, device="cpu") + dst = dst_block.narrow(0, rec.offset, rec.length) dst.copy_(src_cpu, non_blocking=False) # Batch progress updates to reduce overhead @@ -184,21 +268,49 @@ def copy_one(rec: TensorIndexRecord) -> None: mm.flush() # Append footer - meta_payload = { - "format": FILE_FORMAT_V3, - "target_dtype": dtype_to_string(target_dtype), - "align_bytes": int(align_bytes), - "total_elems": int(total_elems), - "index": [ - { - "name": r.name, - "shape": r.shape, - "offset": int(r.offset), - "length": int(r.length), - } - for r in index - ], - } + if len(macroblocks) == 1: + block = macroblocks[0] + meta_payload = { + "format": FILE_FORMAT_V3, + 
"target_dtype": dtype_to_string(block.dtype), + "align_bytes": int(align_bytes), + "total_elems": int(block.total_elems), + "index": [ + { + "name": r.name, + "shape": r.shape, + "offset": int(r.offset), + "length": int(r.length), + } + for r in index + ], + } + else: + meta_payload = { + "format": FILE_FORMAT_V4, + "align_bytes": int(align_bytes), + "total_payload_bytes": int(total_payload_bytes), + "total_elems": sum(block.total_elems for block in macroblocks), + "macroblocks": [ + { + "dtype": dtype_to_string(block.dtype), + "offset_bytes": int(block.offset_bytes), + "length_bytes": int(block.length_bytes), + "length_elems": int(block.total_elems), + } + for block in macroblocks + ], + "index": [ + { + "name": r.name, + "shape": r.shape, + "offset": int(r.offset), + "length": int(r.length), + "macroblock": int(r.macroblock), + } + for r in index + ], + } footer_json = json.dumps( meta_payload, separators=(",", ":"), ensure_ascii=False ).encode("utf-8") diff --git a/tests/test_baseline.py b/tests/test_baseline.py index f864c7b..94d3084 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -5,8 +5,52 @@ import torch import tqdm from flashpack import FlashPackMixin +from flashpack.constants import FILE_FORMAT_V4 +from flashpack.deserialization import assign_from_file, get_flashpack_file_metadata +from flashpack.serialization import pack_to_file from flashpack.utils import timer + +def test_mixed_dtype_roundtrip(tmp_path) -> None: + """ + Ensure checkpoints with multiple dtypes roundtrip via macroblocks. + """ + + class MixedDTypeModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.float_param = torch.nn.Parameter(torch.randn(4, dtype=torch.float32)) + self.bfloat_param = torch.nn.Parameter(torch.randn(3, dtype=torch.bfloat16)) + self.register_buffer("int_buffer", torch.arange(6, dtype=torch.int16)) + + torch.manual_seed(0) + source = MixedDTypeModel() + torch.manual_seed(1) + destination = MixedDTypeModel() + path = tmp_path / "mixed.flashpack" + + pack_to_file( + source.state_dict(), + destination_path=str(path), + target_dtype=None, + silent=True, + ) + + meta = get_flashpack_file_metadata(str(path)) + assert meta["format"] == FILE_FORMAT_V4 + assert len(meta["macroblocks"]) == 3 + + assign_from_file(destination, str(path), device="cpu", silent=True) + + source_state = source.state_dict() + loaded_state = destination.state_dict() + for name, tensor in source_state.items(): + assert name in loaded_state + loaded = loaded_state[name] + assert loaded.dtype == tensor.dtype + assert torch.equal(loaded, tensor) + + HERE = os.path.dirname(os.path.abspath(__file__)) From 3fbf3e86f99a292042bc060438204fbc3fe69737 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Tue, 25 Nov 2025 17:34:31 -0500 Subject: [PATCH 02/11] add revert command, fix missing types, add robust test --- .pre-commit-config.yaml | 37 ++++++++++++++ src/flashpack/__main__.py | 28 +++++++++- src/flashpack/commands.py | 27 ++++++++++ src/flashpack/deserialization.py | 22 ++++++-- src/flashpack/integrations/diffusers/model.py | 13 +++-- src/flashpack/serialization.py | 33 ++++++------ src/flashpack/utils.py | 51 ++++++++++++++++++- tests/test_baseline.py | 37 ++++++++++++-- 8 files changed, 217 insertions(+), 31 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b28c903 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,37 @@ +repos: + - repo: 
https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-docstring-first + - id: check-toml + - id: check-yaml + exclude: packaging/.* + args: + - --allow-multiple-documents + - id: mixed-line-ending + args: [--fix=lf] + - id: end-of-file-fixer + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: 'v0.3.4' + hooks: + - id: ruff + name: lint with ruff + - id: ruff + name: sort imports with ruff + args: [--select, I, --fix] + - id: ruff-format + name: format with ruff + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-vcs-permalinks + - id: debug-statements + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v14.0.6 + hooks: + - id: clang-format diff --git a/src/flashpack/__main__.py b/src/flashpack/__main__.py index cf3e1a6..abd4d2c 100644 --- a/src/flashpack/__main__.py +++ b/src/flashpack/__main__.py @@ -5,7 +5,7 @@ import click from . import __version__ -from .commands import convert_to_flashpack +from .commands import convert_to_flashpack, revert_from_flashpack from .constants import FILE_FORMAT_V3, FILE_FORMAT_V4 from .deserialization import get_flashpack_file_metadata from .integrations import patch_integrations @@ -194,5 +194,31 @@ def convert( exit(1) +@main.command(name="revert") +@click.argument("path", type=click.Path(exists=True)) +@click.argument("destination_path", type=click.Path(), required=False) +@click.option("--verbose", "-v", is_flag=True, help="Verbose output.") +def revert( + path: str, + destination_path: str, + verbose: bool, +) -> None: + """ + Revert a flashpack file back to a safetensors or torch file. + """ + try: + result_path = revert_from_flashpack( + path, + destination_path, + silent=not verbose, + ) + print(green(f"Success: Saved to {os.path.abspath(result_path)}")) + except Exception as e: + print(red(f"Error: {e}")) + if verbose: + traceback.print_exc() + exit(1) + + if __name__ == "__main__": main() diff --git a/src/flashpack/commands.py b/src/flashpack/commands.py index f97ca06..6f875ec 100644 --- a/src/flashpack/commands.py +++ b/src/flashpack/commands.py @@ -285,3 +285,30 @@ def convert_to_flashpack( silent, **kwargs, ) + + +def revert_from_flashpack( + path: str, + destination_path: str | None = None, + silent: bool = True, +) -> str: + """ + Reverts a flashpack file to a state dictionary. 
+ """ + from .deserialization import revert_from_file + + state_dict = revert_from_file(path, silent=silent) + if not destination_path: + destination_path = path.replace(".flashpack", ".safetensors") + + _, ext = os.path.splitext(destination_path) + if ext == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, destination_path) + else: + import torch + + torch.save(state_dict, destination_path) + + return destination_path diff --git a/src/flashpack/deserialization.py b/src/flashpack/deserialization.py index a4cf4ab..ab56874 100644 --- a/src/flashpack/deserialization.py +++ b/src/flashpack/deserialization.py @@ -20,6 +20,7 @@ ) from .utils import ( get_module_and_attribute, + get_packing_dtype, human_num_elements, is_ignored_tensor_name, maybe_init_distributed, @@ -179,8 +180,9 @@ def _cpu_storage_from_memmaps( warnings.filterwarnings("ignore", category=UserWarning) for mm, spec in zip(memmaps, specs): tensor = torch.from_numpy(mm) - if spec.dtype == torch.bfloat16: - tensor = tensor.view(torch.bfloat16) + packing_dtype = get_packing_dtype(spec.dtype) + if spec.dtype != packing_dtype: + tensor = tensor.view(spec.dtype) blocks.append(tensor) return FlashTensorStorage(blocks=blocks, backing_arrays=memmaps) @@ -368,6 +370,20 @@ def _get_align(block_idx: int) -> int: raise ValueError(f"Could not get tensor for record {rec}") from e +def revert_from_file( + path: str, + silent: bool = True, +) -> dict[str, torch.Tensor]: + """ + Revert a flashpack file to a state dictionary. + """ + storage, meta = read_flashpack_file(path, silent=silent) + state_dict = {} + for name, view in iterate_from_flash_tensor(storage, meta): + state_dict[name] = view.detach().cpu() + return state_dict + + def assign_from_file( model: torch.nn.Module, path: str, @@ -375,7 +391,7 @@ def assign_from_file( strict: bool | None = None, strict_params: bool = True, strict_buffers: bool = False, - keep_flash_ref_on_model: bool = True, + keep_flash_ref_on_model: bool = False, silent: bool = True, num_streams: int = DEFAULT_NUM_STREAMS, chunk_bytes: int = DEFAULT_CHUNK_BYTES, diff --git a/src/flashpack/integrations/diffusers/model.py b/src/flashpack/integrations/diffusers/model.py index 4ba3849..97290ae 100644 --- a/src/flashpack/integrations/diffusers/model.py +++ b/src/flashpack/integrations/diffusers/model.py @@ -5,12 +5,11 @@ from pathlib import Path from typing import Any -import torch -from huggingface_hub import create_repo, snapshot_download - import diffusers +import torch from diffusers import ModelMixin from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from huggingface_hub import create_repo, snapshot_download from ... 
import __version__ from ...constants import ( @@ -42,7 +41,9 @@ def save_pretrained_flashpack( private: bool | None = None, **kwargs: Any, ) -> None: - if not os.path.isdir(save_directory): + if not os.path.exists(save_directory): + os.makedirs(save_directory) + elif not os.path.isdir(save_directory): raise ValueError(f"Save directory {save_directory} is not a directory") model_path = os.path.join(save_directory, "model.flashpack") @@ -114,7 +115,9 @@ def from_pretrained_flashpack( device = ( torch.device(device) if isinstance(device, str) - else torch.device("cpu") if device is None else device + else torch.device("cpu") + if device is None + else device ) user_agent = { diff --git a/src/flashpack/serialization.py b/src/flashpack/serialization.py index 7919893..3684e33 100644 --- a/src/flashpack/serialization.py +++ b/src/flashpack/serialization.py @@ -16,7 +16,7 @@ MAGIC, U64LE, ) -from .utils import dtype_to_string, timer, torch_dtype_to_numpy_dtype +from .utils import dtype_to_string, get_packing_dtype, timer, torch_dtype_to_numpy_dtype @dataclass @@ -174,10 +174,7 @@ def _lcm(a: int, b: int) -> int: block.offset_bytes : block.offset_bytes + block.length_bytes ] np_dtype = torch_dtype_to_numpy_dtype(block.dtype) - if block.dtype == torch.bfloat16: - typed_view = block_slice.view(np.uint16) - else: - typed_view = block_slice.view(np_dtype) + typed_view = block_slice.view(np_dtype) block_numpy_views.append(typed_view) block_views.append(torch.from_numpy(typed_view)) @@ -202,16 +199,18 @@ def copy_one(rec: TensorIndexRecord) -> None: block = macroblocks[rec.macroblock] dst_block = block_views[rec.macroblock] src = state_dict[rec.name] - target = block.dtype - if target == torch.bfloat16: - src_cpu = src.view(-1).to(dtype=target, device="cpu") - src_bits = src_cpu.view(torch.uint16) + target_dtype = block.dtype + packing_dtype = get_packing_dtype(target_dtype) + + if target_dtype != packing_dtype: + src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") + src_bits = src_cpu.view(packing_dtype) dst = dst_block.narrow(0, rec.offset, rec.length).view( - torch.uint16 + packing_dtype ) dst.copy_(src_bits, non_blocking=False) else: - src_cpu = src.view(-1).to(dtype=target, device="cpu") + src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") dst = dst_block.narrow(0, rec.offset, rec.length) dst.copy_(src_cpu, non_blocking=False) @@ -238,16 +237,18 @@ def copy_one(rec: TensorIndexRecord) -> None: block = macroblocks[rec.macroblock] dst_block = block_views[rec.macroblock] src = state_dict[rec.name] + target_dtype = block.dtype + packing_dtype = get_packing_dtype(target_dtype) - if block.dtype == torch.bfloat16: - src_cpu = src.view(-1).to(dtype=block.dtype, device="cpu") - src_bits = src_cpu.view(torch.uint16) + if target_dtype != packing_dtype: + src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") + src_bits = src_cpu.view(packing_dtype) dst = dst_block.narrow(0, rec.offset, rec.length).view( - torch.uint16 + packing_dtype ) dst.copy_(src_bits, non_blocking=False) else: - src_cpu = src.view(-1).to(dtype=block.dtype, device="cpu") + src_cpu = src.view(-1).to(dtype=target_dtype, device="cpu") dst = dst_block.narrow(0, rec.offset, rec.length) dst.copy_(src_cpu, non_blocking=False) diff --git a/src/flashpack/utils.py b/src/flashpack/utils.py index c55cfce..d3a411e 100644 --- a/src/flashpack/utils.py +++ b/src/flashpack/utils.py @@ -11,6 +11,11 @@ logger = logging.getLogger("flashpack") +try: + float4_e2m1fn = torch.float4_e2m1fn +except AttributeError: + float4_e2m1fn = None # 
type: ignore + def maybe_init_distributed( rank: int | None = None, @@ -70,21 +75,63 @@ def dtype_to_string(dtype: torch.dtype) -> str: return str(dtype).split(".")[1] +def get_packing_dtype(dtype: torch.dtype) -> torch.dtype: + """ + Get the dtype to use for packing a given dtype. + """ + if dtype is float4_e2m1fn: + raise ValueError(f"Unsupported dtype for packing: {dtype}") + elif dtype in [ + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e8m0fnu, + ]: + return torch.uint8 + elif dtype is torch.bfloat16: + return torch.uint16 + elif dtype is torch.complex32: + return torch.uint32 + else: + return dtype + + def torch_dtype_to_numpy_dtype(dtype: torch.dtype) -> np.dtype: """ Convert a torch.dtype to a numpy.dtype. """ + if dtype is float4_e2m1fn: + raise ValueError( + "4-bit data types are not supported at this time due to NumPy not having a similar 4-bit dtype. " + "If you feel up to tackling the task of supporting this, an idea is to combine 2 4-bit values " + "into a single 8-bit value and use that as the payload, then un-do this in the deserialization step. " + "Please feel free to contribute a PR if you're interested in adding support for this!" + ) + mapping = { torch.float32: np.float32, torch.float64: np.float64, torch.float16: np.float16, - torch.bfloat16: np.uint16, # store raw bf16 bits as uint16 payload torch.int8: np.int8, - torch.uint8: np.uint8, torch.int16: np.int16, torch.int32: np.int32, torch.int64: np.int64, + torch.uint8: np.uint8, + torch.uint16: np.uint16, + torch.uint32: np.uint32, + torch.uint64: np.uint64, torch.bool: np.bool_, + torch.complex64: np.complex64, + torch.complex128: np.complex128, + # unsupported dtypes, we map to uints + torch.float8_e4m3fn: np.uint8, + torch.float8_e4m3fnuz: np.uint8, + torch.float8_e5m2: np.uint8, + torch.float8_e5m2fnuz: np.uint8, + torch.float8_e8m0fnu: np.uint8, + torch.bfloat16: np.uint16, + torch.complex32: np.uint32, } if dtype not in mapping: raise ValueError(f"Unsupported dtype for packing: {dtype}") diff --git a/tests/test_baseline.py b/tests/test_baseline.py index 94d3084..dcea69a 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -19,9 +19,38 @@ def test_mixed_dtype_roundtrip(tmp_path) -> None: class MixedDTypeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.float_param = torch.nn.Parameter(torch.randn(4, dtype=torch.float32)) - self.bfloat_param = torch.nn.Parameter(torch.randn(3, dtype=torch.bfloat16)) - self.register_buffer("int_buffer", torch.arange(6, dtype=torch.int16)) + self.float_param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32)) + self.bfloat_param = torch.nn.Parameter(torch.zeros(3, dtype=torch.bfloat16)) + self.float16_param = torch.nn.Parameter(torch.zeros(2, dtype=torch.float16)) + self.register_buffer("int8_buffer", torch.zeros(4, dtype=torch.int8)) + self.register_buffer("uint8_buffer", torch.zeros(4, dtype=torch.uint8)) + self.register_buffer("int16_buffer", torch.zeros(4, dtype=torch.int16)) + self.register_buffer("uint16_buffer", torch.zeros(4, dtype=torch.uint16)) + self.register_buffer("int32_buffer", torch.zeros(4, dtype=torch.int32)) + self.register_buffer("uint32_buffer", torch.zeros(4, dtype=torch.uint32)) + self.register_buffer("int64_buffer", torch.zeros(4, dtype=torch.int64)) + self.register_buffer("uint64_buffer", torch.zeros(4, dtype=torch.uint64)) + self.register_buffer( + "float8_buffer", torch.zeros(4, dtype=torch.float8_e4m3fn) + ) + self.register_buffer( + 
"float8_fnuz_buffer", torch.zeros(4, dtype=torch.float8_e4m3fnuz) + ) + self.register_buffer( + "float8_e5m2_buffer", torch.zeros(4, dtype=torch.float8_e5m2) + ) + self.register_buffer( + "float8_e5m2_fnuz_buffer", torch.zeros(4, dtype=torch.float8_e5m2fnuz) + ) + self.register_buffer( + "float8_e8m0fnu_buffer", torch.zeros(4, dtype=torch.float8_e8m0fnu) + ) + self.register_buffer( + "complex64_buffer", torch.zeros(4, dtype=torch.complex64) + ) + self.register_buffer( + "complex128_buffer", torch.zeros(4, dtype=torch.complex128) + ) torch.manual_seed(0) source = MixedDTypeModel() @@ -38,7 +67,7 @@ def __init__(self) -> None: meta = get_flashpack_file_metadata(str(path)) assert meta["format"] == FILE_FORMAT_V4 - assert len(meta["macroblocks"]) == 3 + assert len(meta["macroblocks"]) == 18 assign_from_file(destination, str(path), device="cpu", silent=True) From 17e58b6737083841d553d8183940fc2e17dc2cf2 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Tue, 25 Nov 2025 17:36:21 -0500 Subject: [PATCH 03/11] add pre commit hook --- .github/workflows/pre-commit.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 .github/workflows/pre-commit.yaml diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 0000000..56481b5 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,18 @@ +name: Lint (pre-commit) + +on: + push: + branches: + - main + pull_request: + types: [assigned, opened, synchronize, reopened] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: "3.11" + - uses: pre-commit/action@v3.0.0 From 4a2cb93a800676f5829f78cbafc5575f05105462 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Tue, 25 Nov 2025 17:45:36 -0500 Subject: [PATCH 04/11] lints --- pyproject.toml | 2 +- scripts/run_benchmark.py | 8 ++++---- src/flashpack/integrations/diffusers/auto_model.py | 8 ++++---- src/flashpack/integrations/diffusers/patch.py | 4 ++-- src/flashpack/integrations/diffusers/pipeline.py | 8 +++----- src/flashpack/integrations/transformers/model.py | 7 ++++--- src/flashpack/integrations/transformers/patch.py | 11 +++++------ src/flashpack/mixin.py | 5 +++-- tests/test_speed_comparison.py | 11 ++++++++--- tests/test_wan_pipeline.py | 4 ++-- 10 files changed, 36 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a508996..532988b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,4 +46,4 @@ dev = [ ] [tool.setuptools_scm] -write_to = "src/flashpack/version.py" \ No newline at end of file +write_to = "src/flashpack/version.py" diff --git a/scripts/run_benchmark.py b/scripts/run_benchmark.py index fe3edb9..d2b89c1 100644 --- a/scripts/run_benchmark.py +++ b/scripts/run_benchmark.py @@ -8,13 +8,13 @@ import torch from flashpack.integrations.diffusers import patch_diffusers_auto_model from flashpack.integrations.transformers import patch_transformers_auto_model +from huggingface_hub import snapshot_download patch_diffusers_auto_model() patch_transformers_auto_model() -from diffusers.models import AutoModel as DiffusersAutoModel -from huggingface_hub import snapshot_download -from transformers import AutoModel as TransformersAutoModel +from diffusers.models import AutoModel as DiffusersAutoModel # noqa: E402 +from transformers import AutoModel as TransformersAutoModel # noqa: E402 def test_model( @@ -201,4 +201,4 @@ def sync_and_flush() -> None: ) print_test_result( test_model_name, 
accelerate_time, flashpack_time, total_bytes - ) \ No newline at end of file + ) diff --git a/src/flashpack/integrations/diffusers/auto_model.py b/src/flashpack/integrations/diffusers/auto_model.py index 5d5d8e5..5f6b59b 100644 --- a/src/flashpack/integrations/diffusers/auto_model.py +++ b/src/flashpack/integrations/diffusers/auto_model.py @@ -14,15 +14,15 @@ import os -from flashpack.integrations.diffusers.model import FlashPackDiffusersModelMixin -from flashpack.utils import logger -from huggingface_hub.utils import validate_hf_hub_args - from diffusers.configuration_utils import ConfigMixin from diffusers.utils.dynamic_modules_utils import ( get_class_from_dynamic_module, resolve_trust_remote_code, ) +from huggingface_hub.utils import validate_hf_hub_args + +from flashpack.integrations.diffusers.model import FlashPackDiffusersModelMixin +from flashpack.utils import logger class AutoFlashPackModel(ConfigMixin): diff --git a/src/flashpack/integrations/diffusers/patch.py b/src/flashpack/integrations/diffusers/patch.py index f415d47..0161b17 100644 --- a/src/flashpack/integrations/diffusers/patch.py +++ b/src/flashpack/integrations/diffusers/patch.py @@ -7,10 +7,10 @@ def patch_diffusers_auto_model() -> None: """ patch_diffusers_pipeline_loading_utils() - from flashpack.integrations.diffusers.auto_model import AutoFlashPackModel - import diffusers.models.auto_model + from flashpack.integrations.diffusers.auto_model import AutoFlashPackModel + diffusers.models.auto_model.AutoModel = AutoFlashPackModel diff --git a/src/flashpack/integrations/diffusers/pipeline.py b/src/flashpack/integrations/diffusers/pipeline.py index 9d2beb8..affb9fe 100644 --- a/src/flashpack/integrations/diffusers/pipeline.py +++ b/src/flashpack/integrations/diffusers/pipeline.py @@ -7,10 +7,6 @@ import torch import torch.distributed as dist -from huggingface_hub import DDUFEntry, create_repo, read_dduf_file, snapshot_download -from packaging import version -from typing_extensions import Self - from diffusers import OnnxRuntimeModel from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from diffusers.quantizers import PipelineQuantizationConfig @@ -30,6 +26,9 @@ populate_model_card, ) from diffusers.utils.torch_utils import get_device, is_compiled_module +from huggingface_hub import DDUFEntry, create_repo, read_dduf_file, snapshot_download +from packaging import version +from typing_extensions import Self from ...constants import ( DEFAULT_ALIGN_BYTES, @@ -278,7 +277,6 @@ def from_pretrained_flashpack( variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) - use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) diff --git a/src/flashpack/integrations/transformers/model.py b/src/flashpack/integrations/transformers/model.py index 7075193..57e77d6 100644 --- a/src/flashpack/integrations/transformers/model.py +++ b/src/flashpack/integrations/transformers/model.py @@ -5,9 +5,8 @@ from typing import Any import torch -from huggingface_hub import create_repo, snapshot_download - import transformers +from huggingface_hub import create_repo, snapshot_download from transformers.modeling_utils import PreTrainedModel from transformers.utils.hub import create_and_tag_model_card @@ -119,7 +118,9 @@ def from_pretrained_flashpack( device = ( torch.device(device) if isinstance(device, str) - else torch.device("cpu") 
if device is None else device + else torch.device("cpu") + if device is None + else device ) user_agent = { diff --git a/src/flashpack/integrations/transformers/patch.py b/src/flashpack/integrations/transformers/patch.py index d22e4d9..8ba7c03 100644 --- a/src/flashpack/integrations/transformers/patch.py +++ b/src/flashpack/integrations/transformers/patch.py @@ -29,9 +29,6 @@ def patch_auto_factory() -> None: from transformers.configuration_utils import ( PretrainedConfig as PreTrainedConfig, ) - from flashpack.integrations.transformers.model import ( - FlashPackTransformersModelMixin, - ) import transformers.models.auto.auto_factory from transformers.dynamic_module_utils import ( @@ -40,6 +37,10 @@ def patch_auto_factory() -> None: ) from transformers.models.auto.configuration_auto import AutoConfig + from flashpack.integrations.transformers.model import ( + FlashPackTransformersModelMixin, + ) + def patched_from_config(cls, config, **kwargs): trust_remote_code = kwargs.pop("trust_remote_code", None) has_remote_code = ( @@ -316,6 +317,4 @@ def from_pretrained_flashpack( transformers.models.auto.auto_factory._BaseAutoModelClass.from_pretrained = ( patched_from_pretrained ) - transformers.models.auto.auto_factory._BaseAutoModelClass.from_pretrained_flashpack = ( - from_pretrained_flashpack - ) + transformers.models.auto.auto_factory._BaseAutoModelClass.from_pretrained_flashpack = from_pretrained_flashpack diff --git a/src/flashpack/mixin.py b/src/flashpack/mixin.py index 1a9cc38..1680259 100644 --- a/src/flashpack/mixin.py +++ b/src/flashpack/mixin.py @@ -17,7 +17,6 @@ class FlashPackMixin: - flashpack_coerce_dtype: bool = False flashpack_init_method: str | None = None flashpack_ignore_names: list[str] | None = None @@ -53,7 +52,9 @@ def from_flashpack( device = ( torch.device(device) if isinstance(device, str) - else torch.device("cpu") if device is None else device + else torch.device("cpu") + if device is None + else device ) with init_empty_weights(): diff --git a/tests/test_speed_comparison.py b/tests/test_speed_comparison.py index 97e7cee..28888b6 100644 --- a/tests/test_speed_comparison.py +++ b/tests/test_speed_comparison.py @@ -1,9 +1,10 @@ import os import time +import matplotlib import safetensors.torch import torch -import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np @@ -16,17 +17,19 @@ sf_filename = os.path.join(repo_dir, "model.safetensors") flashpack_filename = os.path.join(repo_dir, "model.flashpack") -print(f"Preparing model") +print("Preparing model") model = GPT2Model.from_pretrained("gpt2", device_map="cuda") if not os.path.exists(flashpack_filename): pack_to_file(model, flashpack_filename, target_dtype=model.dtype) print("Running load time comparison (10 runs each)") + def cuda_sync(): if torch.cuda.is_available(): torch.cuda.synchronize() + num_runs = 10 times_pt = [] times_sf = [] @@ -111,7 +114,9 @@ def cuda_sync(): ax.set_yticklabels(labels, color=label_color) ax.invert_yaxis() # top-to-bottom order as specified -ax.set_xlabel("Loading Time (seconds)", fontsize=12, fontweight="bold", color=label_color) +ax.set_xlabel( + "Loading Time (seconds)", fontsize=12, fontweight="bold", color=label_color +) ax.set_title( "load_state_dict() Time Comparison", fontsize=14, diff --git a/tests/test_wan_pipeline.py b/tests/test_wan_pipeline.py index d5e06dd..0cf801a 100644 --- a/tests/test_wan_pipeline.py +++ b/tests/test_wan_pipeline.py @@ -1,5 +1,6 @@ import os import sys +from typing import Optional import torch from diffusers.models 
import AutoencoderKLWan, WanTransformer3DModel @@ -13,7 +14,6 @@ from flashpack.utils import timer from huggingface_hub import snapshot_download from transformers import AutoTokenizer, UMT5EncoderModel -from typing import Optional class FlashPackWanTransformer3DModel( @@ -129,4 +129,4 @@ def __init__( num_inference_steps=28, generator=generator, ) - generator.manual_seed(42) \ No newline at end of file + generator.manual_seed(42) From b872a7071752de19ab3badf21b3e2bf662c40e82 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Tue, 25 Nov 2025 17:50:42 -0500 Subject: [PATCH 05/11] dos2unix --- README.md | 182 +++++++++++++++++++++++++++--------------------------- 1 file changed, 91 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index e4156a4..1591f17 100644 --- a/README.md +++ b/README.md @@ -1,91 +1,91 @@ -
-[FlashPack Logo image]
-
-Disk-to-GPU Tensor loading at up to 25Gbps without GDS
-
-[Benchmark Results image]
-Run this benchmark in `scripts/run_benchmark.py`
-
-[Benchmark Results image]
-Run this benchmark in `tests/test_speed_comparison.py`
- -# Integration Guide -## Mixins -### Diffusers/Transformers - -```py -# Integration classes -from flashpack.integrations.diffusers import FlashPackDiffusersModelMixin, FlashPackDiffusionPipeline -from flashpack.integrations.transformers import FlashPackTransformersModelMixin - -# Base classes -from diffusers.models import MyModel, SomeOtherModel -from diffusers.pipelines import MyPipeline - -# Define mixed classes -class FlashPackMyModel(MyModel, FlashPackDiffusersModelMixin): - pass - -class FlashPackMyPipeline(MyPipeline, FlashPackDiffusionPipine): - def __init__( - self, - my_model: FlashPackMyModel, - other_model: SomeOtherModel, - ) -> None: - super().__init__() - -# Load base pipeline -pipeline = FlashPackMyPipeline.from_pretrained("some/repository") - -# Save flashpack pipeline -pipeline.save_pretrained_flashpack( - "some_directory", - push_to_hub=False, # pass repo_id when using this -) - -# Load directly from flashpack directory or repository -pipeline = FlashPackMyPipeline.from_pretrained_flashpack("my/flashpack-repository") -``` - -### Vanilla PyTorch - -```py -from flashpack import FlashPackMixin - -class MyModule(nn.Module, FlashPackMixin): - def __init__(self, some_arg: int = 4) -> None: - ... - -module = MyModule(some_arg = 4) -module.save_flashpack("model.flashpack") - -loaded_module = module.from_flashpack("model.flashpack", some_arg=4) -``` - -## Direct Integration - -```py -from flashpack import pack_to_file, assign_from_file - -flashpack_path = "/path/to/model.flashpack" -model = nn.Module(...) - -pack_to_file(model, flashpack_path) # write state dict to file -assign_from_file(model, flashpack_path) # load state dict from file -``` +
+[FlashPack Logo image]
+
+Disk-to-GPU Tensor loading at up to 25Gbps without GDS
+
+[Benchmark Results image]
+Run this benchmark in `scripts/run_benchmark.py`
+
+[Benchmark Results image]
+Run this benchmark in `tests/test_speed_comparison.py`
+ +# Integration Guide +## Mixins +### Diffusers/Transformers + +```py +# Integration classes +from flashpack.integrations.diffusers import FlashPackDiffusersModelMixin, FlashPackDiffusionPipeline +from flashpack.integrations.transformers import FlashPackTransformersModelMixin + +# Base classes +from diffusers.models import MyModel, SomeOtherModel +from diffusers.pipelines import MyPipeline + +# Define mixed classes +class FlashPackMyModel(MyModel, FlashPackDiffusersModelMixin): + pass + +class FlashPackMyPipeline(MyPipeline, FlashPackDiffusionPipine): + def __init__( + self, + my_model: FlashPackMyModel, + other_model: SomeOtherModel, + ) -> None: + super().__init__() + +# Load base pipeline +pipeline = FlashPackMyPipeline.from_pretrained("some/repository") + +# Save flashpack pipeline +pipeline.save_pretrained_flashpack( + "some_directory", + push_to_hub=False, # pass repo_id when using this +) + +# Load directly from flashpack directory or repository +pipeline = FlashPackMyPipeline.from_pretrained_flashpack("my/flashpack-repository") +``` + +### Vanilla PyTorch + +```py +from flashpack import FlashPackMixin + +class MyModule(nn.Module, FlashPackMixin): + def __init__(self, some_arg: int = 4) -> None: + ... + +module = MyModule(some_arg = 4) +module.save_flashpack("model.flashpack") + +loaded_module = module.from_flashpack("model.flashpack", some_arg=4) +``` + +## Direct Integration + +```py +from flashpack import pack_to_file, assign_from_file + +flashpack_path = "/path/to/model.flashpack" +model = nn.Module(...) + +pack_to_file(model, flashpack_path) # write state dict to file +assign_from_file(model, flashpack_path) # load state dict from file +``` From d78a12baac9c4beb468fdbbf7064a71e81f0b2a7 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 00:01:21 +0000 Subject: [PATCH 06/11] fix integrations --- src/flashpack/deserialization.py | 27 ++++++++-- .../integrations/transformers/model.py | 1 + src/flashpack/mixin.py | 28 ++++++----- tests/test_baseline.py | 49 ++++++++++--------- tests/test_integrations.py | 19 +------ 5 files changed, 65 insertions(+), 59 deletions(-) diff --git a/src/flashpack/deserialization.py b/src/flashpack/deserialization.py index ab56874..e92cd3a 100644 --- a/src/flashpack/deserialization.py +++ b/src/flashpack/deserialization.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.distributed as dist +import tqdm from .constants import ( DEFAULT_CHUNK_BYTES, @@ -208,8 +209,12 @@ def _copy_memmaps_into_storage( block_tensor = storage.block(idx) num_pipeline_buffers = max(1, min(num_streams, 8)) + + # For dtypes that require bit-reinterpretation (e.g. 
bfloat16 stored as uint16), + # allocate staging buffers in the packing dtype + packing_dtype = get_packing_dtype(spec.dtype) staging_bufs = [ - torch.empty(elems_per_chunk, dtype=spec.dtype, pin_memory=True) + torch.empty(elems_per_chunk, dtype=packing_dtype, pin_memory=True) for _ in range(num_pipeline_buffers) ] num_cuda_streams = max(1, min(num_streams, 8)) @@ -221,7 +226,7 @@ def _copy_memmaps_into_storage( sz = end - start buf_idx = chunk_idx % num_pipeline_buffers - buf = staging_bufs[buf_idx].narrow(0, 0, sz) + buf_raw = staging_bufs[buf_idx].narrow(0, 0, sz) stream = streams[chunk_idx % num_cuda_streams] if chunk_idx >= num_pipeline_buffers: @@ -231,9 +236,13 @@ def _copy_memmaps_into_storage( with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) src_t = torch.from_numpy(np_view) - if src_t.dtype != spec.dtype: - src_t = src_t.to(dtype=spec.dtype) - buf.copy_(src_t, non_blocking=False) + buf_raw.copy_(src_t, non_blocking=False) + + # Reinterpret bits if needed (e.g. uint16 -> bfloat16) + if spec.dtype != packing_dtype: + buf = buf_raw.view(spec.dtype) + else: + buf = buf_raw with torch.cuda.stream(stream): block_tensor.narrow(0, start, sz).copy_(buf, non_blocking=True) @@ -379,8 +388,16 @@ def revert_from_file( """ storage, meta = read_flashpack_file(path, silent=silent) state_dict = {} + progress: tqdm.tqdm | None = None + + if not silent: + progress = tqdm.tqdm(desc="Reverting from flashpack", total=len(storage)) + for name, view in iterate_from_flash_tensor(storage, meta): state_dict[name] = view.detach().cpu() + if progress: + progress.update(1) + return state_dict diff --git a/src/flashpack/integrations/transformers/model.py b/src/flashpack/integrations/transformers/model.py index 57e77d6..ddfadc0 100644 --- a/src/flashpack/integrations/transformers/model.py +++ b/src/flashpack/integrations/transformers/model.py @@ -168,6 +168,7 @@ def from_pretrained_flashpack( ) kwargs.update(model_kwargs) + kwargs.pop("config", None) return cls.from_flashpack( flashpack_path, config, diff --git a/src/flashpack/mixin.py b/src/flashpack/mixin.py index 1680259..a4ce397 100644 --- a/src/flashpack/mixin.py +++ b/src/flashpack/mixin.py @@ -1,7 +1,8 @@ from __future__ import annotations import inspect -from typing import Any +from collections.abc import Callable +from typing import Any, ClassVar import torch from accelerate import init_empty_weights @@ -17,11 +18,11 @@ class FlashPackMixin: - flashpack_coerce_dtype: bool = False - flashpack_init_method: str | None = None - flashpack_ignore_names: list[str] | None = None - flashpack_ignore_prefixes: list[str] | None = None - flashpack_ignore_suffixes: list[str] | None = None + flashpack_coerce_dtype: ClassVar[bool] = False + flashpack_init_method: ClassVar[str | None] = None + flashpack_ignore_names: ClassVar[list[str] | None] = None + flashpack_ignore_prefixes: ClassVar[list[str] | None] = None + flashpack_ignore_suffixes: ClassVar[list[str] | None] = None @classmethod def from_flashpack( @@ -44,6 +45,7 @@ def from_flashpack( local_rank: int | None = None, world_size: int | None = None, coerce_dtype: bool = False, + init_fn: Callable[..., "FlashPackMixin"] | None = None, **kwargs: Any, ) -> FlashPackMixin: """ @@ -58,14 +60,16 @@ def from_flashpack( ) with init_empty_weights(): - if cls.flashpack_init_method is not None and hasattr( - cls, cls.flashpack_init_method - ): - init_fn = getattr(cls, cls.flashpack_init_method) - else: - init_fn = cls + if init_fn is None: + if cls.flashpack_init_method is not None and 
hasattr( + cls, cls.flashpack_init_method + ): + init_fn = getattr(cls, cls.flashpack_init_method) + else: + init_fn = cls parameters = inspect.signature(init_fn).parameters + kwargs = {k: v for k, v in kwargs.items() if k in parameters} if "rank" in parameters: kwargs["rank"] = rank if "local_rank" in parameters: diff --git a/tests/test_baseline.py b/tests/test_baseline.py index dcea69a..50bd38c 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -19,37 +19,37 @@ def test_mixed_dtype_roundtrip(tmp_path) -> None: class MixedDTypeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.float_param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32)) - self.bfloat_param = torch.nn.Parameter(torch.zeros(3, dtype=torch.bfloat16)) - self.float16_param = torch.nn.Parameter(torch.zeros(2, dtype=torch.float16)) - self.register_buffer("int8_buffer", torch.zeros(4, dtype=torch.int8)) - self.register_buffer("uint8_buffer", torch.zeros(4, dtype=torch.uint8)) - self.register_buffer("int16_buffer", torch.zeros(4, dtype=torch.int16)) - self.register_buffer("uint16_buffer", torch.zeros(4, dtype=torch.uint16)) - self.register_buffer("int32_buffer", torch.zeros(4, dtype=torch.int32)) - self.register_buffer("uint32_buffer", torch.zeros(4, dtype=torch.uint32)) - self.register_buffer("int64_buffer", torch.zeros(4, dtype=torch.int64)) - self.register_buffer("uint64_buffer", torch.zeros(4, dtype=torch.uint64)) + self.float_param = torch.nn.Parameter(torch.ones(4, dtype=torch.float32)) + self.bfloat_param = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16)) + self.float16_param = torch.nn.Parameter(torch.ones(2, dtype=torch.float16)) + self.register_buffer("int8_buffer", torch.ones(4, dtype=torch.int8)) + self.register_buffer("uint8_buffer", torch.ones(4, dtype=torch.uint8)) + self.register_buffer("int16_buffer", torch.ones(4, dtype=torch.int16)) + self.register_buffer("uint16_buffer", torch.ones(4, dtype=torch.uint16)) + self.register_buffer("int32_buffer", torch.ones(4, dtype=torch.int32)) + self.register_buffer("uint32_buffer", torch.ones(4, dtype=torch.uint32)) + self.register_buffer("int64_buffer", torch.ones(4, dtype=torch.int64)) + self.register_buffer("uint64_buffer", torch.ones(4, dtype=torch.uint64)) self.register_buffer( - "float8_buffer", torch.zeros(4, dtype=torch.float8_e4m3fn) + "float8_buffer", torch.ones(4, dtype=torch.float8_e4m3fn) ) self.register_buffer( - "float8_fnuz_buffer", torch.zeros(4, dtype=torch.float8_e4m3fnuz) + "float8_fnuz_buffer", torch.ones(4, dtype=torch.float8_e4m3fnuz) ) self.register_buffer( - "float8_e5m2_buffer", torch.zeros(4, dtype=torch.float8_e5m2) + "float8_e5m2_buffer", torch.ones(4, dtype=torch.float8_e5m2) ) self.register_buffer( - "float8_e5m2_fnuz_buffer", torch.zeros(4, dtype=torch.float8_e5m2fnuz) + "float8_e5m2_fnuz_buffer", torch.ones(4, dtype=torch.float8_e5m2fnuz) ) self.register_buffer( - "float8_e8m0fnu_buffer", torch.zeros(4, dtype=torch.float8_e8m0fnu) + "float8_e8m0fnu_buffer", torch.ones(4, dtype=torch.float8_e8m0fnu) ) self.register_buffer( - "complex64_buffer", torch.zeros(4, dtype=torch.complex64) + "complex64_buffer", torch.ones(4, dtype=torch.complex64) ) self.register_buffer( - "complex128_buffer", torch.zeros(4, dtype=torch.complex128) + "complex128_buffer", torch.ones(4, dtype=torch.complex128) ) torch.manual_seed(0) @@ -138,7 +138,7 @@ class FlashPackWanTransformer3DModel( model_path = os.path.join(save_dir, "model.flashpack") if not os.path.exists(model_path): - 
initial_model.save_pretrained_flashpack(save_dir, target_dtype=torch.bfloat16) + initial_model.save_pretrained_flashpack(save_dir) flashpack_model = FlashPackWanTransformer3DModel.from_pretrained_flashpack( save_dir, @@ -211,12 +211,13 @@ class FlashPackWanTextEncoderModel( model_path = os.path.join(save_dir, "model.flashpack") if not os.path.exists(model_path): - initial_model.save_pretrained_flashpack(save_dir, target_dtype=torch.bfloat16) + initial_model.save_pretrained_flashpack(save_dir) - flashpack_model = FlashPackWanTextEncoderModel.from_pretrained_flashpack( - save_dir, - device="cuda" if torch.cuda.is_available() else "cpu", - ) + with timer("load_flashpack"): + flashpack_model = FlashPackWanTextEncoderModel.from_pretrained_flashpack( + save_dir, + device="cuda" if torch.cuda.is_available() else "cpu", + ) # Build lookups initial_model_params = { diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 0c06468..69cf37c 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -48,21 +48,4 @@ def test_transformers() -> None: with TemporaryDirectory() as tmpdir: model.save_pretrained_flashpack(tmpdir) assert os.path.exists(os.path.join(tmpdir, "model.flashpack")) - assert AutoModel.from_pretrained_flashpack(tmpdir) is not None - patch_transformers_auto_model() - - from flashpack.integrations.transformers.model import ( - FlashPackTransformersModelMixin, - ) - from transformers.models import AutoModel - - model = AutoModel.from_pretrained( - "openai/clip-vit-base-patch32", - ) - assert model is not None - assert isinstance(model, FlashPackTransformersModelMixin) - - with TemporaryDirectory() as tmpdir: - model.save_pretrained_flashpack(tmpdir) - assert os.path.exists(os.path.join(tmpdir, "model.flashpack")) - assert AutoModel.from_pretrained_flashpack(tmpdir) is not None + assert AutoModel.from_pretrained_flashpack(tmpdir) is not None \ No newline at end of file From 23f5d29209818c2c3ce5f37e4016426eb75f1ec7 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 01:20:53 +0000 Subject: [PATCH 07/11] update tests --- tests/test_speed_comparison.py | 259 +++++++++++++++++---------------- tests/test_wan_pipeline.py | 62 +++++--- 2 files changed, 173 insertions(+), 148 deletions(-) diff --git a/tests/test_speed_comparison.py b/tests/test_speed_comparison.py index 28888b6..9cbfa9d 100644 --- a/tests/test_speed_comparison.py +++ b/tests/test_speed_comparison.py @@ -12,133 +12,136 @@ from huggingface_hub import snapshot_download from transformers import GPT2Model -repo_dir = snapshot_download("gpt2") -pt_filename = os.path.join(repo_dir, "pytorch_model.bin") -sf_filename = os.path.join(repo_dir, "model.safetensors") -flashpack_filename = os.path.join(repo_dir, "model.flashpack") - -print("Preparing model") -model = GPT2Model.from_pretrained("gpt2", device_map="cuda") -if not os.path.exists(flashpack_filename): - pack_to_file(model, flashpack_filename, target_dtype=model.dtype) - -print("Running load time comparison (10 runs each)") - - -def cuda_sync(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -num_runs = 10 -times_pt = [] -times_sf = [] -times_sf_fast = [] -times_fp = [] - -# Repeat timings -for i in range(num_runs): - # PyTorch .bin - os.environ.pop("SAFETENSORS_FAST_GPU", None) - start = time.time() - state_dict = torch.load(pt_filename, map_location="cuda") - model.load_state_dict(state_dict, strict=False) - cuda_sync() - end = time.time() - times_pt.append(end - start) - - # Safetensors - start = time.time() - 
state_dict = safetensors.torch.load_file(sf_filename, device="cuda") - model.load_state_dict(state_dict, strict=False) - cuda_sync() - end = time.time() - times_sf.append(end - start) - - # Safetensors (fast gpu) - os.environ["SAFETENSORS_FAST_GPU"] = "1" - start = time.time() - state_dict = safetensors.torch.load_file(sf_filename, device="cuda") - model.load_state_dict(state_dict, strict=False) - cuda_sync() - end = time.time() - times_sf_fast.append(end - start) - - # Flashpack - start = time.time() - assign_from_file(model, flashpack_filename, device="cuda") - cuda_sync() - end = time.time() - times_fp.append(end - start) - -print("Timing complete. Means (s):") -print(f" pytorch: {np.mean(times_pt):.3f}") -print(f" safetensors: {np.mean(times_sf):.3f}") -print(f" safetensors (fast gpu): {np.mean(times_sf_fast):.3f}") -print(f" flashpack: {np.mean(times_fp):.3f}") - -# Plot configuration (aligned with scripts/plot-benchmark.py) -accelerate_color = "#0f5ef3" -flashpack_color = "#adff02" -label_color = "#111111" - -labels = [ - "pytorch", - "safetensors", - "safetensors (fast gpu)", - "flashpack", -] -means = [ - float(np.mean(times_pt)), - float(np.mean(times_sf)), - float(np.mean(times_sf_fast)), - float(np.mean(times_fp)), -] - -colors = [accelerate_color, accelerate_color, accelerate_color, flashpack_color] - -fig, ax = plt.subplots(figsize=(10, 4)) -ax.patch.set_facecolor((0, 0, 0, 0)) - -# Style spines and ticks -for spine in ax.spines.values(): - spine.set_color(label_color) -ax.xaxis.label.set_color(label_color) -ax.yaxis.label.set_color(label_color) -ax.tick_params(axis="x", colors=label_color) -ax.tick_params(axis="y", colors=label_color) - -y_pos = np.arange(len(labels)) -bars = ax.barh(y_pos, means, color=colors, alpha=0.8) -ax.set_yticks(y_pos) -ax.set_yticklabels(labels, color=label_color) -ax.invert_yaxis() # top-to-bottom order as specified - -ax.set_xlabel( - "Loading Time (seconds)", fontsize=12, fontweight="bold", color=label_color -) -ax.set_title( - "load_state_dict() Time Comparison", - fontsize=14, - fontweight="bold", - pad=16, - color=label_color, -) -ax.grid(axis="x", alpha=0.3, linestyle="--") - -# Add value labels at the end of each bar -for bar, val in zip(bars, means): - ax.text( - bar.get_width() + max(means) * 0.01, - bar.get_y() + bar.get_height() / 2, - f"{val:.2f}s", - va="center", - ha="left", - fontsize=9, +def test_speed_comparison() -> None: + """ + Test the speed comparison between PyTorch, Safetensors, and Flashpack. 
+ """ + repo_dir = snapshot_download("gpt2") + pt_filename = os.path.join(repo_dir, "pytorch_model.bin") + sf_filename = os.path.join(repo_dir, "model.safetensors") + flashpack_filename = os.path.join(repo_dir, "model.flashpack") + + print("Preparing model") + model = GPT2Model.from_pretrained("gpt2", device_map="cuda") + if not os.path.exists(flashpack_filename): + pack_to_file(model, flashpack_filename, target_dtype=model.dtype) + + print("Running load time comparison (10 runs each)") + + + def cuda_sync(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + num_runs = 10 + times_pt = [] + times_sf = [] + times_sf_fast = [] + times_fp = [] + + # Repeat timings + for i in range(num_runs): + # PyTorch .bin + os.environ.pop("SAFETENSORS_FAST_GPU", None) + start = time.time() + state_dict = torch.load(pt_filename, map_location="cuda") + model.load_state_dict(state_dict, strict=False) + cuda_sync() + end = time.time() + times_pt.append(end - start) + + # Safetensors + start = time.time() + state_dict = safetensors.torch.load_file(sf_filename, device="cuda") + model.load_state_dict(state_dict, strict=False) + cuda_sync() + end = time.time() + times_sf.append(end - start) + + # Safetensors (fast gpu) + os.environ["SAFETENSORS_FAST_GPU"] = "1" + start = time.time() + state_dict = safetensors.torch.load_file(sf_filename, device="cuda") + model.load_state_dict(state_dict, strict=False) + cuda_sync() + end = time.time() + times_sf_fast.append(end - start) + + # Flashpack + start = time.time() + assign_from_file(model, flashpack_filename, device="cuda") + cuda_sync() + end = time.time() + times_fp.append(end - start) + + print("Timing complete. Means (s):") + print(f" pytorch: {np.mean(times_pt):.3f}") + print(f" safetensors: {np.mean(times_sf):.3f}") + print(f" safetensors (fast gpu): {np.mean(times_sf_fast):.3f}") + print(f" flashpack: {np.mean(times_fp):.3f}") + + # Plot configuration (aligned with scripts/plot-benchmark.py) + accelerate_color = "#0f5ef3" + flashpack_color = "#adff02" + label_color = "#111111" + + labels = [ + "pytorch", + "safetensors", + "safetensors (fast gpu)", + "flashpack", + ] + means = [ + float(np.mean(times_pt)), + float(np.mean(times_sf)), + float(np.mean(times_sf_fast)), + float(np.mean(times_fp)), + ] + + colors = [accelerate_color, accelerate_color, accelerate_color, flashpack_color] + + fig, ax = plt.subplots(figsize=(10, 4)) + ax.patch.set_facecolor((0, 0, 0, 0)) + + # Style spines and ticks + for spine in ax.spines.values(): + spine.set_color(label_color) + ax.xaxis.label.set_color(label_color) + ax.yaxis.label.set_color(label_color) + ax.tick_params(axis="x", colors=label_color) + ax.tick_params(axis="y", colors=label_color) + + y_pos = np.arange(len(labels)) + bars = ax.barh(y_pos, means, color=colors, alpha=0.8) + ax.set_yticks(y_pos) + ax.set_yticklabels(labels, color=label_color) + ax.invert_yaxis() # top-to-bottom order as specified + + ax.set_xlabel( + "Loading Time (seconds)", fontsize=12, fontweight="bold", color=label_color + ) + ax.set_title( + "load_state_dict() Time Comparison", + fontsize=14, + fontweight="bold", + pad=16, color=label_color, ) - -plt.tight_layout() -output_path = "./speed_comparison.png" -plt.savefig(output_path, dpi=300, bbox_inches="tight", transparent=True) -print(f"Graph saved to: {output_path}") + ax.grid(axis="x", alpha=0.3, linestyle="--") + + # Add value labels at the end of each bar + for bar, val in zip(bars, means): + ax.text( + bar.get_width() + max(means) * 0.01, + bar.get_y() + bar.get_height() / 2, + 
f"{val:.2f}s", + va="center", + ha="left", + fontsize=9, + color=label_color, + ) + + plt.tight_layout() + output_path = "./speed_comparison.png" + plt.savefig(output_path, dpi=300, bbox_inches="tight", transparent=True) + print(f"Graph saved to: {output_path}") diff --git a/tests/test_wan_pipeline.py b/tests/test_wan_pipeline.py index 0cf801a..69bcb14 100644 --- a/tests/test_wan_pipeline.py +++ b/tests/test_wan_pipeline.py @@ -1,7 +1,7 @@ import os -import sys from typing import Optional +import pytest import torch from diffusers.models import AutoencoderKLWan, WanTransformer3DModel from diffusers.pipelines import WanPipeline @@ -55,21 +55,25 @@ def __init__( HERE = os.path.dirname(os.path.abspath(__file__)) -pipeline_dir = os.path.join(HERE, "wan_pipeline") -os.makedirs(pipeline_dir, exist_ok=True) +PIPELINE_DIR = os.path.join(HERE, "wan_pipeline") -if len(sys.argv) < 2: - raise ValueError("Usage: python wan_pipe.py ") -if sys.argv[1] not in ["save", "load"]: - raise ValueError("Usage: python wan_pipe.py ") +@pytest.fixture(scope="module") +def repo_dir(): + """Download and cache the Wan model repository.""" + return snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") -is_save = sys.argv[1] == "save" -is_load = sys.argv[1] == "load" -repo_dir = snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") +@pytest.fixture(scope="module") +def pipeline_dir(): + """Return the directory for saving/loading the flashpack pipeline.""" + os.makedirs(PIPELINE_DIR, exist_ok=True) + return PIPELINE_DIR -if is_save: + +@pytest.fixture(scope="module") +def saved_pipeline(repo_dir, pipeline_dir): + """Save the pipeline using flashpack and return the path.""" transformer = FlashPackWanTransformer3DModel.from_pretrained( os.path.join(repo_dir, "transformer"), torch_dtype=torch.bfloat16, @@ -98,35 +102,53 @@ def __init__( ) with timer("save"): - pipeline.save_pretrained_flashpack( - pipeline_dir, - ) + pipeline.save_pretrained_flashpack(pipeline_dir) + + return pipeline_dir + -elif is_load: +def test_save_pipeline(saved_pipeline): + """Test that the pipeline can be saved using flashpack.""" + assert os.path.exists(saved_pipeline) + # Check that the expected files exist + assert os.path.isdir(saved_pipeline) + + +def test_load_and_inference_accelerate(repo_dir): + """Test loading and running inference with accelerate.""" device = "cuda" if torch.cuda.is_available() else "cpu" - generator = torch.Generator(device=device).manual_seed(42) + with timer("load_and_inference_accelerate"): pipeline = FlashPackWanPipeline.from_pretrained( repo_dir, device_map="cuda", torch_dtype=torch.bfloat16, ) - pipeline( + output = pipeline( prompt="A beautiful sunset over a calm ocean.", width=832, height=480, num_inference_steps=28, ) + assert output is not None + + +def test_load_and_inference_flashpack(saved_pipeline): + """Test loading and running inference with flashpack.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + generator = torch.Generator(device=device).manual_seed(42) + with timer("load_and_inference_flashpack"): pipeline = FlashPackWanPipeline.from_pretrained_flashpack( - pipeline_dir, device_map=device, silent=False + saved_pipeline, device_map=device, silent=False ) - pipeline( + output = pipeline( prompt="A beautiful sunset over a calm ocean.", width=832, height=480, num_inference_steps=28, generator=generator, ) - generator.manual_seed(42) + + assert output is not None From 8ea58618ad3e089526c14c5e4fa74160c2c61e56 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 01:21:25 +0000 
Subject: [PATCH 08/11] lints --- tests/test_integrations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 69cf37c..9595480 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -48,4 +48,4 @@ def test_transformers() -> None: with TemporaryDirectory() as tmpdir: model.save_pretrained_flashpack(tmpdir) assert os.path.exists(os.path.join(tmpdir, "model.flashpack")) - assert AutoModel.from_pretrained_flashpack(tmpdir) is not None \ No newline at end of file + assert AutoModel.from_pretrained_flashpack(tmpdir) is not None From 68034ad390ef2334a299f4cc1653aec90dbc8efa Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 01:23:07 +0000 Subject: [PATCH 09/11] more lints --- tests/test_speed_comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_speed_comparison.py b/tests/test_speed_comparison.py index 9cbfa9d..b936e03 100644 --- a/tests/test_speed_comparison.py +++ b/tests/test_speed_comparison.py @@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download from transformers import GPT2Model + def test_speed_comparison() -> None: """ Test the speed comparison between PyTorch, Safetensors, and Flashpack. @@ -28,7 +29,6 @@ def test_speed_comparison() -> None: print("Running load time comparison (10 runs each)") - def cuda_sync(): if torch.cuda.is_available(): torch.cuda.synchronize() From fb33db825f265e8c92844abec6f82d1377eb55ef Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 01:30:15 +0000 Subject: [PATCH 10/11] more lints --- tests/test_integrations.py | 2 +- tests/test_wan_pipeline.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 9595480..7d6a73f 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -37,7 +37,7 @@ def test_transformers() -> None: from flashpack.integrations.transformers.model import ( FlashPackTransformersModelMixin, ) - from transformers.models import AutoModel + from transformers.models.auto.modeling_auto import AutoModel model = AutoModel.from_pretrained( "openai/clip-vit-base-patch32", diff --git a/tests/test_wan_pipeline.py b/tests/test_wan_pipeline.py index 69bcb14..68ab69a 100644 --- a/tests/test_wan_pipeline.py +++ b/tests/test_wan_pipeline.py @@ -116,8 +116,6 @@ def test_save_pipeline(saved_pipeline): def test_load_and_inference_accelerate(repo_dir): """Test loading and running inference with accelerate.""" - device = "cuda" if torch.cuda.is_available() else "cpu" - with timer("load_and_inference_accelerate"): pipeline = FlashPackWanPipeline.from_pretrained( repo_dir, From e39895a166f9829c29eecdcbf8155b34d4732087 Mon Sep 17 00:00:00 2001 From: Benjamin Paine Date: Wed, 26 Nov 2025 01:52:15 +0000 Subject: [PATCH 11/11] add docs, update colors --- README.md | 96 ++++++++++++++++++++++++++++++++++ scripts/plot_benchmark.py | 2 +- tests/test_speed_comparison.py | 2 +- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1591f17..6313707 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@

Disk-to-GPU Tensor loading at up to 25Gbps without GDS

+## Updates + +- **2025-11-25**: Now supports **multiple data types per checkpoint** with no regressions in speed! +
@@ -89,3 +93,95 @@ model = nn.Module(...) pack_to_file(model, flashpack_path) # write state dict to file assign_from_file(model, flashpack_path) # load state dict from file ``` + +# CLI Commands + +FlashPack provides a command-line interface for converting, inspecting, and reverting flashpack files. + +## `flashpack convert` + +Convert a model to a flashpack file. + +```bash +flashpack convert [destination_path] [options] +``` + +**Arguments:** +- `path_or_repo_id` - Local path or Hugging Face repository ID +- `destination_path` - (Optional) Output path for the flashpack file + +**Options:** +| Option | Description | +|--------|-------------| +| `--subfolder` | Subfolder of the model (for repo_id) | +| `--variant` | Model variant (for repo_id) | +| `--dtype` | Target dtype for the flashpack file. When omitted, no type changes are made | +| `--ignore-names` | Tensor names to ignore (can be specified multiple times) | +| `--ignore-prefixes` | Tensor prefixes to ignore (can be specified multiple times) | +| `--ignore-suffixes` | Tensor suffixes to ignore (can be specified multiple times) | +| `--use-transformers` | Load the path as a transformers model | +| `--use-diffusers` | Load the path as a diffusers model | +| `-v, --verbose` | Enable verbose output | + +**Examples:** +```bash +# Convert a local model +flashpack convert ./my_model ./my_model.flashpack + +# Convert from Hugging Face +flashpack convert stabilityai/stable-diffusion-xl-base-1.0 --subfolder unet --use-diffusers + +# Convert with specific dtype +flashpack convert ./my_model ./my_model.flashpack --dtype float16 +``` + +## `flashpack revert` + +Revert a flashpack file back to safetensors or torch format. + +```bash +flashpack revert [destination_path] [options] +``` + +**Arguments:** +- `path` - Path to the flashpack file +- `destination_path` - (Optional) Output path for the reverted file + +**Options:** +| Option | Description | +|--------|-------------| +| `-v, --verbose` | Enable verbose output | + +**Example:** +```bash +flashpack revert ./my_model.flashpack ./my_model.safetensors +``` + +## `flashpack metadata` + +Print the metadata of a flashpack file. + +```bash +flashpack metadata [options] +``` + +**Arguments:** +- `path` - Path to the flashpack file + +**Options:** +| Option | Description | +|--------|-------------| +| `-i, --show-index` | Show the tensor index | +| `-j, --json` | Output metadata in JSON format | + +**Examples:** +```bash +# View basic metadata +flashpack metadata ./my_model.flashpack + +# View metadata with tensor index +flashpack metadata ./my_model.flashpack --show-index + +# Output as JSON +flashpack metadata ./my_model.flashpack --json +``` diff --git a/scripts/plot_benchmark.py b/scripts/plot_benchmark.py index 44a7e20..402a004 100644 --- a/scripts/plot_benchmark.py +++ b/scripts/plot_benchmark.py @@ -10,7 +10,7 @@ # configuration accelerate_color = "#0f5ef3" flashpack_color = "#adff02" -label_color = "#111111" +label_color = "#eeeeee" model_labels = { "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": "Wan2.1 1.3B DiT", "Wan-AI/Wan2.1-T2V-14B-Diffusers": "Wan2.1 14B DiT", diff --git a/tests/test_speed_comparison.py b/tests/test_speed_comparison.py index b936e03..77320a0 100644 --- a/tests/test_speed_comparison.py +++ b/tests/test_speed_comparison.py @@ -83,7 +83,7 @@ def cuda_sync(): # Plot configuration (aligned with scripts/plot-benchmark.py) accelerate_color = "#0f5ef3" flashpack_color = "#adff02" - label_color = "#111111" + label_color = "#eeeeee" labels = [ "pytorch",
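
Taken together, the Updates note, the `--dtype` behavior documented for `flashpack convert` (no type changes when omitted), and the mixed-dtype round-trip in `tests/test_baseline.py` describe one workflow: pack a module whose parameters and buffers span several dtypes, then assign the tensors back without coercion. Below is a minimal sketch of that flow using the `pack_to_file`/`assign_from_file` helpers shown in the README; the module, file name, dtype choices, and the CPU device are illustrative.

```py
import torch
import torch.nn as nn

from flashpack import assign_from_file, pack_to_file


class TinyMixedDTypeModule(nn.Module):
    """Illustrative module mixing float32, bfloat16, and int8 tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4, dtype=torch.float32))
        self.half_weight = nn.Parameter(torch.ones(4, dtype=torch.bfloat16))
        self.register_buffer("counts", torch.ones(4, dtype=torch.int8))


model = TinyMixedDTypeModule()

# Omitting a target dtype keeps every tensor in its original dtype
# (matching the convert docs: "When omitted, no type changes are made").
pack_to_file(model, "tiny_mixed.flashpack")

# Assign the packed tensors back onto a freshly constructed module.
restored = TinyMixedDTypeModule()
assign_from_file(restored, "tiny_mixed.flashpack", device="cpu")

# Each tensor should come back in the dtype it was packed with.
assert restored.half_weight.dtype == torch.bfloat16
assert restored.counts.dtype == torch.int8
```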