diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index f2ff6f03..f135c6f2 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -74,6 +74,7 @@ def __init__( self.training_container: ContainerUnion | None = None self.prioritized_capacity = prioritized_capacity self.prioritized_sampling = prioritized_sampling + self.pending_container: ContainerUnion | None = None # Remote buffer fields self.remote_manager_rank = remote_manager_rank @@ -113,8 +114,22 @@ def add(self, training_container: ContainerUnion) -> dict[str, float] | None: # Handle remote buffer communication. if self.remote_manager_rank is not None: self._add_counter += 1 + + if self.pending_container is None: + self.pending_container = self.initialize(training_container) + assert self.pending_container is not None + assert isinstance(training_container, type(self.pending_container)) # type: ignore + + self.pending_container.extend(training_container) # type: ignore + + if isinstance(self.pending_container, (Trajectories, Transitions)): + self.pending_container.log_probs = None + if isinstance(self.pending_container, Trajectories): + self.pending_container.estimator_outputs = None if self._add_counter % self.remote_buffer_freq == 0: - return self._send_objs(training_container) + score = self._send_objs(self.pending_container) + self.pending_container = None + return score def _send_objs(self, training_container: ContainerUnion) -> dict[str, float]: """Sends a training container to the remote manager.""" @@ -170,11 +185,11 @@ def initialize(self, training_container: ContainerUnion) -> None: object to set the buffer type. 
""" if isinstance(training_container, Trajectories): - self.training_container = cast(ContainerUnion, Trajectories(self.env)) + return cast(ContainerUnion, Trajectories(self.env)) # type: ignore elif isinstance(training_container, Transitions): - self.training_container = cast(ContainerUnion, Transitions(self.env)) + return cast(ContainerUnion, Transitions(self.env)) # type: ignore elif isinstance(training_container, StatesContainer): - self.training_container = cast(ContainerUnion, StatesContainer(self.env)) + return cast(ContainerUnion, StatesContainer(self.env)) # type: ignore else: raise ValueError(f"Unsupported type: {type(training_container)}") @@ -186,7 +201,7 @@ def _add_objs(self, training_container: ContainerUnion): to add. """ if self.training_container is None: - self.initialize(training_container) + self.training_container = self.initialize(training_container) assert self.training_container is not None assert isinstance(training_container, type(self.training_container)) # type: ignore diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 1b1b71ee..cacd2637 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -107,6 +107,7 @@ class HypergridArgs(CommonArgs): replay_buffer_size: int = 0 timing: bool = True half_precision: bool = False + remote_buffer_freq = 1 @dataclass diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 0313eeb8..6f3db577 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -775,7 +775,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: capacity=args.replay_buffer_size, prioritized_capacity=False, remote_manager_rank=distributed_context.assigned_buffer, - remote_buffer_freq=1, + remote_buffer_freq=args.remote_buffer_freq, ) gflownet = gflownet.to(device) @@ -1141,6 +1141,13 @@ def cleanup(): help="Distributed backend to use: gloo, ccl or mpi", ) + 
parser.add_argument( + "--remote_buffer_freq", + type=int, + default=1, + help="Frequency (in training iterations) at which training ranks send trajectories to remote replay buffer", + ) + # Selective averaging settings. parser.add_argument( "--use_selective_averaging",