
Commit ada5dac

[Feature] Documentation
ghstack-source-id: 684a5c5 Pull-Request: #3192
1 parent 13434eb commit ada5dac

File tree

2 files changed: +274, -41 lines


docs/source/conf.py

Lines changed: 13 additions & 1 deletion
@@ -198,7 +198,19 @@ def kill_procs(gallery_conf, fname):
 ]


-aafig_default_options = {"scale": 1.5, "aspect": 1.0, "proportional": True}
+# aafigure configuration
+# Ensure SVG output for HTML and PDF for LaTeX, while keeping text builder raw
+aafig_format = {"latex": "pdf", "html": "svg", "text": None}
+
+# Use monospace + textual rendering by default for better ASCII alignment
+# Values are percentages (without the % sign)
+aafig_default_options = {
+    "scale": 130,
+    "aspect": 60,
+    "proportional": False,
+    "textual": True,
+    "line_width": 1.2,
+}

 # -- Generate knowledge base references -----------------------------------
 current_path = os.path.dirname(os.path.realpath(__file__))

docs/source/reference/collectors.rst

Lines changed: 261 additions & 40 deletions
@@ -117,14 +117,19 @@ try to limit the cases where a deepcopy will be executed. The following chart sh

    Policy copy decision tree in Collectors.

-Weight Synchronization using Weight Update Schemes
---------------------------------------------------
+Weight Synchronization
+----------------------

-RL pipelines are typically split in two big computational buckets: training, and inference.
-While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
-synchronize its weights with the inference one.
-In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
-used in both instances. From there, anything can happen:
+In reinforcement learning, the training pipeline is typically split into two computational phases:
+**inference** (data collection) and **training** (policy optimization). While the inference pipeline
+sends data to the training one, the training pipeline needs to periodically synchronize its weights
+with the inference workers to ensure they collect data using up-to-date policies.
+
+Overview & Motivation
+~~~~~~~~~~~~~~~~~~~~~
+
+In the simplest setting, the same policy weights are used in both training and inference. However,
+real-world RL systems often face additional complexity:

 - In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
   `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
@@ -140,15 +145,226 @@ used in both instances. From there, anything can happen:
 asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach
 is to store the weights on some intermediary server and let the workers fetch them when necessary.

-TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight
-transfer:
+Key Challenges
+^^^^^^^^^^^^^^
+
+Modern RL training often involves multiple models that need independent synchronization:
+
+1. **Multiple Models Per Collector**: A collector might need to update:
+
+   - The main policy network
+   - A value network in a Ray actor within the replay buffer
+   - Models embedded in the environment itself
+   - Separate world models or auxiliary networks
+
+2. **Different Update Strategies**: Each model may require different synchronization approaches:
+
+   - Full state_dict transfer vs. TensorDict-based updates
+   - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
+   - Varied update frequencies
+
+3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
+   specific worker indices, requiring a more flexible update mechanism.
+
+The Solution
+^^^^^^^^^^^^
+
+TorchRL addresses these challenges through a flexible, modular architecture built around four components:
+
+- **WeightSyncScheme**: Defines *what* to synchronize and *how* (user-facing configuration)
+- **WeightSender**: Handles distributing weights from the main process to workers (internal)
+- **WeightReceiver**: Handles applying weights in worker processes (internal)
+- **TransportBackend**: Manages the actual communication layer (internal)
+
+This design allows you to independently configure synchronization for multiple models,
+choose appropriate transport mechanisms, and swap strategies without rewriting your training code.
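For instance, under this design switching the policy from pipe-based to shared-memory synchronization is a change to the scheme object only, not to the training loop. A minimal sketch (constructor arguments are omitted and assumed to have sensible defaults)::

    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    # Pipe-based transfer: a copy of the weights is shipped to each worker
    schemes = {"policy": MultiProcessWeightSyncScheme()}

    # Zero-copy alternative: workers read the weights from shared-memory buffers
    schemes = {"policy": SharedMemWeightSyncScheme()}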
+
+Architecture & Concepts
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Component Roles
+^^^^^^^^^^^^^^^
+
+The weight synchronization system separates concerns into four distinct layers:
+
+1. **WeightSyncScheme** (User-Facing)
+
+   This is your main configuration interface. You create scheme objects that define:

-- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer;
-- A `Receiver` class that casts the weights to the destination module (policy or other utility module);
-- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else).
-- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them.
+   - The synchronization strategy (``"state_dict"`` or ``"tensordict"``)
+   - The transport mechanism (multiprocessing pipes, shared memory, Ray, RPC, etc.)
+   - Additional options like auto-registration and timeout behavior

-Each of these classes is detailed below.
+   When working with collectors, you pass a dictionary mapping model IDs to schemes.
+
+2. **WeightSender** (Internal)
+
+   Created by the scheme in the main training process. The sender:
+
+   - Holds a reference to the source model
+   - Manages transport connections to all workers
+   - Extracts weights using the configured strategy
+   - Broadcasts weight updates across all transports
+
+3. **WeightReceiver** (Internal)
+
+   Created by the scheme in each worker process. The receiver:
+
+   - Holds a reference to the destination model
+   - Polls its transport for weight updates
+   - Applies received weights using the configured strategy
+   - Handles model registration and initialization
+
+4. **TransportBackend** (Internal)
+
+   Implements the actual communication mechanism:
+
+   - ``MPTransport``: Uses multiprocessing pipes
+   - ``SharedMemTransport``: Uses shared memory buffers (zero-copy)
+   - ``RayTransport``: Uses Ray's object store
+   - ``RPCTransport``: Uses PyTorch RPC
+   - ``DistributedTransport``: Uses collective communication (NCCL, Gloo, MPI)
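Putting the four layers together from the user's side: you only instantiate schemes; senders, receivers, and transports are created for you. A minimal sketch (the ``strategy`` keyword, the ``"policy"`` model ID, and the scheme-to-transport pairing noted in the comments are assumptions based on the lists above)::

    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,  # pairs with MPTransport (multiprocessing pipes)
        SharedMemWeightSyncScheme,     # pairs with SharedMemTransport (shared memory)
    )

    # One scheme per model ID; the collector builds the matching
    # WeightSender / WeightReceiver pairs and transports internally.
    weight_sync_schemes = {
        "policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
    }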
+
+Initialization Phase
+^^^^^^^^^^^^^^^^^^^^
+
+When you create a collector with weight sync schemes, the following initialization occurs:
+
+.. aafig::
+    :aspect: 60
+    :scale: 130
+    :textual:
+
    [diagram: "INITIALIZATION PHASE" -- the WeightSyncScheme (configuration: strategy,
    transport type) creates a WeightSender in the main process and a WeightReceiver in each
    worker process; both hold the strategy and a model reference and are wired together
    through the transport layer (MPTransport (pipes), SharedMemTransport (shared memory),
    or RayTransport (Ray store)).]
+
+The scheme creates a sender in the main process and a receiver in each worker, then establishes
+transport connections between them.
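A sketch of the collector-side setup that triggers this phase, assuming a multiprocessed collector such as ``MultiSyncDataCollector`` accepts the ``weight_sync_schemes`` dictionary described above (environment, policy, and batch sizes are purely illustrative)::

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import MultiProcessWeightSyncScheme

    # Toy continuous-control policy: 3 observation dims -> 1 action dim (Pendulum)
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

    collector = MultiSyncDataCollector(
        create_env_fn=[lambda: GymEnv("Pendulum-v1")] * 2,  # two inference workers
        policy=policy,
        frames_per_batch=64,
        total_frames=1024,
        weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
    )
    # Construction is the point at which the scheme creates the WeightSender (main
    # process) and one WeightReceiver per worker, wired through multiprocessing pipes.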
+
+Synchronization Phase
+^^^^^^^^^^^^^^^^^^^^^
+
+When you call ``collector.update_policy_weights_()``, the weight synchronization proceeds as follows:
+
+.. aafig::
+    :aspect: 60
+    :scale: 130
+    :textual:
+
    [diagram: "SYNCHRONIZATION PHASE" -- the WeightSender in the main process extracts weights
    from the source model via its Strategy (extract()), sends them through the Transport to the
    WeightReceiver in the worker process (with an optional acknowledgment flowing back), and the
    receiver applies them to the destination model via its Strategy (apply()).]
+
+1. **Extract**: Sender extracts weights from the source model (state_dict or TensorDict)
+2. **Send**: Sender broadcasts weights through all registered transports
+3. **Acknowledge** (optional): Some transports send acknowledgment back
+4. **Poll**: Receiver checks its transport for new weights
+5. **Apply**: Receiver applies weights to the destination model
+
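In a training loop, this whole sequence is typically triggered by a single call after each optimization step. A sketch, reusing the collector from the previous snippet (loss computation and optimizer are omitted)::

    for batch in collector:
        # ... compute the loss on `batch` and take an optimizer step (omitted) ...

        # Push the freshly updated policy weights to all inference workers;
        # this runs the extract -> send -> (ack) -> poll -> apply sequence above.
        collector.update_policy_weights_()

    collector.shutdown()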
+
+Multi-Model Synchronization
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+One of the key features is support for synchronizing multiple models independently:
+
+.. aafig::
+    :aspect: 60
+    :scale: 130
+    :textual:
+
    [diagram: the Orchestrator in the main process holds Policy A and Model B, each with its own
    Weight Sender; every worker process (1, 2, ..., B) runs a Collector with its own copies of
    Policy A and Model B and matching Weight Receivers A and B, connected to the senders over pipes.]
+
+Each model gets its own sender/receiver pair, allowing independent synchronization frequencies,
+different transport mechanisms per model, and model-specific strategies.
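A sketch of such a two-model configuration (the ``"value_model"`` ID and the choice of schemes are illustrative assumptions)::

    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    weight_sync_schemes = {
        # Hot path: the policy is refreshed often, so use zero-copy shared memory
        "policy": SharedMemWeightSyncScheme(),
        # Auxiliary model updated less frequently, over multiprocessing pipes
        "value_model": MultiProcessWeightSyncScheme(),
    }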

 Usage Examples
 ~~~~~~~~~~~~~~
@@ -301,32 +517,55 @@ across multiple inference workers:
 dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured
 models and supports advanced features like lazy initialization.
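For reference, the two strategies correspond to the two standard ways of snapshotting a module's parameters (plain PyTorch vs. ``tensordict``); this is ordinary library usage, shown here outside of any scheme::

    import torch.nn as nn
    from tensordict import TensorDict

    module = nn.Linear(3, 1)

    # "state_dict" strategy: flat name -> tensor mapping
    weights_sd = module.state_dict()

    # "tensordict" strategy: structured TensorDict view of the parameters
    # (the format described above as supporting lazy initialization)
    weights_td = TensorDict.from_module(module)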

-Weight Senders
-~~~~~~~~~~~~~~
+API Reference
+~~~~~~~~~~~~~
+
+The weight synchronization system provides both user-facing configuration classes and internal
+implementation classes that are automatically managed by the collectors.
+
+Schemes (User-Facing Configuration)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+These are the main classes you'll use to configure weight synchronization. Pass them in the
+``weight_sync_schemes`` dictionary when creating collectors.

 .. currentmodule:: torchrl.weight_update

 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst

-    WeightSender
-    RayModuleTransformSender
+    WeightSyncScheme
+    MultiProcessWeightSyncScheme
+    SharedMemWeightSyncScheme
+    NoWeightSyncScheme
+    RayWeightSyncScheme
+    RayModuleTransformScheme
+    RPCWeightSyncScheme
+    DistributedWeightSyncScheme
+
+Senders and Receivers (Internal)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Weight Receivers
-~~~~~~~~~~~~~~~~
+These classes are automatically created and managed by the schemes. You typically don't need
+to interact with them directly.

 .. currentmodule:: torchrl.weight_update

 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst

+    WeightSender
     WeightReceiver
+    RayModuleTransformSender
     RayModuleTransformReceiver

-Transports
-~~~~~~~~~~
+Transport Backends (Internal)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Transport classes handle the actual communication between processes. They are automatically
+selected and configured by the schemes.

 .. currentmodule:: torchrl.weight_update

@@ -342,24 +581,6 @@
     RPCTransport
     DistributedTransport

-Schemes
-~~~~~~~
-
-.. currentmodule:: torchrl.weight_update
-
-.. autosummary::
-    :toctree: generated/
-    :template: rl_template.rst
-
-    WeightSyncScheme
-    MultiProcessWeightSyncScheme
-    SharedMemWeightSyncScheme
-    NoWeightSyncScheme
-    RayWeightSyncScheme
-    RayModuleTransformScheme
-    RPCWeightSyncScheme
-    DistributedWeightSyncScheme
-
 Legacy: Weight Synchronization in Distributed Environments
 ----------------------------------------------------------
