From 350208657dbf457338f19af3b93514b56c2da690 Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Mon, 10 Nov 2025 14:46:15 -0800
Subject: [PATCH 1/2] Add switchable in-flight weight updates to Generator

Summary:
This change adds optional parameters to Generator.update_weights() to support
in-flight weight updates without waiting for pending requests to complete. The
original blocking behavior is preserved as the default, with new opt-in
parameters:
- wait_for_pending (default=True): When False, updates weights immediately
  without draining the request queue
- reset_cache (default=True): When False, preserves the KV cache during updates

This enables faster weight updates during training at the cost of potential
mid-generation weight switching for in-flight requests.

Test Plan:
- Verified Python syntax compiles successfully
- Default behavior unchanged (backwards compatible)
- New behavior available via explicit parameter flags
---
 src/forge/actors/generator.py | 65 +++++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 26 deletions(-)

diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 32ae69906..3538b06d2 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -426,18 +426,28 @@ async def run(self) -> None:
             self.request_lock.notify_all()
 
     @endpoint
-    async def update_weights(self, version: int) -> None:
+    async def update_weights(
+        self, version: int, *, wait_for_pending: bool = True, reset_cache: bool = True
+    ) -> None:
         """Update weights on base model from a generator version to be found in a torchstore volume.
 
         Args:
             generator_version (int): Generator version from which to update. This will correspond to
                 a key in a torchstore volume.
+            wait_for_pending (bool): If True (default), blocks new requests and waits for pending requests to
+                complete before updating weights. If False, updates weights immediately without waiting,
+                allowing in-flight updates (requests may see new weights mid-generation).
+            reset_cache (bool): If True (default), resets the KV prefix cache after updating weights.
+                Set to False to preserve cache when doing in-flight updates.
 
         Example:
             >>> trainer.train_step(...)
             >>> version += 1
             >>> await trainer.push_weights()
+            >>> # Original blocking behavior (default)
             >>> generator.update_weights(version)
+            >>> # In-flight updates (opt-in)
+            >>> generator.update_weights(version, wait_for_pending=False, reset_cache=False)
         """
         # TODO: enable shared memory prefetch for DCP-based weight sync
         if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
@@ -448,27 +458,28 @@ async def update_weights(self, version: int) -> None:
         # Serialize updates (only one update at a time)
         async with self.update_lock:
             # Grab the lock to stop accepting requests and wait on pending requests
-            async with self.request_lock:
-                self.accepting_requests = False
-                curr_requests = [fut for _, fut in self.requests.values()]
-                if curr_requests:
-                    # Record pending requests metrics
-                    record_metric(
-                        "generator_perf/update_weights/avg_pending_requests",
-                        len(curr_requests),
-                        Reduce.MEAN,
-                    )
-                    record_metric(
-                        "generator_perf/update_weights/max_pending_requests",
-                        len(curr_requests),
-                        Reduce.MAX,
-                    )
-                    logger.debug(f"Waiting for {len(curr_requests)} pending requests")
-
-                # Wait until all pending requests have been processed
-                # TODO: If generating long sequences, this might be long and will block
-                # generator weight updates
-                await self.request_lock.wait_for(lambda: len(self.requests) == 0)
+            if wait_for_pending:
+                async with self.request_lock:
+                    self.accepting_requests = False
+                    curr_requests = [fut for _, fut in self.requests.values()]
+                    if curr_requests:
+                        # Record pending requests metrics
+                        record_metric(
+                            "generator_perf/update_weights/avg_pending_requests",
+                            len(curr_requests),
+                            Reduce.MEAN,
+                        )
+                        record_metric(
+                            "generator_perf/update_weights/max_pending_requests",
+                            len(curr_requests),
+                            Reduce.MAX,
+                        )
+                        logger.debug(f"Waiting for {len(curr_requests)} pending requests")
+
+                    # Wait until all pending requests have been processed
+                    # TODO: If generating long sequences, this might be long and will block
+                    # generator weight updates
+                    await self.request_lock.wait_for(lambda: len(self.requests) == 0)
 
             # Record weight update metrics
             record_metric(
@@ -492,12 +503,14 @@ async def update_weights(self, version: int) -> None:
             self.generator_version = version
 
             # After updating the weights, we need to reset the KV cache
-            self.scheduler.reset_prefix_cache()
+            if reset_cache:
+                self.scheduler.reset_prefix_cache()
 
             # Resume accepting requests and wake up any waiting generate() calls
-            async with self.request_lock:
-                self.accepting_requests = True
-                self.request_lock.notify_all()
+            if wait_for_pending:
+                async with self.request_lock:
+                    self.accepting_requests = True
+                    self.request_lock.notify_all()
 
         logger.info(f"Weight update completed (now v{self.generator_version})")

From 947592cc420a0c3342b14a79a1f671b9ff445ebf Mon Sep 17 00:00:00 2001
From: Yuxuan Hu
Date: Mon, 10 Nov 2025 14:53:22 -0800
Subject: [PATCH 2/2] Make in-flight weight update behavior configurable via dataclass

Summary:
Added two new configuration fields to the Generator dataclass:
- wait_for_pending_on_update (default=True): Controls whether weight updates
  wait for pending requests to complete
- reset_cache_on_update (default=True): Controls whether to reset KV cache
  after weight updates

The update_weights() method now uses these config values as defaults, but
still allows per-call overrides via optional parameters. This enables users to
configure the behavior globally via config files while maintaining flexibility
for per-call customization.
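
Usage sketch (illustrative only; it mirrors the Example block in the Generator
docstring and assumes the new dataclass fields are forwarded through
.as_service(...) the same way the existing Generator fields are):

    >>> # Opt into in-flight updates globally at service construction
    >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
    ...     wait_for_pending_on_update=False,
    ...     reset_cache_on_update=False,
    ... )
    >>> generator.update_weights(version)  # uses the config defaults above
    >>> generator.update_weights(version, wait_for_pending=True)  # per-call override
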
Test Plan:
- Verified Python syntax compiles successfully
- Default behavior unchanged (backwards compatible)
- Config values can be set in Generator instantiation
- Per-call overrides still work as before
---
 src/forge/actors/generator.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 3538b06d2..94a4affe8 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -80,6 +80,10 @@ class Generator(ForgeActor):
         sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
         use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on whether or not
             RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.
+        wait_for_pending_on_update (bool): If True (default), weight updates will block new requests and wait
+            for pending requests to complete. If False, enables in-flight weight updates.
+        reset_cache_on_update (bool): If True (default), resets the KV prefix cache after weight updates.
+            Set to False to preserve cache during updates.
 
     Example:
         >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
@@ -97,6 +101,8 @@ class Generator(ForgeActor):
     use_dcp_for_weight_sync: bool | None = None
     prefetch_weights_to_shm: bool = True
     n_fetcher_procs: int = 8
+    wait_for_pending_on_update: bool = True
+    reset_cache_on_update: bool = True
 
     def __post_init__(self):
         super().__init__()
@@ -427,28 +433,39 @@ async def run(self) -> None:
 
     @endpoint
     async def update_weights(
-        self, version: int, *, wait_for_pending: bool = True, reset_cache: bool = True
+        self,
+        version: int,
+        *,
+        wait_for_pending: bool | None = None,
+        reset_cache: bool | None = None,
     ) -> None:
         """Update weights on base model from a generator version to be found in a torchstore volume.
 
         Args:
             generator_version (int): Generator version from which to update. This will correspond to
                 a key in a torchstore volume.
-            wait_for_pending (bool): If True (default), blocks new requests and waits for pending requests to
+            wait_for_pending (bool | None): If True, blocks new requests and waits for pending requests to
                 complete before updating weights. If False, updates weights immediately without waiting,
                 allowing in-flight updates (requests may see new weights mid-generation).
-            reset_cache (bool): If True (default), resets the KV prefix cache after updating weights.
+                If None (default), uses the value from wait_for_pending_on_update config.
+            reset_cache (bool | None): If True, resets the KV prefix cache after updating weights.
                 Set to False to preserve cache when doing in-flight updates.
+                If None (default), uses the value from reset_cache_on_update config.
 
         Example:
             >>> trainer.train_step(...)
             >>> version += 1
             >>> await trainer.push_weights()
-            >>> # Original blocking behavior (default)
+            >>> # Uses config defaults
             >>> generator.update_weights(version)
-            >>> # In-flight updates (opt-in)
+            >>> # Override config for this call
             >>> generator.update_weights(version, wait_for_pending=False, reset_cache=False)
         """
+        # Use config defaults if not explicitly provided
+        if wait_for_pending is None:
+            wait_for_pending = self.wait_for_pending_on_update
+        if reset_cache is None:
+            reset_cache = self.reset_cache_on_update
         # TODO: enable shared memory prefetch for DCP-based weight sync
         if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
@@ -474,7 +491,9 @@ async def update_weights(
                             len(curr_requests),
                             Reduce.MAX,
                         )
-                        logger.debug(f"Waiting for {len(curr_requests)} pending requests")
+                        logger.debug(
+                            f"Waiting for {len(curr_requests)} pending requests"
+                        )
 
                     # Wait until all pending requests have been processed
                     # TODO: If generating long sequences, this might be long and will block
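
Taken together, the two patches support a training loop that swaps weights
without draining the generator. A sketch (loop variables such as num_steps and
batch are hypothetical; the trainer/push_weights calls follow the docstring
example above, and in-flight mode assumes callers tolerate mid-generation
weight switches):

    >>> for step in range(num_steps):
    ...     trainer.train_step(batch)
    ...     version += 1
    ...     await trainer.push_weights()
    ...     # In-flight update: skip draining pending requests, keep the KV cache
    ...     generator.update_weights(version, wait_for_pending=False, reset_cache=False)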