
   Policy copy decision tree in Collectors.

Weight Synchronization using Weight Update Schemes
--------------------------------------------------

RL pipelines are typically split into two big computational buckets: training and inference.
While the inference pipeline sends data to the training one, the training pipeline occasionally needs to
synchronize its weights with the inference one.
In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
used in both instances. From there, anything can happen:

- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers
  (named `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the
  weights for its instance of the policy.
- In some cases, the environment or the postprocessing hooks can rely on a model that itself needs
  synchronization. This means that there can be multiple endpoints in the weight-transfer API, and one needs to think
  beyond policy-to-policy weight synchronization strategies.
- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different
  (quantized vs. non-quantized).
  This makes weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
- One also typically has to choose who initiates a transfer: should the inference engine actively request new weights,
  or should the trainer push its weights to the workers? An intermediate approach is to store the weights on an
  intermediary server and let the workers fetch them when necessary.

TorchRL tries to account for each of these problems in a flexible manner. We identify four basic components in a
weight transfer:

- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer;
- A `Receiver` class that casts the weights to the destination module (policy or other utility module);
- A `Transport` class that codes up the actual transfer of the weights (through shared memory, NCCL or anything else);
- A `Scheme` that defines which sender, receiver and transport have to be used and how to initialize them.

Each of these classes is detailed below.
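
As a rough orientation, here is a minimal sketch of how the four components relate. It only reuses the scheme API
shown in the usage examples below; worker registration and the actual transfer are omitted:

.. code-block:: python

    from torchrl.weight_update import MultiProcessWeightSyncScheme

    # The Scheme bundles the choices: which Sender, Receiver and Transport to
    # use, and in which format the weights travel ("state_dict" here).
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

    # Trainer side: the Sender registers workers and initiates transfers.
    sender = scheme.create_sender()

    # Worker side: the Receiver applies incoming weights to its local module.
    receiver = scheme.create_receiver()

    # The Transport (pipes, shared memory, collectives, ...) is set up when
    # workers are registered; see the usage examples below for the full loop.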

Usage Examples
~~~~~~~~~~~~~~

.. note::
    **Runnable versions** of these examples are available in the repository:

    - `examples/collectors/weight_sync_standalone.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_standalone.py>`_: Standalone weight synchronization
    - `examples/collectors/weight_sync_collectors.py <https://github.com/pytorch/rl/blob/main/examples/collectors/weight_sync_collectors.py>`_: Collector integration

Using Weight Update Schemes Independently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch import multiprocessing as mp
    from tensordict import TensorDict
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create a simple policy
    policy = nn.Linear(4, 2)

    # Example 1: Multiprocess weight synchronization with state_dict
    # ---------------------------------------------------------------
    # On the main process side (trainer):
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    # Register worker pipes
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Send weights to workers
    weights = policy.state_dict()
    sender.update_weights(weights)

    # On the worker process side:
    # receiver = scheme.create_receiver()
    # receiver.register_model(policy)
    # receiver.register_worker_transport(child_pipe)
    # # Receive and apply weights
    # result = receiver._transport.receive_weights(timeout=5.0)
    # if result is not None:
    #     model_id, weights = result
    #     receiver.apply_weights(weights)

    # Example 2: Shared memory weight synchronization
    # ------------------------------------------------
    # Create shared memory scheme with auto-registration
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
    shared_sender = shared_scheme.create_sender()

    # Register worker pipes for lazy registration
    parent_pipe2, child_pipe2 = mp.Pipe()
    shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)

    # Send weights (automatically creates the shared buffer on first send)
    weights_td = TensorDict.from_module(policy)
    shared_sender.update_weights(weights_td)

    # Workers automatically see updates via shared memory!

Using Weight Update Schemes with Collectors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization
across multiple inference workers:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Create environment and policy
    env = GymEnv("CartPole-v1")
    policy = TensorDictModule(
        nn.Linear(env.observation_spec["observation"].shape[-1],
                  env.action_spec.shape[-1]),
        in_keys=["observation"],
        out_keys=["action"],
    )

    # Example 1: Single collector with multiprocess scheme
    # -----------------------------------------------------
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": scheme},
    )

    # Collect data and update weights periodically
    for i, data in enumerate(collector):
        # ... training step with data ...

        # Update policy weights every N iterations
        if i % 10 == 0:
            new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)

    collector.shutdown()

    # Example 2: Multiple collectors with shared memory
    # --------------------------------------------------
    # Shared memory is more efficient for frequent updates
    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)

    collector = MultiSyncDataCollector(
        create_env_fn=[
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
            lambda: GymEnv("CartPole-v1"),
        ],
        policy=policy,
        frames_per_batch=192,
        total_frames=10000,
        weight_sync_schemes={"policy": shared_scheme},
    )

    # Workers automatically see weight updates via shared memory
    for data in collector:
        # ... training ...
        collector.update_policy_weights_(TensorDict.from_module(policy))

    collector.shutdown()

.. note::
    When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all
    processes share the same memory buffers. This is ideal for frequent weight updates but requires all
    processes to be on the same machine.

.. note::
    The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state
    dictionaries, while ``"tensordict"`` uses the TensorDict format, which can be more efficient for structured
    models and supports advanced features like lazy initialization.
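
For illustration, here is a small sketch of what the two formats look like for a toy linear policy. It only relies
on calls already used in the examples above (``state_dict()`` and ``TensorDict.from_module``):

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)

    # "state_dict" strategy: a flat, ordinary PyTorch state dict
    weights_sd = policy.state_dict()          # keys: "weight", "bias"

    # "tensordict" strategy: a TensorDict mirroring the module structure;
    # it can be indexed, flattened, or placed in shared memory
    weights_td = TensorDict.from_module(policy)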

Weight Senders
~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSender
    RayModuleTransformSender

Weight Receivers
~~~~~~~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightReceiver
    RayModuleTransformReceiver

Transports
~~~~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    TransportBackend
    MPTransport
    SharedMemTransport
    RayTransport
    RayActorTransport
    RPCTransport
    DistributedTransport

Schemes
~~~~~~~

.. currentmodule:: torchrl.weight_update

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    WeightSyncScheme
    MultiProcessWeightSyncScheme
    SharedMemWeightSyncScheme
    NoWeightSyncScheme
    RayWeightSyncScheme
    RayModuleTransformScheme
    RPCWeightSyncScheme
    DistributedWeightSyncScheme

Legacy: Weight Synchronization in Distributed Environments
-----------------------------------------------------------

.. warning::
    The `WeightUpdater` is considered legacy as of the 0.11 release and will be deprecated soon.
    The weight update schemes, which provide more flexibility and better compatibility with heavy
    weight transfers (e.g., LLMs), are to be preferred.

In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.