Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
self.training_container: ContainerUnion | None = None
self.prioritized_capacity = prioritized_capacity
self.prioritized_sampling = prioritized_sampling
self.pending_container: ContainerUnion | None = None

# Remote buffer fields
self.remote_manager_rank = remote_manager_rank
Expand Down Expand Up @@ -113,8 +114,22 @@ def add(self, training_container: ContainerUnion) -> dict[str, float] | None:
# Handle remote buffer communication.
if self.remote_manager_rank is not None:
self._add_counter += 1

if self.pending_container is None:
self.pending_container = self.initialize(training_container)
assert self.pending_container is not None
assert isinstance(training_container, type(self.pending_container)) # type: ignore

self.pending_container.extend(training_container) # type: ignore

if isinstance(self.pending_container, (Trajectories, Transitions)):
self.pending_container.log_probs = None
if isinstance(self.pending_container, Trajectories):
self.pending_container.estimator_outputs = None
if self._add_counter % self.remote_buffer_freq == 0:
return self._send_objs(training_container)
score = self._send_objs(self.pending_container)
self.pending_container = None
return score

def _send_objs(self, training_container: ContainerUnion) -> dict[str, float]:
"""Sends a training container to the remote manager."""
Expand Down Expand Up @@ -170,11 +185,11 @@ def initialize(self, training_container: ContainerUnion) -> None:
object to set the buffer type.
"""
if isinstance(training_container, Trajectories):
self.training_container = cast(ContainerUnion, Trajectories(self.env))
return cast(ContainerUnion, Trajectories(self.env)) # type: ignore
elif isinstance(training_container, Transitions):
self.training_container = cast(ContainerUnion, Transitions(self.env))
return cast(ContainerUnion, Transitions(self.env)) # type: ignore
elif isinstance(training_container, StatesContainer):
self.training_container = cast(ContainerUnion, StatesContainer(self.env))
return cast(ContainerUnion, StatesContainer(self.env)) # type: ignore
else:
raise ValueError(f"Unsupported type: {type(training_container)}")

Expand All @@ -186,7 +201,7 @@ def _add_objs(self, training_container: ContainerUnion):
to add.
"""
if self.training_container is None:
self.initialize(training_container)
self.training_container = self.initialize(training_container)
assert self.training_container is not None
assert isinstance(training_container, type(self.training_container)) # type: ignore

Expand Down
1 change: 1 addition & 0 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class HypergridArgs(CommonArgs):
replay_buffer_size: int = 0
timing: bool = True
half_precision: bool = False
# Annotated like the sibling attributes: an unannotated name in a dataclass
# body is a plain class attribute rather than a field, so it could not be set
# through the constructor. NOTE(review): the enclosing class header is not
# visible here; confirm it is decorated with @dataclass.
remote_buffer_freq: int = 1


@dataclass
Expand Down
9 changes: 8 additions & 1 deletion tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
capacity=args.replay_buffer_size,
prioritized_capacity=False,
remote_manager_rank=distributed_context.assigned_buffer,
remote_buffer_freq=1,
remote_buffer_freq=args.remote_buffer_freq,
)

gflownet = gflownet.to(device)
Expand Down Expand Up @@ -1141,6 +1141,13 @@ def cleanup():
help="Distributed backend to use: gloo, ccl or mpi",
)

# CLI flag whose value is forwarded to the replay buffer's
# ``remote_buffer_freq`` argument when the buffer is constructed.
parser.add_argument(
    "--remote_buffer_freq",
    type=int,
    default=1,
    help="Frequency (in training iterations) at which training ranks send trajectories to the remote replay buffer",
)

# Selective averaging settings.
parser.add_argument(
"--use_selective_averaging",
Expand Down