Merged

47 commits
3307fff
sync changes to cluster
josephdviviano Oct 30, 2025
dbace9d
sync to mila
josephdviviano Oct 31, 2025
cde64bb
changes to sync for agent
josephdviviano Oct 31, 2025
5f4033b
simple prototype (need to add FAISS)
josephdviviano Oct 31, 2025
418f49f
getting ready for sync
josephdviviano Oct 31, 2025
32df6c4
merge
josephdviviano Oct 31, 2025
e78e363
sketch of scoring is working
josephdviviano Nov 1, 2025
1cd9cf6
remove perf tracker
younik Nov 4, 2025
2766082
fix score for single node
younik Nov 6, 2025
b41878b
improve logging
younik Nov 7, 2025
581d283
refactor: move assertion from get_trajectory_pfs_and_pbs to test
ali-m07 Nov 26, 2025
e6ac53b
black fix
josephdviviano Nov 27, 2025
6c9ec7b
requested device tweaks
josephdviviano Nov 28, 2025
0819089
utility to compile gflownets
josephdviviano Nov 28, 2025
c31a0e6
added default actions
josephdviviano Nov 28, 2025
f1fed0e
added benchmarks for various model / loss compilations
josephdviviano Nov 28, 2025
d605e02
changed performance mode to deterministic mode -- which is off by def…
josephdviviano Nov 30, 2025
ceb686c
style: run black formatter on test file
ali-m07 Nov 30, 2025
31858d7
Merge pull request #438 from ali-m07/fix-todo-move-assertion-to-test
josephdviviano Dec 1, 2025
6b5d57c
changes for train hypergrid
younik Dec 10, 2025
6063f24
Merge branch 'master' into multinode_experiments
younik Dec 10, 2025
aedb7d8
pre-commit
younik Dec 10, 2025
2087ff0
fix pre-commit
younik Dec 10, 2025
3ccb8db
fix isort
younik Dec 10, 2025
7c22f79
set_seed is no longer by default deterministic
josephdviviano Dec 10, 2025
977db75
added debug path to is_initial/sink_state
josephdviviano Dec 10, 2025
9c15288
Merge pull request #444 from GFNOrg/seed_fix
josephdviviano Dec 10, 2025
d19cd20
improved paths for mask computation
josephdviviano Dec 11, 2025
e825e07
wired in debug flags into all envs, added checking mechanism to assis…
josephdviviano Dec 11, 2025
c51639d
added debug_protected hot paths to all actions, reformatted all scrip…
josephdviviano Dec 11, 2025
27767a1
fixed test args
josephdviviano Dec 11, 2025
d62340a
Merge branch 'master' of github.com:GFNOrg/torchgfn into optimize_states
josephdviviano Dec 11, 2025
c0d0237
added benchmarking code and plotting notebook
josephdviviano Dec 11, 2025
7ed6244
loosened test
josephdviviano Dec 11, 2025
5ca6408
Merge pull request #447 from GFNOrg/multinode_experiments
josephdviviano Dec 11, 2025
ece656f
loosened test
josephdviviano Dec 11, 2025
9c6ef3a
skipping the benchmarking code which takes too long
josephdviviano Dec 11, 2025
ab56308
made a new folder for the benchmarking notebook
josephdviviano Dec 11, 2025
b1bb28c
Merge pull request #449 from GFNOrg/optimize_states
josephdviviano Dec 11, 2025
862c873
vectorized operation
josephdviviano Dec 12, 2025
ff99fbc
Merge branch 'master' of github.com:GFNOrg/torchgfn into compile_utility
josephdviviano Dec 12, 2025
9684536
swapped check_action_validity for debug
josephdviviano Dec 12, 2025
9748192
Merge pull request #443 from GFNOrg/compile_utility
josephdviviano Dec 12, 2025
cf641d1
move dummy/exit action to actions.py, changed comments in compile.py
josephdviviano Dec 12, 2025
953c89b
merge
josephdviviano Dec 12, 2025
6f620ad
Merge pull request #450 from GFNOrg/dummy_exit_actions_for_graph_buil…
josephdviviano Dec 12, 2025
ebdc37b
merge conflicts
josephdviviano Dec 12, 2025
3 changes: 2 additions & 1 deletion .github/workflows/test_tutorials.yml
@@ -23,4 +23,5 @@ jobs:
pip install .[all] # Install all dependencies, including dev

- name: Test tutorials notebooks
run: pytest --nbmake tutorials/notebooks --nbmake-timeout=600
run: pytest --nbmake tutorials/notebooks --nbmake-timeout=600
--ignore tutorials/notebooks/intro_torchgfn_performance_tuning.ipynb
20 changes: 20 additions & 0 deletions docs/source/guides/states_actions_containers.md
@@ -14,6 +14,18 @@ Because multiple trajectories can have different lengths, batching requires appe

For discrete environments, the action set is represented by the set $\{0, \dots, n_{actions} - 1\}$, where the last action ($n_{actions} - 1$) always corresponds to the exit or terminate action, i.e., the one that results in a transition of the type $s \rightarrow s_f$; not all actions are possible at all states. Each `States` object in a discrete environment is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, which uses the environment's logic to define these two tensors.
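
As an illustration, a toy `DiscreteStates` subclass might implement `update_masks` as follows. This is a minimal sketch: the ring environment, class name, and masking rule are hypothetical, and only the mask shapes follow the API described above.

```python
import torch

from gfn.states import DiscreteStates


class RingStates(DiscreteStates):
    """Toy states on a ring of 8 nodes; actions: step left, step right, exit."""

    state_shape = (1,)
    n_actions = 3
    s0 = torch.tensor([0])
    sf = torch.tensor([-1])

    def update_masks(self) -> None:
        # Both move actions are always allowed in this toy environment.
        self.forward_masks = torch.ones(
            (*self.batch_shape, self.n_actions), dtype=torch.bool, device=self.device
        )
        # Illustrative rule: exiting is only allowed at node 0.
        self.forward_masks[..., -1] = self.tensor[..., 0] == 0
        # Any non-exit action could have led to any state on the ring.
        self.backward_masks = torch.ones(
            (*self.batch_shape, self.n_actions - 1), dtype=torch.bool, device=self.device
        )
```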

### Debug guards and factory signatures

To keep compiled hot paths fast, `States`/`DiscreteStates`/`GraphStates` expect a `debug` flag passed at construction time. When `debug=False` (the default), no Python-side checks run in hot paths; when `debug=True`, shape/device/type guards run to catch silent bugs. Environments carry an env-level `debug` flag and pass it along when they instantiate `States`.

When defining your own `States` subclass or environment factories, make sure all state factories accept `debug`:

- Constructors: `__init__(..., debug: bool = False, ...)` should store `self.debug` and pass it along when cloning or slicing.
- Factory classmethods: `make_random_states`, `make_initial_states`, `make_sink_states` (and any overrides) **must** accept `debug` (or `**kwargs`) and forward it to `States(...)`. The base class enforces this and will raise a clear `TypeError` otherwise.
- Env helpers: if you override `states_from_tensor` or `reset` in an environment, thread `self.debug` into state construction so all emitted states share the env-level setting.

This pattern avoids graph breaks in `torch.compile` by letting you keep `debug=False` in compiled runs while still enabling strong checks in development and tests.
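
A minimal sketch of a debug-aware subclass, assuming the factory signature pattern described above (the class name, state contents, and random-state logic are hypothetical):

```python
import torch

from gfn.states import States


class PointStates(States):
    """Hypothetical 2D point states whose factories forward `debug`."""

    state_shape = (2,)
    s0 = torch.zeros(2)
    sf = torch.full((2,), -1.0)

    @classmethod
    def make_random_states(
        cls,
        batch_shape: tuple[int, ...],
        device: torch.device | None = None,
        debug: bool = False,
    ) -> "PointStates":
        tensor = torch.rand((*batch_shape, *cls.state_shape), device=device)
        # Forward `debug` so the env-level setting reaches every States object.
        return cls(tensor, debug=debug)
```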

## Actions

Actions should be thought of as the internal actions of an agent building a compositional object. They correspond to transitions $s \rightarrow s'$. An abstract `Actions` class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise.
@@ -24,6 +36,14 @@ Additionally, each subclass needs to define two more class variable tensors:
- `dummy_action`: A tensor used to pad the action sequences of the shorter trajectories in a batch of trajectories. It is `[-1]` for discrete environments.
- `exit_action`: A tensor that corresponds to the termination action. It is `[n_{actions} - 1]` for discrete environments.
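
For example, a minimal discrete `Actions` subclass defining these class variables might look as follows (the class name and action count are hypothetical; only the two tensors and `action_shape` are prescribed by the API):

```python
import torch

from gfn.actions import Actions


class MyActions(Actions):
    """Hypothetical actions for an environment with 4 actions (3 moves + exit)."""

    action_shape = (1,)  # each action is a single integer
    dummy_action = torch.tensor([-1])  # pads the shorter trajectories of a batch
    exit_action = torch.tensor([3])  # index n_actions - 1 triggers termination
```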

### Debug guards and factory signatures

`Actions` mirrors the `States` pattern: constructors and factories accept `debug: bool = False`. Keep `debug=False` in compiled/hot paths to avoid Python-side asserts; flip it on in development/tests to run shape/device validations. When defining custom subclasses, ensure:

- `__init__(..., debug: bool = False, ...)` stores `self.debug` and only runs validations when `debug` is True.
- Factory classmethods (`make_dummy_actions`, `make_exit_actions`, helpers like `from_tensor_dict`) accept `debug` (or `**kwargs`) and forward it to the constructor.
- Environment helpers (`actions_from_tensor`, `actions_from_batch_shape`) should thread the env-level `debug` so all emitted actions share the setting.
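
As a usage sketch, assuming an environment whose constructor accepts the env-level `debug` flag described above (the `HyperGrid` arguments here are illustrative):

```python
import torch

from gfn.gym import HyperGrid

# Hot path: guards off (the default), suitable for compiled runs.
env = HyperGrid(ndim=2, height=8, debug=False)
actions = env.actions_from_tensor(
    torch.zeros((16, 1), dtype=torch.long, device=env.device)
)
assert actions.debug is False  # the env-level setting is threaded through

# Development/tests: guards on, so malformed tensors raise immediately.
debug_env = HyperGrid(ndim=2, height=8, debug=True)
```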

## Containers

Containers are collections of `States`, along with other information, such as reward values, or densities $p(s' \mid s)$. Three containers are available:
138 changes: 101 additions & 37 deletions src/gfn/actions.py
@@ -39,18 +39,22 @@ class Actions(ABC):
# The following class variable corresponds to $s \rightarrow s_f$ transitions.
exit_action: ClassVar[torch.Tensor] # Action to exit the environment.

def __init__(self, tensor: torch.Tensor):
def __init__(self, tensor: torch.Tensor, debug: bool = False):
"""Initializes an Actions object with a batch of actions.

Args:
tensor: Tensor of shape (*batch_shape, *action_shape) representing a batch of
actions.
debug: Whether to run debug validations on the constructed Actions.
"""
assert (
tensor.shape[-len(self.action_shape) :] == self.action_shape
), f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}."
# Debug-only validation keeps hot paths tensor-only when debug is False.
if debug:
if tensor.shape[-len(self.action_shape) :] != self.action_shape:
raise ValueError(
f"Batched actions tensor has shape {tensor.shape}, expected {self.action_shape}."
)

self.tensor = tensor
self.debug = debug

@property
def device(self) -> torch.device:
@@ -72,13 +76,17 @@ def batch_shape(self) -> tuple[int, ...]:

@classmethod
def make_dummy_actions(
cls, batch_shape: tuple[int, ...], device: torch.device | None = None
cls,
batch_shape: tuple[int, ...],
device: torch.device | None = None,
debug: bool = False,
) -> Actions:
"""Creates an Actions object filled with dummy actions.

Args:
batch_shape: Shape of the batch dimensions.
device: The device to create the actions on.
debug: Whether to run debug validations on the constructed Actions.

Returns:
An Actions object with the specified batch shape filled with dummy actions.
@@ -87,17 +95,21 @@ def make_dummy_actions(
tensor = cls.dummy_action.repeat(*batch_shape, *((1,) * action_ndim))
if device is not None:
tensor = tensor.to(device)
return cls(tensor)
return cls(tensor, debug=debug)

@classmethod
def make_exit_actions(
cls, batch_shape: tuple[int, ...], device: torch.device | None = None
cls,
batch_shape: tuple[int, ...],
device: torch.device | None = None,
debug: bool = False,
) -> Actions:
"""Creates an Actions object filled with exit actions.

Args:
batch_shape: Shape of the batch dimensions.
device: The device to create the actions on.
debug: Whether to run debug validations on the constructed Actions.

Returns:
An Actions object with the specified batch shape filled with exit actions.
@@ -106,7 +118,7 @@ def make_exit_actions(
tensor = cls.exit_action.repeat(*batch_shape, *((1,) * action_ndim))
if device is not None:
tensor = tensor.to(device)
return cls(tensor)
return cls(tensor, debug=debug)

def __len__(self) -> int:
"""Returns the number of actions in the batch.
@@ -142,7 +154,7 @@ def __getitem__(
A new Actions object with the selected actions.
"""
actions = self.tensor[index]
return self.__class__(actions)
return self.__class__(actions, debug=self.debug)

def __setitem__(
self,
@@ -158,7 +170,7 @@ def __setitem__(
self.tensor[index] = actions.tensor

@classmethod
def stack(cls, actions_list: List[Actions]) -> Actions:
def stack(cls, actions_list: List[Actions], debug: bool | None = None) -> Actions:
"""Stacks a list of Actions objects along a new dimension (0).

The individual actions need to have the same batch shape. An example application
@@ -173,8 +185,13 @@ def stack(cls, actions_list: List[Actions]) -> Actions:
Returns:
A new Actions object with the stacked actions.
"""
if debug is None:
# Reuse caller-provided debug setting when available to keep behavior consistent.
debug = getattr(actions_list[0], "debug", False) if actions_list else False
debug = bool(debug)

actions_tensor = torch.stack([actions.tensor for actions in actions_list], dim=0)
return cls(actions_tensor)
return cls(actions_tensor, debug=debug)

def extend(self, other: Actions) -> None:
"""Concatenates another Actions object along the final batch dimension.
@@ -215,7 +232,7 @@ def extend_with_dummy_actions(self, required_first_dim: int) -> None:
return
n = required_first_dim - self.batch_shape[0]
dummy_actions = self.__class__.make_dummy_actions(
(n, self.batch_shape[1]), device=self.device
(n, self.batch_shape[1]), device=self.device, debug=self.debug
)
self.tensor = torch.cat((self.tensor, dummy_actions.tensor), dim=0)
else:
@@ -234,21 +251,41 @@ def _compare(self, other: torch.Tensor) -> torch.Tensor:
equal.
"""
n_batch_dims = len(self.batch_shape)
if n_batch_dims == 1:
assert (other.shape == self.action_shape) or (
other.shape == self.batch_shape + self.action_shape
), f"Expected shape {self.action_shape} or {self.batch_shape + self.action_shape}, got {other.shape}."
else:
assert (
other.shape == self.batch_shape + self.action_shape
), f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}."
if self.debug:
if n_batch_dims == 1:
# other.shape can either have only the action shape, or the
# flattened batch_shape + action_shape.
if other.shape not in (
self.action_shape,
self.batch_shape + self.action_shape,
):
raise ValueError(
(
f"Expected shape {self.action_shape} or "
f"{self.batch_shape + self.action_shape}, got {other.shape}."
)
)
else:
# other.shape must have the full batch and action shape.
if other.shape != self.batch_shape + self.action_shape:
raise ValueError(
(
f"Expected shape {self.batch_shape + self.action_shape}, "
f"got {other.shape}."
)
)

out = self.tensor == other
if len(self.action_shape) > 1:
out = out.flatten(start_dim=n_batch_dims)
out = out.all(dim=-1)

assert out.shape == self.batch_shape
if self.debug:
if out.shape != self.batch_shape:
raise ValueError(
f"Comparison output has shape {out.shape}, expected {self.batch_shape}."
)

return out

@property
Expand Down Expand Up @@ -287,7 +324,7 @@ def clone(self) -> Actions:
Returns:
A new Actions object with the same tensor.
"""
return self.__class__(self.tensor.clone())
return self.__class__(self.tensor.clone(), debug=self.debug)


class GraphActionType(enum.IntEnum):
@@ -329,19 +366,28 @@ class GraphActions(Actions):
EDGE_INDEX_KEY: 4,
}

def __init__(self, tensor: torch.Tensor):
# Required by the Actions base class for DB/SubTB style algorithms.
action_shape = (5,)
dummy_action = torch.tensor(
[GraphActionType.DUMMY, -2, -2, -2, -2], dtype=torch.long
)
exit_action = torch.tensor([GraphActionType.EXIT, -1, -1, -1, -1], dtype=torch.long)

def __init__(self, tensor: torch.Tensor, debug: bool = False):
"""Initializes a GraphActions object.

Args:
tensor: A tensor of shape (*batch_shape, 5) containing the action type,
node class, node index, edge class, and edge index components.
debug: Whether to run debug validations on the constructed GraphActions.
"""
if tensor.shape[-1] != 5:
raise ValueError(
f"Expected tensor of shape (*batch_shape, 5), got {tensor.shape}.\n"
"The last dimension should contain the action type, node class, node index, edge class, and edge index."
)
if debug:
if tensor.shape[-1] != 5:
raise ValueError(
f"Expected tensor of shape (*batch_shape, 5), got {tensor.shape}.\n"
"The last dimension should contain the action type, node class, node index, edge class, and edge index."
)
self.tensor = tensor
self.debug = debug

@property
def batch_shape(self) -> tuple[int, ...]:
@@ -350,11 +396,14 @@ def batch_shape(self) -> tuple[int, ...]:
Returns:
The batch shape as a tuple.
"""
assert self.tensor.shape[-1] == 5
if self.debug:
assert self.tensor.shape[-1] == 5
return self.tensor.shape[:-1]

@classmethod
def from_tensor_dict(cls, tensor_dict: TensorDict) -> GraphActions:
def from_tensor_dict(
cls, tensor_dict: TensorDict, debug: bool = False
) -> GraphActions:
"""Creates a GraphActions object from a tensor dict.

Args:
@@ -374,7 +423,8 @@ def from_tensor_dict(cls, tensor_dict: TensorDict) -> GraphActions:
return cls(
torch.cat(
[action_type, node_class, node_index, edge_class, edge_index], dim=-1
)
),
debug=debug,
)

def __repr__(self):
@@ -450,7 +500,10 @@ def edge_index(self) -> torch.Tensor:

@classmethod
def make_dummy_actions(
cls, batch_shape: tuple[int], device: torch.device
cls,
batch_shape: tuple[int],
device: torch.device | None = None,
debug: bool = False,
) -> GraphActions:
"""Creates a GraphActions object filled with dummy actions.

@@ -462,13 +515,20 @@
A GraphActions object with the specified batch shape filled with dummy
actions.
"""
tensor = torch.zeros(batch_shape + (5,), dtype=torch.long, device=device)
tensor = torch.zeros(
batch_shape + (5,),
dtype=torch.long,
device=device,
)
tensor[..., cls.ACTION_INDICES[cls.ACTION_TYPE_KEY]] = GraphActionType.DUMMY
return cls(tensor)
return cls(tensor, debug=debug)

@classmethod
def make_exit_actions(
cls, batch_shape: tuple[int], device: torch.device
cls,
batch_shape: tuple[int],
device: torch.device | None = None,
debug: bool = False,
) -> GraphActions:
"""Creates a GraphActions object filled with exit actions.

@@ -479,9 +539,13 @@
Returns:
A GraphActions object with the specified batch shape filled with exit actions.
"""
tensor = torch.zeros(batch_shape + (5,), dtype=torch.long, device=device)
tensor = torch.zeros(
batch_shape + (5,),
dtype=torch.long,
device=device,
)
tensor[..., cls.ACTION_INDICES[cls.ACTION_TYPE_KEY]] = GraphActionType.EXIT
return cls(tensor)
return cls(tensor, debug=debug)

@classmethod
def edge_index_action_to_src_dst(