21 changes: 21 additions & 0 deletions mlx_engine/generate.py
@@ -128,6 +128,9 @@ def load_model(
kv_bits: Optional[int] = None,
kv_group_size: Optional[int] = None,
quantized_kv_start: Optional[int] = None,
flash_mode: Optional[bool] = False,
flash_ram_gb: Optional[float] = 10.0,
flash_debug: Optional[bool] = False,
) -> ModelKit | VisionModelKit:
"""
Load a language model or vision-language model from the specified path.
@@ -146,6 +149,9 @@ def load_model(
kv_bits (Optional[int]): Number of bits for KV cache quantization.
kv_group_size (Optional[int]): Group size for KV cache quantization.
quantized_kv_start (Optional[int]): Step to begin KV cache quantization when enabled.
flash_mode (Optional[bool]): Enable flash weight streaming for large models over RAM.
flash_ram_gb (Optional[float]): RAM budget in GB for the flash mode page cache.
flash_debug (Optional[bool]): Enable verbose debugging for flash loading.

Returns:
ModelKit | VisionModelKit: An initialized model instance:
@@ -158,6 +164,21 @@ def load_model(
ValueError: If the model configuration is invalid or unsupported
"""
set_seed(seed)

if flash_mode:
try:
from mlx_flash import FlashConfig
from mlx_flash.integration.lmstudio import apply_flash_patch

flash_cfg = FlashConfig(
enabled=True,
ram_budget_gb=flash_ram_gb,
debug=flash_debug,
)
apply_flash_patch(flash_cfg)

Review comment (P1): Patch the loader actually used by mlx-engine

When flash_mode=True, this only calls mlx_flash.integration.lmstudio.apply_flash_patch(), which monkey-patches mlx_lm.load. But every real model load in this repo goes through mlx_lm.utils.load instead: ModelKit._full_model_init() uses it in mlx_engine/model_kit/model_kit.py:92, BatchedModelKit.__init__() uses it in mlx_engine/model_kit/batched_model_kit.py:78, and the batchability probe uses the mlx_lm.utils.load alias in mlx_engine/generate.py:226. As a result, the new flag never routes mlx-engine's text-model loads through FlashManager, and large models still follow the original eager load path and its OOM behavior.


except ImportError:
logger.warning("flash_mode requested but mlx-flash is not installed.")

model_path = Path(model_path)
config_json = json.loads((model_path / "config.json").read_text())
model_type = config_json.get("model_type", None)
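
The routing gap flagged in the review comment above can be sketched as follows. This is a hypothetical illustration, not the real mlx_flash API: `patch_all_load_sites` and the stand-in modules are assumptions. The point it demonstrates is that both `mlx_lm.load` and `mlx_lm.utils.load` must be replaced for mlx-engine's call sites to pick up the flash loader.

```python
import sys
import types

def patch_all_load_sites(new_load):
    """Replace `load` on both mlx_lm and mlx_lm.utils.

    mlx-engine calls mlx_lm.utils.load, so patching only mlx_lm.load
    (as the current integration does) leaves those call sites eager.
    """
    for mod_name in ("mlx_lm", "mlx_lm.utils"):
        mod = sys.modules.get(mod_name)
        if mod is not None:
            setattr(mod, "load", new_load)

# Demo with stand-in modules; mlx_lm itself is not imported here.
pkg = types.ModuleType("mlx_lm")
utils = types.ModuleType("mlx_lm.utils")
pkg.load = utils.load = lambda path: "eager"
pkg.utils = utils
sys.modules["mlx_lm"] = pkg
sys.modules["mlx_lm.utils"] = utils

patch_all_load_sites(lambda path: "flash")
print(sys.modules["mlx_lm.utils"].load("model"))
```

One caveat with this style of patching: any module that already did `from mlx_lm.utils import load` holds its own reference to the original function, so the patch must run before those imports (or also rewrite the importing module's attribute) to take effect everywhere.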