
Commit e8f6fa5

[Feature] Weight Synchronization Schemes - Core Infrastructure
ghstack-source-id: eab1403 Pull-Request: #3185
1 parent 9d5c276 commit e8f6fa5

File tree

9 files changed: +2408 -2 lines changed

docs/source/reference/collectors.rst

Lines changed: 249 additions & 1 deletion
@@ -117,9 +117,257 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.

Weight Synchronization using Weight Update Schemes
--------------------------------------------------

RL pipelines are typically split into two big computational buckets: training and inference.
While the inference pipeline sends data to the training one, the training pipeline occasionally needs to
synchronize its weights with the inference one.
In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
used in both instances. From there, anything can happen:

- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
  for its instance of the policy.
- In some cases, the environment or the postprocessing hooks can rely on a model which itself needs
  synchronization. This means that there can be multiple endpoints in the data transfer API and one needs to think beyond
  policy-to-policy weight synchronization strategies.
- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized
  vs. non-quantized).
  This makes weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
- One typically also has to choose who initiates a transfer: should the inference engine actively ask for new
  weights, or should the trainer push its weights to the workers? An intermediate approach
  is to store the weights on an intermediary server and let the workers fetch them when necessary.

TorchRL tries to account for each of these problems in a flexible manner. We identify four basic components in a weight
transfer:

- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer;
- A `Receiver` class that applies the weights to the destination module (policy or other utility module);
- A `Transport` class that implements the actual transfer of the weights (through shared memory, NCCL or anything else);
- A `Scheme` that defines which sender, receiver and transport are to be used and how to initialize them.

Each of these classes is detailed below.

Usage Examples
~~~~~~~~~~~~~~

.. note::
    **Runnable versions** of these examples are available in the repository:

    - `examples/collectors/weight_sync_standalone.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_standalone.py>`_: Standalone weight synchronization
    - `examples/collectors/weight_sync_collectors.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_collectors.py>`_: Collector integration

Using Weight Update Schemes Independently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch import multiprocessing as mp
    from tensordict import TensorDict
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create a simple policy
    policy = nn.Linear(4, 2)

    # Example 1: Multiprocess weight synchronization with state_dict
    # --------------------------------------------------------------
    # On the main process side (trainer):
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    # Register worker pipes
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Send weights to workers
    weights = policy.state_dict()
    sender.update_weights(weights)

    # On the worker process side:
    # receiver = scheme.create_receiver()
    # receiver.register_model(policy)
    # receiver.register_worker_transport(child_pipe)
    # # Receive and apply weights
    # result = receiver._transport.receive_weights(timeout=5.0)
    # if result is not None:
    #     model_id, weights = result
    #     receiver.apply_weights(weights)

    # Example 2: Shared memory weight synchronization
    # ------------------------------------------------
    # Create shared memory scheme with auto-registration
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
    shared_sender = shared_scheme.create_sender()

    # Register worker pipes for lazy registration
    parent_pipe2, child_pipe2 = mp.Pipe()
    shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)

    # Send weights (automatically creates shared buffer on first send)
    weights_td = TensorDict.from_module(policy)
    shared_sender.update_weights(weights_td)

    # Workers automatically see updates via shared memory!
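
The worker-side half of Example 1 is shown commented out above. For orientation, it can be wrapped in a process
target. The sketch below is a minimal, hypothetical arrangement that only reuses the receiver calls from the
example (including the private ``_transport`` attribute, which may change); ``worker_main`` is an illustrative
name, not a TorchRL API:

.. code-block:: python

    def worker_main(child_pipe, scheme):
        import torch.nn as nn

        # Rebuild the same architecture inside the worker process
        policy = nn.Linear(4, 2)

        # Mirror the sender-side setup shown above
        receiver = scheme.create_receiver()
        receiver.register_model(policy)
        receiver.register_worker_transport(child_pipe)

        # Block until the trainer sends a weight update (or give up after 5s)
        result = receiver._transport.receive_weights(timeout=5.0)
        if result is not None:
            model_id, weights = result
            receiver.apply_weights(weights)

    # Hypothetical wiring: spawn the worker before sending weights
    # proc = mp.Process(target=worker_main, args=(child_pipe, scheme))
    # proc.start()
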
Using Weight Update Schemes with Collectors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization
across multiple inference workers:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(env.observation_spec["observation"].shape[-1],
                  env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )

    # Example 1: Single collector with multiprocess scheme
    # -----------------------------------------------------
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": scheme},
    )

    # Collect data and update weights periodically
    for i, data in enumerate(collector):
        # ... training step with data ...

        # Update policy weights every N iterations
        if i % 10 == 0:
            new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)

    collector.shutdown()

    # Example 2: Multiple collectors with shared memory
    # --------------------------------------------------
    # Shared memory is more efficient for frequent updates
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)

    collector = MultiSyncDataCollector(
        create_env_fn=[
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
        ],
        policy=policy,
        frames_per_batch=192,
        total_frames=10000,
        weight_sync_schemes={"policy": shared_scheme},
    )

    # Workers automatically see weight updates via shared memory
    for data in collector:
        # ... training ...
        collector.update_policy_weights_(TensorDict.from_module(policy))

    collector.shutdown()

.. note::
    When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
    processes share the same memory buffers. This is ideal for frequent weight updates but requires all
    processes to be on the same machine.

.. note::
    The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state
    dictionaries, while ``"tensordict"`` uses the TensorDict format, which can be more efficient for structured
    models and supports advanced features like lazy initialization.
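
To make the difference between the two formats concrete, here is a small sketch (plain ``torch`` and
``tensordict`` calls only, not a weight-sync API) showing the same weights of a toy policy under each strategy:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)

    # "state_dict" strategy: a flat dict mapping parameter names to tensors
    sd = policy.state_dict()             # keys: "weight", "bias"

    # "tensordict" strategy: a TensorDict mirroring the module's structure
    td = TensorDict.from_module(policy)  # nested for nested modules

    # Both containers hold the same underlying values
    assert (sd["weight"] == td["weight"]).all()
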
Weight Senders
~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSender
    RayModuleTransformSender

Weight Receivers
~~~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightReceiver
    RayModuleTransformReceiver

Transports
~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    TransportBackend
    MPTransport
    SharedMemTransport
    RayTransport
    RayActorTransport
    RPCTransport
    DistributedTransport

Schemes
~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    NoWeightSyncScheme
    RayWeightSyncScheme
    RayModuleTransformScheme
    RPCWeightSyncScheme
    DistributedWeightSyncScheme

Legacy: Weight Synchronization in Distributed Environments
----------------------------------------------------------

.. warning::
    The ``WeightUpdater`` is considered legacy as of the 0.11 release and will be deprecated soon.
    The weight update schemes, which provide more flexibility and better compatibility with heavy
    weight transfers (e.g., LLMs), are to be preferred.

In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
