From a86e1586f4972119f6cba07342a25bbd0fee7c39 Mon Sep 17 00:00:00 2001
From: Matthew Wong
Date: Fri, 20 Mar 2026 23:12:52 -0700
Subject: [PATCH] feat: support flash weight streaming via mlx-flash

Add optional flash_mode, flash_ram_gb, and flash_debug parameters to
load_model(). When flash_mode is enabled, mlx_flash is imported and its
LM Studio integration patch is applied with the configured RAM budget
and debug flag; if mlx-flash is not installed, a warning is logged and
loading proceeds without flash streaming.

The new parameters are plain bool/float (not Optional) since their
defaults are concrete values and None is never a meaningful input.
---
 mlx_engine/generate.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mlx_engine/generate.py b/mlx_engine/generate.py
index 0f9c4264..057c6803 100644
--- a/mlx_engine/generate.py
+++ b/mlx_engine/generate.py
@@ -128,6 +128,9 @@ def load_model(
     kv_bits: Optional[int] = None,
     kv_group_size: Optional[int] = None,
     quantized_kv_start: Optional[int] = None,
+    flash_mode: bool = False,
+    flash_ram_gb: float = 10.0,
+    flash_debug: bool = False,
 ) -> ModelKit | VisionModelKit:
     """
     Load a language model or vision-language model from the specified path.
@@ -146,6 +149,9 @@
         kv_bits (Optional[int]): Number of bits for KV cache quantization.
         kv_group_size (Optional[int]): Group size for KV cache quantization.
         quantized_kv_start (Optional[int]): Step to begin KV cache quantization when enabled.
+        flash_mode (bool): Enable flash weight streaming for models larger than available RAM.
+        flash_ram_gb (float): RAM budget in GB for the flash mode page cache.
+        flash_debug (bool): Enable verbose debugging for flash loading.
 
     Returns:
         ModelKit | VisionModelKit: An initialized model instance:
@@ -158,6 +164,21 @@
         ValueError: If the model configuration is invalid or unsupported
     """
     set_seed(seed)
+
+    if flash_mode:
+        try:
+            from mlx_flash import FlashConfig
+            from mlx_flash.integration.lmstudio import apply_flash_patch
+
+            flash_cfg = FlashConfig(
+                enabled=True,
+                ram_budget_gb=flash_ram_gb,
+                debug=flash_debug,
+            )
+            apply_flash_patch(flash_cfg)
+        except ImportError:
+            logger.warning("flash_mode requested but mlx-flash is not installed.")
+
     model_path = Path(model_path)
     config_json = json.loads((model_path / "config.json").read_text())
     model_type = config_json.get("model_type", None)