From dcd18171eb549df9fe604a5f23b982375d47ff7f Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Sun, 2 Nov 2025 06:51:29 -0800 Subject: [PATCH 1/3] pending container added --- src/gfn/containers/replay_buffer.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 3b6cd49a..b1e7169e 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -74,6 +74,7 @@ def __init__( self.training_container: ContainerUnion | None = None self.prioritized_capacity = prioritized_capacity self.prioritized_sampling = prioritized_sampling + self.pending_container: ContainerUnion | None = None # Remote buffer fields self.remote_manager_rank = remote_manager_rank @@ -115,8 +116,23 @@ def add(self, training_container: ContainerUnion) -> float | None: # Handle remote buffer communication if self.remote_manager_rank is not None: self._add_counter += 1 + + if self.pending_container is None: + self.pending_container = self.initialize(training_container) + assert self.pending_container is not None + assert isinstance(training_container, type(self.pending_container)) # type: ignore + + self.pending_container.extend(training_container) # type: ignore + + if isinstance(self.pending_container, (Trajectories, Transitions)): + self.pending_container.log_probs = None + if isinstance(self.pending_container, Trajectories): + self.pending_container.estimator_outputs = None if self._add_counter % self.remote_buffer_freq == 0: - return self._send_objs(training_container) + score = self._send_objs(self.pending_container) + self.pending_container = None + print("Cleared pending container.", flush=True) + return score def _send_objs(self, training_container: ContainerUnion) -> float: """Sends a training container to the remote manager.""" @@ -166,11 +182,11 @@ def initialize(self, training_container: ContainerUnion) -> None: object to set the buffer type. """ if isinstance(training_container, Trajectories): - self.training_container = cast(ContainerUnion, Trajectories(self.env)) + return cast(ContainerUnion, Trajectories(self.env)) # type: ignore elif isinstance(training_container, Transitions): - self.training_container = cast(ContainerUnion, Transitions(self.env)) + return cast(ContainerUnion, Transitions(self.env)) # type: ignore elif isinstance(training_container, StatesContainer): - self.training_container = cast(ContainerUnion, StatesContainer(self.env)) + return cast(ContainerUnion, StatesContainer(self.env)) # type: ignore else: raise ValueError(f"Unsupported type: {type(training_container)}") @@ -182,7 +198,7 @@ def _add_objs(self, training_container: ContainerUnion): to add. """ if self.training_container is None: - self.initialize(training_container) + self.training_container = self.initialize(training_container) assert self.training_container is not None assert isinstance(training_container, type(self.training_container)) # type: ignore From d41266981301c12cb750d93ca0e29b6676567bae Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Mon, 10 Nov 2025 08:40:09 -0800 Subject: [PATCH 2/3] Added args and removed print --- src/gfn/containers/replay_buffer.py | 1 - tutorials/examples/train_hypergrid.py | 9 ++++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index b1e7169e..c2630409 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -131,7 +131,6 @@ def add(self, training_container: ContainerUnion) -> float | None: if self._add_counter % self.remote_buffer_freq == 0: score = self._send_objs(self.pending_container) self.pending_container = None - print("Cleared pending container.", flush=True) return score def _send_objs(self, training_container: ContainerUnion) -> float: diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index e251fc70..306e2be0 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -629,7 +629,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: capacity=args.replay_buffer_size, prioritized_capacity=False, remote_manager_rank=distributed_context.assigned_buffer, - remote_buffer_freq=1, + remote_buffer_freq=args.remote_buffer_freq, ) gflownet = gflownet.to(device) @@ -989,6 +989,13 @@ def cleanup(): help="Distributed backend to use: gloo, ccl or mpi", ) + parser.add_argument( + "--remote_buffer_freq", + type=int, + default=1, + help="Frequency (in training iterations) at which training ranks sends trajectories to remote replay buffer", + ) + # Selective averaging settings. parser.add_argument( "--use_selective_averaging", From 4e3289539e413fb691e3883511d2722bc81d5809 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 11 Dec 2025 22:41:09 -0500 Subject: [PATCH 3/3] tests fixed --- tutorials/examples/test_scripts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 1b1b71ee..cacd2637 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -107,6 +107,7 @@ class HypergridArgs(CommonArgs): replay_buffer_size: int = 0 timing: bool = True half_precision: bool = False + remote_buffer_freq = 1 @dataclass