diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index 7b7e053f498..0061a895578 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -150,7 +150,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index f63c4d17409..a684a1b724c 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -148,7 +148,7 @@ def make_env(): if i == 10: t0 = time.time() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") collector.shutdown() exit() diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index c7981ae5209..795660fc683 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -125,5 +125,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index 1963e76f89e..151879a5423 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -113,5 +113,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 2dcbe93a425..10a37d47a87 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -119,5 +119,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 586efcbeab1..2c52c84321a 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -160,5 +160,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 236a4292674..5c9ef50b08a 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ 
b/examples/distributed/collectors/single_machine/rpc.py @@ -129,5 +129,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index cef4fe6a0e8..84cc1b1de99 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -149,5 +149,5 @@ def gym_make(): t0 = time.time() collector.shutdown() t1 = time.time() - torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps") + torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps") exit() diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py index 691f1b6c466..6e685c6cc70 100644 --- a/examples/rlhf/train.py +++ b/examples/rlhf/train.py @@ -147,7 +147,7 @@ def main(cfg): elif it % log_interval == 0: # loss as float. note: this is a CPU-GPU sync point loss = batch.loss.item() - msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" + msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt * 1000:.2f}ms" torchrl_logger.info(msg) loss_logger.info(msg) diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py index cc1cdf763da..b1bc2719086 100644 --- a/examples/rlhf/train_reward.py +++ b/examples/rlhf/train_reward.py @@ -155,7 +155,7 @@ def main(cfg): acc = _accuracy( batch.chosen_data.end_scores, batch.rejected_data.end_scores ) - msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" + msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt * 1000:.2f}ms" torchrl_logger.info(msg) loss_logger.info(msg) diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index 6ca676f25dc..88493192364 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -2536,7 +2536,7 @@ def test_batching( # Create 2 threads and send inputs inputs = [ TensorDict( - text=Text(prompt=[f"Question {i}?", f"Question {i+2}?"]), + text=Text(prompt=[f"Question {i}?", f"Question {i + 2}?"]), batch_size=(2,), ) for i in range(2) diff --git a/test/test_libs.py b/test/test_libs.py index 575c27701e3..d09a180d659 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3465,7 +3465,7 @@ def test_direct_download(self, task, tmpdir): def test_d4rl_dummy(self, task): t0 = time.time() _ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2) - torchrl_logger.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time() - t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -3490,7 +3490,7 @@ def test_dataset_build(self, task, split_trajs, from_env): offline = sample.get(key) # assert sim.dtype == offline.dtype, key assert sim.shape[-1] == offline.shape[-1], key - torchrl_logger.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time() - t0}s") @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) @pytest.mark.parametrize("split_trajs", [True, False]) @@ -3509,7 +3509,7 @@ def test_d4rl_iteration(self, task, split_trajs): for sample in data: # noqa: B007 i += 1 assert len(data) // i == batch_size - torchrl_logger.info(f"terminated test after {time.time()-t0}s") + torchrl_logger.info(f"terminated test after {time.time() - t0}s") 
_MINARI_DATASETS = [] @@ -3769,7 +3769,7 @@ def test_load(self, dataset_idx, split): t0 = time.time() for i, sample in enumerate(data): t1 = time.time() - torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms") assert data.metadata["action_space"].is_in(sample["action"]) assert data.metadata["observation_space"].is_in(sample["observation"]) t0 = time.time() @@ -3907,7 +3907,7 @@ def test_load(self): t0 = time.time() for i, _ in enumerate(data): t1 = time.time() - torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms") t0 = time.time() if i == 10: break @@ -3961,7 +3961,7 @@ def test_load(self, image_size): assert (batch.get("pixels") != 0).any() assert (batch.get(("next", "pixels")) != 0).any() t1 = time.time() - torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms") + torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms") t0 = time.time() if i == 10: break diff --git a/test/test_rb.py b/test/test_rb.py index 71e5b30043f..f3abf85f9a0 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -188,7 +188,9 @@ ) @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: - def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): + def _get_rb( + self, rb_type, size, sampler, writer, storage, compilable=False, **kwargs + ): if storage is not None: storage = storage(size, compilable=compilable) @@ -204,6 +206,7 @@ def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): writer=writer, batch_size=3, compilable=compilable, + **kwargs, ) return rb @@ -375,7 +378,7 @@ def test_extend(self, rb_type, sampler, writer, storage, size, datatype): OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) - and isinstance(rb._storage, TensorStorage) + and isinstance(rb.storage, TensorStorage) ) if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( @@ -383,7 +386,7 @@ def test_extend(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - length = min(rb._storage.max_size, len(rb) + data_shape) + length = min(rb.storage.max_size, len(rb) + data_shape) if writer is TensorDictMaxValueWriter: data["next", "reward"][-length:] = 1_000_000 with ( @@ -406,7 +409,7 @@ def data_iter(): data_iter = data_iter() for d in data_iter: - for b in rb._storage: + for b in rb.storage: if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) b = b.exclude("index").select(*keys, strict=False) @@ -429,7 +432,7 @@ def data_iter(): OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data2) - and isinstance(rb._storage, TensorStorage) + and isinstance(rb.storage, TensorStorage) ) with ( pytest.warns( @@ -533,7 +536,7 @@ def test_sample(self, rb_type, sampler, writer, storage, size, datatype): OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) - and isinstance(rb._storage, TensorStorage) + and isinstance(rb.storage, TensorStorage) ) if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( @@ -606,7 +609,7 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): OLD_TORCH and writer is not TensorDictMaxValueWriter and size < len(data) - and isinstance(rb._storage, TensorStorage) + and isinstance(rb.storage, TensorStorage) ) if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: with pytest.raises( @@ 
-624,7 +627,7 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) d1 = rb[2] - d2 = rb._storage[2] + d2 = rb.storage[2] if type(d1) is not type(d2): d1 = d1[0] if is_tensor_collection(data) or isinstance(data, torch.Tensor): @@ -639,7 +642,12 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + rb_type=rb_type, + sampler=sampler, + writer=writer, + storage=storage, + size=size, + delayed_init=False, ) serialized = pickle.dumps(rb) rb2 = pickle.loads(serialized) @@ -1033,13 +1041,13 @@ def test_storage_inplace_writing(self, storage_type, collate_fn): ) assert (rb[3:4] == 0).all() assert len(rb) == 100 - assert rb._writer._cursor == 100 + assert rb.writer._cursor == 100 rb[10:20] = TensorDict( {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] ) assert (rb[10:20] == 0).all() assert len(rb) == 100 - assert rb._writer._cursor == 100 + assert rb.writer._cursor == 100 rb[torch.arange(30, 40)] = TensorDict( {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] ) @@ -1068,13 +1076,13 @@ def test_storage_inplace_writing_transform(self, storage_type, collate_fn): ) assert (rb[3:4] == 2).all(), rb[3:4]["a"] assert len(rb) == 100 - assert rb._writer._cursor == 100 + assert rb.writer._cursor == 100 rb[10:20] = TensorDict( {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] ) assert (rb[10:20] == 2).all() assert len(rb) == 100 - assert rb._writer._cursor == 100 + assert rb.writer._cursor == 100 rb[torch.arange(30, 40)] = TensorDict( {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] ) @@ -1125,17 +1133,17 @@ def test_storage_inplace_writing_ndim(self, storage_type): ) assert (rb[0, 3:4] == 0).all() assert (rb[1, 3:4] != 0).all() - assert rb._writer._cursor == 50 + assert rb.writer._cursor == 50 rb[1, 5:6] = TensorDict( {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] ) assert (rb[1, 5:6] == 0).all() - assert rb._writer._cursor == 50 + assert rb.writer._cursor == 50 rb[:, 7:8] = TensorDict( {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] ).expand(2, 1) assert (rb[:, 7:8] == 0).all() - assert rb._writer._cursor == 50 + assert rb.writer._cursor == 50 # test broadcasting rb[:, 10:20] = TensorDict( {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] @@ -1408,7 +1416,10 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): class TestRNG: def test_rb_rng(self): state = torch.random.get_rng_state() - rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) + rb = ReplayBufferRNG( + sampler=RandomSampler(), storage=LazyTensorStorage(100), delayed_init=False + ) + assert rb.initialized rb.extend(torch.arange(100)) rb._rng.set_state(state) a = rb.sample(32) @@ -1587,7 +1598,7 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) batch1 = self._get_data(rbtype, size=5) cond = ( - OLD_TORCH and size < len(batch1) and isinstance(rb._storage, TensorStorage) + OLD_TORCH and size < len(batch1) and isinstance(rb.storage, TensorStorage) ) with ( pytest.warns( @@ -1601,16 +1612,16 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): # Added fewer data than storage max size if size > 5 or storage is None: - assert 
rb._writer._cursor == 5 + assert rb.writer._cursor == 5 # Added more data than storage max size elif size < 5: - assert rb._writer._cursor == 5 - size + assert rb.writer._cursor == 5 - size # Added as data as storage max size else: - assert rb._writer._cursor == 0 + assert rb.writer._cursor == 0 batch2 = self._get_data(rbtype, size=size - 1) rb.extend(batch2) - assert rb._writer._cursor == size - 1 + assert rb.writer._cursor == size - 1 def test_add(self, rbtype, storage, size, prefetch): torch.manual_seed(0) @@ -1650,7 +1661,7 @@ def test_extend(self, rbtype, storage, size, prefetch): torch.manual_seed(0) rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) with ( pytest.warns( UserWarning, @@ -1662,7 +1673,7 @@ def test_extend(self, rbtype, storage, size, prefetch): rb.extend(data) length = len(rb) for d in data[-length:]: - for b in rb._storage: + for b in rb.storage: if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) b = b.exclude("index").select(*keys, strict=False) @@ -1681,7 +1692,7 @@ def test_sample(self, rbtype, storage, size, prefetch): torch.manual_seed(0) rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) with ( pytest.warns( UserWarning, @@ -1715,7 +1726,7 @@ def test_index(self, rbtype, storage, size, prefetch): torch.manual_seed(0) rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) with ( pytest.warns( UserWarning, @@ -1726,7 +1737,7 @@ def test_index(self, rbtype, storage, size, prefetch): ): rb.extend(data) d1 = rb[2] - d2 = rb._storage[2] + d2 = rb.storage[2] if type(d1) is not type(d2): d1 = d1[0] b = d1 == d2 @@ -2258,14 +2269,14 @@ def test_max_value_writer_serialize( device=device, ) rb.extend(td) - rb._writer.dumps(tmpdir) + rb.writer.dumps(tmpdir) # check we can dump twice - rb._writer.dumps(tmpdir) + rb.writer.dumps(tmpdir) other = TensorDictMaxValueWriter(rank_key="key") other.loads(tmpdir) - assert len(rb._writer._current_top_values) == len(other._current_top_values) + assert len(rb.writer._current_top_values) == len(other._current_top_values) torch.testing.assert_close( - torch.tensor(rb._writer._current_top_values), + torch.tensor(rb.writer._current_top_values), torch.tensor(other._current_top_values), ) @@ -2930,9 +2941,9 @@ def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity) sc = samples[samples["traj"] == 0]["step_count"] assert (sc == 1).sum() == (sc == 2).sum() assert (sc == 1).sum() == (sc == 4).sum() - assert rb._sampler._cache + assert rb.sampler._cache rb.extend(data, update_priority=False) - assert not rb._sampler._cache + assert not rb.sampler._cache @pytest.mark.parametrize("ndim", [1, 2]) @pytest.mark.parametrize("strict_length", [True, False]) @@ -2999,14 +3010,14 @@ def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span): pass if i == 1000: break - assert not rb._sampler.span[0] - # if rb._sampler.span[0]: + 
assert not rb.sampler.span[0] + # if rb.sampler.span[0]: # assert found_traj_4_truncated_left - if rb._sampler.span[1]: + if rb.sampler.span[1]: assert found_traj_4_truncated_right else: assert not found_traj_4_truncated_right - if strict_length and not rb._sampler.span[1]: + if strict_length and not rb.sampler.span[1]: assert not found_traj_0 else: assert found_traj_0 @@ -3027,27 +3038,27 @@ def test_prb_update_max_priority(self, max_priority_within_buffer): rb.update_priority(idx, 21 - data) if data <= 10: # The max is always going to be the first value - assert rb._sampler._max_priority[0] == 21 - assert rb._sampler._max_priority[1] == 0 + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 elif not max_priority_within_buffer: # The max is the historical max, which was at idx 0 - assert rb._sampler._max_priority[0] == 21 - assert rb._sampler._max_priority[1] == 0 + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 else: # the max is the current max. Find it and compare sumtree = torch.as_tensor( - [rb._sampler._sum_tree[i] for i in range(rb._sampler._max_capacity)] + [rb.sampler._sum_tree[i] for i in range(rb.sampler._max_capacity)] ) - assert rb._sampler._max_priority[0] == sumtree.max() - assert rb._sampler._max_priority[1] == sumtree.argmax() + assert rb.sampler._max_priority[0] == sumtree.max() + assert rb.sampler._max_priority[1] == sumtree.argmax() idx = rb.extend(torch.arange(10)) rb.update_priority(idx, 12) if max_priority_within_buffer: - assert rb._sampler._max_priority[0] == 12 - assert rb._sampler._max_priority[1] == 0 + assert rb.sampler._max_priority[0] == 12 + assert rb.sampler._max_priority[1] == 0 else: - assert rb._sampler._max_priority[0] == 21 - assert rb._sampler._max_priority[1] == 0 + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" @@ -3214,13 +3225,13 @@ def test_prb_ndim(self): ) data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) idx = rb.extend(data) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() rb.update_priority(idx, 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s, info = rb.sample(return_info=True) rb.update_priority(info["index"], 3) assert ( - torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[info["index"]] + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[info["index"]] == 3 ).all() @@ -3232,13 +3243,13 @@ def test_prb_ndim(self): ) data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) idx = rb.extend(data) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() rb.update_priority(idx, 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s = rb.sample() rb.update_priority(s["index"], 3) assert ( - torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[s["index"]] == 3 + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] == 3 ).all() # third case: 1d TPRB @@ -3251,17 +3262,15 @@ def test_prb_ndim(self): ) data = TensorDict({"a": 
torch.arange(10), "p": torch.ones(10) / 2}, [10]) idx = rb.extend(data) - assert ( - torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 0.5 - ).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() rb.update_priority(idx, 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s = rb.sample() s["p"] = torch.ones(4) * 10_000 rb.update_tensordict_priority(s) assert ( - torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[s["index"]] + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] == 10_000 ).all() @@ -3279,15 +3288,15 @@ def test_prb_ndim(self): {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] ) idx = rb.extend(data) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() rb.update_priority(idx, 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s, info = rb.sample(return_info=True) rb.update_priority(info["index"], 3) - priorities = torch.tensor( - [rb._sampler._sum_tree[i] for i in range(10)] - ).reshape((5, 2)) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) assert (priorities[info["index"]] == 3).all() # fifth case: 2d TRB @@ -3301,15 +3310,15 @@ def test_prb_ndim(self): {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] ) idx = rb.extend(data) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() rb.update_priority(idx, 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s = rb.sample() rb.update_priority(s["index"], 10_000) - priorities = torch.tensor( - [rb._sampler._sum_tree[i] for i in range(10)] - ).reshape((5, 2)) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) assert (priorities[s["index"].unbind(-1)] == 10_000).all() s2 = rb.sample() @@ -3329,18 +3338,16 @@ def test_prb_ndim(self): {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] ) idx = rb.extend(data) - assert ( - torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 0.5 - ).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() rb.update_priority(idx, torch.ones(()) * 2) - assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all() + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() s = rb.sample() # setting the priorities to a value that is so big that the buffer will resample them s["p"] = torch.ones(4) * 10_000 rb.update_tensordict_priority(s) - priorities = torch.tensor( - [rb._sampler._sum_tree[i] for i in range(10)] - ).reshape((5, 2)) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) assert (priorities[s["index"].unbind(-1)] == 10_000).all() s2 = rb.sample() @@ -3777,32 +3784,32 @@ def test_rb_indexing(self, explicit): assert_allclose_td(sample0, sample1) # check indexing of components - assert isinstance(rb._storage[:], StorageEnsemble) - assert isinstance(rb._storage[:2], StorageEnsemble) - assert 
isinstance(rb._storage[torch.tensor([0, 1])], StorageEnsemble) - assert isinstance(rb._storage[np.array([0, 1])], StorageEnsemble) - assert isinstance(rb._storage[[0, 1]], StorageEnsemble) - assert isinstance(rb._storage[1], LazyMemmapStorage) - - rb._storage[:, :3] - rb._storage[:2, :3] - rb._storage[torch.tensor([0, 1]), :3] - rb._storage[np.array([0, 1]), :3] - rb._storage[[0, 1], :3] - - assert isinstance(rb._sampler[:], SamplerEnsemble) - assert isinstance(rb._sampler[:2], SamplerEnsemble) - assert isinstance(rb._sampler[torch.tensor([0, 1])], SamplerEnsemble) - assert isinstance(rb._sampler[np.array([0, 1])], SamplerEnsemble) - assert isinstance(rb._sampler[[0, 1]], SamplerEnsemble) - assert isinstance(rb._sampler[1], RandomSampler) - - assert isinstance(rb._writer[:], WriterEnsemble) - assert isinstance(rb._writer[:2], WriterEnsemble) - assert isinstance(rb._writer[torch.tensor([0, 1])], WriterEnsemble) - assert isinstance(rb._writer[np.array([0, 1])], WriterEnsemble) - assert isinstance(rb._writer[[0, 1]], WriterEnsemble) - assert isinstance(rb._writer[0], RoundRobinWriter) + assert isinstance(rb.storage[:], StorageEnsemble) + assert isinstance(rb.storage[:2], StorageEnsemble) + assert isinstance(rb.storage[torch.tensor([0, 1])], StorageEnsemble) + assert isinstance(rb.storage[np.array([0, 1])], StorageEnsemble) + assert isinstance(rb.storage[[0, 1]], StorageEnsemble) + assert isinstance(rb.storage[1], LazyMemmapStorage) + + rb.storage[:, :3] + rb.storage[:2, :3] + rb.storage[torch.tensor([0, 1]), :3] + rb.storage[np.array([0, 1]), :3] + rb.storage[[0, 1], :3] + + assert isinstance(rb.sampler[:], SamplerEnsemble) + assert isinstance(rb.sampler[:2], SamplerEnsemble) + assert isinstance(rb.sampler[torch.tensor([0, 1])], SamplerEnsemble) + assert isinstance(rb.sampler[np.array([0, 1])], SamplerEnsemble) + assert isinstance(rb.sampler[[0, 1]], SamplerEnsemble) + assert isinstance(rb.sampler[1], RandomSampler) + + assert isinstance(rb.writer[:], WriterEnsemble) + assert isinstance(rb.writer[:2], WriterEnsemble) + assert isinstance(rb.writer[torch.tensor([0, 1])], WriterEnsemble) + assert isinstance(rb.writer[np.array([0, 1])], WriterEnsemble) + assert isinstance(rb.writer[[0, 1]], WriterEnsemble) + assert isinstance(rb.writer[0], RoundRobinWriter) def _rbtype(datatype): @@ -3965,6 +3972,7 @@ def test_rb_multidim_collector( storage=storage_cls(max_size=10, ndim=2), sampler=sampler_cls(), writer=writer_cls(), + delayed_init=False, ) return rb = rbtype( @@ -3972,7 +3980,7 @@ def test_rb_multidim_collector( sampler=sampler_cls(), writer=writer_cls(), ) - if not isinstance(rb._sampler, SliceSampler) and transform is not None: + if not isinstance(rb.sampler, SliceSampler) and transform is not None: pytest.skip("no need to test this combination") if transform: for t in transform: @@ -4073,7 +4081,7 @@ def test_simple_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): rb.dumps(tmpdir) rb_test.loads(tmpdir) assert_allclose_td(rb_test[:], rb[:]) - assert rb._writer._cursor == rb_test._writer._cursor + assert rb.writer._cursor == rb_test._writer._cursor @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) @pytest.mark.parametrize("frames_per_batch", [22, 122]) @@ -4109,9 +4117,9 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): rb_test.storage.checkpointer = checkpointer() for data in collector: rb.extend(data) - assert rb._storage.max_size == 102 + assert rb.storage.max_size == 102 if frames_per_batch > 100: - assert 
rb._storage._is_full + assert rb.storage._is_full assert len(rb) == 102 # Checks that when writing to the buffer with a batch greater than the total # size, we get the last step written properly. @@ -4120,7 +4128,7 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): rb.dumps(tmpdir) rb_test.loads(tmpdir) assert_allclose_td(rb_test[:], rb[:]) - assert rb._writer._cursor == rb_test._writer._cursor + assert rb.writer._cursor == rb_test._writer._cursor @pytest.mark.skipif(not _has_ray, reason="ray required for this test.") @@ -4181,6 +4189,48 @@ def test_ray_rb_iter(self): rb.close() +class TestSharedStorageInit: + def worker(self, rb, worker_id, queue): + length = len(rb) + data = TensorDict({"x": torch.full((2,), worker_id)}, batch_size=(2,)) + worker_id * 2 + index = rb.extend(data) + assert len(rb) >= length + 2 + assert (rb[index] == data).all() + queue.put("done") + + @pytest.mark.parametrize( + "storage_cls, use_tmpdir", + [ + (LazyTensorStorage, False), + (LazyMemmapStorage, False), + (LazyMemmapStorage, True), + ], + ) + def test_shared_storage_multiprocess(self, storage_cls, use_tmpdir, tmpdir): + if use_tmpdir: + storage_cls = functools.partial(storage_cls, scratch_dir=tmpdir) + storage = storage_cls(max_size=100, shared_init=True) + rb = ReplayBuffer(storage=storage, batch_size=2).share(True) + queue = mp.Queue() + + processes = [] + for i in range(4): + p = mp.Process(target=self.worker, args=(rb, i, queue)) + processes.append(p) + p.start() + + for p in processes: + p.join() + queue.get() + + all_data = storage.get(slice(0, 8)) + values = set(all_data["x"].tolist()) + expected = {0.0, 1.0, 2.0, 3.0} + assert expected.issubset(values) + assert len(storage) >= 8 + + @pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") class TestCompressedListStorage: """Test cases for CompressedListStorage.""" @@ -4464,6 +4514,61 @@ def test_compressed_storage_memory_efficiency(self): ), f"Compression ratio {compression_ratio} is too low" +class TestRBLazyInit: + def test_lazy_init(self): + def transform(td): + return td + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + sampler=partial(RandomSampler), + transform_factory=lambda: transform, + ) + assert not rb.initialized + assert not hasattr(rb, "_storage") + assert rb._init_storage is not None + assert not hasattr(rb, "_sampler") + assert rb._init_sampler is not None + assert not hasattr(rb, "_writer") + assert rb._init_writer is not None + rb.extend(TensorDict(batch_size=[2])) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + sampler=partial(RandomSampler), + ) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + sampler=partial(RandomSampler), + delayed_init=False, + ) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + if __name__ == "__main__": args, 
unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index 3075bbb3a00..0792d7e7de6 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -204,7 +204,7 @@ def push_weights_from_transformers_optimized( batch = weight_items[i : i + batch_size] self.push_weights(iter(batch)) torchrl_logger.info( - f"Transferred batch {i//batch_size + 1}/{(len(weight_items) + batch_size - 1)//batch_size}" + f"Transferred batch {i // batch_size + 1}/{(len(weight_items) + batch_size - 1) // batch_size}" ) else: # Transfer all at once @@ -266,11 +266,11 @@ def _increment_all_collector_versions(self): try: collector.increment_version() torchrl_logger.debug( - f"Incremented version for collector {i+1}/{len(self.collectors)}" + f"Incremented version for collector {i + 1}/{len(self.collectors)}" ) except Exception as e: torchrl_logger.warning( - f"Failed to increment version for collector {i+1}: {e}" + f"Failed to increment version for collector {i + 1}: {e}" ) torchrl_logger.info("All collector versions incremented") diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index ea8670df4d8..179afdee8b2 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -24,7 +24,8 @@ except ImportError: from torch._dynamo import is_compiling -from functools import partial +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar from tensordict import ( is_tensor_collection, @@ -71,6 +72,22 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import _InvertTransform, Transform +T = TypeVar("T") +if TYPE_CHECKING: + from typing import Self +else: + Self = T + + +def _maybe_delay_init(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self._delayed_init and not self.initialized: + self._init() + return func(self, *args, **kwargs) + + return wrapper + class ReplayBuffer: """A generic, composable replay buffer class. @@ -153,6 +170,12 @@ class ReplayBuffer: compilable (bool, optional): whether the writer is compilable. If ``True``, the writer cannot be shared between multiple processes. Defaults to ``False``. + delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform + the first time the buffer is used rather than during construction. + This is useful when the replay buffer needs to be pickled and sent to remote workers, + particularly when using transforms with modules that require gradients. + If not specified, defaults to ``True`` when ``transform_factory`` is provided, + and ``False`` otherwise. 
Examples: >>> import torch @@ -243,19 +266,55 @@ def __init__( generator: torch.Generator | None = None, shared: bool = False, compilable: bool | None = None, + delayed_init: bool | None = None, ) -> None: - self._storage = self._maybe_make_storage(storage, compilable=compilable) - self._storage.attach(self) - self._sampler = self._maybe_make_sampler(sampler) - self._writer = self._maybe_make_writer(writer) - self._writer.register_storage(self._storage) + self._delayed_init = delayed_init + self._initialized = False + + # Store init parameters for potential delayed initialization + self._init_storage = storage + self._init_sampler = sampler + self._init_writer = writer + self._init_collate_fn = collate_fn + self._init_transform = transform + self._init_transform_factory = transform_factory + self._init_checkpointer = checkpointer + self._init_generator = generator + self._init_compilable = compilable + + if transform is not None and transform_factory is not None: + raise TypeError( + f"transform and transform_factory are mutually exclusive. " + f"Got transform={transform} and transform_factory={transform_factory}." + ) - self._get_collate_fn(collate_fn) - self._pin_memory = pin_memory + # Auto-detect delayed_init when transform_factory is provided + if transform_factory is not None and delayed_init is None: + delayed_init = True + elif delayed_init is None: + delayed_init = False + + # Update _delayed_init after auto-detection + self._delayed_init = delayed_init + self._pin_memory = pin_memory self._prefetch = bool(prefetch) self._prefetch_cap = prefetch or 0 self._prefetch_queue = collections.deque() + self._batch_size = batch_size + + if batch_size is None and prefetch: + raise ValueError( + "Dynamic batch-size specification is incompatible " + "with multithreaded sampling. " + "When using prefetch, the batch-size must be specified in " + "advance. " + ) + + if dim_extend is not None and dim_extend < 0: + raise ValueError("dim_extend must be a positive value.") + self._dim_extend = dim_extend + if self._prefetch_cap: self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap) @@ -267,31 +326,86 @@ def __init__( self._replay_lock = threading.RLock() self._futures_lock = threading.RLock() - self._transform = self._maybe_make_transform(transform, transform_factory) + # If not delayed, initialize immediately + if not self._delayed_init: + self._init() - if batch_size is None and prefetch: - raise ValueError( - "Dynamic batch-size specification is incompatible " - "with multithreaded sampling. " - "When using prefetch, the batch-size must be specified in " - "advance. " + def _init(self) -> None: + """Initialize the replay buffer components. + + This method is called either immediately during __init__ (if delayed_init=False) + or on first use of the buffer (if delayed_init=True). + """ + if self._initialized: + return + + self._initialized = True + try: + # Initialize storage + self._storage = self._maybe_make_storage( + self._init_storage, compilable=self._init_compilable ) - if ( - batch_size is None - and hasattr(self._sampler, "drop_last") - and self._sampler.drop_last - ): - raise ValueError( - "Samplers with drop_last=True must work with a predictable batch-size. " - "Please pass the batch-size to the ReplayBuffer constructor." 
+ self._storage.attach(self) + + # Initialize sampler + self._sampler = self._maybe_make_sampler(self._init_sampler) + + # Initialize writer + self._writer = self._maybe_make_writer(self._init_writer) + self._writer.register_storage(self._storage) + + # Initialize collate function + self._get_collate_fn(self._init_collate_fn) + + # Initialize transform + self._transform = self._maybe_make_transform( + self._init_transform, self._init_transform_factory ) - self._batch_size = batch_size - if dim_extend is not None and dim_extend < 0: - raise ValueError("dim_extend must be a positive value.") - self.dim_extend = dim_extend - self._storage.checkpointer = checkpointer - self.set_rng(generator=generator) - self._initialize_prioritized_sampler() + + # Check batch_size compatibility with sampler + if ( + self._batch_size is None + and hasattr(self._sampler, "drop_last") + and self._sampler.drop_last + ): + raise ValueError( + "Samplers with drop_last=True must work with a predictable batch-size. " + "Please pass the batch-size to the ReplayBuffer constructor." + ) + + # Set dim_extend properly now that storage is initialized + if self._dim_extend is None: + if self._storage is not None: + ndim = self._storage.ndim + self._dim_extend = ndim - 1 + else: + self._dim_extend = 1 + + # Set checkpointer and generator + self._storage.checkpointer = self._init_checkpointer + self.set_rng(generator=self._init_generator) + + # Initialize prioritized sampler if needed + self._initialize_prioritized_sampler() + + # Remove init parameters + self._init_storage = None + self._init_sampler = None + self._init_writer = None + self._init_collate_fn = None + self._init_transform = None + self._init_transform_factory = None + self._init_checkpointer = None + self._init_generator = None + self._init_compilable = None + except Exception as e: + self._initialized = False + raise e + + @property + def initialized(self) -> bool: + """Whether the replay buffer has been initialized.""" + return self._initialized def _initialize_prioritized_sampler(self) -> None: """Initialize priority trees for existing data when using PrioritizedSampler. @@ -385,14 +499,16 @@ def _maybe_make_transform( transform.eval() return transform - def share(self, shared: bool = True): + def share(self, shared: bool = True) -> Self: self.shared = shared if self.shared: self._write_lock = multiprocessing.Lock() else: self._write_lock = contextlib.nullcontext() + return self - def set_rng(self, generator): + @_maybe_delay_init + def set_rng(self, generator) -> None: self._rng = generator self._storage._rng = generator self._sampler._rng = generator @@ -425,7 +541,7 @@ def dim_extend(self, value): ) if value is None: - if self._storage is not None: + if self._initialized and self._storage is not None: ndim = self._storage.ndim value = ndim - 1 else: @@ -447,6 +563,7 @@ def _get_collate_fn(self, collate_fn): ) ) + @_maybe_delay_init def set_storage(self, storage: Storage, collate_fn: Callable | None = None): """Sets a new storage in the replay buffer and returns the previous storage. 
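Reviewer note on the hunk above: the delayed-initialization machinery boils down to storing the constructor arguments as factories and routing the public entry points through `_maybe_delay_init`, which calls `_init()` exactly once on first use and rolls back on failure. A minimal, self-contained sketch of that pattern follows; the `LazyBuffer` class and its attribute names are illustrative only, not TorchRL API.

```python
from functools import wraps


def _maybe_delay_init(func):
    """Build the real components before running the wrapped method, if still pending."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self._delayed_init and not self.initialized:
            self._init()
        return func(self, *args, **kwargs)

    return wrapper


class LazyBuffer:
    """Illustrative stand-in for the delayed-init logic in ReplayBuffer."""

    def __init__(self, storage_factory, delayed_init=True):
        self._delayed_init = delayed_init
        self._initialized = False
        self._init_storage = storage_factory  # kept around until first use
        if not delayed_init:
            self._init()

    @property
    def initialized(self):
        return self._initialized

    def _init(self):
        if self._initialized:
            return
        self._initialized = True
        try:
            self._storage = self._init_storage()  # materialize the component
            self._init_storage = None  # drop the factory, as in the hunk above
        except Exception:
            self._initialized = False  # roll back so a later call can retry
            raise

    @_maybe_delay_init
    def extend(self, data):
        self._storage.extend(data)
        return self._storage


buf = LazyBuffer(list)  # nothing is built yet: cheap to pickle and ship
assert not buf.initialized
buf.extend([1, 2, 3])  # first use triggers _init()
assert buf.initialized
```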
@@ -462,6 +579,7 @@ def set_storage(self, storage: Storage, collate_fn: Callable | None = None): return prev_storage + @_maybe_delay_init def set_writer(self, writer: Writer): """Sets a new writer in the replay buffer and returns the previous writer.""" prev_writer = self._writer @@ -469,12 +587,14 @@ def set_writer(self, writer: Writer): self._writer.register_storage(self._storage) return prev_writer + @_maybe_delay_init def set_sampler(self, sampler: Sampler): """Sets a new sampler in the replay buffer and returns the previous sampler.""" prev_sampler = self._sampler self._sampler = sampler return prev_sampler + @_maybe_delay_init def __len__(self) -> int: with self._replay_lock: return len(self._storage) @@ -489,6 +609,7 @@ def _setattr(self, attr, value): return None # explicit return for remote calls @property + @_maybe_delay_init def write_count(self) -> int: """The total number of items written so far in the buffer through add and extend.""" return self._writer._write_count @@ -517,6 +638,7 @@ def __repr__(self) -> str: ) return f"{self.__class__.__name__}(\n{storage}, \n{sampler}, \n{writer}, {transform}\n{batch_size}, \n{collate_fn})" + @_maybe_delay_init @pin_memory_output def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any: if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)): @@ -548,6 +670,7 @@ def __getitem__(self, index: int | torch.Tensor | NestedKey) -> Any: return data + @_maybe_delay_init def __setitem__(self, index, value) -> None: if isinstance(index, str) or (isinstance(index, tuple) and unravel_key(index)): self[:][index] = value @@ -572,6 +695,7 @@ def __setitem__(self, index, value) -> None: self._storage[index] = value return + @_maybe_delay_init def state_dict(self) -> dict[str, Any]: return { "_storage": self._storage.state_dict(), @@ -584,6 +708,7 @@ def state_dict(self) -> dict[str, Any]: else None, } + @_maybe_delay_init def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage.load_state_dict(state_dict["_storage"]) self._sampler.load_state_dict(state_dict["_sampler"]) @@ -597,6 +722,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: rng.set_state(state) self.set_rng(generator=rng) + @_maybe_delay_init def dumps(self, path): """Saves the replay buffer on disk at the specified path. @@ -653,6 +779,7 @@ def dumps(self, path): with open(path / "buffer_metadata.json", "w") as file: json.dump({"batch_size": self._batch_size}, file) + @_maybe_delay_init def loads(self, path): """Loads a replay buffer state at the given path. @@ -680,18 +807,22 @@ def loads(self, path): metadata = json.load(file) self._batch_size = metadata["batch_size"] + @_maybe_delay_init def save(self, *args, **kwargs): """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) + @_maybe_delay_init def dump(self, *args, **kwargs): """Alias for :meth:`dumps`.""" return self.dumps(*args, **kwargs) + @_maybe_delay_init def load(self, *args, **kwargs): """Alias for :meth:`loads`.""" return self.loads(*args, **kwargs) + @_maybe_delay_init def register_save_hook(self, hook: Callable[[Any], Any]): """Registers a save hook for the storage. @@ -701,6 +832,7 @@ def register_save_hook(self, hook: Callable[[Any], Any]): """ self._storage.register_save_hook(hook) + @_maybe_delay_init def register_load_hook(self, hook: Callable[[Any], Any]): """Registers a load hook for the storage. 
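For context, the decorated properties and methods above are what make the construction pattern exercised by the new `TestRBLazyInit` test in `test/test_rb.py` work end to end: passing `partial`/factory objects keeps the buffer uninitialized until the first call that actually needs the storage. The snippet below mirrors that test; the import paths are the usual `torchrl.data.replay_buffers` ones and may differ slightly from what the test module imports.

```python
from functools import partial

from tensordict import TensorDict
from torchrl.data.replay_buffers import (
    ListStorage,
    RandomSampler,
    ReplayBuffer,
    RoundRobinWriter,
)


def identity_transform(td):
    return td


rb = ReplayBuffer(
    storage=partial(ListStorage),
    writer=partial(RoundRobinWriter),
    sampler=partial(RandomSampler),
    transform_factory=lambda: identity_transform,  # a factory implies delayed_init=True
)
assert not rb.initialized  # components are still factories at this point
rb.extend(TensorDict(batch_size=[2]))  # first use triggers _init()
assert rb.initialized
```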
@@ -710,6 +842,7 @@ def register_load_hook(self, hook: Callable[[Any], Any]): """ self._storage.register_load_hook(hook) + @_maybe_delay_init def add(self, data: Any) -> int: """Add a single element to the replay buffer. @@ -763,6 +896,7 @@ def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tens self._sampler.extend(index) return index + @_maybe_delay_init def extend( self, data: Sequence, *, update_priority: bool | None = None ) -> torch.Tensor: @@ -803,6 +937,7 @@ def extend( return torch.zeros((0, self._storage.ndim), dtype=torch.long) return self._extend(data, update_priority=update_priority) + @_maybe_delay_init def update_priority( self, index: int | torch.Tensor | tuple[torch.Tensor], @@ -836,6 +971,7 @@ def _sample(self, batch_size: int) -> tuple[Any, dict]: return data, info + @_maybe_delay_init def empty(self, empty_write_count: bool = True): """Empties the replay buffer and reset cursor to 0. @@ -846,6 +982,7 @@ def empty(self, empty_write_count: bool = True): self._sampler._empty() self._storage._empty() + @_maybe_delay_init def sample(self, batch_size: int | None = None, return_info: bool = False) -> Any: """Samples a batch of data from the replay buffer. @@ -904,9 +1041,11 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An return out, info return result[0] + @_maybe_delay_init def mark_update(self, index: int | torch.Tensor) -> None: self._sampler.mark_update(index, storage=self._storage) + @_maybe_delay_init def append_transform( self, transform: Transform, *, invert: bool = False # noqa-F821 ) -> ReplayBuffer: # noqa: D417 @@ -942,6 +1081,7 @@ def append_transform( self._transform.append(transform) return self + @_maybe_delay_init def insert_transform( self, index: int, @@ -970,6 +1110,7 @@ def insert_transform( _iterator = None + @_maybe_delay_init def next(self): """Returns the next item in the replay buffer. @@ -988,6 +1129,7 @@ def next(self): self._iterator = None return None + @_maybe_delay_init def __iter__(self): if self._sampler.ran_out: self._sampler.ran_out = False @@ -1001,6 +1143,7 @@ def __iter__(self): ): yield self.sample() + @_maybe_delay_init def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() if getattr(self, "_rng", None) is not None: @@ -1038,6 +1181,7 @@ def __setstate__(self, state: dict[str, Any]): self.set_rng(rng) @property + @_maybe_delay_init def sampler(self) -> Sampler: """The sampler of the replay buffer. @@ -1047,6 +1191,7 @@ def sampler(self) -> Sampler: return self._sampler @property + @_maybe_delay_init def writer(self) -> Writer: """The writer of the replay buffer. @@ -1056,6 +1201,7 @@ def writer(self) -> Writer: return self._writer @property + @_maybe_delay_init def storage(self) -> Storage: """The storage of the replay buffer. @@ -1064,6 +1210,15 @@ def storage(self) -> Storage: """ return self._storage + @property + @_maybe_delay_init + def transform(self) -> Transform: + """The transform of the replay buffer. + + The transform must be an instance of :class:`~torchrl.envs.transforms.Transform`. + """ + return self._transform + class PrioritizedReplayBuffer(ReplayBuffer): """Prioritized replay buffer. @@ -1133,6 +1288,13 @@ class PrioritizedReplayBuffer(ReplayBuffer): ... rb.add(d) >>> rb.extend(data) + delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform + the first time the buffer is used rather than during construction. 
+ This is useful when the replay buffer needs to be pickled and sent to remote workers, + particularly when using transforms with modules that require gradients. + If not specified, defaults to ``True`` when ``transform_factory`` is provided, + and ``False`` otherwise. + .. note:: Generic prioritized replay buffers (ie. non-tensordict backed) require calling :meth:`~.sample` with the ``return_info`` argument set to @@ -1184,6 +1346,7 @@ def __init__( transform: Transform | None = None, # noqa-F821 batch_size: int | None = None, dim_extend: int | None = None, + delayed_init: bool = False, ) -> None: if storage is None: storage = ListStorage(max_size=1_000) @@ -1198,6 +1361,7 @@ def __init__( transform=transform, batch_size=batch_size, dim_extend=dim_extend, + delayed_init=delayed_init, ) @@ -1287,6 +1451,12 @@ class TensorDictReplayBuffer(ReplayBuffer): compilable (bool, optional): whether the writer is compilable. If ``True``, the writer cannot be shared between multiple processes. Defaults to ``False``. + delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform + the first time the buffer is used rather than during construction. + This is useful when the replay buffer needs to be pickled and sent to remote workers, + particularly when using transforms with modules that require gradients. + If not specified, defaults to ``True`` when ``transform_factory`` is provided, + and ``False`` otherwise. Examples: >>> import torch @@ -1404,6 +1574,7 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor: return priority + @_maybe_delay_init def add(self, data: TensorDictBase) -> int: if self._transform is not None: with _set_dispatch_td_nn_modules(is_tensor_collection(data)): @@ -1419,6 +1590,7 @@ def add(self, data: TensorDictBase) -> int: self.update_tensordict_priority(data) return index + @_maybe_delay_init def extend( self, tensordicts: TensorDictBase, *, update_priority: bool | None = None ) -> torch.Tensor: @@ -1482,6 +1654,7 @@ def _set_index_in_td(self, tensordict, index): return tensordict.set("index", expand_as_right(index, tensordict)) + @_maybe_delay_init def update_tensordict_priority(self, data: TensorDictBase) -> None: if not isinstance(self._sampler, PrioritizedSampler): return @@ -1664,6 +1837,12 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): compilable (bool, optional): whether the writer is compilable. If ``True``, the writer cannot be shared between multiple processes. Defaults to ``False``. + delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform + the first time the buffer is used rather than during construction. + This is useful when the replay buffer needs to be pickled and sent to remote workers, + particularly when using transforms with modules that require gradients. + If not specified, defaults to ``True`` when ``transform_factory`` is provided, + and ``False`` otherwise. Examples: >>> import torch @@ -1894,6 +2073,12 @@ class ReplayBufferEnsemble(ReplayBuffer): shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + delayed_init (bool, optional): whether to initialize storage, writer, sampler and transform + the first time the buffer is used rather than during construction. + This is useful when the replay buffer needs to be pickled and sent to remote workers, + particularly when using transforms with modules that require gradients. 
+ If not specified, defaults to ``True`` when ``transform_factory`` is provided, + and ``False`` otherwise. Examples: >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform @@ -1994,6 +2179,14 @@ def __init__( if rbs: if storages is not None or samplers is not None or writers is not None: raise RuntimeError + # Ensure all replay buffers are initialized before creating ensemble + for rb in rbs: + if ( + hasattr(rb, "_delayed_init") + and rb._delayed_init + and not rb.initialized + ): + rb._init() storages = StorageEnsemble( *[rb._storage for rb in rbs], transforms=[rb._transform for rb in rbs] ) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 3fe6ce7b919..4669ce103ff 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -6,8 +6,10 @@ import abc import logging +import multiprocessing as mp import os import sys +import tempfile import textwrap import warnings from collections import OrderedDict @@ -29,7 +31,6 @@ from tensordict.base import _NESTED_TENSORS_AS_LISTS from tensordict.memmap import MemoryMappedTensor from tensordict.utils import _zip_strict -from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger @@ -578,7 +579,7 @@ def __init__( storage, max_size=None, *, - device: torch.device = "cpu", + device: torch.device | str = "cpu", ndim: int = 1, compilable: bool = False, ): @@ -746,13 +747,15 @@ def __getstate__(self): del state["_len_value"] state["len__context"] = length elif not self.initialized: - # check that the storage is initialized - raise RuntimeError( - f"Cannot share a storage of type {type(self)} between processes if " - f"it has not been initialized yet. Populate the buffer with " - f"some data in the main process before passing it to the other " - f"subprocesses (or create the buffer explicitly with a TensorStorage)." - ) + if not self.shared_init: + # check that the storage is initialized + raise RuntimeError( + f"Cowardly refusing to share a storage of type {type(self)} between processes if " + f"it has not been initialized yet. You can either:\n" + f"- Populate the buffer with some data in the main process before passing it to the other processes (or create the buffer explicitly with a TensorStorage).\n" + f"- set shared_init=True when creating the storage such that it can be initialized by the remote processes." 
+ ) + return state else: # check that the content is shared, otherwise tell the user we can't help storage = self._storage @@ -782,10 +785,10 @@ def __setstate__(self, state): len = state.pop("len__context", None) if len is not None: if not state["_compilable"]: - state["_len_value"] = len - else: _len_value = mp.Value("i", len) state["_len_value"] = _len_value + else: + state["_len_value"] = len self.__dict__.update(state) def state_dict(self) -> dict[str, Any]: @@ -896,6 +899,8 @@ def set( self._init(tree_map(lambda x: x[0], data)) else: self._init(data) + assert self.initialized + if is_tensor_collection(data): self._storage[cursor] = data else: @@ -939,6 +944,8 @@ def set( # noqa: F811 self._init(data[0]) else: self._init(data) + assert self.initialized + if not isinstance(cursor, (*INT_CLASSES, slice)): if not isinstance(cursor, torch.Tensor): cursor = torch.tensor(cursor, dtype=torch.long) @@ -955,10 +962,15 @@ def set( # noqa: F811 ) self._storage[cursor] = data + def _wait_for_init(self): + pass + def get(self, index: int | Sequence[int] | slice) -> Any: _storage = self._storage is_tc = is_tensor_collection(_storage) if not self.initialized: + if getattr(self, "shared_init", False): + self._wait_for_init() raise RuntimeError("Cannot get elements out of a non-initialized storage.") if not self._is_full: if is_tc: @@ -1058,6 +1070,12 @@ class LazyTensorStorage(TensorStorage): Defaults to ``False``. consolidated (bool, optional): if ``True``, the storage will be consolidated after its first expansion. Defaults to ``False``. + shared_init (bool, optional): if ``True``, enables multiprocess coordination + during storage initialization. First process initializes with memmap, + others wait and load from the shared memmap. Defaults to ``False``. + cleanup_memmap (bool, optional): if ``True`` and ``shared_init=True``, + the temporary memmap will be deleted after initialization and the + storage will operate in RAM. Defaults to ``True``. Examples: >>> data = TensorDict({ @@ -1115,10 +1133,12 @@ def __init__( self, max_size: int, *, - device: torch.device = "cpu", + device: torch.device | str = "cpu", ndim: int = 1, compilable: bool = False, consolidated: bool = False, + shared_init: bool = False, + cleanup_memmap: bool = True, ): super().__init__( storage=None, @@ -1128,11 +1148,59 @@ def __init__( compilable=compilable, ) self.consolidated = consolidated + self.shared_init = shared_init + self.cleanup_memmap = cleanup_memmap + + # Initialize multiprocess coordination objects if shared_init is enabled + if self.shared_init: + if self._compilable: + raise RuntimeError( + "Cannot share a compilable storage between processes." 
+ ) + self._init_lock = mp.Lock() + self._init_event = mp.Event() + self._make_init_directory() + + def _make_init_directory(self): + if getattr(self, "scratch_dir", None) is not None: + self._init_directory = self.scratch_dir + return + # Create a shared directory + self.scratch_dir = self._init_directory = tempfile.mkdtemp( + prefix="torchrl_storage_init_" + ) + return def _init( self, data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821 ) -> None: + if not self.shared_init: + return self._init_standard(data) + + # Try to become coordinator + is_coordinator = not self._init_event.is_set() + is_coordinator = is_coordinator and self._init_lock.acquire(block=False) + + if is_coordinator: + try: + # We are the coordinator + self._init_coordinator(data) + finally: + # Signal other processes that initialization is complete + self._init_event.set() + self._init_lock.release() + else: + # Failed to acquire lock, wait for coordinator + self._wait_for_init() + + self.initialized = True + + def _init_standard( + self, + data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821 + ) -> None: + """Standard initialization without multiprocess coordination.""" if not self._compilable: # TODO: Investigate why this seems to have a performance impact with # the compiler @@ -1177,6 +1245,39 @@ def max_size_along_dim0(data_shape): f"Initialized LazyTensorStorage with {self._storage.shape} shape" ) + def _init_coordinator( + self, + data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821 + ) -> None: + """Initialize storage as the coordinating process using temporary memmap.""" + # Use LazyMemmapStorage which does everything we want + temp_memmap_storage = LazyMemmapStorage( + max_size=self.max_size, + scratch_dir=self._init_directory, + ndim=self.ndim, + existsok=False, + shared_init=False, # Don't recurse + ) + temp_memmap_storage._init_standard(data) + self._storage = temp_memmap_storage._storage + return + + def _wait_for_init(self) -> None: + # wait till coordinator has initialized + self._init_event.wait() + storage = TensorDict.load_memmap(self._init_directory) + self._storage = storage + self.initialized = True + return + + # Read blocks + def get(self, indices: slice) -> TensorDictBase | torch.Tensor | Any: + if not self.initialized and self.shared_init: + # Trigger initialization with dummy data + self._wait_for_init() + idx = super().get(indices) + return idx + class LazyMemmapStorage(LazyTensorStorage): """A memory-mapped storage for tensors and tensordicts. @@ -1187,6 +1288,8 @@ class LazyMemmapStorage(LazyTensorStorage): Keyword Args: scratch_dir (str or path): directory where memmap-tensors will be written. + If ``shared_init=True`` and no ``scratch_dir`` is provided, a shared + temporary directory will be created automatically. device (torch.device, optional): device where the sampled tensors will be stored and sent. Default is :obj:`torch.device("cpu")`. If ``None`` is provided, the device is automatically gathered from the @@ -1199,6 +1302,9 @@ class LazyMemmapStorage(LazyTensorStorage): existsok (bool, optional): whether an error should be raised if any of the tensors already exists on disk. Defaults to ``True``. If ``False``, the tensor will be opened as is, not overewritten. + shared_init (bool, optional): if ``True``, enables multiprocess coordination + during storage initialization. First process initializes the memmap, + others wait and load from the shared directory. Defaults to ``False``. .. 
@@ -1187,6 +1288,8 @@ class LazyMemmapStorage(LazyTensorStorage):

     Keyword Args:
         scratch_dir (str or path): directory where memmap-tensors will be written.
+            If ``shared_init=True`` and no ``scratch_dir`` is provided, a shared
+            temporary directory will be created automatically.
         device (torch.device, optional): device where the sampled tensors will be
             stored and sent. Default is :obj:`torch.device("cpu")`.
             If ``None`` is provided, the device is automatically gathered from the
@@ -1199,6 +1302,9 @@
         existsok (bool, optional): whether an error should be raised if any of the
             tensors already exists on disk. Defaults to ``True``. If ``False``, the
             tensor will be opened as is, not overewritten.
+        shared_init (bool, optional): if ``True``, enables multiprocess coordination
+            during storage initialization: the first process initializes the memmap,
+            the others wait and load it from the shared directory. Defaults to ``False``.

     .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to
         where the storage is already stored to avoid executing long copies of data that is
         already stored on disk.
@@ -1273,12 +1379,12 @@ def __init__(
         max_size: int,
         *,
         scratch_dir=None,
-        device: torch.device = "cpu",
+        device: torch.device | str = "cpu",
         ndim: int = 1,
         existsok: bool = False,
         compilable: bool = False,
+        shared_init: bool = False,
     ):
-        super().__init__(max_size, ndim=ndim, compilable=compilable)
         self.initialized = False
         self.scratch_dir = None
         self.existsok = existsok
@@ -1286,6 +1392,13 @@
             self.scratch_dir = str(scratch_dir)
             if self.scratch_dir[-1] != "/":
                 self.scratch_dir += "/"
+        super().__init__(
+            max_size,
+            ndim=ndim,
+            compilable=compilable,
+            shared_init=shared_init,
+            cleanup_memmap=False,
+        )
         self.device = (
             _make_ordinal_device(torch.device(device))
             if device != "auto"
@@ -1355,7 +1468,31 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]

-    def _init(self, data: TensorDictBase | torch.Tensor) -> None:
+    def _init(
+        self,
+        data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
+    ) -> None:
+        if not self.shared_init:
+            return self._init_standard(data)
+
+        is_coordinator = not self._init_event.is_set()
+        is_coordinator = is_coordinator and self._init_lock.acquire(block=False)
+        if is_coordinator:
+            # Coordinator: run the regular memmap initialization
+            try:
+                return self._init_coordinator(data)
+            finally:
+                self._init_event.set()
+                self._init_lock.release()
+        else:
+            # Another process won the race: wait for it to finish initializing
+            self._wait_for_init()
+            self.initialized = True
+
+    def _init_coordinator(self, data: TensorDictBase | torch.Tensor | Any) -> None:
+        return self._init_standard(data)
+
+    def _init_standard(self, data: TensorDictBase | torch.Tensor) -> None:
         torchrl_logger.debug("Creating a MemmapStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -1402,6 +1539,9 @@ def max_size_along_dim0(data_shape):
         self.initialized = True

     def get(self, index: int | Sequence[int] | slice) -> Any:
+        if not self.initialized and self.shared_init:
+            # The coordinator may still be initializing: wait for it to finish
+            self._wait_for_init()
         result = super().get(index)
         return result
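For the shared-initialization path, the on-disk handshake boils down to the coordinator materializing a memmapped TensorDict in the shared directory (through ``LazyMemmapStorage._init_standard``) while every waiter reads it back with ``TensorDict.load_memmap``. The snippet below reproduces that round trip with tensordict's public API alone; it is not the exact code path of the patch, and the directory prefix, shapes and keys are made up.

```python
import tempfile

import torch
from tensordict import TensorDict

shared_dir = tempfile.mkdtemp(prefix="torchrl_storage_init_")

# Coordinator side: expand a data sample to the buffer size and write it as a memmap.
sample = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
buffer = sample.expand(100).clone()  # stand-in for the pre-allocated storage
buffer.memmap_(prefix=shared_dir)  # materialize the buffer on disk

# Waiter side (normally a different process): load the same buffer from disk.
loaded = TensorDict.load_memmap(shared_dir)
assert loaded.shape == (100,)
assert loaded["obs"].shape == (100, 3)
```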
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 1eed520f49a..814853f57ed 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -157,6 +157,7 @@ class RoundRobinWriter(Writer):
     def __init__(self, compilable: bool = False) -> None:
         super().__init__(compilable=compilable)
         self._cursor = 0
+        self._write_count  # noqa

     def dumps(self, path):
         path = Path(path).absolute()
@@ -280,18 +281,28 @@ def __getstate__(self):
         state = super().__getstate__()
         if get_spawning_popen() is None:
             cursor = self._cursor
+            write_count = self._write_count
             del state["_cursor_value"]
+            del state["_write_count_value"]
             state["cursor__context"] = cursor
+            state["write_count__context"] = write_count
         return state

     def __setstate__(self, state):
         cursor = state.pop("cursor__context", None)
+        write_count = state.pop("write_count__context", None)
         if cursor is not None:
             if not state["_compilable"]:
                 _cursor_value = mp.Value("i", cursor)
             else:
                 _cursor_value = cursor
             state["_cursor_value"] = _cursor_value
+        if write_count is not None:
+            if not state["_compilable"]:
+                _write_count_value = mp.Value("i", write_count)
+            else:
+                _write_count_value = write_count
+            state["_write_count_value"] = _write_count_value
         self.__dict__.update(state)

     def __repr__(self):
@@ -603,8 +614,19 @@ def __getstate__(self):
             f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
         )
         state = super().__getstate__()
+        # Handle the mp.Value object for pickling
+        if "_write_count_value" in state:
+            write_count = self._write_count
+            del state["_write_count_value"]
+            state["write_count__context"] = write_count
         return state

+    def __setstate__(self, state):
+        write_count = state.pop("write_count__context", None)
+        if write_count is not None:
+            state["_write_count_value"] = mp.Value("i", write_count)
+        self.__dict__.update(state)
+
     def dumps(self, path):
         path = Path(path).absolute()
         path.mkdir(exist_ok=True)
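Both writers now serialize the shared write counter the same way: the ``mp.Value`` is flattened to a plain integer in ``__getstate__`` and rebuilt in ``__setstate__``, since a synchronized value does not survive a regular pickle. The same pattern on a toy class (class and attribute names are illustrative):

```python
import pickle

from torch import multiprocessing as mp


class Counter:
    def __init__(self) -> None:
        self._count_value = mp.Value("i", 0)

    @property
    def count(self) -> int:
        return self._count_value.value

    def __getstate__(self):
        state = self.__dict__.copy()
        # Replace the shared value with a plain-int snapshot.
        state["count__context"] = state.pop("_count_value").value
        return state

    def __setstate__(self, state):
        # Rebuild a fresh shared value from the snapshot.
        state["_count_value"] = mp.Value("i", state.pop("count__context"))
        self.__dict__.update(state)


c = Counter()
c._count_value.value = 3
assert pickle.loads(pickle.dumps(c)).count == 3
```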
+ f"{keyset1 - keyset0} are in self.params and not in the module, " + f"{keyset0 - keyset1} are in the module but not in self.params." ) self_params.data.update_(params.data) diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index d355d1888c5..aed3cbb9e3d 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -168,7 +168,7 @@ t1 = time.time() torchrl_logger.info( - f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s." + f"solved after {total_count} steps, {total_episodes} episodes and in {t1 - t0}s." ) ################################# diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 501c9be0a05..0ece54926f1 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -435,7 +435,7 @@ def assert0(x): buffer_lazy.extend(data) for _i, _ in enumerate(buffer_lazy): continue -print(f"A total of {_i+1} batches have been collected") +print(f"A total of {_i + 1} batches have been collected") del buffer_lazy