Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def make_env():
if i == 10:
t0 = time.time()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
collector.shutdown()
exit()

Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/delayed_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def make_env():
if i == 10:
t0 = time.time()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
collector.shutdown()
exit()

Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,5 @@ def gym_make():
t0 = time.time()
collector.shutdown()
t1 = time.time()
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
exit()
2 changes: 1 addition & 1 deletion examples/rlhf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def main(cfg):
elif it % log_interval == 0:
# loss as float. note: this is a CPU-GPU sync point
loss = batch.loss.item()
msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms"
msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt * 1000:.2f}ms"
torchrl_logger.info(msg)
loss_logger.info(msg)

Expand Down
2 changes: 1 addition & 1 deletion examples/rlhf/train_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def main(cfg):
acc = _accuracy(
batch.chosen_data.end_scores, batch.rejected_data.end_scores
)
msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms"
msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt * 1000:.2f}ms"
torchrl_logger.info(msg)
loss_logger.info(msg)

Expand Down
2 changes: 1 addition & 1 deletion test/llm/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2536,7 +2536,7 @@ def test_batching(
# Create 2 threads and send inputs
inputs = [
TensorDict(
text=Text(prompt=[f"Question {i}?", f"Question {i+2}?"]),
text=Text(prompt=[f"Question {i}?", f"Question {i + 2}?"]),
batch_size=(2,),
)
for i in range(2)
Expand Down
12 changes: 6 additions & 6 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3465,7 +3465,7 @@ def test_direct_download(self, task, tmpdir):
def test_d4rl_dummy(self, task):
t0 = time.time()
_ = D4RLExperienceReplay(task, split_trajs=True, from_env=True, batch_size=2)
torchrl_logger.info(f"terminated test after {time.time()-t0}s")
torchrl_logger.info(f"terminated test after {time.time() - t0}s")

@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
@pytest.mark.parametrize("split_trajs", [True, False])
Expand All @@ -3490,7 +3490,7 @@ def test_dataset_build(self, task, split_trajs, from_env):
offline = sample.get(key)
# assert sim.dtype == offline.dtype, key
assert sim.shape[-1] == offline.shape[-1], key
torchrl_logger.info(f"terminated test after {time.time()-t0}s")
torchrl_logger.info(f"terminated test after {time.time() - t0}s")

@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
@pytest.mark.parametrize("split_trajs", [True, False])
Expand All @@ -3509,7 +3509,7 @@ def test_d4rl_iteration(self, task, split_trajs):
for sample in data: # noqa: B007
i += 1
assert len(data) // i == batch_size
torchrl_logger.info(f"terminated test after {time.time()-t0}s")
torchrl_logger.info(f"terminated test after {time.time() - t0}s")


_MINARI_DATASETS = []
Expand Down Expand Up @@ -3769,7 +3769,7 @@ def test_load(self, dataset_idx, split):
t0 = time.time()
for i, sample in enumerate(data):
t1 = time.time()
torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms")
torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms")
assert data.metadata["action_space"].is_in(sample["action"])
assert data.metadata["observation_space"].is_in(sample["observation"])
t0 = time.time()
Expand Down Expand Up @@ -3907,7 +3907,7 @@ def test_load(self):
t0 = time.time()
for i, _ in enumerate(data):
t1 = time.time()
torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms")
torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms")
t0 = time.time()
if i == 10:
break
Expand Down Expand Up @@ -3961,7 +3961,7 @@ def test_load(self, image_size):
assert (batch.get("pixels") != 0).any()
assert (batch.get(("next", "pixels")) != 0).any()
t1 = time.time()
torchrl_logger.info(f"sampling time {1000 * (t1-t0): 4.4f}ms")
torchrl_logger.info(f"sampling time {1000 * (t1 - t0): 4.4f}ms")
t0 = time.time()
if i == 10:
break
Expand Down
Loading
Loading