@@ -117,14 +117,19 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
117117
118118   Policy copy decision tree in Collectors.
119119
120- Weight Synchronization using Weight Update Schemes 
121- --------------------------------------------------  
120+ Weight Synchronization
121+ ---------------------- 
122122
123- RL pipelines are typically split in two big computational buckets: training, and inference.
124- While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
125- synchronize its weights with the inference one.
126- In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
127- used in both instances. From there, anything can happen:
123+ In reinforcement learning, the training pipeline is typically split into two computational phases:
124+ **inference ** (data collection) and **training ** (policy optimization). While the inference pipeline
125+ sends data to the training one, the training pipeline needs to periodically synchronize its weights
126+ with the inference workers to ensure they collect data using up-to-date policies.
127+ 
128+ Overview & Motivation
129+ ~~~~~~~~~~~~~~~~~~~~~ 
130+ 
131+ In the simplest setting, the same policy weights are used in both training and inference. However,
132+ real-world RL systems often face additional complexity:
128133
129134- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
130135  `DataCollectors ` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
@@ -140,15 +145,226 @@ used in both instances. From there, anything can happen:
140145  asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach
141146  is to store the weights on some intermediary server and let the workers fetch them when necessary.
142147
143- TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight
144- transfer:
148+ Key Challenges
149+ ^^^^^^^^^^^^^^ 
150+ 
151+ Modern RL training often involves multiple models that need independent synchronization:
152+ 
153+ 1. **Multiple Models Per Collector **: A collector might need to update:
154+ 
155+    - The main policy network
156+    - A value network in a Ray actor within the replay buffer
157+    - Models embedded in the environment itself
158+    - Separate world models or auxiliary networks
159+ 
160+ 2. **Different Update Strategies **: Each model may require different synchronization approaches:
161+ 
162+    - Full state_dict transfer vs. TensorDict-based updates
163+    - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
164+    - Varied update frequencies
165+ 
166+ 3. **Worker-Agnostic Updates **: Some models (like those in shared Ray actors) shouldn't be tied to
167+    specific worker indices, requiring a more flexible update mechanism.
168+ 
169+ The Solution
170+ ^^^^^^^^^^^^ 
171+ 
172+ TorchRL addresses these challenges through a flexible, modular architecture built around four components:
173+ 
174+ - **WeightSyncScheme **: Defines *what * to synchronize and *how * (user-facing configuration)
175+ - **WeightSender **: Handles distributing weights from the main process to workers (internal)
176+ - **WeightReceiver **: Handles applying weights in worker processes (internal)
177+ - **TransportBackend **: Manages the actual communication layer (internal)
178+ 
179+ This design allows you to independently configure synchronization for multiple models,
180+ choose appropriate transport mechanisms, and swap strategies without rewriting your training code.
181+ 
182+ Architecture & Concepts
183+ ~~~~~~~~~~~~~~~~~~~~~~~ 
184+ 
185+ Component Roles
186+ ^^^^^^^^^^^^^^^ 
187+ 
188+ The weight synchronization system separates concerns into four distinct layers:
189+ 
190+ 1. **WeightSyncScheme ** (User-Facing)
191+ 
192+    This is your main configuration interface. You create scheme objects that define:
145193
146- - A `Sender ` class that somehow gets the weights (or a reference to them) and initializes the transfer;
147- - A `Receiver ` class that casts the weights to the destination module (policy or other utility module);
148- - A `Transport ` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else).
149- - A Scheme that defines what sender, receiver and transport have to be used and how to initialize them.
194+    - The synchronization strategy (``"state_dict" `` or ``"tensordict" ``)
195+    - The transport mechanism (multiprocessing pipes, shared memory, Ray, RPC, etc.)
196+    - Additional options like auto-registration and timeout behavior
150197
151- Each of these classes is detailed below.
198+    When working with collectors, you pass a dictionary mapping model IDs to schemes.
199+ 
200+ 2. **WeightSender ** (Internal)
201+ 
202+    Created by the scheme in the main training process. The sender:
203+ 
204+    - Holds a reference to the source model
205+    - Manages transport connections to all workers
206+    - Extracts weights using the configured strategy
207+    - Broadcasts weight updates across all transports
208+ 
209+ 3. **WeightReceiver ** (Internal)
210+ 
211+    Created by the scheme in each worker process. The receiver:
212+ 
213+    - Holds a reference to the destination model
214+    - Polls its transport for weight updates
215+    - Applies received weights using the configured strategy
216+    - Handles model registration and initialization
217+ 
218+ 4. **TransportBackend ** (Internal)
219+ 
220+    Implements the actual communication mechanism:
221+ 
222+    - ``MPTransport ``: Uses multiprocessing pipes
223+    - ``SharedMemTransport ``: Uses shared memory buffers (zero-copy)
224+    - ``RayTransport ``: Uses Ray's object store
225+    - ``RPCTransport ``: Uses PyTorch RPC
226+    - ``DistributedTransport ``: Uses collective communication (NCCL, Gloo, MPI)
227+ 
228+ Initialization Phase
229+ ^^^^^^^^^^^^^^^^^^^^ 
230+ 
231+ When you create a collector with weight sync schemes, the following initialization occurs:
232+ 
233+ .. aafig ::
234+     :aspect:  60
235+     :scale:  130
236+     :textual: 
237+ 
238+     INITIALIZATION PHASE
239+     ====================
240+ 
241+                         WeightSyncScheme
242+                         +------------------+ 
243+                         |                   | 
244+                         |  Configuration:   | 
245+                         |  - strategy       | 
246+                         |  - transport type | 
247+                         |                   | 
248+                         +--------+---------+ 
249+                                  | 
250+                     +------------+-------------+ 
251+                     |                           | 
252+                     |  creates                  |  creates
253+                     |                           | 
254+                     v                          v
255+             Main Process                 Worker Process
256+             +--------------+              +---------------+ 
257+             |  WeightSender |              |  WeightReceiver| 
258+             |               |              |                | 
259+             |  - strategy   |              |  - strategy    | 
260+             |  - transports |              |  - transport   | 
261+             |  - model ref  |              |  - model ref   | 
262+             |               |              |                | 
263+             |  Registers:   |              |  Registers:    | 
264+             |  - "model"    |              |  - model       | 
265+             |  - "workers"  |              |  - transport   | 
266+             +--------------+              +---------------+ 
267+                     |                             | 
268+                     |    Transport Layer          | 
269+                     |     +----------------+       | 
270+                     +--> |  MPTransport    |  <----+
271+                     |     |  (pipes)        |       | 
272+                     |     +----------------+       | 
273+                     |                             | 
274+                     |    +--------------------+    | 
275+                     +-> |"SharedMemTransport"| <-+
276+                     |    |"(shared memory)""    |    | 
277+                     |    +--------------------+    | 
278+                     |                             | 
279+                     |    +----------------+        | 
280+                     +-> |  RayTransport   |  <-----+
281+                         |  (Ray store)    | 
282+                         +----------------+ 
283+ 
284+ The scheme creates a sender in the main process and a receiver in each worker, then establishes
285+ transport connections between them.
286+ 
287+ Synchronization Phase
288+ ^^^^^^^^^^^^^^^^^^^^^ 
289+ 
290+ When you call ``collector.update_policy_weights_() ``, the weight synchronization proceeds as follows:
291+ 
292+ .. aafig ::
293+     :aspect:  60
294+     :scale:  130
295+     :textual: 
296+ 
297+     SYNCHRONIZATION PHASE
298+     =====================
299+ 
300+         Main Process                                    Worker Process
301+         
302+     +-------------------+                               +-------------------+
303+     |  WeightSender      |                               | WeightReceiver    | 
304+ |                    |                               |                   | 
305+ |  1. Extract        |                               | 4. Poll transport | 
306+ |     weights from   |                               |                   | 
307+ |     model using    |                               |    for weights    | 
308+ |     strategy       |                               |                   | 
309+ |                    |    2. Send via                |                   | 
310+ |  +-------------+   |       Transport               | +--------------+  | 
311+ |  | Strategy    |   |    +------------+             | | Strategy     |  | 
312+ |  | extract()   |   |    |            |             | | apply()      |  | 
313+ |  +-------------+   +----+ Transport  +------------>+ +--------------+  | 
314+ |         |          |    |            |             |        |          | 
315+ |         v          |    +------------+             |        v          | 
316+ |  +-------------+   |                               | +--------------+  | 
317+ |  | Model       |   |                               | | Model (dest) |  | 
318+ |  | (source)    |   |  3. Acknowledgment (optional) | |              |  | 
319+ |  +-------------+   | <---------------------------+ | +--------------+  | 
320+ |                    |                               |                   | 
321+ 
322+                                                         |     to model using | 
323+ |     strategy       | 
324+ +-------------------+ 
325+ 
326+ 1. **Extract **: Sender extracts weights from the source model (state_dict or TensorDict)
327+ 2. **Send **: Sender broadcasts weights through all registered transports
328+ 3. **Acknowledge ** (optional): Some transports send acknowledgment back
329+ 4. **Poll **: Receiver checks its transport for new weights
330+ 5. **Apply **: Receiver applies weights to the destination model
331+ 
332+ Multi-Model Synchronization
333+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
334+ 
335+ One of the key features is support for synchronizing multiple models independently:
336+ 
337+ .. aafig ::
338+     :aspect:  60
339+     :scale:  130
340+     :textual: 
341+ 
342+       Main Process                 Worker Processes (1, 2, ..., B)
343+ 
344+     +-----------------+            +-------------------+
345+     |  Orchestrator    |            | Collector         | 
346+ |                  |            |                   | 
347+ |  Models:         |            | Models:           | 
348+ |   +----------+   |            |  +--------+       | 
349+ |   | Policy A |   |            |  |Policy A|       | 
350+ |   +----------+   |            |  +--------+       | 
351+ |                  |            |                   | 
352+ |   +----------+   |            |  +--------+       | 
353+ |   | Model  B |   |            |  |Model  B|       | 
354+ |   +----------+   |            |  +--------+       | 
355+ |                  |            |                   | 
356+ |  Weight Senders: |            | Weight Receivers: | 
357+ |   +----------+   |            |   +------------+  | 
358+ |   | Sender A +---+------------+-> + Receiver A |  | 
359+ |   +----------+   |            |   +------------+  | 
360+ |                  |            |                   | 
361+ |   +----------+   |            |  +------------+   | 
362+ |   | Sender B +---+------------+->+ Receiver B |   | 
363+ |   +----------+   |  "Pipes"   |  +------------+   | 
364+ 
365+ 
366+ Each model gets its own sender/receiver pair, allowing independent synchronization frequencies,
367+ different transport mechanisms per model, and model-specific strategies.
152368
153369Usage Examples
154370~~~~~~~~~~~~~~ 
@@ -301,32 +517,55 @@ across multiple inference workers:
301517    dictionaries, while ``"tensordict" `` uses TensorDict format which can be more efficient for structured
302518    models and supports advanced features like lazy initialization.
303519
304- Weight Senders
305- ~~~~~~~~~~~~~~ 
520+ API Reference
521+ ~~~~~~~~~~~~~ 
522+ 
523+ The weight synchronization system provides both user-facing configuration classes and internal
524+ implementation classes that are automatically managed by the collectors.
525+ 
526+ Schemes (User-Facing Configuration)
527+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
528+ 
529+ These are the main classes you'll use to configure weight synchronization. Pass them in the
530+ ``weight_sync_schemes `` dictionary when creating collectors.
306531
307532.. currentmodule :: torchrl.weight_update 
308533
309534.. autosummary ::
310535    :toctree:  generated/
311536    :template:  rl_template.rst
312537
313-     WeightSender
314-     RayModuleTransformSender
538+     WeightSyncScheme
539+     MultiProcessWeightSyncScheme
540+     SharedMemWeightSyncScheme
541+     NoWeightSyncScheme
542+     RayWeightSyncScheme
543+     RayModuleTransformScheme
544+     RPCWeightSyncScheme
545+     DistributedWeightSyncScheme
546+ 
547+ Senders and Receivers (Internal)
548+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
315549
316- Weight Receivers 
317- ~~~~~~~~~~~~~~~~ 
550+ These classes are automatically created and managed by the schemes. You typically don't need 
551+ to interact with them directly. 
318552
319553.. currentmodule :: torchrl.weight_update 
320554
321555.. autosummary ::
322556    :toctree:  generated/
323557    :template:  rl_template.rst
324558
559+     WeightSender
325560    WeightReceiver
561+     RayModuleTransformSender
326562    RayModuleTransformReceiver
327563
328- Transports
329- ~~~~~~~~~~ 
564+ Transport Backends (Internal)
565+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
566+ 
567+ Transport classes handle the actual communication between processes. They are automatically
568+ selected and configured by the schemes.
330569
331570.. currentmodule :: torchrl.weight_update 
332571
@@ -342,24 +581,6 @@ Transports
342581    RPCTransport
343582    DistributedTransport
344583
345- Schemes
346- ~~~~~~~ 
347- 
348- .. currentmodule :: torchrl.weight_update 
349- 
350- .. autosummary ::
351-     :toctree:  generated/
352-     :template:  rl_template.rst
353- 
354-     WeightSyncScheme
355-     MultiProcessWeightSyncScheme
356-     SharedMemWeightSyncScheme
357-     NoWeightSyncScheme
358-     RayWeightSyncScheme
359-     RayModuleTransformScheme
360-     RPCWeightSyncScheme
361-     DistributedWeightSyncScheme
362- 
363584Legacy: Weight Synchronization in Distributed Environments
364585---------------------------------------------------------- 
365586
0 commit comments