From 0b5b071df1fde53960040835776dc64bd2644242 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 28 Jan 2026 15:09:56 -0500 Subject: [PATCH 1/2] Frontend support for magcache, threshold and skips. Signed-off-by: BuffMcBigHuge --- frontend/src/components/SettingsPanel.tsx | 72 +++++++++++++++++++++++ frontend/src/data/parameterMetadata.ts | 15 +++++ frontend/src/hooks/usePipelines.ts | 1 + frontend/src/hooks/useStreamState.ts | 3 + frontend/src/hooks/useWebRTC.ts | 6 ++ frontend/src/lib/api.ts | 6 ++ frontend/src/pages/StreamPage.tsx | 36 ++++++++++++ frontend/src/types/index.ts | 4 ++ 8 files changed, 143 insertions(+) diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx index 9af3ed09a..5aa6b30e5 100644 --- a/frontend/src/components/SettingsPanel.tsx +++ b/frontend/src/components/SettingsPanel.tsx @@ -66,6 +66,12 @@ interface SettingsPanelProps { onNoiseControllerChange?: (enabled: boolean) => void; manageCache?: boolean; onManageCacheChange?: (enabled: boolean) => void; + useMagcache?: boolean; + onUseMagcacheChange?: (enabled: boolean) => void; + magcacheThresh?: number; + onMagcacheThreshChange?: (thresh: number) => void; + magcacheK?: number; + onMagcacheKChange?: (k: number) => void; quantization?: "fp8_e4m3fn" | null; onQuantizationChange?: (quantization: "fp8_e4m3fn" | null) => void; kvCacheAttentionBias?: number; @@ -123,6 +129,12 @@ export function SettingsPanel({ onNoiseControllerChange, manageCache = true, onManageCacheChange, + useMagcache = false, + onUseMagcacheChange, + magcacheThresh = 0.12, + onMagcacheThreshChange, + magcacheK = 4, + onMagcacheKChange, quantization = "fp8_e4m3fn", onQuantizationChange, kvCacheAttentionBias = 0.3, @@ -160,6 +172,11 @@ export function SettingsPanel({ vaceContextScale, onVaceContextScaleChange ); + const magcacheThreshSlider = useLocalSliderValue( + magcacheThresh, + onMagcacheThreshChange + ); + const magcacheKSlider = useLocalSliderValue(magcacheK, 
onMagcacheKChange); // Validation error states const [heightError, setHeightError] = useState(null); @@ -746,6 +763,61 @@ export function SettingsPanel({
+ {/* MagCache toggle - shown for pipelines that support it */} + {pipelines?.[pipelineId]?.supportsMagcache && ( + <> +
+ + {})} + variant="outline" + size="sm" + className="h-7" + > + {useMagcache ? "ON" : "OFF"} + +
+ {/* MagCache quality controls - only shown when MagCache is enabled */} + {useMagcache && ( + <> + parseFloat(v) || 0.12} + /> + Math.round(v)} + inputParser={v => parseInt(v) || 2} + /> + + )} + + )} + {/* KV Cache bias control - shown for pipelines that support it */} {pipelines?.[pipelineId]?.supportsKvCacheBias && ( = { tooltip: "Enables pipeline to automatically manage the cache which influences newly generated frames. Disable for manual control via Reset Cache.", }, + useMagcache: { + label: "MagCache:", + tooltip: + "Enables MagCache (magnitude-aware residual caching) to skip redundant denoising steps for faster inference. Toggle during a live stream to compare speed vs quality.", + }, + magcacheThresh: { + label: "MagCache Threshold:", + tooltip: + "Controls quality vs speed tradeoff. Lower values (0.05-0.08) = better quality, fewer skips. Higher values (0.15-0.25) = faster, more skips. Default: 0.12", + }, + magcacheK: { + label: "MagCache Max Skips:", + tooltip: + "Maximum consecutive denoising steps to skip before forcing a compute. Lower values (1) = better quality. Higher values (3-4) = faster but may cause instability. Default: 2", + }, resetCache: { label: "Reset Cache:", tooltip: diff --git a/frontend/src/hooks/usePipelines.ts b/frontend/src/hooks/usePipelines.ts index d8fb47ac4..28d1f634d 100644 --- a/frontend/src/hooks/usePipelines.ts +++ b/frontend/src/hooks/usePipelines.ts @@ -65,6 +65,7 @@ export function usePipelines() { supportsCacheManagement: schema.supports_cache_management, supportsKvCacheBias: schema.supports_kv_cache_bias, supportsQuantization: schema.supports_quantization, + supportsMagcache: schema.supports_magcache, minDimension: schema.min_dimension, recommendedQuantizationVramThreshold: schema.recommended_quantization_vram_threshold ?? 
undefined, diff --git a/frontend/src/hooks/useStreamState.ts b/frontend/src/hooks/useStreamState.ts index 176fc55b0..eda3c9a2b 100644 --- a/frontend/src/hooks/useStreamState.ts +++ b/frontend/src/hooks/useStreamState.ts @@ -150,6 +150,9 @@ export function useStreamState() { noiseScale: initialDefaults.noiseScale, noiseController: initialDefaults.noiseController, manageCache: true, + useMagcache: false, + magcacheThresh: 0.12, + magcacheK: 2, quantization: null, kvCacheAttentionBias: 0.3, paused: false, diff --git a/frontend/src/hooks/useWebRTC.ts b/frontend/src/hooks/useWebRTC.ts index 8d6dc990c..f51e5ff80 100644 --- a/frontend/src/hooks/useWebRTC.ts +++ b/frontend/src/hooks/useWebRTC.ts @@ -16,6 +16,9 @@ interface InitialParameters { noise_scale?: number; noise_controller?: boolean; manage_cache?: boolean; + use_magcache?: boolean; + magcache_thresh?: number; + magcache_K?: number; kv_cache_attention_bias?: number; vace_ref_images?: string[]; vace_context_scale?: number; @@ -324,6 +327,9 @@ export function useWebRTC(options?: UseWebRTCOptions) { noise_scale?: number; noise_controller?: boolean; manage_cache?: boolean; + use_magcache?: boolean; + magcache_thresh?: number; + magcache_K?: number; reset_cache?: boolean; kv_cache_attention_bias?: number; paused?: boolean; diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 4e765f1a1..a022f658b 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -22,11 +22,16 @@ export interface WebRTCOfferRequest { noise_scale?: number; noise_controller?: boolean; manage_cache?: boolean; + use_magcache?: boolean; + magcache_thresh?: number; + magcache_K?: number; kv_cache_attention_bias?: number; vace_ref_images?: string[]; vace_context_scale?: number; pipeline_ids?: string[]; images?: string[]; + first_frame_image?: string; + last_frame_image?: string; }; } @@ -390,6 +395,7 @@ export interface PipelineSchemaInfo { requires_models: boolean; supports_lora: boolean; supports_vace: boolean; + 
supports_magcache?: boolean; usage: string[]; // Pipeline config schema config_schema: PipelineConfigSchema; diff --git a/frontend/src/pages/StreamPage.tsx b/frontend/src/pages/StreamPage.tsx index e3bb66177..0c99260fb 100644 --- a/frontend/src/pages/StreamPage.tsx +++ b/frontend/src/pages/StreamPage.tsx @@ -524,6 +524,30 @@ export function StreamPage() { }); }; + const handleUseMagcacheChange = (enabled: boolean) => { + updateSettings({ useMagcache: enabled }); + // Send MagCache update to backend (runtime) + sendParameterUpdate({ + use_magcache: enabled, + }); + }; + + const handleMagcacheThreshChange = (thresh: number) => { + updateSettings({ magcacheThresh: thresh }); + // Send MagCache threshold update to backend (runtime) + sendParameterUpdate({ + magcache_thresh: thresh, + }); + }; + + const handleMagcacheKChange = (k: number) => { + updateSettings({ magcacheK: k }); + // Send MagCache K update to backend (runtime) + sendParameterUpdate({ + magcache_K: k, + }); + }; + const handleQuantizationChange = (quantization: "fp8_e4m3fn" | null) => { updateSettings({ quantization }); // Note: This setting requires pipeline reload, so we don't send parameter update here @@ -976,6 +1000,7 @@ export function StreamPage() { noise_scale?: number; noise_controller?: boolean; manage_cache?: boolean; + use_magcache?: boolean; kv_cache_attention_bias?: number; spout_sender?: { enabled: boolean; name: string }; spout_receiver?: { enabled: boolean; name: string }; @@ -1007,6 +1032,11 @@ export function StreamPage() { initialParameters.manage_cache = settings.manageCache ?? true; } + // MagCache for pipelines that support it + if (currentPipeline?.supportsMagcache) { + initialParameters.use_magcache = settings.useMagcache ?? 
false; + } + // KV cache bias for pipelines that support it if (currentPipeline?.supportsKvCacheBias) { initialParameters.kv_cache_attention_bias = @@ -1356,6 +1386,12 @@ export function StreamPage() { onNoiseControllerChange={handleNoiseControllerChange} manageCache={settings.manageCache ?? true} onManageCacheChange={handleManageCacheChange} + useMagcache={settings.useMagcache ?? false} + onUseMagcacheChange={handleUseMagcacheChange} + magcacheThresh={settings.magcacheThresh ?? 0.12} + onMagcacheThreshChange={handleMagcacheThreshChange} + magcacheK={settings.magcacheK ?? 4} + onMagcacheKChange={handleMagcacheKChange} quantization={ settings.quantization !== undefined ? settings.quantization diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 74495bc23..499507354 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -59,6 +59,9 @@ export interface SettingsState { noiseScale?: number; noiseController?: boolean; manageCache?: boolean; + useMagcache?: boolean; + magcacheThresh?: number; + magcacheK?: number; quantization?: "fp8_e4m3fn" | null; kvCacheAttentionBias?: number; paused?: boolean; @@ -117,6 +120,7 @@ export interface PipelineInfo { supportsCacheManagement?: boolean; supportsKvCacheBias?: boolean; supportsQuantization?: boolean; + supportsMagcache?: boolean; minDimension?: number; recommendedQuantizationVramThreshold?: number | null; // Available VAE types from config schema enum (derived from vae_type field presence) From 3b89c3273fd2e9ff454920282c275c6594b3b9c6 Mon Sep 17 00:00:00 2001 From: BuffMcBigHuge Date: Wed, 28 Jan 2026 17:11:05 -0500 Subject: [PATCH 2/2] Magcache server-side implementation. 
Signed-off-by: BuffMcBigHuge --- src/scope/core/pipelines/base_schema.py | 12 + .../pipelines/krea_realtime_video/pipeline.py | 1 + .../longlive/modules/causal_model.py | 261 ++++++++++++++---- src/scope/core/pipelines/longlive/pipeline.py | 1 + src/scope/core/pipelines/longlive/schema.py | 1 + src/scope/core/pipelines/memflow/pipeline.py | 1 + .../core/pipelines/reward_forcing/pipeline.py | 1 + .../pipelines/streamdiffusionv2/pipeline.py | 1 + .../core/pipelines/wan2_1/blocks/denoise.py | 28 ++ .../pipelines/wan2_1/components/generator.py | 77 ++++++ src/scope/core/pipelines/wan2_1/magcache.py | 78 ++++++ .../wan2_1/vace/models/causal_vace_model.py | 232 +++++++++++++--- src/scope/server/schema.py | 17 ++ 13 files changed, 627 insertions(+), 84 deletions(-) create mode 100644 src/scope/core/pipelines/wan2_1/magcache.py diff --git a/src/scope/core/pipelines/base_schema.py b/src/scope/core/pipelines/base_schema.py index 1f4f1106e..f3e4e59f5 100644 --- a/src/scope/core/pipelines/base_schema.py +++ b/src/scope/core/pipelines/base_schema.py @@ -181,6 +181,7 @@ class BasePipelineConfig(BaseModel): supports_cache_management: ClassVar[bool] = False supports_kv_cache_bias: ClassVar[bool] = False supports_quantization: ClassVar[bool] = False + supports_magcache: ClassVar[bool] = False min_dimension: ClassVar[int] = 1 # Whether this pipeline contains modifications based on the original project modified: ClassVar[bool] = False @@ -232,6 +233,16 @@ class BasePipelineConfig(BaseModel): ref_images: list[str] | None = ref_images_field() vace_context_scale: float = vace_context_scale_field() + # MagCache quality parameters (only used when supports_magcache=True) + magcache_thresh: Annotated[float, Field(ge=0.05, le=0.5)] = Field( + default=0.12, + description="MagCache error threshold - higher values allow more skipping (faster but lower quality)", + ) + magcache_K: Annotated[int, Field(ge=1, le=4)] = Field( + default=2, + description="MagCache max consecutive skips - higher 
values allow more skipping but may cause instability", + ) + @classmethod def get_pipeline_metadata(cls) -> dict[str, str]: """Return pipeline identification metadata. @@ -321,6 +332,7 @@ def get_schema_with_metadata(cls) -> dict[str, Any]: metadata["supports_cache_management"] = cls.supports_cache_management metadata["supports_kv_cache_bias"] = cls.supports_kv_cache_bias metadata["supports_quantization"] = cls.supports_quantization + metadata["supports_magcache"] = cls.supports_magcache metadata["min_dimension"] = cls.min_dimension metadata["recommended_quantization_vram_threshold"] = ( cls.recommended_quantization_vram_threshold diff --git a/src/scope/core/pipelines/krea_realtime_video/pipeline.py b/src/scope/core/pipelines/krea_realtime_video/pipeline.py index 5a00255ef..b3ffc960b 100644 --- a/src/scope/core/pipelines/krea_realtime_video/pipeline.py +++ b/src/scope/core/pipelines/krea_realtime_video/pipeline.py @@ -181,6 +181,7 @@ def __init__( self.state.set("current_start_frame", 0) self.state.set("manage_cache", True) self.state.set("kv_cache_attention_bias", DEFAULT_KV_CACHE_ATTENTION_BIAS) + self.state.set("use_magcache", False) self.state.set("height", config.height) self.state.set("width", config.width) diff --git a/src/scope/core/pipelines/longlive/modules/causal_model.py b/src/scope/core/pipelines/longlive/modules/causal_model.py index 8607855bc..169e2b250 100644 --- a/src/scope/core/pipelines/longlive/modules/causal_model.py +++ b/src/scope/core/pipelines/longlive/modules/causal_model.py @@ -1,8 +1,11 @@ # Modified from https://github.com/NVlabs/LongLive # SPDX-License-Identifier: CC-BY-NC-SA-4.0 +import logging import math import torch + +logger = logging.getLogger(__name__) import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin @@ -1118,6 +1121,41 @@ def _forward_inference( ]) """ + # ------------------------------------------------------------------ + # 
MagCache (Magnitude-aware Cache) support + # + # We cache the residual at the *token* level right before the head: + # residual = x_after_blocks - x_before_blocks + # + # When MagCache decides to skip, we reuse the last cached residual: + # x_after_blocks ~= x_before_blocks + residual_cache + # + # This is adapted from the Wan2.1 MagCache reference code: + # `magcache_generate.py` in https://github.com/Zehong-Ma/MagCache + # and matches the paper's formulation (rt = vθ(xt,t) - xt). + # ------------------------------------------------------------------ + + def _magcache_reset(): + pass + '''self._magcache_step = 0 + self._magcache_accumulated_ratio = 1.0 + self._magcache_accumulated_err = 0.0 + self._magcache_accumulated_steps = 0 + self._magcache_residual_cache = None + self._magcache_stats = {"skipped": 0, "computed": 0} + ''' + + # Expose for the diffusion wrapper to reset cleanly (e.g., toggle change). + if not hasattr(self, "_magcache_reset"): + self._magcache_reset = _magcache_reset # type: ignore[attr-defined] + + magcache_cfg = getattr(self, "_magcache_config", None) + magcache_enabled = bool(getattr(magcache_cfg, "enabled", False)) + magcache_num_steps = getattr(self, "_magcache_num_steps", None) + + if magcache_enabled and not hasattr(self, "_magcache_step"): + _magcache_reset() + # time embeddings # with amp.autocast(dtype=torch.float32): e = self.time_embedding( @@ -1164,60 +1202,181 @@ def custom_forward(*inputs, **kwargs): return custom_forward - cache_update_info = None - cache_update_infos = [] # Collect cache update info for all blocks - for block_index, block in enumerate(self.blocks): - # print(f"block_index: {block_index}") - if torch.is_grad_enabled() and self.gradient_checkpointing: - kwargs.update( - { - "kv_cache": kv_cache[block_index], - "current_start": current_start, - "cache_start": cache_start, - } - ) - # print(f"forward checkpointing") - result = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - **kwargs, - 
use_reentrant=False, - ) - # Handle the result - if kv_cache is not None and isinstance(result, tuple): - x, block_cache_update_info = result - cache_update_infos.append((block_index, block_cache_update_info)) - # Extract base info for subsequent blocks (without concrete cache update details) - cache_update_info = block_cache_update_info[ - :2 - ] # (current_end, local_end_index) - else: - x = result + # Decide whether to skip the expensive transformer blocks. + # When skipping we reuse the cached residual and do not update KV/cross-attn caches. + skip_forward = False + ori_x = x + + # Debug: track MagCache statistics + if not hasattr(self, "_magcache_stats"): + self._magcache_stats = {"skipped": 0, "computed": 0} + + if magcache_enabled and magcache_num_steps is not None: + from scope.core.pipelines.wan2_1.magcache import wan21_t2v_13b_mag_ratios + + retention_ratio = float(getattr(magcache_cfg, "retention_ratio", 0.2)) + retain_steps = int(float(magcache_num_steps) * retention_ratio) + step_idx = int(getattr(self, "_magcache_step", 0)) + residual_cache = getattr(self, "_magcache_residual_cache", None) + + # Match reference implementation: only consider skipping after retention phase + # The retention phase ensures early steps (which are most important for quality) + # are always computed fully. 
+ if step_idx >= retain_steps: + ratios = getattr(self, "_magcache_ratios", None) + if ratios is None or int(ratios.shape[0]) != int(magcache_num_steps): + ratios = wan21_t2v_13b_mag_ratios(int(magcache_num_steps)) + setattr(self, "_magcache_ratios", ratios) + + if step_idx < int(ratios.shape[0]): + cur_ratio = float(ratios[step_idx]) + accumulated_ratio = float( + getattr(self, "_magcache_accumulated_ratio", 1.0) + ) + accumulated_steps = int( + getattr(self, "_magcache_accumulated_steps", 0) + ) + accumulated_err = float( + getattr(self, "_magcache_accumulated_err", 0.0) + ) + + # Accumulate error estimate for this step + # (matches reference: accumulate BEFORE deciding to skip) + accumulated_ratio *= cur_ratio + accumulated_steps += 1 + cur_skip_err = abs(1.0 - accumulated_ratio) + accumulated_err += cur_skip_err + + thresh = float(getattr(magcache_cfg, "thresh", 0.12)) + K = int(getattr(magcache_cfg, "K", 1)) + + # Check if we can skip: need a cached residual, error below threshold, + # and haven't exceeded max consecutive skips + if ( + residual_cache is not None + and accumulated_err < thresh + and accumulated_steps <= K + ): + skip_forward = True + # Save updated accumulation state for next step + setattr(self, "_magcache_accumulated_ratio", accumulated_ratio) + setattr(self, "_magcache_accumulated_steps", accumulated_steps) + setattr(self, "_magcache_accumulated_err", accumulated_err) + logger.debug( + f"MagCache SKIP: step={step_idx}, acc_err={accumulated_err:.4f} < thresh={thresh}, acc_steps={accumulated_steps} <= K={K}" + ) + else: + # Force compute: reset accumulation state + # (matches reference: reset when NOT skipping) + setattr(self, "_magcache_accumulated_ratio", 1.0) + setattr(self, "_magcache_accumulated_steps", 0) + setattr(self, "_magcache_accumulated_err", 0.0) + logger.debug( + f"MagCache COMPUTE: step={step_idx}, acc_err={accumulated_err:.4f}, thresh={thresh}, " + f"acc_steps={accumulated_steps}, K={K}, cache={'exists' if residual_cache is 
not None else 'None'}" + ) else: - kwargs.update( - { - "kv_cache": kv_cache[block_index], - "crossattn_cache": crossattn_cache[block_index], - "current_start": current_start, - "cache_start": cache_start, - } + logger.debug( + f"MagCache RETAIN: step={step_idx} < retain_steps={retain_steps}" ) - # print(f"forward no checkpointing") - result = block(x, **kwargs) - # Handle the result - if kv_cache is not None and isinstance(result, tuple): - x, block_cache_update_info = result - cache_update_infos.append((block_index, block_cache_update_info)) - # Extract base info for subsequent blocks (without concrete cache update details) - cache_update_info = block_cache_update_info[ - :2 - ] # (current_end, local_end_index) + + cache_update_infos = [] # Collected only when not skipping + + # Safety check: if we're about to skip but the KV cache expects updates, + # we must not skip to avoid cache index misalignment. + # This can happen when magcache_K is set too high relative to num_steps. + if skip_forward and kv_cache is not None: + # Check if the first block's cache indicates we need an update + # by comparing current_start with the cache's expected position + first_cache = kv_cache[0] if kv_cache else None + if first_cache is not None: + cache_global_end = first_cache.get("global_end_index") + if cache_global_end is not None: + cache_end_val = ( + cache_global_end.item() + if hasattr(cache_global_end, "item") + else int(cache_global_end) + ) + # If current_start is beyond what cache expects, force compute + if current_start is not None and current_start > cache_end_val: + logger.warning( + f"MagCache: forcing compute due to KV cache misalignment " + f"(current_start={current_start}, cache_end={cache_end_val})" + ) + skip_forward = False + + if skip_forward: + x = ori_x + getattr(self, "_magcache_residual_cache") + self._magcache_stats["skipped"] += 1 + else: + self._magcache_stats["computed"] += 1 + for block_index, block in enumerate(self.blocks): + if 
torch.is_grad_enabled() and self.gradient_checkpointing: + kwargs.update( + { + "kv_cache": kv_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + } + ) + result = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + **kwargs, + use_reentrant=False, + ) + if kv_cache is not None and isinstance(result, tuple): + x, block_cache_update_info = result + cache_update_infos.append((block_index, block_cache_update_info)) + else: + x = result else: - x = result - # log_gpu_memory(f"in _forward_inference: {x[0].device}") - # After all blocks are processed, apply cache updates in a single pass - if kv_cache is not None and cache_update_infos: - self._apply_cache_updates(kv_cache, cache_update_infos) + kwargs.update( + { + "kv_cache": kv_cache[block_index], + "crossattn_cache": crossattn_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + } + ) + result = block(x, **kwargs) + if kv_cache is not None and isinstance(result, tuple): + x, block_cache_update_info = result + cache_update_infos.append((block_index, block_cache_update_info)) + else: + x = result + + if kv_cache is not None and cache_update_infos: + self._apply_cache_updates(kv_cache, cache_update_infos) + + if magcache_enabled: + setattr(self, "_magcache_residual_cache", x - ori_x) + + ''' + if magcache_enabled and magcache_num_steps is not None: + next_step = int(getattr(self, "_magcache_step", 0)) + 1 + setattr(self, "_magcache_step", next_step) + if next_step >= int(magcache_num_steps): + # Log MagCache stats at end of each chunk + stats = getattr(self, "_magcache_stats", {"skipped": 0, "computed": 0}) + if stats["skipped"] > 0 or stats["computed"] > 0: + total = stats["skipped"] + stats["computed"] + skip_pct = 100 * stats["skipped"] / total if total > 0 else 0 + logger.info( + f"MagCache: skipped {stats['skipped']}/{total} steps ({skip_pct:.1f}%)" + ) + # Reset ALL state between chunks. 
+ # CRITICAL: The residual cache MUST be cleared because each chunk has + # different latent content. Reusing a residual from chunk N on chunk N+1 + # causes severe artifacts ("looping noise") since the cached residual + # was computed for completely different spatial content. + self._magcache_step = 0 + self._magcache_accumulated_ratio = 1.0 + self._magcache_accumulated_err = 0.0 + self._magcache_accumulated_steps = 0 + self._magcache_stats = {"skipped": 0, "computed": 0} + self._magcache_residual_cache = None # Clear to prevent cross-chunk artifacts + ''' # head x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2)) diff --git a/src/scope/core/pipelines/longlive/pipeline.py b/src/scope/core/pipelines/longlive/pipeline.py index 70372183f..30754e68b 100644 --- a/src/scope/core/pipelines/longlive/pipeline.py +++ b/src/scope/core/pipelines/longlive/pipeline.py @@ -181,6 +181,7 @@ def __init__( self.state.set("current_start_frame", 0) self.state.set("manage_cache", True) self.state.set("kv_cache_attention_bias", 1.0) + self.state.set("use_magcache", False) self.state.set("height", config.height) self.state.set("width", config.width) diff --git a/src/scope/core/pipelines/longlive/schema.py b/src/scope/core/pipelines/longlive/schema.py index fa56d9e7d..2624466af 100644 --- a/src/scope/core/pipelines/longlive/schema.py +++ b/src/scope/core/pipelines/longlive/schema.py @@ -39,6 +39,7 @@ class LongLiveConfig(BasePipelineConfig): ] supports_cache_management = True + supports_magcache = True supports_quantization = True min_dimension = 16 modified = True diff --git a/src/scope/core/pipelines/memflow/pipeline.py b/src/scope/core/pipelines/memflow/pipeline.py index e4ad2f339..f8701b8a2 100644 --- a/src/scope/core/pipelines/memflow/pipeline.py +++ b/src/scope/core/pipelines/memflow/pipeline.py @@ -181,6 +181,7 @@ def __init__( self.state.set("current_start_frame", 0) self.state.set("manage_cache", True) self.state.set("kv_cache_attention_bias", 1.0) + 
self.state.set("use_magcache", False) self.state.set("height", config.height) self.state.set("width", config.width) diff --git a/src/scope/core/pipelines/reward_forcing/pipeline.py b/src/scope/core/pipelines/reward_forcing/pipeline.py index c6a0214de..f24413929 100644 --- a/src/scope/core/pipelines/reward_forcing/pipeline.py +++ b/src/scope/core/pipelines/reward_forcing/pipeline.py @@ -155,6 +155,7 @@ def __init__( self.state.set("current_start_frame", 0) self.state.set("manage_cache", True) self.state.set("kv_cache_attention_bias", 1.0) + self.state.set("use_magcache", False) self.state.set("height", config.height) self.state.set("width", config.width) diff --git a/src/scope/core/pipelines/streamdiffusionv2/pipeline.py b/src/scope/core/pipelines/streamdiffusionv2/pipeline.py index f594e2e54..fa3f6d6ad 100644 --- a/src/scope/core/pipelines/streamdiffusionv2/pipeline.py +++ b/src/scope/core/pipelines/streamdiffusionv2/pipeline.py @@ -157,6 +157,7 @@ def __init__( self.state.set("current_start_frame", 0) self.state.set("manage_cache", True) self.state.set("kv_cache_attention_bias", 1.0) + self.state.set("use_magcache", False) self.state.set("noise_scale", 0.7) self.state.set("noise_controller", True) diff --git a/src/scope/core/pipelines/wan2_1/blocks/denoise.py b/src/scope/core/pipelines/wan2_1/blocks/denoise.py index 2c344680b..dd26993de 100644 --- a/src/scope/core/pipelines/wan2_1/blocks/denoise.py +++ b/src/scope/core/pipelines/wan2_1/blocks/denoise.py @@ -123,6 +123,25 @@ def inputs(self) -> list[InputParam]: type_hint=float, description="Scaling factor for VACE hint injection", ), + # MagCache toggle (runtime) + InputParam( + "use_magcache", + default=False, + type_hint=bool, + description="Enable MagCache (magnitude-aware residual caching) for faster inference", + ), + InputParam( + "magcache_thresh", + default=0.12, + type_hint=float, + description="MagCache error threshold - lower = better quality, higher = faster", + ), + InputParam( + "magcache_K", + 
default=2, + type_hint=int, + description="MagCache max consecutive skips before forcing compute", + ), ] @property @@ -151,6 +170,7 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState batch_size = noise.shape[0] num_frames = noise.shape[1] denoising_step_list = block_state.current_denoising_step_list.clone() + magcache_num_steps = int(len(denoising_step_list)) conditional_dict = {"prompt_embeds": block_state.conditioning_embeds} @@ -196,6 +216,10 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState kv_cache_attention_bias=block_state.kv_cache_attention_bias, vace_context=block_state.vace_context, vace_context_scale=block_state.vace_context_scale, + use_magcache=block_state.use_magcache, + magcache_thresh=block_state.magcache_thresh, + magcache_K=block_state.magcache_K, + magcache_num_steps=magcache_num_steps, ) next_timestep = denoising_step_list[index + 1] @@ -232,6 +256,10 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState kv_cache_attention_bias=block_state.kv_cache_attention_bias, vace_context=block_state.vace_context, vace_context_scale=block_state.vace_context_scale, + use_magcache=block_state.use_magcache, + magcache_thresh=block_state.magcache_thresh, + magcache_K=block_state.magcache_K, + magcache_num_steps=magcache_num_steps, ) block_state.latents = denoised_pred diff --git a/src/scope/core/pipelines/wan2_1/components/generator.py b/src/scope/core/pipelines/wan2_1/components/generator.py index 09bf6968b..58e4465e0 100644 --- a/src/scope/core/pipelines/wan2_1/components/generator.py +++ b/src/scope/core/pipelines/wan2_1/components/generator.py @@ -1,9 +1,11 @@ # Modified from https://github.com/guandeh17/Self-Forcing import inspect import json +import logging import os import types +logger = logging.getLogger(__name__) import torch from scope.core.pipelines.utils import load_state_dict @@ -206,6 +208,21 @@ def _call_model(self, *args, **kwargs): } return 
self.model(*args, **accepted) + def _get_base_model(self, model): + """Walk down wrappers (PEFT, VACE) to find the actual model implementation.""" + curr = model + while True: + if hasattr(curr, "base_model"): # PEFT wrapper + curr = curr.base_model + elif hasattr(curr, "causal_wan_model"): # VACE wrapper + curr = curr.causal_wan_model + elif hasattr(curr, "model") and isinstance(curr.model, torch.nn.Module): + # Some other wrappers might use .model + curr = curr.model + else: + break + return curr + def forward( self, noisy_image_or_video: torch.Tensor, @@ -229,6 +246,12 @@ def forward( vace_context: torch.Tensor | None = None, vace_context_scale: float = 1.0, sink_recache_after_switch: bool = False, + # MagCache (runtime toggle) + use_magcache: bool | None = None, + magcache_thresh: float | None = None, + magcache_K: int | None = None, + magcache_retention_ratio: float | None = None, + magcache_num_steps: int | None = None, ) -> torch.Tensor: prompt_embeds = conditional_dict["prompt_embeds"] @@ -240,6 +263,60 @@ def forward( logits = None # X0 prediction + + # Apply MagCache settings to the underlying model (if provided). + # We find the actual model implementation through any wrappers (LoRA, VACE). + if use_magcache is not None: + # Lazy import to avoid bringing numpy into every import path. 
+ from scope.core.pipelines.wan2_1.magcache import MagCacheConfig + + cfg = MagCacheConfig( + enabled=bool(use_magcache), + thresh=float(magcache_thresh) if magcache_thresh is not None else 0.12, + K=int(magcache_K) if magcache_K is not None else 2, + retention_ratio=float(magcache_retention_ratio) + if magcache_retention_ratio is not None + else 0.2, + ) + + # Find all modules in the stack that might have MagCache logic + target_models = [] + curr = self.model + target_models.append(curr) + while True: + if hasattr(curr, "base_model"): + curr = curr.base_model + target_models.append(curr) + elif hasattr(curr, "causal_wan_model"): + curr = curr.causal_wan_model + target_models.append(curr) + elif hasattr(curr, "model") and isinstance(curr.model, torch.nn.Module): + curr = curr.model + target_models.append(curr) + else: + break + + for model in target_models: + # Update config + step count, and reset internal state when toggled/changed. + if getattr(model, "_magcache_config", None) != cfg: + logger.info( + f"MagCache config changed on {model.__class__.__name__}: enabled={cfg.enabled}, thresh={cfg.thresh}, K={cfg.K}" + ) + model._magcache_config = cfg + if hasattr(model, "_magcache_reset"): + model._magcache_reset() + + if magcache_num_steps is not None: + if getattr(model, "_magcache_num_steps", None) != int( + magcache_num_steps + ): + logger.info( + f"MagCache num_steps set on {model.__class__.__name__}: {magcache_num_steps}" + ) + model._magcache_num_steps = int(magcache_num_steps) + if hasattr(model, "_magcache_reset"): + model._magcache_reset() + if kv_cache is not None: flow_pred = self._call_model( noisy_image_or_video.permute(0, 2, 1, 3, 4), diff --git a/src/scope/core/pipelines/wan2_1/magcache.py b/src/scope/core/pipelines/wan2_1/magcache.py new file mode 100644 index 000000000..7b1f2b392 --- /dev/null +++ b/src/scope/core/pipelines/wan2_1/magcache.py @@ -0,0 +1,78 @@ +""" +MagCache utilities for Wan2.1-family diffusion models. 
+ +This implements the *runtime* portion of MagCache (NeurIPS 2025): +- Use a pre-calibrated per-step magnitude ratio curve (gamma_t) +- Accumulate an error estimate when skipping consecutive steps +- Reuse a cached residual when the estimated error stays below a threshold + +Paper: "MagCache: Fast Video Generation with Magnitude-Aware Cache" (arXiv:2506.09045) +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +def nearest_interp(src: np.ndarray, target_length: int) -> np.ndarray: + """Nearest-neighbor interpolate a 1D array to target_length.""" + if target_length <= 0: + raise ValueError("target_length must be positive") + src = np.asarray(src, dtype=np.float64) + src_length = int(src.shape[0]) + if src_length == 0: + raise ValueError("src must be non-empty") + if target_length == 1: + return np.asarray([src[-1]], dtype=np.float64) + if src_length == 1: + return np.repeat(src, target_length).astype(np.float64) + + scale = (src_length - 1) / (target_length - 1) + mapped_indices = np.round(np.arange(target_length) * scale).astype(int) + return src[mapped_indices] + + +# --------------------------------------------------------------------------- +# Pre-calibrated magnitude ratios +# --------------------------------------------------------------------------- +# +# The upstream reference implementation provides "mag_ratios" as an interleaved +# [cond_step0, uncond_step0, cond_step1, uncond_step1, ...] curve and runs the +# model twice per sampling step (CFG). Scope's LongLive pipeline does not do +# CFG in the diffusion wrapper, so we use the *conditional* curve. +# +# Source: https://raw.githubusercontent.com/Zehong-Ma/MagCache/refs/heads/main/MagCache4Wan2.1/magcache_generate.py +# + +# Wan2.1 T2V 1.3B, sample_steps=50 (50-step curve), conditional branch only. +# (The upstream array is length 100 with interleaved cond/uncond + initial padding.) 
+_WAN21_T2V_13B_INTERLEAVED_100 = np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939], + dtype=np.float64, +) + +def wan21_t2v_13b_mag_ratios(num_steps: int) -> np.ndarray: + """Return conditional-branch mag ratios for Wan2.1 T2V 1.3B.""" + # Conditional ratios are every other value starting at index 0: + # [cond0, uncond0, cond1, uncond1, ...] + cond = _WAN21_T2V_13B_INTERLEAVED_100[0::2] + # The source includes an initial padding entry; keep it (gamma_0 ~= 1.0). + if cond.shape[0] != 50: + # Should not happen, but keep robust. 
+ cond = cond[:50] + if num_steps == cond.shape[0]: + return cond.copy() + return nearest_interp(cond, num_steps) + + +@dataclass +class MagCacheConfig: + enabled: bool = False + # Total accumulated error threshold (δ in the paper) + thresh: float = 0.12 + # Max consecutive skipped steps (K in the paper) + # Reference default is 2, which is more conservative for quality + K: int = 2 + # Fraction of initial steps to preserve unchanged (default 20%) + retention_ratio: float = 0.2 diff --git a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py index 96ed07d29..74c042ba0 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py @@ -2,6 +2,7 @@ # Adapted for causal/autoregressive generation with factory pattern # Pipeline-agnostic using duck typing - works with any CausalWanModel import inspect +import logging import math import torch @@ -12,6 +13,8 @@ create_vace_attention_block_class, ) +logger = logging.getLogger(__name__) + # TODO: Consolidate this with other pipeline implementations into a shared wan2_1/utils module. # This is a standard sinusoidal positional embedding - identical across all pipelines apart from krea which has forced dtype @@ -467,48 +470,211 @@ def custom_forward(*inputs, **kwargs): return custom_forward - # Process through blocks + # -------------------------------------------------------------- + # MagCache support (for VACE-wrapped causal Wan models) + # + # Same strategy as the base LongLive CausalWanModel: + # cache token-level residual right before the head, optionally skip blocks. 
+ # -------------------------------------------------------------- + + def _magcache_reset(): + # Restore all MagCache runtime state to its initial values: step + # counter, error accumulators, cached residual, and the + # skip/compute statistics reported at the end of each chunk. + self._magcache_step = 0 + self._magcache_accumulated_ratio = 1.0 + self._magcache_accumulated_err = 0.0 + self._magcache_accumulated_steps = 0 + self._magcache_residual_cache = None + self._magcache_stats = {"skipped": 0, "computed": 0} + if not hasattr(self, "_magcache_reset"): + self._magcache_reset = _magcache_reset # type: ignore[attr-defined] + + magcache_cfg = getattr(self, "_magcache_config", None) + magcache_enabled = bool(getattr(magcache_cfg, "enabled", False)) + magcache_num_steps = getattr(self, "_magcache_num_steps", None) + + if magcache_enabled and not hasattr(self, "_magcache_step"): + _magcache_reset() + + skip_forward = False + ori_x = x + + # Debug: track MagCache statistics + if not hasattr(self, "_magcache_stats"): + self._magcache_stats = {"skipped": 0, "computed": 0} + + if magcache_enabled and magcache_num_steps is not None: + from scope.core.pipelines.wan2_1.magcache import ( + wan21_t2v_13b_mag_ratios, + ) + + retention_ratio = float(getattr(magcache_cfg, "retention_ratio", 0.2)) + retain_steps = int(float(magcache_num_steps) * retention_ratio) + step_idx = int(getattr(self, "_magcache_step", 0)) + residual_cache = getattr(self, "_magcache_residual_cache", None) + + # Match reference implementation: only consider skipping after retention phase + # The retention phase ensures early steps (which are most important for quality) + # are always computed fully.
+ if step_idx >= retain_steps: + ratios = getattr(self, "_magcache_ratios", None) + if ratios is None or int(ratios.shape[0]) != int(magcache_num_steps): + ratios = wan21_t2v_13b_mag_ratios(int(magcache_num_steps)) + setattr(self, "_magcache_ratios", ratios) + + if step_idx < int(ratios.shape[0]): + cur_ratio = float(ratios[step_idx]) + accumulated_ratio = float( + getattr(self, "_magcache_accumulated_ratio", 1.0) + ) + accumulated_steps = int( + getattr(self, "_magcache_accumulated_steps", 0) + ) + accumulated_err = float( + getattr(self, "_magcache_accumulated_err", 0.0) + ) + + # Accumulate error estimate for this step + # (matches reference: accumulate BEFORE deciding to skip) + accumulated_ratio *= cur_ratio + accumulated_steps += 1 + cur_skip_err = abs(1.0 - accumulated_ratio) + accumulated_err += cur_skip_err + + thresh = float(getattr(magcache_cfg, "thresh", 0.12)) + K = int(getattr(magcache_cfg, "K", 1)) + + # Check if we can skip: need a cached residual, error below threshold, + # and haven't exceeded max consecutive skips + if ( + residual_cache is not None + and accumulated_err < thresh + and accumulated_steps <= K + ): + skip_forward = True + # Save updated accumulation state for next step + setattr(self, "_magcache_accumulated_ratio", accumulated_ratio) + setattr(self, "_magcache_accumulated_steps", accumulated_steps) + setattr(self, "_magcache_accumulated_err", accumulated_err) + logger.debug( + f"MagCache (VACE) SKIP: step={step_idx}, acc_err={accumulated_err:.4f} < thresh={thresh}, acc_steps={accumulated_steps} <= K={K}" + ) + else: + # Force compute: reset accumulation state + # (matches reference: reset when NOT skipping) + setattr(self, "_magcache_accumulated_ratio", 1.0) + setattr(self, "_magcache_accumulated_steps", 0) + setattr(self, "_magcache_accumulated_err", 0.0) + logger.debug( + f"MagCache (VACE) COMPUTE: step={step_idx}, acc_err={accumulated_err:.4f}, thresh={thresh}, " + f"acc_steps={accumulated_steps}, K={K}, cache={'exists' if 
residual_cache is not None else 'None'}" + ) + else: + logger.debug( + f"MagCache (VACE) RETAIN: step={step_idx} < retain_steps={retain_steps}" + ) + + # Process through blocks (or skip) cache_update_infos = [] - for block_index, block in enumerate(self.blocks): - # Build per-block kwargs: - # - kv_cache/crossattn_cache are always per-block indexed - # - Additional block_kwargs are dynamically filtered based on block's signature - # and automatically indexed if they're per-block lists - filtered_block_kwargs = self._filter_block_kwargs(block_kwargs, block_index) - per_block_kwargs = { - "kv_cache": kv_cache[block_index], - "current_start": current_start, - **filtered_block_kwargs, - } - - if torch.is_grad_enabled() and self.causal_wan_model.gradient_checkpointing: - kwargs = {**base_kwargs, **per_block_kwargs} - result = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, - **kwargs, - use_reentrant=False, + + # Safety check: if we're about to skip but the KV cache expects updates, + # we must not skip to avoid cache index misalignment. + # This can happen when magcache_K is set too high relative to num_steps. 
+ if skip_forward and kv_cache is not None: + # Check if the first block's cache indicates we need an update + # by comparing current_start with the cache's expected position + first_cache = kv_cache[0] if kv_cache else None + if first_cache is not None: + cache_global_end = first_cache.get("global_end_index") + if cache_global_end is not None: + cache_end_val = ( + cache_global_end.item() + if hasattr(cache_global_end, "item") + else int(cache_global_end) + ) + # If current_start is beyond what cache expects, force compute + if current_start is not None and current_start > cache_end_val: + logger.warning( + f"MagCache (VACE): forcing compute due to KV cache misalignment " + f"(current_start={current_start}, cache_end={cache_end_val})" + ) + skip_forward = False + + if skip_forward: + x = ori_x + getattr(self, "_magcache_residual_cache") + self._magcache_stats["skipped"] += 1 + else: + self._magcache_stats["computed"] += 1 + for block_index, block in enumerate(self.blocks): + # Build per-block kwargs: + # - kv_cache/crossattn_cache are always per-block indexed + # - Additional block_kwargs are dynamically filtered based on block's signature + # and automatically indexed if they're per-block lists + filtered_block_kwargs = self._filter_block_kwargs( + block_kwargs, block_index ) - if kv_cache is not None and isinstance(result, tuple): - x, block_cache_update_info = result - cache_update_infos.append((block_index, block_cache_update_info)) - else: - x = result - else: - per_block_kwargs["crossattn_cache"] = crossattn_cache[block_index] - kwargs = {**base_kwargs, **per_block_kwargs} - result = block(x, **kwargs) - if kv_cache is not None and isinstance(result, tuple): - x, block_cache_update_info = result - cache_update_infos.append((block_index, block_cache_update_info)) + per_block_kwargs = { + "kv_cache": kv_cache[block_index], + "current_start": current_start, + **filtered_block_kwargs, + } + + if torch.is_grad_enabled() and 
self.causal_wan_model.gradient_checkpointing: + kwargs = {**base_kwargs, **per_block_kwargs} + result = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + **kwargs, + use_reentrant=False, + ) + if kv_cache is not None and isinstance(result, tuple): + x, block_cache_update_info = result + cache_update_infos.append((block_index, block_cache_update_info)) + else: + x = result else: - x = result + per_block_kwargs["crossattn_cache"] = crossattn_cache[block_index] + kwargs = {**base_kwargs, **per_block_kwargs} + result = block(x, **kwargs) + if kv_cache is not None and isinstance(result, tuple): + x, block_cache_update_info = result + cache_update_infos.append((block_index, block_cache_update_info)) + else: + x = result if kv_cache is not None and cache_update_infos: self.causal_wan_model._apply_cache_updates( kv_cache, cache_update_infos, **block_kwargs ) + if magcache_enabled and not skip_forward: + setattr(self, "_magcache_residual_cache", x - ori_x) + + if magcache_enabled and magcache_num_steps is not None: + next_step = int(getattr(self, "_magcache_step", 0)) + 1 + setattr(self, "_magcache_step", next_step) + if next_step >= int(magcache_num_steps): + # Log MagCache stats at end of each chunk + stats = getattr(self, "_magcache_stats", {"skipped": 0, "computed": 0}) + if stats["skipped"] > 0 or stats["computed"] > 0: + total = stats["skipped"] + stats["computed"] + skip_pct = 100 * stats["skipped"] / total if total > 0 else 0 + logger.info( + f"MagCache (VACE): skipped {stats['skipped']}/{total} steps ({skip_pct:.1f}%)" + ) + # Reset ALL state between chunks. + # CRITICAL: The residual cache MUST be cleared because each chunk has + # different latent content. Reusing a residual from chunk N on chunk N+1 + # causes severe artifacts ("looping noise") since the cached residual + # was computed for completely different spatial content. 
+ self._magcache_step = 0 + self._magcache_accumulated_ratio = 1.0 + self._magcache_accumulated_err = 0.0 + self._magcache_accumulated_steps = 0 + self._magcache_stats = {"skipped": 0, "computed": 0} + self._magcache_residual_cache = None # Clear to prevent cross-chunk artifacts + x = self.causal_wan_model.head( x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) ) diff --git a/src/scope/server/schema.py b/src/scope/server/schema.py index fbe7ef3bf..cf72bd6bd 100644 --- a/src/scope/server/schema.py +++ b/src/scope/server/schema.py @@ -96,6 +96,23 @@ class Parameters(BaseModel): default=None, description="Enable automatic cache management for parameter updates", ) + use_magcache: bool | None = Field( + default=None, + description="Enable MagCache (magnitude-aware residual caching) to skip redundant denoising steps. " + "Only supported by pipelines that advertise supports_magcache.", + ) + magcache_thresh: float | None = Field( + default=None, + description="MagCache error threshold. Lower values = better quality, fewer skips. Higher values = faster, more skips. Default: 0.12", + ge=0.05, + le=0.5, + ) + magcache_K: int | None = Field( + default=None, + description="MagCache max consecutive skips before forcing a compute. Lower = better quality. Higher = faster but may cause instability. Default: 2", + ge=1, + le=4, + ) reset_cache: bool | None = Field(default=None, description="Trigger a cache reset") kv_cache_attention_bias: float | None = Field( default=None,