diff --git a/src/gfn/env.py b/src/gfn/env.py index 646094f0..3146ceb0 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,7 +1,14 @@ import warnings from abc import ABC, abstractmethod from collections import Counter -from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Dict, + Optional, + Tuple, + cast, +) if TYPE_CHECKING: from gfn.gflownet import GFlowNet @@ -17,6 +24,25 @@ NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) +class EnvFastPathMixin: + """Marker mixin for environments exposing tensor-only fast-path helpers. + + Environments inheriting this mixin are expected to override: + + - ``step_tensor``: vectorized transition operating purely on tensors. + - ``forward_action_masks_tensor``: tensor-based forward action masks. + - ``states_from_tensor_fast``: lightweight wrapper that avoids redundant + allocations when reconstructing ``States`` objects from raw tensors. + + The mixin itself does not provide implementations; it purely signals that + the environment intends to support the fast path and enables nominal checks + such as ``isinstance(env, EnvFastPathMixin)`` without relying on structural + typing. + """ + + fast_path_enabled: bool = True + + class Env(ABC): """Base class for all environments. @@ -37,6 +63,22 @@ class Env(ABC): is_discrete: bool = False + @dataclass + class TensorStepResult: + """Container returned by tensor-level step helpers. + + Attributes: + next_states: Tensor containing the next states produced by the step. + is_sink_state: Optional boolean tensor indicating which rows are sink. + forward_masks: Optional boolean tensor with forward action masks. + backward_masks: Optional boolean tensor with backward action masks. + """ + + next_states: torch.Tensor + is_sink_state: torch.Tensor | None = None + forward_masks: torch.Tensor | None = None + backward_masks: torch.Tensor | None = None + def __init__( self, s0: torch.Tensor | GeometricData, @@ -145,6 +187,51 @@ def actions_from_batch_shape(self, batch_shape: Tuple) -> Actions: """ return self.Actions.make_dummy_actions(batch_shape, device=self.device) + @property + def has_tensor_fast_path(self) -> bool: + """Whether this environment opts into the tensor-only fast API.""" + + return isinstance(self, EnvFastPathMixin) + + def states_from_tensor_fast(self, tensor: torch.Tensor) -> States: + """Fallback helper recreating ``States`` objects from tensors. + + Fast-path environments can override this to avoid redundant mask + recomputation or to attach cached metadata. The default simply calls + ``states_from_tensor``. + """ + + return self.states_from_tensor(tensor) + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> "Env.TensorStepResult": + """Tensor equivalent of `_step` with default object-based fallback. + + Environments can override this method to provide compiler-friendly + implementations that avoid constructing `States`/`Actions`. The default + fallback simply wraps tensors into the standard containers and delegates + to `_step`, ensuring parity with the legacy path. + """ + + states = self.states_from_tensor(states_tensor.clone()) + actions = self.actions_from_tensor(actions_tensor.clone()) + new_states = self._step(states, actions) + return self.TensorStepResult(next_states=new_states.tensor.clone()) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor helper returning forward masks for the supplied states. 
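+
+        Implementations are expected to return a boolean tensor of shape
+        ``(batch, n_actions)`` mirroring what ``update_masks`` computes on the
+        corresponding ``States`` object (illustrative contract only; see the
+        discrete overrides later in this diff).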
+ + Base environments do not provide a generic implementation because mask + semantics are environment-specific. Subclasses (e.g., ``DiscreteEnv``) + are expected to override this to expose a fallback compatible with the + fast sampler path. + """ + + raise NotImplementedError( + f"{self.__class__.__name__} does not expose tensor forward masks." + ) + @abstractmethod def step(self, states: States, actions: Actions) -> States: """Forward transition function of the environment. @@ -559,6 +646,21 @@ def states_from_batch_shape( assert isinstance(out, DiscreteStates) return out + def states_from_tensor_fast(self, tensor: torch.Tensor) -> DiscreteStates: + """Return `DiscreteStates` without extra bookkeeping for fast paths.""" + + states = self.states_from_tensor(tensor) + assert isinstance(states, DiscreteStates) + return states + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Recompute forward masks for the supplied state tensor.""" + + states = self.states_from_tensor(states_tensor.clone()) + self.update_masks(states) + assert states.forward_masks is not None + return states.forward_masks.clone() + def reset( self, batch_shape: int | Tuple[int, ...], diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 9cb6839d..ac727a5c 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -241,6 +241,45 @@ def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: return getattr(ctx, "current_estimator_output", None) +class FastPolicyMixin(PolicyMixin): + """Optional mixin for policies that ingest tensors directly on fast paths. + + Estimators inheriting this mixin should implement the tensor-oriented hooks + below so samplers can bypass `States`/`Actions` allocation when environments + expose compatible helpers. + """ + + fast_path_enabled: bool = True + + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, + conditions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Preprocess raw tensors into module-ready features.""" + + raise NotImplementedError( + f"{self.__class__.__name__} does not implement fast_features." + ) + + def fast_distribution( + self, + features: torch.Tensor, + *, + forward_masks: Optional[torch.Tensor] = None, + backward_masks: Optional[torch.Tensor] = None, + **policy_kwargs: Any, + ) -> Distribution: + """Build the action distribution from tensor features.""" + + raise NotImplementedError( + f"{self.__class__.__name__} does not implement fast_distribution." 
+ ) + + class RecurrentPolicyMixin(PolicyMixin): """Mixin for recurrent policies that maintain and update a rollout carry.""" @@ -1227,7 +1266,7 @@ def init_carry( return init_carry_fn(batch_size, device) -class DiffusionPolicyEstimator(PolicyMixin, Estimator): +class DiffusionPolicyEstimator(FastPolicyMixin, Estimator): """Base class for diffusion policy estimators.""" def __init__(self, s_dim: int, module: nn.Module, is_backward: bool = False): @@ -1282,6 +1321,16 @@ def to_probability_distribution( """ raise NotImplementedError + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + class PinnedBrownianMotionForward(DiffusionPolicyEstimator): # TODO: support OU process def __init__( @@ -1345,8 +1394,13 @@ def to_probability_distribution( A IsotropicGaussian distribution (distribution of the next states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] - t_curr = states.tensor[:, [-1]] + return self._distribution_from_tensor(states.tensor, module_output) + + def _distribution_from_tensor( + self, states_tensor: torch.Tensor, module_output: torch.Tensor + ) -> IsotropicGaussian: + s_curr = states_tensor[:, :-1] + t_curr = states_tensor[:, [-1]] module_output = torch.where( (1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0 @@ -1359,6 +1413,22 @@ def to_probability_distribution( fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1) return IsotropicGaussian(fwd_mean, fwd_std) + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> IsotropicGaussian: + if states_tensor is None: + raise ValueError( + "states_tensor is required for PinnedBrownianMotionForward fast path." + ) + module_output = self.module(features) + return self._distribution_from_tensor(states_tensor, module_output) + class PinnedBrownianMotionBackward(DiffusionPolicyEstimator): # TODO: support OU process def __init__( @@ -1422,10 +1492,15 @@ def to_probability_distribution( A IsotropicGaussian distribution (distribution of the previous states) """ assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" - s_curr = states.tensor[:, :-1] - t_curr = states.tensor[:, [-1]] # shape: (*batch_shape,) + return self._distribution_from_tensor(states.tensor, module_output) + + def _distribution_from_tensor( + self, states_tensor: torch.Tensor, module_output: torch.Tensor + ) -> IsotropicGaussian: + s_curr = states_tensor[:, :-1] + t_curr = states_tensor[:, [-1]] - is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0 + is_s0 = (t_curr - self.dt) < self.dt * 1e-2 bwd_mean = torch.where( is_s0, s_curr, @@ -1437,3 +1512,19 @@ def to_probability_distribution( self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(), ) return IsotropicGaussian(bwd_mean, bwd_std) + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> IsotropicGaussian: + if states_tensor is None: + raise ValueError( + "states_tensor is required for PinnedBrownianMotionBackward fast path." 
+ ) + module_output = self.module(features) + return self._distribution_from_tensor(states_tensor, module_output) diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 37296c61..e338c86d 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -492,18 +492,32 @@ def get_geometric_within_contributions( Returns: The contributions tensor of shape (max_len * (max_len+1) / 2, n_trajectories). """ - L = self.lamda + max_len = trajectories.max_length - t_idx = trajectories.terminating_idx + if max_len == 0 or len(trajectories) == 0: + return torch.zeros( + (0, len(trajectories)), + device=trajectories.device, + dtype=torch.get_default_dtype(), + ) - # The following tensor represents the weights given to each possible - # sub-trajectory length. - contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( - torch.get_default_dtype() - ) - contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) + dtype = torch.get_default_dtype() + device = trajectories.device + t_idx = trajectories.terminating_idx.to(dtype) + + # Clamp lambda away from 0/1 to avoid divisions by zero or log(0) while keeping + # the computation compatible with torch.compile. + lamda = torch.as_tensor(self.lamda, device=device, dtype=dtype) + finfo = torch.finfo(dtype) + lamda = torch.clamp(lamda, finfo.tiny, 1 - finfo.eps) + + # Geometric weights for each possible sub-trajectory length, computed in log + # space to reduce error when lamda is close to 1. + lengths = torch.arange(max_len, device=device, dtype=dtype) + log_weights = lengths * torch.log(lamda) + contributions = torch.exp(log_weights).unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( - torch.arange(max_len, 0, -1, device=t_idx.device), + torch.arange(max_len, 0, -1, device=device), dim=0, output_size=int(max_len * (max_len + 1) / 2), ) @@ -512,13 +526,14 @@ def get_geometric_within_contributions( # where n is the length of the trajectory corresponding to that column # We can do it the ugly way, or using the cool identity: # https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29 - per_trajectory_denom = ( - 1.0 - / (1 - L) ** 2 - * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) - ).to(torch.get_default_dtype()) - contributions = contributions / per_trajectory_denom / len(trajectories) + # Closed-form normalization: + # sum_{i=0}^{n-1} (n - i) * lamda^i + lamda_pow_n = torch.pow(lamda, t_idx) + numerator = lamda * (lamda_pow_n - 1) + (1 - lamda) * t_idx + denominator = (1 - lamda) ** 2 + per_trajectory_denom = numerator / denominator + contributions = contributions / per_trajectory_denom / len(trajectories) return contributions def loss( diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index 2aca72bf..cc13eb9a 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -6,12 +6,14 @@ from gfn.actions import Actions from gfn.containers import Trajectories -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates from gfn.utils.common import is_int_dtype -# This environment is the torchgfn implmentation of the bit sequences task presented in :Malkin, Nikolay & Jain, Moksh & Bengio, Emmanuel & Sun, Chen & Bengio, Yoshua. (2022). -# Trajectory Balance: Improved Credit Assignment in GFlowNets. 
https://arxiv.org/pdf/2201.13259 +# This environment is the torchgfn implmentation of the bit sequences task presented in +# :Malkin, Nikolay & Jain, Moksh & Bengio, Emmanuel & Sun, Chen & Bengio, Yoshua. +# (2022). Trajectory Balance: Improved Credit Assignment in GFlowNets. +# https://arxiv.org/pdf/2201.13259 class BitSequenceStates(DiscreteStates): @@ -186,7 +188,7 @@ def row_to_binary_string(row, row_mask): return [row_to_binary_string(tensor[i], mask[i]) for i in range(tensor.shape[0])] -class BitSequence(DiscreteEnv): +class BitSequence(EnvFastPathMixin, DiscreteEnv): """Append-only BitSequence environment. This environment represents a sequence of binary words and provides methods to @@ -347,6 +349,18 @@ def update_masks(self, states: BitSequenceStates) -> None: ) states.backward_masks[~is_sink, last_actions] = True + def _lengths_from_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + return torch.count_nonzero(states_tensor != -1, dim=-1) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.ones((batch, self.n_actions), dtype=torch.bool, device=device) + lengths = self._lengths_from_tensor(states_tensor) + masks[lengths == self.words_per_seq, :-1] = False + masks[lengths < self.words_per_seq, -1] = False + return masks + def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. @@ -368,6 +382,56 @@ def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates ) return self.States(old_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + actions_vals = actions_tensor.squeeze(-1) + else: + actions_vals = actions_tensor + + exit_val = self.n_actions - 1 + is_exit = actions_vals == exit_val + next_states = states_tensor.clone() + + lengths = self._lengths_from_tensor(states_tensor) + + non_exit_idx = (~is_exit).nonzero(as_tuple=True)[0] + if len(non_exit_idx) > 0: + insert_pos = lengths[non_exit_idx] + next_states[non_exit_idx, insert_pos] = actions_vals[non_exit_idx] + + if is_exit.any(): + sink_row = torch.full( + (self.words_per_seq,), + exit_val, + dtype=torch.long, + device=states_tensor.device, + ) + next_states[is_exit] = sink_row + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.n_actions - 1), + dtype=torch.bool, + device=states_tensor.device, + ) + is_sink_state = torch.all(next_states == exit_val, dim=-1) + non_sink = ~is_sink_state + if non_sink.any(): + new_lengths = self._lengths_from_tensor(next_states[non_sink]) + last_idx = torch.clamp(new_lengths - 1, min=0) + rows = non_sink.nonzero(as_tuple=True)[0] + last_actions = next_states[rows, last_idx] + backward_masks[rows, last_actions] = True + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step( self, states: BitSequenceStates, actions: Actions ) -> BitSequenceStates: @@ -794,6 +858,69 @@ def update_masks(self, states: BitSequenceStates) -> None: states.backward_masks[~is_sink, last_actions] = True states.backward_masks[~is_sink, first_actions + (self.n_actions - 1) // 2] = True + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + return 
super().forward_action_masks_tensor(states_tensor) + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + actions_vals = actions_tensor.squeeze(-1) + else: + actions_vals = actions_tensor + + exit_val = self.n_actions - 1 + append_threshold = (self.n_actions - 1) // 2 + is_exit = actions_vals == exit_val + append_mask = actions_vals < append_threshold + prepend_mask = (~is_exit) & (~append_mask) + + next_states = states_tensor.clone() + lengths = self._lengths_from_tensor(states_tensor) + + if append_mask.any(): + idx = append_mask.nonzero(as_tuple=True)[0] + insert_pos = lengths[idx] + next_states[idx, insert_pos] = actions_vals[idx] + + if prepend_mask.any(): + idx = prepend_mask.nonzero(as_tuple=True)[0] + next_states[idx, 1:] = next_states[idx, :-1] + next_states[idx, 0] = actions_vals[idx] - append_threshold + + if is_exit.any(): + sink_row = torch.full( + (self.words_per_seq,), + exit_val, + dtype=torch.long, + device=states_tensor.device, + ) + next_states[is_exit] = sink_row + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.n_actions - 1), + dtype=torch.bool, + device=states_tensor.device, + ) + is_sink_state = torch.all(next_states == exit_val, dim=-1) + non_sink = ~is_sink_state + if non_sink.any(): + new_lengths = self._lengths_from_tensor(next_states[non_sink]) + last_idx = torch.clamp(new_lengths - 1, min=0) + rows = non_sink.nonzero(as_tuple=True)[0] + last_actions = next_states[rows, last_idx] + first_actions = next_states[rows, 0] + backward_masks[rows, last_actions] = True + backward_masks[rows, first_actions + append_threshold] = True + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates: """Performs a step in the environment. 
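# Worked example for the prepend branch of `step_tensor` above (illustrative
# numbers: word_size=2 and words_per_seq=4, so n_actions=9 and the
# append/prepend threshold is (9 - 1) // 2 = 4): applying action 6 to the state
# [2, 3, -1, -1] shifts it right and writes 6 - 4 = 2 in front, yielding
# [2, 2, 3, -1], while action 1 would instead be appended at the first -1 slot.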
@@ -808,16 +935,18 @@ def step(self, states: BitSequenceStates, actions: Actions) -> BitSequenceStates old_tensor = states.tensor.clone() append_mask = (actions.tensor < (self.n_actions - 1) // 2).squeeze() prepend_mask = ~append_mask - assert states.length - old_tensor[append_mask & ~is_exit, states.length[append_mask & ~is_exit]] = ( - actions.tensor[append_mask & ~is_exit].squeeze() - ) + assert states.length is not None + append_rows = append_mask & ~is_exit + old_tensor[append_rows, states.length[append_rows]] = actions.tensor[ + append_rows + ].squeeze() old_tensor[prepend_mask & ~is_exit, 1:] = old_tensor[ prepend_mask & ~is_exit, :-1 ] - old_tensor[prepend_mask & ~is_exit, 0] = ( - actions.tensor[prepend_mask & ~is_exit].squeeze() - (self.n_actions - 1) // 2 + prepend_rows = prepend_mask & ~is_exit + old_tensor[prepend_rows, 0] = ( + actions.tensor[prepend_rows].squeeze() - (self.n_actions - 1) // 2 ) old_tensor[is_exit] = torch.full_like( diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 583d0b90..90ef4df4 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -4,11 +4,11 @@ import torch from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.states import States -class Box(Env): +class Box(EnvFastPathMixin, Env): """Box environment, corresponding to the one in Section 4.1 of https://arxiv.org/abs/2301.12594 Attributes: @@ -101,6 +101,26 @@ def backward_step(self, states: States, actions: Actions) -> States: """ return self.States(states.tensor - actions.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + next_states = states_tensor.clone() + exit_action = self.exit_action.to(states_tensor.device).to(states_tensor.dtype) + exit_mask = torch.all(actions_tensor == exit_action, dim=-1) + non_exit = ~exit_mask + + next_states[non_exit] = next_states[non_exit] + actions_tensor[non_exit] + + if exit_mask.any(): + assert isinstance(self.sf, torch.Tensor) + sf_tensor = self.sf.to(states_tensor.device).to(states_tensor.dtype) + next_states[exit_mask] = sf_tensor + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=exit_mask.clone(), + ) + @staticmethod def norm(x: torch.Tensor) -> torch.Tensor: """Computes the L2 norm of the input tensor along the last dimension. diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 72e5975e..4ed88f55 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -11,7 +11,7 @@ from scipy.stats import wishart from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.gym.helpers.diffusion_utils import viz_2d_slice from gfn.states import States from gfn.utils.common import filter_kwargs_for_callable, temporarily_set_seed @@ -672,7 +672,7 @@ def visualize( ###################################### -class DiffusionSampling(Env): +class DiffusionSampling(EnvFastPathMixin, Env): """Diffusion sampling environment. Attributes: @@ -802,6 +802,45 @@ def step(self, states: States, actions: Actions) -> States: next_states_tensor[..., -1] = next_states_tensor[..., -1] + self.dt return self.States(next_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + """Tensor fast-path equivalent of `_step`. 
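+
+        Expects ``states_tensor`` of shape ``(batch, dim + 1)``, whose trailing
+        column is the diffusion time, and ``actions_tensor`` of shape
+        ``(batch, dim)``, as checked by the asserts below.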
+ + Mirrors the legacy wrapper by skipping already-sink rows, applying the action + update to the remaining states, and forcing exit actions onto the sink state. + """ + + assert states_tensor.shape[-1] == self.dim + 1 + assert actions_tensor.shape[-1] == self.dim + + device = states_tensor.device + dtype = states_tensor.dtype + sf_tensor = cast(torch.Tensor, self.sf).to(device=device, dtype=dtype) + exit_action = self.exit_action.to(device=device, dtype=dtype) + + # Detect rows that are already padded sink states, and exit rows that should + # transition to the sink regardless of their current state. + sink_mask = torch.all(states_tensor == sf_tensor, dim=-1) + exit_mask = torch.all(actions_tensor == exit_action, dim=-1) + update_mask = ~(sink_mask | exit_mask) + + next_states = states_tensor.clone() + if update_mask.any(): + next_states[update_mask, :-1] = ( + next_states[update_mask, :-1] + actions_tensor[update_mask] + ) + dt = torch.as_tensor(self.dt, device=device, dtype=dtype) + next_states[update_mask, -1] = next_states[update_mask, -1] + dt + + if exit_mask.any(): + next_states[exit_mask] = sf_tensor + + next_sink_mask = sink_mask | exit_mask + return self.TensorStepResult( + next_states=next_states, is_sink_state=next_sink_mask + ) + def backward_step(self, states: States, actions: Actions) -> States: """Backward step function for the SimpleGaussianMixtureModel environment. diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 3ed1c5eb..9162badf 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -5,7 +5,7 @@ import torch.nn as nn from gfn.actions import Actions -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates, States @@ -59,7 +59,7 @@ def forward(self, states: torch.Tensor) -> torch.Tensor: return -(states * tmp).sum(-1) -class DiscreteEBM(DiscreteEnv): +class DiscreteEBM(EnvFastPathMixin, DiscreteEnv): """Environment for discrete energy-based models. This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf. 
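# Worked example of the tensor masks added below (illustrative, ndim=2, so
# n_actions = 2 * ndim + 1 = 5): the state [-1, 0] still has an unset first
# coordinate, giving forward mask [True, False, True, False, False] (set dim 0
# to 0 or to 1), while the fully set state [1, 0] gives
# [False, False, False, False, True] (exit only).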
@@ -132,6 +132,16 @@ def update_masks(self, states: DiscreteStates) -> None: states.backward_masks[..., : self.ndim] = states.tensor == 0 states.backward_masks[..., self.ndim : 2 * self.ndim] = states.tensor == 1 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + available = states_tensor == -1 + masks[:, : self.ndim] = available + masks[:, self.ndim : 2 * self.ndim] = available + masks[:, -1] = torch.all(states_tensor != -1, dim=-1) + return masks + def make_random_states( self, batch_shape: Tuple, device: torch.device | None = None ) -> DiscreteStates: @@ -186,6 +196,49 @@ def step(self, states: States, actions: Actions) -> States: ) return self.States(states.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 1: + actions_idx = actions_tensor + else: + actions_idx = actions_tensor.squeeze(-1) + + exit_idx = self.n_actions - 1 + next_states = states_tensor.clone() + device = states_tensor.device + + is_exit = actions_idx == exit_idx + mask0 = (actions_idx < self.ndim) & ~is_exit + mask1 = (actions_idx >= self.ndim) & (actions_idx < 2 * self.ndim) & ~is_exit + + if mask0.any(): + rows = mask0.nonzero(as_tuple=True)[0] + cols = actions_idx[rows] + next_states[rows, cols] = 0 + + if mask1.any(): + rows = mask1.nonzero(as_tuple=True)[0] + cols = actions_idx[rows] - self.ndim + next_states[rows, cols] = 1 + + if is_exit.any(): + next_states[is_exit] = self.sf.to(device=device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros_like(forward_masks) + backward_masks[:, : self.ndim] = next_states == 0 + backward_masks[:, self.ndim : 2 * self.ndim] = next_states == 1 + + is_sink_state = torch.all(next_states == self.sf.to(device=device), dim=-1) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: States, actions: Actions) -> States: """Performs a backward step. diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 57933bbf..379d283d 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -8,7 +8,7 @@ from torch import Size, Tensor from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from gfn.estimators import Estimator, PolicyMixin +from gfn.estimators import Estimator, FastPolicyMixin from gfn.gym import Box from gfn.states import States from gfn.utils.modules import MLP @@ -936,7 +936,7 @@ def split_PF_module_output( return (exit_probability, mixture_logits, alpha_theta, beta_theta, alpha_r, beta_r) -class BoxPFEstimator(Estimator, PolicyMixin): +class BoxPFEstimator(FastPolicyMixin, Estimator): r"""Estimator for `P_F` for the Box environment. This estimator uses the `DistributionWrapper` distribution. @@ -978,6 +978,7 @@ def __init__( max_concentration: The maximum concentration for the Beta distributions. 
""" super().__init__(module) + self.env = env self._n_comp_max = max(n_components_s0, n_components) self.n_components_s0 = n_components_s0 self.n_components = n_components @@ -1059,8 +1060,33 @@ def _normalize(x: Tensor) -> Tensor: self.n_components_s0, ) + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for BoxPFEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output) + -class BoxPBEstimator(Estimator, PolicyMixin): +class BoxPBEstimator(FastPolicyMixin, Estimator): r"""Estimator for `P_B` for the Box environment. This estimator uses the `QuarterCircle(northeastern=False)` distribution. @@ -1096,6 +1122,7 @@ def __init__( """ super().__init__(module, is_backward=True) self.module = module + self.env = env self.n_components = n_components self.min_concentration = min_concentration @@ -1145,3 +1172,28 @@ def _normalize(x: Tensor) -> Tensor: alpha=alpha, beta=beta, ) + + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs: Any, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for BoxPBEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 1215b5f5..db6d97b8 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -14,7 +14,7 @@ import torch from gfn.actions import Actions -from gfn.env import DiscreteEnv +from gfn.env import DiscreteEnv, EnvFastPathMixin from gfn.states import DiscreteStates from gfn.utils.common import ensure_same_device @@ -48,7 +48,7 @@ def smallest_multiplier_to_integers(float_vector, precision=3): return smallest_multiplier -class HyperGrid(DiscreteEnv): +class HyperGrid(EnvFastPathMixin, DiscreteEnv): """HyperGrid environment from the GFlowNets paper. 
The states are represented as 1-d tensors of length `ndim` with values in @@ -159,6 +159,15 @@ def update_masks(self, states: DiscreteStates) -> None: ) states.backward_masks = states.tensor != 0 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor-only equivalent of `update_masks` for forward masks.""" + + base = states_tensor != (self.height - 1) + exit_column = torch.ones( + (states_tensor.shape[0], 1), dtype=torch.bool, device=states_tensor.device + ) + return torch.cat([base, exit_column], dim=-1) + def make_random_states( self, batch_shape: Tuple[int, ...], device: torch.device | None = None ) -> DiscreteStates: @@ -191,6 +200,51 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: assert new_states_tensor.shape == states.tensor.shape return self.States(new_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + """Tensor-only transition combined with mask outputs for fast paths.""" + + assert states_tensor.dtype == torch.long + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + is_exit_action = actions_idx.squeeze(-1) == exit_idx + next_states = states_tensor.clone() + + non_exit_mask = ~is_exit_action + non_exit_mask_exp = non_exit_mask.unsqueeze(-1) + safe_actions = torch.where( + non_exit_mask_exp, actions_idx, torch.zeros_like(actions_idx) + ) + delta = torch.zeros_like(next_states) + delta = delta.scatter(-1, safe_actions, 1, reduce="add") + delta = delta * non_exit_mask_exp.to(next_states.dtype) + next_states = next_states + delta + + sink_state = self.sf.to(device=states_tensor.device) + while sink_state.ndim < next_states.ndim: + sink_state = sink_state.unsqueeze(0) + sink_state = sink_state.expand_as(next_states) + next_states = torch.where(is_exit_action.unsqueeze(-1), sink_state, next_states) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = next_states != 0 + is_sink_state = (next_states == self.sf.to(device=states_tensor.device)).all( + dim=-1 + ) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a backward step in the environment. diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index e4b34b8c..b2af9df5 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -4,11 +4,11 @@ from torch.distributions import Normal # TODO: extend to Beta from gfn.actions import Actions -from gfn.env import Env +from gfn.env import Env, EnvFastPathMixin from gfn.states import States -class Line(Env): +class Line(EnvFastPathMixin, Env): """Mixture of Gaussians Line environment. 
Attributes: @@ -84,6 +84,32 @@ def step(self, states: States, actions: Actions) -> States: assert states.tensor.shape == states.batch_shape + (2,) return self.States(states.tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> Env.TensorStepResult: + next_states = states_tensor.clone() + if actions_tensor.ndim == 2 and actions_tensor.shape[-1] == 1: + action_vals = actions_tensor.squeeze(-1) + else: + action_vals = actions_tensor + + exit_val = float(self.exit_action.item()) + exit_mask = action_vals == exit_val + non_exit = ~exit_mask + + next_states[non_exit, 0] = next_states[non_exit, 0] + action_vals[non_exit] + next_states[non_exit, 1] = next_states[non_exit, 1] + 1 + + if exit_mask.any(): + assert isinstance(self.sf, torch.Tensor) + sf_tensor = self.sf.to(states_tensor.device) + sf_tensor = sf_tensor.to(states_tensor.dtype) + next_states[exit_mask] = sf_tensor + + return self.TensorStepResult( + next_states=next_states, is_sink_state=exit_mask.clone() + ) + def backward_step(self, states: States, actions: Actions) -> States: """Performs a backward step in the environment. diff --git a/src/gfn/gym/perfect_tree.py b/src/gfn/gym/perfect_tree.py index 7c3512da..220ef5a0 100644 --- a/src/gfn/gym/perfect_tree.py +++ b/src/gfn/gym/perfect_tree.py @@ -2,11 +2,11 @@ import torch -from gfn.env import Actions, DiscreteEnv, DiscreteStates +from gfn.env import Actions, DiscreteEnv, DiscreteStates, EnvFastPathMixin from gfn.states import States -class PerfectBinaryTree(DiscreteEnv): +class PerfectBinaryTree(EnvFastPathMixin, DiscreteEnv): r"""Perfect Tree Environment. This environment is a perfect binary tree, where there is a bijection between @@ -75,6 +75,8 @@ def __init__( self.inverse_transition_table, self.term_states, ) = self._build_tree() + self._leaf_lower = 2**self.depth - 1 + self._leaf_upper = 2 ** (self.depth + 1) - 1 def _build_tree(self) -> tuple[dict, dict, DiscreteStates]: """Builds the tree and the transition tables. 
@@ -192,6 +194,76 @@ def update_masks(self, states: DiscreteStates) -> None: # Initial state has no available backward action states.backward_masks[initial_state_mask] = False + def _is_leaf_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + values = states_tensor.view(-1) + return (values >= self._leaf_lower) & (values < self._leaf_upper) + + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + leaf_mask = self._is_leaf_tensor(states_tensor) + sink_mask = (states_tensor == self.sf.to(device)).all(dim=-1) + non_leaf = ~(leaf_mask | sink_mask) + masks[non_leaf, : self.branching_factor] = True + masks[leaf_mask | sink_mask, -1] = True + return masks + + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + device = states_tensor.device + next_states = states_tensor.clone() + actions_flat = actions_idx.squeeze(-1) + state_vals = next_states.squeeze(-1) + + is_exit = actions_flat == exit_idx + non_exit = ~is_exit + if non_exit.any(): + parents = state_vals[non_exit] + child_idx = parents.clone() + left_mask = actions_flat[non_exit] == 0 + right_mask = actions_flat[non_exit] == 1 + if left_mask.any(): + child_idx[left_mask] = 2 * parents[left_mask] + 1 + if right_mask.any(): + child_idx[right_mask] = 2 * parents[right_mask] + 2 + next_states[non_exit, 0] = child_idx + + if is_exit.any(): + next_states[is_exit] = self.sf.to(device=device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros( + (next_states.shape[0], self.branching_factor), + dtype=torch.bool, + device=device, + ) + next_vals = next_states.squeeze(-1) + sink_mask = (next_states == self.sf.to(device)).all(dim=-1) + initial_mask = next_vals == self.s0.item() + even_mask = (next_vals % 2 == 0) & ~sink_mask + odd_mask = (next_vals % 2 == 1) & ~sink_mask + backward_masks[even_mask, 1] = True + backward_masks[odd_mask, 0] = True + backward_masks[initial_mask] = False + + is_sink_state = sink_mask + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def get_states_indices(self, states: States): """Returns the indices of the states. diff --git a/src/gfn/gym/set_addition.py b/src/gfn/gym/set_addition.py index 65f0707a..71ea58df 100644 --- a/src/gfn/gym/set_addition.py +++ b/src/gfn/gym/set_addition.py @@ -2,10 +2,10 @@ import torch -from gfn.env import Actions, DiscreteEnv, DiscreteStates +from gfn.env import Actions, DiscreteEnv, DiscreteStates, EnvFastPathMixin -class SetAddition(DiscreteEnv): +class SetAddition(EnvFastPathMixin, DiscreteEnv): """Append only MDP, similarly to what is described in Remark 8 of Shen et al. 
2023 [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html) @@ -118,6 +118,41 @@ def update_masks(self, states: DiscreteStates) -> None: states.backward_masks[..., : self.n_items] = states.tensor != 0 + def forward_action_masks_tensor(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Tensor equivalent of `update_masks` for forward masks.""" + + batch = states_tensor.shape[0] + device = states_tensor.device + masks = torch.zeros((batch, self.n_actions), dtype=torch.bool, device=device) + + n_items_per_state = states_tensor.sum(dim=-1) + states_that_must_end = n_items_per_state >= self.max_traj_len + states_that_may_continue = ~states_that_must_end + + if states_that_may_continue.any(): + cont_states = states_tensor[states_that_may_continue] == 0 + cont_masks = torch.zeros( + (cont_states.shape[0], self.n_actions), + dtype=torch.bool, + device=device, + ) + cont_masks[:, : self.n_items] = cont_states + masks[states_that_may_continue] = cont_masks + + if states_that_must_end.any(): + end_masks = torch.zeros( + (int(states_that_must_end.sum().item()), self.n_actions), + dtype=torch.bool, + device=device, + ) + end_masks[:, -1] = True + masks[states_that_must_end] = end_masks + + if not self.fixed_length: + masks[..., -1] = True + + return masks + def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a step in the environment. @@ -131,6 +166,45 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add") return self.States(new_states_tensor) + def step_tensor( + self, states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> DiscreteEnv.TensorStepResult: + """Tensor-only transition mirroring the legacy `_step` path.""" + + if actions_tensor.ndim == 1: + actions_idx = actions_tensor.view(-1, 1) + else: + assert actions_tensor.shape[-1] == 1 + actions_idx = actions_tensor + + exit_idx = self.n_actions - 1 + is_exit = actions_idx.squeeze(-1) == exit_idx + next_states = states_tensor.clone() + + non_exit_mask = ~is_exit + if torch.any(non_exit_mask): + sel_states = next_states[non_exit_mask] + sel_actions = actions_idx[non_exit_mask] + sel_states = sel_states.scatter(-1, sel_actions, 1, reduce="add") + next_states[non_exit_mask] = sel_states + + if torch.any(is_exit): + next_states[is_exit] = self.sf.to(device=states_tensor.device) + + forward_masks = self.forward_action_masks_tensor(next_states) + backward_masks = torch.zeros_like(forward_masks) + backward_masks[..., : self.n_items] = next_states != 0 + is_sink_state = (next_states == self.sf.to(device=states_tensor.device)).all( + dim=-1 + ) + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_sink_state, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """Performs a backward step in the environment. 
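# End-to-end sketch of the tensor fast path introduced in this diff
# (illustrative usage only, assuming the HyperGrid changes above):
#
#     import torch
#     from gfn.gym import HyperGrid
#
#     env = HyperGrid(ndim=2, height=4)
#     states = env.reset(batch_shape=8)
#     actions = torch.zeros(8, 1, dtype=torch.long)  # action 0 increments dim 0
#     result = env.step_tensor(states.tensor, actions)  # Env.TensorStepResult
#     masks = env.forward_action_masks_tensor(result.next_states)
#     assert torch.equal(masks, result.forward_masks)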
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index a4d74409..070c4965 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -1,17 +1,34 @@
-from typing import Any, List, Optional, Tuple, cast
+import warnings
+from typing import Any, Callable, List, Optional, Tuple, cast
 
 import torch
 
 from gfn.actions import Actions
 from gfn.containers import Trajectories
-from gfn.env import Env
-from gfn.estimators import Estimator, PolicyEstimatorProtocol
+from gfn.env import Env, EnvFastPathMixin
+from gfn.estimators import Estimator, FastPolicyMixin, PolicyEstimatorProtocol
 from gfn.states import GraphStates, States
 from gfn.utils.common import ensure_same_device
 from gfn.utils.graphs import graph_states_share_storage
 from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs
 
 
+def _mark_cudagraph_step() -> None:
+    compiler = getattr(torch, "compiler", None)
+    if compiler is None:
+        return
+    marker = getattr(compiler, "cudagraph_mark_step_begin", None)
+    if callable(marker):
+        marker()
+
+
+def _fill_like_reference(reference: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    fill = value.to(device=reference.device, dtype=reference.dtype)
+    while fill.ndim < reference.ndim:
+        fill = fill.unsqueeze(0)
+    return fill.expand_as(reference).clone()
+
+
 class Sampler:
     """Estimator‑driven sampler for GFlowNet environments.
 
@@ -890,3 +907,306 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
     )
 
     return new_trajectories, new_trajectories_log_pf, new_trajectories_log_pb
+
+
+class CompiledChunkSampler(Sampler):
+    """Chunked tensor sampler that stays on the fast path for torch.compile."""
+
+    def __init__(
+        self,
+        estimator: Estimator,
+        *,
+        chunk_size: int = 32,
+        compile_mode: str = "reduce-overhead",
+    ) -> None:
+        super().__init__(estimator)
+        self.chunk_size = int(chunk_size)
+        self.compile_mode = compile_mode
+        self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {}
+
+    def sample_trajectories(
+        self,
+        env: Env,
+        n: Optional[int] = None,
+        states: Optional[States] = None,
+        conditions: Optional[torch.Tensor] = None,
+        save_estimator_outputs: bool = False,
+        save_logprobs: bool = False,
+        **policy_kwargs: Any,
+    ) -> Trajectories:
+
+        # TODO(log-probs): supporting save_logprobs would require storing each chunk
+        # step's sampled actions together with a mask of active rows, evaluating
+        # policy.fast_distribution(...).log_prob(...) on them, and masking out rows
+        # forced to the exit action so the padded semantics match
+        # Trajectories.log_probs, i.e., keeping per-step tensors of shape
+        # (chunk_len, batch, action_dim) and writing them into the context at the end.
+
+        # TODO(estimator outputs): likewise, the raw tensors produced by
+        # policy.fast_features/fast_distribution would need to be captured for active
+        # rows, padded back to batch size, and appended per chunk so they can be
+        # stacked like the legacy sampler's outputs.
+        if save_estimator_outputs or save_logprobs:
+            raise NotImplementedError(
+                "CompiledChunkSampler does not yet record log-probs or estimator outputs."
+            )
+
+        if not isinstance(env, EnvFastPathMixin):
+            raise TypeError(
+                "CompiledChunkSampler requires environments implementing EnvFastPathMixin."
+            )
+
+        if not isinstance(self.estimator, FastPolicyMixin):
+            raise TypeError(
+                "CompiledChunkSampler requires estimators implementing FastPolicyMixin."
+ ) + + assert self.chunk_size > 0, "chunk_size must be positive" + + policy = cast(FastPolicyMixin, self.estimator) + + if states is None: + assert n is not None, "Either `n` or `states` must be provided." + states_obj = env.reset(batch_shape=(n,)) + else: + states_obj = states + assert len(states_obj.batch_shape) == 1, "States batch must be 1-D." + + batch = states_obj.batch_shape[0] + device = states_obj.device + + if conditions is not None: + assert ( + conditions.shape[0] == batch + ), "Conditions batch dimension must match states batch size." + ensure_same_device(device, conditions.device) + + curr_states = states_obj.tensor + done = states_obj.is_sink_state.clone() + exit_action_value = env.exit_action.to(device=curr_states.device) + dummy_action_value = env.dummy_action.to(device=curr_states.device) + + # `step_actions_seq` keeps the raw sampled actions (with exits injected for + # finished rows) so we can exactly replay the tensor-forward environment when + # rebuilding the state stack after the chunk loop. `recorded_actions_seq` + # mirrors those actions but rewrites already-finished rows with the env dummy + # action so that downstream Trajectories consumers (DBG/SubTB losses) never see + # transitions originating from sink states. + recorded_actions_seq: List[torch.Tensor] = [] + step_actions_seq: List[torch.Tensor] = [] + sink_seq: List[torch.Tensor] = [] + + chunk_size = int(policy_kwargs.pop("chunk_size", self.chunk_size)) + + def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ + torch.Tensor, + torch.Tensor, + List[torch.Tensor], + List[torch.Tensor], + List[torch.Tensor], + torch.Tensor, + ]: + """ + This function is the core of the chunked sampler. It is responsible for + sampling actions for a chunk of states. It is called in a loop until all + states are done. It returns the current states, a boolean mask indicating + which states are done, the actions sampled for the chunk, and a boolean + mask indicating which states are sinks. + + The purpose of this function is to serve as a torch.compile-ed function to + speed up the sampling process. It is called in a loop until all states are + done. + + Args: + current_states: The current states to sample actions for. + done_mask: A boolean mask indicating which states are done. + + Returns: + """ + local_step_actions: List[torch.Tensor] = [] + local_recorded_actions: List[torch.Tensor] = [] + local_sinks: List[torch.Tensor] = [] + step_template: torch.Tensor | None = None + record_template: torch.Tensor | None = None + steps_taken = 0 + + for _ in range(chunk_size): + if bool(done_mask.all().item()): + assert step_template is not None and record_template is not None + pad_step = _fill_like_reference(step_template, exit_action_value) + pad_record = _fill_like_reference(record_template, dummy_action_value) + local_step_actions.append(pad_step) + local_recorded_actions.append(pad_record) + local_sinks.append(done_mask.clone()) + continue + + state_view = current_states + features = policy.fast_features( + state_view, + forward_masks=None, + backward_masks=None, + conditions=conditions, + ) + dist = policy.fast_distribution( + features, + forward_masks=None, + backward_masks=None, + states_tensor=state_view, + **policy_kwargs, + ) + + actions_tensor = dist.sample() + + if done_mask.any(): + # Broadcast the boolean mask and the per-env exit/dummy templates so + # they match the estimator's sampled action tensor shape (covers + # both scalar Discrete actions and potential multi-dim action + # heads). 
+ mask = done_mask + while mask.ndim < actions_tensor.ndim: + mask = mask.unsqueeze(-1) + + exit_fill = exit_action_value.to( + device=actions_tensor.device, dtype=actions_tensor.dtype + ) + while exit_fill.ndim < actions_tensor.ndim: + exit_fill = exit_fill.unsqueeze(0) + + dummy_fill = dummy_action_value.to( + device=actions_tensor.device, dtype=actions_tensor.dtype + ) + while dummy_fill.ndim < actions_tensor.ndim: + dummy_fill = dummy_fill.unsqueeze(0) + + step_actions = torch.where(mask, exit_fill, actions_tensor) + record_actions = torch.where(mask, dummy_fill, actions_tensor) + else: + step_actions = actions_tensor + record_actions = actions_tensor + + # Only the step actions (exit-padded) are used to advance the tensor + # env. The recorded actions (dummy-padded) are used to reconstruct the + # state stack after the chunk loop. + step_res = env.step_tensor(current_states, step_actions) + current_states = step_res.next_states + sinks = step_res.is_sink_state + if sinks is None: + sinks = env.states_from_tensor(current_states).is_sink_state + + done_mask = done_mask | sinks + local_step_actions.append(step_actions) + local_recorded_actions.append(record_actions) + step_template = step_actions.detach() + record_template = record_actions.detach() + local_sinks.append(sinks) + steps_taken += 1 + + return ( + current_states, + done_mask, + local_step_actions, + local_recorded_actions, + local_sinks, + torch.tensor(steps_taken, device=current_states.device), + ) + + chunk_fn: Callable = _chunk_loop + chunk_fn_compiled = False + device_type = curr_states.device.type + compile_allowed = ( + hasattr(torch, "compile") and device_type in ("cuda", "cpu") and conditions is None and not policy_kwargs + ) + cache_key = (id(env), device_type) + if compile_allowed: + cached = self._compiled_chunk_cache.get(cache_key) + if cached is not None: + chunk_fn = cached + chunk_fn_compiled = True + else: + try: + compiled = torch.compile(_chunk_loop, mode=self.compile_mode) # type: ignore[arg-type] + self._compiled_chunk_cache[cache_key] = compiled + chunk_fn = compiled + chunk_fn_compiled = True + except Exception: + warnings.warn( + "Compilation of chunk_loop failed, using non-compiled version.", + stacklevel=2, + ) + chunk_fn = _chunk_loop + + # Main loop: call the compiled function until all states are done. + while not bool(done.all().item()): + if chunk_fn_compiled: + _mark_cudagraph_step() + ( + curr_states, + done, + step_actions_chunk, + recorded_actions_chunk, + sinks_chunk, + steps_taken_tensor, + ) = chunk_fn(curr_states, done) + steps_taken = int(steps_taken_tensor.item()) + if steps_taken: + step_actions_seq.extend(step_actions_chunk[:steps_taken]) + recorded_actions_seq.extend(recorded_actions_chunk[:steps_taken]) + sink_seq.extend(sinks_chunk[:steps_taken]) + + if recorded_actions_seq: + actions_tsr = torch.stack(recorded_actions_seq, dim=0) + T = actions_tsr.shape[0] + + s = states_obj.tensor + states_stack = [s] + for t in range(T): + # Re-simulate using the true step actions so reconstructed states match + # the chunk rollout exactly even though padded (recorded) actions may + # differ. 
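+                # (The dummy-padded recorded actions may not be valid inputs to
+                # `step_tensor`, so the replay below uses `step_actions_seq`.)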
                step = env.step_tensor(s, step_actions_seq[t])
+                s = step.next_states
+                states_stack.append(s)
+            states_tsr = torch.stack(states_stack, dim=0)
+
+            sinks_tsr = torch.stack(sink_seq, dim=0)
+            first_sink = torch.argmax(sinks_tsr.to(torch.long), dim=0)
+            never_sink = ~sinks_tsr.any(dim=0)
+            first_sink = torch.where(
+                never_sink,
+                torch.tensor(T - 1, device=device),
+                first_sink,
+            )
+            terminating_idx = first_sink + 1
+        else:
+            states_tsr = states_obj.tensor.unsqueeze(0)
+            actions_tsr = env.actions_from_batch_shape((0, batch)).tensor
+            terminating_idx = torch.zeros(batch, dtype=torch.long, device=device)
+
+        # Ensure the stacked (dummy-padded) actions respect the environment's action
+        # shape before wrapping them into an Actions container (e.g., discrete envs
+        # expect (..., 1)). Without this guard DB/SubTB estimators would fail when the
+        # chunk sampler returns rank-1 tensors.
+        action_shape = getattr(env, "action_shape", None)
+        if action_shape:
+            tail_shape = tuple(actions_tsr.shape[-len(action_shape) :])
+            if tail_shape != tuple(action_shape):
+                if tuple(action_shape) == (1,):
+                    actions_tsr = actions_tsr.unsqueeze(-1)
+                else:
+                    raise ValueError(
+                        "CompiledChunkSampler produced actions with shape "
+                        f"{actions_tsr.shape}, expected trailing dims {action_shape}."
+                    )
+
+        trajectories = Trajectories(
+            env=env,
+            states=env.states_from_tensor(states_tsr),
+            conditions=conditions,
+            actions=env.actions_from_tensor(actions_tsr),
+            terminating_idx=terminating_idx,
+            is_backward=policy.is_backward,
+            log_rewards=None,
+            log_probs=None,
+            estimator_outputs=None,
+        )
+        return trajectories
diff --git a/src/gfn/utils/compile.py b/src/gfn/utils/compile.py
new file mode 100644
index 00000000..db6ded2d
--- /dev/null
+++ b/src/gfn/utils/compile.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+
+
+def try_compile_gflownet(
+    gfn,
+    *,
+    mode: str = "default",
+    components: Iterable[str] = ("pf", "pb", "logZ", "logF"),
+) -> dict[str, bool]:
+    """Best-effort compilation of estimator modules attached to a GFlowNet.
+
+    Args:
+        gfn: The GFlowNet instance to compile.
+        mode: Compilation mode forwarded to ``torch.compile``.
+        components: Attribute names to attempt compilation on (e.g., ``pf``).
+
+    Returns:
+        Mapping from component name to compilation success status.
+    """
+
+    if not hasattr(torch, "compile"):
+        return {name: False for name in components}
+
+    results: dict[str, bool] = {}
+    for name in components:
+
+        # If the estimator does not exist, we cannot compile it.
+        if not hasattr(gfn, name):
+            results[name] = False
+            continue
+
+        estimator = getattr(gfn, name)
+        module = getattr(estimator, "module", None)
+
+        # If the estimator does not have a module, we cannot compile it.
+        if module is None:
+            results[name] = False
+            continue
+
+        # Otherwise attempt compilation, falling back to the eager module on failure.
+        try:
+            assert isinstance(estimator.module, torch.nn.Module)
+            estimator.module = torch.compile(module, mode=mode)
+            results[name] = True
+        except Exception:
+            results[name] = False
+
+    return results
diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py
index 5f1a75e8..b007f6e0 100644
--- a/src/gfn/utils/prob_calculations.py
+++ b/src/gfn/utils/prob_calculations.py
@@ -96,6 +96,14 @@ def get_trajectory_pfs(
     valid_actions = trajectories.actions[action_mask]
 
     if valid_states.batch_shape != valid_actions.batch_shape:
-        raise AssertionError("Something wrong happening with log_pf evaluations")
+        raise AssertionError(
+            "Mismatched batch shapes during log_pf evaluation: "
+            f"valid_states {valid_states.batch_shape} vs "
+            f"valid_actions {valid_actions.batch_shape} "
+            f"(state_mask {tuple(state_mask.shape)}, "
+            f"action_mask {tuple(action_mask.shape)}, "
+            f"is_sink_state {tuple(trajectories.states.is_sink_state.shape)}, "
+            f"is_dummy {tuple(trajectories.actions.is_dummy.shape)})."
+        )
 
     if trajectories.has_log_probs and not recalculate_all_logprobs:
diff --git a/testing/test_environments.py b/testing/test_environments.py
index 42abc54f..f889447e 100644
--- a/testing/test_environments.py
+++ b/testing/test_environments.py
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, cast
 
 import numpy as np
 import pytest
@@ -7,7 +7,8 @@
 from gfn.actions import GraphActions, GraphActionType
 from gfn.env import Env, NonValidActionsError
-from gfn.gym import Box, DiscreteEBM, HyperGrid
+from gfn.gym import BitSequence, BitSequencePlus, Box, DiscreteEBM, HyperGrid, Line
+from gfn.gym.diffusion_sampling import DiffusionSampling
 from gfn.gym.graph_building import GraphBuilding
 from gfn.gym.perfect_tree import PerfectBinaryTree
 from gfn.gym.set_addition import SetAddition
@@ -133,6 +134,44 @@ def test_HyperGrid_bwd_step():
     states = env._backward_step(states, failing_actions)
 
 
+def test_HyperGrid_fast_path_matches_legacy():
+    NDIM = 3
+    ENV_HEIGHT = 5
+    BATCH_SIZE = 64
+
+    env = HyperGrid(ndim=NDIM, height=ENV_HEIGHT)
+    states = env.reset(batch_shape=BATCH_SIZE, random=True, seed=123)
+
+    assert states.forward_masks is not None
+    tensor_masks = env.forward_action_masks_tensor(states.tensor)
+    legacy_forward_masks = cast(torch.Tensor, states.forward_masks)
+    assert torch.equal(tensor_masks, legacy_forward_masks)
+
+    action_dist = torch.distributions.Categorical(
+        probs=states.forward_masks.to(dtype=torch.float32)
+    )
+    actions_tensor = action_dist.sample().unsqueeze(-1)
+
+    legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+    assert legacy_next.forward_masks is not None
+    assert legacy_next.backward_masks is not None
+    legacy_step_forward = cast(torch.Tensor, legacy_next.forward_masks)
+    legacy_step_backward = cast(torch.Tensor, legacy_next.backward_masks)
+
+    fast = env.step_tensor(states.tensor, actions_tensor)
+
+    assert torch.equal(fast.next_states, legacy_next.tensor)
+    assert fast.is_sink_state is not None
+    fast_is_sink = cast(torch.Tensor, fast.is_sink_state)
+    assert torch.equal(fast_is_sink, legacy_next.is_sink_state)
+    assert fast.forward_masks is not None
+    fast_forward_masks = cast(torch.Tensor, fast.forward_masks)
+    assert fast.backward_masks is not None
+    fast_backward_masks = cast(torch.Tensor, fast.backward_masks)
+    assert torch.equal(fast_forward_masks, 
 def test_DiscreteEBM_fwd_step():
 NDIM = 2
 BATCH_SIZE = 4
@@ -193,6 +233,45 @@ def test_DiscreteEBM_bwd_step():
 states = env._backward_step(states, failing_actions)
 
 
+def test_DiscreteEBM_fast_path_matches_legacy():
+ NDIM = 5
+ BATCH_SIZE = 48
+ env = DiscreteEBM(ndim=NDIM)
+ states_tensor = torch.randint(
+ -1, 2, (BATCH_SIZE, NDIM), dtype=torch.long, device=env.device
+ )
+ states = env.states_from_tensor(states_tensor.clone())
+ assert states.forward_masks is not None
+ forward_masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = torch.distributions.Categorical(
+ probs=forward_masks.to(dtype=torch.float32)
+ ).sample()
+ actions_tensor = actions_tensor.unsqueeze(-1)
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ assert legacy_next.forward_masks is not None
+ assert legacy_next.backward_masks is not None
+ legacy_forward = cast(torch.Tensor, legacy_next.forward_masks)
+ legacy_backward = cast(torch.Tensor, legacy_next.backward_masks)
+
+ fast = env.step_tensor(states.tensor, actions_tensor)
+ assert fast.forward_masks is not None
+ assert fast.backward_masks is not None
+ assert fast.is_sink_state is not None
+
+ assert torch.equal(fast.next_states, legacy_next.tensor)
+ assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state)
+
+ non_sink = ~legacy_next.is_sink_state
+ if non_sink.any():
+ assert torch.equal(
+ cast(torch.Tensor, fast.forward_masks)[non_sink], legacy_forward[non_sink]
+ )
+ fast_backward = cast(torch.Tensor, fast.backward_masks)[non_sink, : 2 * NDIM]
+ legacy_backward_trim = legacy_backward[non_sink, : 2 * NDIM]
+ assert torch.equal(fast_backward, legacy_backward_trim)
+
+
 @pytest.mark.parametrize("delta", [0.1, 0.5, 1.0])
 def test_box_fwd_step(delta: float):
 env = Box(delta=delta)
@@ -592,6 +671,28 @@ def test_graph_env():
 assert states.tensor.x.shape == (0, 1)
 
 
+def test_Line_fast_path_matches_legacy():
+ BATCH_SIZE = 32
+ env = Line(
+ mus=[0.0, 2.0],
+ sigmas=[0.5, 0.75],
+ init_value=0.1,
+ n_steps_per_trajectory=5,
+ )
+ states = env.reset(batch_shape=BATCH_SIZE)
+ actions_tensor = torch.randn(BATCH_SIZE, 1, device=states.device)
+ exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.25
+ actions_tensor[exit_mask] = env.exit_action.item()
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ fast = env.step_tensor(states.tensor, actions_tensor)
+
+ assert torch.allclose(fast.next_states, legacy_next.tensor)
+ assert fast.is_sink_state is not None
+ fast_sink = cast(torch.Tensor, fast.is_sink_state)
+ assert torch.equal(fast_sink, legacy_next.is_sink_state)
+
+
 def test_set_addition_fwd_step():
 N_ITEMS = 4
 MAX_ITEMS = 3
@@ -648,6 +749,147 @@ def test_set_addition_fwd_step():
 assert torch.allclose(rewards, expected_rewards)
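The Box test below builds its random step actions from a radius and an angle. Drawing the radius uniformly from [0, delta) concentrates samples near the origin, which is harmless for a parity check; if area-uniform coverage of the disk were ever needed, the radius would be drawn with a square root instead. A minimal sketch (`sample_in_disk` is illustrative, not part of this diff):

```python
import torch


def sample_in_disk(n: int, delta: float, device=None) -> torch.Tensor:
    """Draw n points uniformly (by area) from the disk of radius `delta`.

    Taking the square root of a uniform variate makes the radial density
    proportional to r, which is what area-uniformity requires; sampling the
    radius directly (as the test below does) biases points toward the center.
    """
    r = delta * torch.rand(n, device=device).sqrt()
    theta = 2 * torch.pi * torch.rand(n, device=device)
    return torch.stack([r * torch.cos(theta), r * torch.sin(theta)], dim=-1)
```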
+
+
+def test_box_fast_path_matches_legacy():
+ BATCH_SIZE = 48
+ DELTA = 0.2
+ env = Box(delta=DELTA)
+ states = env.reset(batch_shape=BATCH_SIZE)
+ actions_tensor = torch.zeros(
+ BATCH_SIZE, 2, dtype=torch.get_default_dtype(), device=states.device
+ )
+ exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.2
+ actions_tensor[exit_mask] = env.exit_action.to(actions_tensor.device).to(
+ actions_tensor.dtype
+ )
+ non_exit_idx = (~exit_mask).nonzero(as_tuple=True)[0]
+ if len(non_exit_idx) > 0:
+ radii = torch.rand(len(non_exit_idx), device=states.device) * DELTA
+ angles = torch.rand(len(non_exit_idx), device=states.device) * 2 * torch.pi
+ actions_tensor[non_exit_idx, 0] = radii * torch.cos(angles)
+ actions_tensor[non_exit_idx, 1] = radii * torch.sin(angles)
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ fast = env.step_tensor(states.tensor, actions_tensor)
+
+ assert torch.allclose(fast.next_states, legacy_next.tensor)
+ assert fast.is_sink_state is not None
+ assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state)
+
+
+def test_diffusion_sampling_fast_path_matches_legacy():
+ BATCH_SIZE = 16
+ env = DiffusionSampling(
+ target_str="gmm2", target_kwargs={}, num_discretization_steps=8
+ )
+ states = env.reset(batch_shape=BATCH_SIZE)
+ actions_tensor = torch.randn(
+ BATCH_SIZE, env.dim, device=states.device, dtype=states.tensor.dtype
+ )
+ exit_mask = torch.rand(BATCH_SIZE, device=states.device) < 0.2
+ if exit_mask.any():
+ exit_action = env.exit_action.to(device=states.device, dtype=states.tensor.dtype)
+ actions_tensor[exit_mask] = exit_action
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ fast = env.step_tensor(states.tensor, actions_tensor)
+
+ assert torch.allclose(fast.next_states, legacy_next.tensor)
+ assert fast.is_sink_state is not None
+ assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state)
+
+
+def test_bitsequence_fast_path_matches_legacy():
+ BATCH_SIZE = 32
+ env = BitSequence(word_size=2, seq_size=8, n_modes=4, temperature=1.0)
+ states = env.reset(batch_shape=BATCH_SIZE)
+ for _ in range(3):
+ assert states.forward_masks is not None
+ masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = (
+ torch.distributions.Categorical(probs=masks.to(dtype=torch.float32))
+ .sample()
+ .unsqueeze(-1)
+ )
+ states = env._step(states, env.actions_from_tensor(actions_tensor))
+
+ assert states.forward_masks is not None
+ forward_masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = (
+ torch.distributions.Categorical(probs=forward_masks.to(dtype=torch.float32))
+ .sample()
+ .unsqueeze(-1)
+ )
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ assert legacy_next.forward_masks is not None
+ assert legacy_next.backward_masks is not None
+
+ fast = env.step_tensor(states.tensor, actions_tensor)
+ assert fast.forward_masks is not None
+ assert fast.backward_masks is not None
+ assert fast.is_sink_state is not None
+
+ assert torch.equal(fast.next_states, legacy_next.tensor)
+ assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state)
+
+ non_sink = ~legacy_next.is_sink_state
+ if non_sink.any():
+ assert torch.equal(
+ cast(torch.Tensor, fast.forward_masks)[non_sink],
+ cast(torch.Tensor, legacy_next.forward_masks)[non_sink],
+ )
+ assert torch.equal(
+ cast(torch.Tensor, fast.backward_masks)[non_sink],
+ cast(torch.Tensor, legacy_next.backward_masks)[non_sink],
+ )
+
+
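The two bit-sequence tests drive the environment forward with repeated `_step` calls before the final parity check. The same rollout can be phrased purely on tensors via `step_tensor`, which is essentially what the chunked samplers later in this diff do. A compressed sketch under that assumption (`rollout_tensor` and `policy_fn` are illustrative names, not part of the diff):

```python
import torch


def rollout_tensor(env, states_tensor: torch.Tensor, policy_fn, max_steps: int = 100):
    """Repeatedly apply env.step_tensor until every trajectory hits the sink.

    Assumes `env` follows the TensorStepResult contract. Real samplers must
    also overwrite actions for already-finished rows with exit/dummy actions
    (as the chunked samplers below do); that bookkeeping is omitted here.
    """
    done = torch.zeros(
        states_tensor.shape[0], dtype=torch.bool, device=states_tensor.device
    )
    for _ in range(max_steps):
        if bool(done.all()):
            break
        actions = policy_fn(states_tensor)  # (batch, *action_shape)
        result = env.step_tensor(states_tensor, actions)
        states_tensor = result.next_states
        if result.is_sink_state is not None:
            done = done | result.is_sink_state
    return states_tensor
```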
+def test_bitsequence_plus_fast_path_matches_legacy():
+ BATCH_SIZE = 24
+ env = BitSequencePlus(word_size=2, seq_size=16, n_modes=5, temperature=1.0)
+ states = env.reset(batch_shape=BATCH_SIZE)
+ for _ in range(4):
+ assert states.forward_masks is not None
+ masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = (
+ torch.distributions.Categorical(probs=masks.to(dtype=torch.float32))
+ .sample()
+ .unsqueeze(-1)
+ )
+ states = env._step(states, env.actions_from_tensor(actions_tensor))
+
+ assert states.forward_masks is not None
+ forward_masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = (
+ torch.distributions.Categorical(probs=forward_masks.to(dtype=torch.float32))
+ .sample()
+ .unsqueeze(-1)
+ )
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ assert legacy_next.forward_masks is not None
+ assert legacy_next.backward_masks is not None
+
+ fast = env.step_tensor(states.tensor, actions_tensor)
+ assert fast.forward_masks is not None
+ assert fast.backward_masks is not None
+ assert fast.is_sink_state is not None
+
+ assert torch.equal(fast.next_states, legacy_next.tensor)
+ assert torch.equal(cast(torch.Tensor, fast.is_sink_state), legacy_next.is_sink_state)
+
+ non_sink = ~legacy_next.is_sink_state
+ if non_sink.any():
+ assert torch.equal(
+ cast(torch.Tensor, fast.forward_masks)[non_sink],
+ cast(torch.Tensor, legacy_next.forward_masks)[non_sink],
+ )
+ assert torch.equal(
+ cast(torch.Tensor, fast.backward_masks)[non_sink],
+ cast(torch.Tensor, legacy_next.backward_masks)[non_sink],
+ )
+
+
 def test_set_addition_bwd_step():
 N_ITEMS = 5
 MAX_ITEMS = 4
@@ -692,6 +934,49 @@ def test_set_addition_bwd_step():
 assert torch.all(states.is_initial_state)
+
+
+def test_set_addition_fast_path_matches_legacy():
+ N_ITEMS = 6
+ MAX_ITEMS = 4
+ BATCH_SIZE = 48
+
+ env = SetAddition(
+ n_items=N_ITEMS, max_items=MAX_ITEMS, reward_fn=lambda s: s.sum(-1)
+ )
+ states_tensor = torch.randint(
+ 0, 2, (BATCH_SIZE, N_ITEMS), dtype=torch.get_default_dtype()
+ )
+ states = env.states_from_tensor(states_tensor.clone())
+
+ assert states.forward_masks is not None
+ forward_masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = torch.distributions.Categorical(
+ probs=forward_masks.to(dtype=torch.float32)
+ ).sample()
+ actions_tensor = actions_tensor.unsqueeze(-1)
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ assert legacy_next.forward_masks is not None
+ assert legacy_next.backward_masks is not None
+ legacy_forward = cast(torch.Tensor, legacy_next.forward_masks)
+ legacy_backward = cast(torch.Tensor, legacy_next.backward_masks)
+
+ fast = env.step_tensor(states.tensor, actions_tensor)
+ assert fast.forward_masks is not None
+ assert fast.backward_masks is not None
+ assert fast.is_sink_state is not None
+
+ assert torch.equal(fast.next_states, legacy_next.tensor)
+ assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state)
+
+ fast_forward = cast(torch.Tensor, fast.forward_masks)
+ fast_backward = cast(torch.Tensor, fast.backward_masks)[..., : env.n_items]
+
+ non_sink = ~legacy_next.is_sink_state
+ if non_sink.any():
+ assert torch.equal(fast_forward[non_sink], legacy_forward[non_sink])
+ assert torch.equal(fast_backward[non_sink], legacy_backward[non_sink])
+
+
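Several tests here sample one valid action per row by feeding the boolean forward mask straight into `torch.distributions.Categorical`. Since `Categorical` normalizes each probability row, a mask cast to float yields a uniform choice among the allowed actions. A standalone illustration:

```python
import torch

# Uniform sampling over allowed actions: Categorical renormalizes row-wise,
# so a boolean mask cast to float is a uniform distribution over the True
# entries of each row.
masks = torch.tensor([[True, False, True], [True, True, True]])
dist = torch.distributions.Categorical(probs=masks.to(dtype=torch.float32))
actions = dist.sample()  # shape (2,), each entry a valid column index
assert bool(masks[torch.arange(2), actions].all())
```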
 def test_perfect_binary_tree_fwd_step():
 DEPTH = 3
 BATCH_SIZE = 2
@@ -781,6 +1066,49 @@ def test_perfect_binary_tree_bwd_step():
 assert torch.all(states.is_initial_state)
+
+
+def test_perfect_binary_tree_fast_path_matches_legacy():
+ DEPTH = 4
+ BATCH_SIZE = 64
+
+ env = PerfectBinaryTree(
+ depth=DEPTH,
+ reward_fn=lambda s: s.to(torch.get_default_dtype()) + 1,
+ )
+ states_tensor = torch.randint(
+ 0, env.n_nodes, (BATCH_SIZE, 1), dtype=torch.long, device=env.device
+ )
+ states = env.states_from_tensor(states_tensor.clone())
+ assert states.forward_masks is not None
+ forward_masks = cast(torch.Tensor, states.forward_masks)
+ actions_tensor = torch.distributions.Categorical(
+ probs=forward_masks.to(dtype=torch.float32)
+ ).sample()
+ actions_tensor = actions_tensor.unsqueeze(-1)
+
+ legacy_next = env._step(states, env.actions_from_tensor(actions_tensor.clone()))
+ assert legacy_next.forward_masks is not None
+ assert legacy_next.backward_masks is not None
+ legacy_forward = cast(torch.Tensor, legacy_next.forward_masks)
+ legacy_backward = cast(torch.Tensor, legacy_next.backward_masks)
+
+ fast = env.step_tensor(states.tensor, actions_tensor)
+ assert fast.forward_masks is not None
+ assert fast.backward_masks is not None
+ assert fast.is_sink_state is not None
+
+ assert torch.equal(fast.next_states, legacy_next.tensor)
+ assert torch.equal(fast.is_sink_state, legacy_next.is_sink_state)
+
+ non_sink = ~legacy_next.is_sink_state
+ if non_sink.any():
+ assert torch.equal(
+ cast(torch.Tensor, fast.forward_masks)[non_sink], legacy_forward[non_sink]
+ )
+ assert torch.equal(
+ cast(torch.Tensor, fast.backward_masks)[non_sink], legacy_backward[non_sink]
+ )
+
+
 # -----------------------------------------------------------------------------
 # Tests for default sf fill value based on dtype
 # -----------------------------------------------------------------------------
diff --git a/tutorials/examples/train_hypergrid_optimized.py b/tutorials/examples/train_hypergrid_optimized.py
new file mode 100644
index 00000000..1b9cf93f
--- /dev/null
+++ b/tutorials/examples/train_hypergrid_optimized.py
@@ -0,0 +1,1972 @@
+#!/usr/bin/env python
+r"""
+Optimized multi-environment (HyperGrid + Diffusion) training/benchmark script with
+optional torch.compile, vmap, and chunked sampling across several GFlowNet variants.
+
+TODO:
+
+We still need real profiling on CUDA (start with ``torch.profiler.profile`` around
+``chunk_fn``, with CUDA activities enabled) to see kernel counts and copy sizes. If
+compilation is failing, inspect the Dynamo logs to find the blocking op (the
+``env.actions_from_batch_shape`` call inside ``_chunk_loop`` may still fall back to
+Python). If compilation succeeds, the GPU is likely dominated by host-device copies
+and the script fast path should stay on CPU.
+
+Recommended next actions:
+
+1. Run the benchmark with TORCH_LOGS="graph_breaks" and TORCHDYNAMO_VERBOSE=1 to
+ see why the chunk loops bail out of compilation, and share the relevant logs.
+2. Profile the "Library Fast Path" on CUDA with the PyTorch profiler to find the
+ hottest ops. If the chunk sampler is dominated by scatter/where, batch those ops
+ or increase the chunk size to amortize kernel launches.
+3. For the script sampler, try forcing ``chunk_fn_compiled=False`` to confirm
+ whether the slowdown is torch.compile overhead; if eager mode is faster, the
+ compiled graph is recomputing templates or copying more than expected.
+"""
+
+from __future__ import annotations
+
+import argparse
+import statistics
+import time
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Literal, cast
+
+import torch
+from torch.func import vmap
+from tqdm import tqdm
+
+try: # Enable scalar captures for torch.compile to avoid graph breaks on .item().
+ import torch._dynamo as _torch_dynamo + + _torch_dynamo.config.capture_scalar_outputs = True +except Exception: # pragma: no cover - defensive fallback on older PyTorch + _torch_dynamo = None + +from gfn.containers import Trajectories +from gfn.env import Env, EnvFastPathMixin +from gfn.estimators import ( + DiscretePolicyEstimator, + FastPolicyMixin, + PinnedBrownianMotionBackward, + PinnedBrownianMotionForward, + ScalarEstimator, +) +from gfn.gflownet import PFBasedGFlowNet, SubTBGFlowNet +from gfn.gflownet.detailed_balance import DBGFlowNet +from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.gym import HyperGrid +from gfn.gym.diffusion_sampling import DiffusionSampling +from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor +from gfn.samplers import CompiledChunkSampler, Sampler +from gfn.states import DiscreteStates, States +from gfn.utils.common import set_seed +from gfn.utils.compile import try_compile_gflownet +from gfn.utils.modules import ( + MLP, + DiffusionFixedBackwardModule, + DiffusionPISGradNetForward, +) +from gfn.utils.training import validate + + +def _mark_cudagraph_step() -> None: + compiler = getattr(torch, "compiler", None) + if compiler is None: + return + marker = getattr(compiler, "cudagraph_mark_step_begin", None) + if callable(marker): + marker() + + +def _fill_like_reference(reference: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + """Broadcasts `value` to the shape/dtype/device of `reference`.""" + fill = value.to(device=reference.device, dtype=reference.dtype) + while fill.ndim < reference.ndim: + fill = fill.unsqueeze(0) + return fill.expand_as(reference).clone() + + +# Default HyperGrid configuration (easy to extend to multiple envs later on). +HYPERGRID_KWARGS: Dict[str, Any] = { + "ndim": 2, + "height": 32, + "reward_fn_str": "original", + "reward_fn_kwargs": {"R0": 0.1, "R1": 0.5, "R2": 2.0}, + "calculate_partition": False, + "store_all_states": False, + "check_action_validity": __debug__, +} + +DEFAULT_CHUNK_SIZE = 32 +DEFAULT_COMPILE_MODE = "reduce-overhead" + + +@dataclass +class ScenarioConfig: + name: str + description: str + sampler: Literal["standard", "compiled_chunk", "script_chunk"] + use_script_env: bool + use_compile: bool + use_vmap: bool + + +@dataclass(frozen=True) +class FlowVariant: + key: Literal["tb", "dbg", "subtb"] + label: str + description: str + requires_logf: bool + supports_vmap: bool + + +HYPERGRID_SCENARIOS: list[ScenarioConfig] = [ + ScenarioConfig( + name="Baseline (core)", + description="Stock library path: standard env + sampler, no compilation.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=False, + ), + ScenarioConfig( + name="VMap Only", + description="VMAP Accelerated Loss.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=True, + ), + ScenarioConfig( + name="Compile Only (core)", + description="Standard env + sampler with torch.compile but no chunking.", + sampler="standard", + use_script_env=False, + use_compile=True, + use_vmap=True, + ), + ScenarioConfig( + name="Library Fast Path", + description="Core EnvFastPath + CompiledChunkSampler + compile + vmap TB.", + sampler="compiled_chunk", + use_script_env=False, + use_compile=True, + use_vmap=True, + ), + ScenarioConfig( + name="Script Fast Path", + description="Script-local tensor env/sampler, compile, and vmap TB.", + sampler="script_chunk", + use_script_env=True, + use_compile=True, + use_vmap=True, + ), +] + +DIFFUSION_SCENARIOS: list[ScenarioConfig] = [ + 
ScenarioConfig( + name="Diffusion Baseline", + description="Pinned Brownian sampler without compilation or chunking.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=False, + ), + ScenarioConfig( + name="Diffusion VMap Only", + description="Pinned Brownian sampler without compilation or chunking.", + sampler="standard", + use_script_env=False, + use_compile=False, + use_vmap=True, + ), + ScenarioConfig( + name="Diffusion Compile Only", + description="Standard diffusion sampler with torch.compile but no chunking.", + sampler="standard", + use_script_env=False, + use_compile=True, + use_vmap=True, + ), + ScenarioConfig( + name="Diffusion Library Fast Path", + description="EnvFastPath + CompiledChunkSampler (library implementation).", + sampler="compiled_chunk", + use_script_env=False, + use_compile=True, + use_vmap=False, + ), + ScenarioConfig( + name="Diffusion Script Fast Path", + description="Script-local tensor sampler tailored to diffusion states.", + sampler="script_chunk", + use_script_env=False, + use_compile=True, + use_vmap=False, + ), +] + + +FLOW_VARIANTS: dict[str, FlowVariant] = { + "tb": FlowVariant( + key="tb", + label="TBGFlowNet", + description="Trajectory Balance baseline with optional torch.compile/vmap.", + requires_logf=False, + supports_vmap=True, + ), + "dbg": FlowVariant( + key="dbg", + label="DBGFlowNet", + description="Detailed Balance loss with learned log-state flows.", + requires_logf=True, + supports_vmap=False, + ), + "subtb": FlowVariant( + key="subtb", + label="SubTBGFlowNet", + description="Sub-trajectory balance variant with configurable weighting.", + requires_logf=True, + supports_vmap=False, + ), +} + +DEFAULT_FLOW_ORDER = ["tb", "dbg", "subtb"] + +# Plot styling: consistent colors for GFlowNet variants, linestyles for scenarios. +VARIANT_COLORS: dict[str, str] = { + "tb": "#000000", # Trajectory Balance -> black + "subtb": "#d62728", # SubTB -> red + "dbg": "#1f77b4", # Detailed Balance -> blue +} +SCENARIO_LINESTYLES: dict[str, Any] = { + "Baseline (core)": "-", + "Compile Only (core)": "-.", + "Library Fast Path": "--", # fast-path compiled + "Script Fast Path": ":", + "Diffusion Baseline": "-", + "Diffusion Compile Only": "-.", + "Diffusion Library Fast Path": "--", + "Diffusion Script Fast Path": ":", +} +LOSS_LINE_ALPHA = 0.5 + + +@dataclass +class EnvironmentBenchmark: + key: Literal["hypergrid", "diffusion"] + label: str + description: str + color: str + scenarios: list[ScenarioConfig] + supported_flows: list[str] + supports_validation: bool + + +ENVIRONMENT_BENCHMARKS: dict[str, EnvironmentBenchmark] = { + "hypergrid": EnvironmentBenchmark( + key="hypergrid", + label="HyperGrid", + description="High-dimensional discrete lattice with known reward landscape.", + color="#4a90e2", + scenarios=HYPERGRID_SCENARIOS, + supported_flows=list(DEFAULT_FLOW_ORDER), + supports_validation=True, + ), + "diffusion": EnvironmentBenchmark( + key="diffusion", + label="Diffusion Sampling", + description="Continuous-state diffusion sampling benchmark (Pinned Brownian).", + color="#a17be7", + scenarios=DIFFUSION_SCENARIOS, + supported_flows=list(DEFAULT_FLOW_ORDER), + supports_validation=False, + ), +} +DEFAULT_ENV_ORDER = ["hypergrid", "diffusion"] + + +def _normalize_flow_keys(requested: list[str]) -> list[str]: + normalized: list[str] = [] + for key in requested: + alias = key.lower() + if alias not in FLOW_VARIANTS: + supported = ", ".join(sorted(FLOW_VARIANTS)) + raise ValueError( + f"Unsupported GFlowNet variant '{key}'. 
Choose from {supported}." + ) + if alias not in normalized: + normalized.append(alias) + return normalized + + +def _normalize_env_keys(requested: list[str]) -> list[str]: + normalized: list[str] = [] + available = ENVIRONMENT_BENCHMARKS + for key in requested: + alias = key.lower() + if alias not in available: + supported = ", ".join(sorted(available)) + raise ValueError( + f"Unsupported environment '{key}'. Choose from {supported}." + ) + if alias not in normalized: + normalized.append(alias) + return normalized or list(DEFAULT_ENV_ORDER) + + +# Local subclasses for benchmarking-only optimizations (no core library changes) +class HyperGridWithTensorStep(HyperGrid): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + eye = torch.eye(self.ndim, dtype=torch.long) + zero_row = torch.zeros((1, self.ndim), dtype=torch.long) + self._unit_step_template = torch.cat([eye, zero_row], dim=0) + self._sink_state_template = self.sf.to(dtype=torch.long).clone() + + def _get_unit_steps(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + return self._unit_step_template.to(device=device, dtype=dtype) + + def _get_sink_state(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + return self._sink_state_template.to(device=device, dtype=dtype) + + def step_tensor( + self, states: torch.Tensor, actions: torch.Tensor + ) -> Env.TensorStepResult: + assert states.dtype == torch.long + device = states.device + batch = states.shape[0] + exit_idx = self.n_actions - 1 + + if actions.ndim == 1: + action_idx = actions + else: + action_idx = actions.view(-1) + + action_idx = action_idx.to(torch.long) + is_exit = action_idx == exit_idx + + unit_steps = self._get_unit_steps(device, states.dtype) + deltas = unit_steps.index_select(0, action_idx.clamp(max=exit_idx)) + next_states = states + deltas + + sink_state = self._get_sink_state(device, states.dtype).view(1, -1) + next_states = torch.where( + is_exit.view(-1, 1), sink_state.expand_as(next_states), next_states + ) + + forward_masks = torch.cat( + [ + next_states != (self.height - 1), + torch.ones((batch, 1), dtype=torch.bool, device=device), + ], + dim=-1, + ) + backward_masks = next_states != 0 + + return self.TensorStepResult( + next_states=next_states, + is_sink_state=is_exit, + forward_masks=forward_masks, + backward_masks=backward_masks, + ) + + def forward_action_masks(self, states_tensor: torch.Tensor) -> torch.Tensor: + """Returns forward-action masks for a batch of state tensors.""" + base = states_tensor != (self.height - 1) + return torch.cat( + [ + base, + torch.ones( + (states_tensor.shape[0], 1), + dtype=torch.bool, + device=states_tensor.device, + ), + ], + dim=-1, + ) + + +class ChunkedHyperGridSampler(Sampler): + def __init__(self, estimator, chunk_size: int): + super().__init__(estimator) + self.chunk_size = int(chunk_size) + self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {} + + def sample_trajectories( # noqa: C901 + self, + env: HyperGridWithTensorStep, + n: int | None = None, + states: DiscreteStates | None = None, + conditions: torch.Tensor | None = None, + save_estimator_outputs: bool = False, # unused in chunked fast path + save_logprobs: bool = False, # unused in chunked fast path + **policy_kwargs: Any, + ): + assert self.chunk_size > 0 + assert hasattr(env, "step_tensor") + policy_kwargs = dict(policy_kwargs) + epsilon = float(policy_kwargs.pop("epsilon", 0.0)) + + if states is None: + assert n is not None + states_obj = env.reset(batch_shape=(n,)) + else: + states_obj 
= states + + if not isinstance(self.estimator, FastPolicyMixin): + raise TypeError( + "ChunkedHyperGridSampler requires a FastPolicy-compatible estimator." + ) + policy = cast(FastPolicyMixin, self.estimator) + chunk_size = max(1, self.chunk_size) + exit_idx = env.n_actions - 1 + + curr_states = states_obj.tensor + batch = curr_states.shape[0] + device = curr_states.device + + def compute_forward_masks(states_tensor: torch.Tensor) -> torch.Tensor: + if hasattr(env, "forward_action_masks"): + return env.forward_action_masks(states_tensor) + if hasattr(env, "forward_action_masks_tensor"): + return env.forward_action_masks_tensor(states_tensor) + raise TypeError( + "HyperGrid environment must expose forward action masks for fast path." + ) + + def step_tensor( + states_tensor: torch.Tensor, actions_tensor: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + step_result = env.step_tensor(states_tensor, actions_tensor) + if isinstance(step_result, Env.TensorStepResult): + next_states = step_result.next_states + next_masks = step_result.forward_masks + if next_masks is None: + next_masks = compute_forward_masks(next_states) + is_exit_states = step_result.is_sink_state + if is_exit_states is None: + is_exit_states = env.states_from_tensor(next_states).is_sink_state + return next_states, next_masks, is_exit_states + assert isinstance(step_result, tuple) and len(step_result) == 3 + next_states, next_masks, is_exit_states = step_result + return next_states, next_masks, is_exit_states + + exit_action_value = env.exit_action.to(device=device) + dummy_action_value = env.dummy_action.to(device=device) + + forward_masks = compute_forward_masks(curr_states) + done = torch.zeros(batch, dtype=torch.bool, device=device) + actions_seq: List[torch.Tensor] = [] + dones_seq: List[torch.Tensor] = [] + states_seq: List[torch.Tensor] = [curr_states.clone().unsqueeze(0)] + + exit_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + dummy_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + + def _broadcast_done_mask(mask: torch.Tensor, target_ndim: int) -> torch.Tensor: + view_shape = mask.shape + (1,) * (target_ndim - mask.ndim) + return mask.view(view_shape) + + def _get_template( + cache: dict[tuple[int, torch.dtype], torch.Tensor], + base_value: torch.Tensor, + target_ndim: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + key = (target_ndim, dtype) + tensor = cache.get(key) + if tensor is None: + tensor = base_value.to(device=device, dtype=dtype) + if tensor.ndim > target_ndim: + raise ValueError( + f"Base action tensor has ndim={tensor.ndim}, " + f"but target_ndim={target_ndim}." 
+ ) + leading = (1,) * (target_ndim - tensor.ndim) + tensor = tensor.view(leading + tuple(tensor.shape)) + cache[key] = tensor + return tensor + + device_type = curr_states.device.type + compile_allowed = ( + hasattr(torch, "compile") + and device_type in ("cuda", "cpu") + and conditions is None + and not policy_kwargs + ) + compile_key = (id(env), device_type) + chunk_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple] | None = ( + None + ) + chunk_fn_compiled = False + if compile_allowed: + chunk_fn = self._compiled_chunk_cache.get(compile_key) + if chunk_fn is not None: + chunk_fn_compiled = True + + if chunk_fn is None: + + def _chunk_loop( + current_states: torch.Tensor, + current_masks: torch.Tensor, + done_mask: torch.Tensor, + ): + actions_buf: torch.Tensor | None = None + dones_buf: torch.Tensor | None = None + states_buf: torch.Tensor | None = None + pad_template: torch.Tensor | None = None + steps_taken = 0 + + for step in range(chunk_size): + if bool(done_mask.all().item()): + assert actions_buf is not None + assert dones_buf is not None + assert states_buf is not None + assert pad_template is not None + actions_buf[step].copy_(pad_template) + dones_buf[step].copy_(done_mask) + states_buf[step].copy_(current_states) + continue + + masks = current_masks + if done_mask.any(): + masks = masks.clone() + masks[done_mask] = False + masks[done_mask, exit_idx] = True + + features = policy.fast_features( + current_states, + forward_masks=masks, + backward_masks=None, + conditions=conditions, + ) + dist = policy.fast_distribution( + features, + forward_masks=masks, + backward_masks=None, + states_tensor=current_states, + epsilon=epsilon, + **policy_kwargs, + ) + sampled_actions = dist.sample() + step_actions = sampled_actions + record_actions = sampled_actions + + if done_mask.any(): + mask = _broadcast_done_mask(done_mask, sampled_actions.ndim) + exit_fill = _get_template( + exit_template_cache, + exit_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, + ).expand_as(sampled_actions) + dummy_fill = _get_template( + dummy_template_cache, + dummy_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, + ).expand_as(sampled_actions) + step_actions = torch.where(mask, exit_fill, sampled_actions) + record_actions = torch.where(mask, dummy_fill, sampled_actions) + + next_states, next_masks, is_exit = step_tensor( + current_states, step_actions + ) + + if actions_buf is None: + actions_buf = record_actions.new_empty( + (chunk_size,) + tuple(record_actions.shape) + ) + dones_buf = is_exit.new_empty( + (chunk_size,) + tuple(is_exit.shape) + ) + states_buf = next_states.new_empty( + (chunk_size,) + tuple(next_states.shape) + ) + pad_template = _get_template( + dummy_template_cache, + dummy_action_value, + record_actions.ndim, + record_actions.dtype, + record_actions.device, + ).expand_as(record_actions) + + assert actions_buf is not None + assert dones_buf is not None + assert states_buf is not None + + actions_buf[step].copy_(record_actions) + dones_buf[step].copy_(is_exit) + states_buf[step].copy_(next_states) + + current_states = next_states + current_masks = next_masks + done_mask = done_mask | is_exit + steps_taken += 1 + + if actions_buf is None: + batch = current_states.shape[0] + empty_actions = env.actions_from_batch_shape((0, batch)).tensor.to( + device=current_states.device + ) + actions_out = empty_actions + else: + actions_out = actions_buf[:steps_taken] + + empty_dones = done_mask.new_empty((0,) + 
done_mask.shape) + empty_states = current_states.new_empty((0,) + current_states.shape) + dones_out = ( + dones_buf[:steps_taken] if dones_buf is not None else empty_dones + ) + states_out = ( + states_buf[:steps_taken] if states_buf is not None else empty_states + ) + + return ( + current_states, + current_masks, + done_mask, + actions_out, + dones_out, + states_out, + torch.tensor(steps_taken, device=current_states.device), + ) + + chunk_fn = _chunk_loop + if compile_allowed: + try: + chunk_fn = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] + self._compiled_chunk_cache[compile_key] = chunk_fn + chunk_fn_compiled = True + except Exception: + chunk_fn = _chunk_loop + + while not bool(done.all().item()): + if chunk_fn_compiled: + _mark_cudagraph_step() + ( + curr_states, + forward_masks, + done, + actions_chunk, + dones_chunk, + states_chunk, + steps_taken_tensor, + ) = chunk_fn(curr_states, forward_masks, done) + steps_taken = int(steps_taken_tensor.item()) + if steps_taken: + actions_seq.append(actions_chunk) + dones_seq.append(dones_chunk) + states_seq.append(states_chunk) + + if actions_seq: + actions_tsr = torch.cat(actions_seq, dim=0) + states_tsr = torch.cat(states_seq, dim=0) + action_shape = getattr(env, "action_shape", None) + if action_shape: + tail_shape = tuple(actions_tsr.shape[-len(action_shape) :]) + if tail_shape != tuple(action_shape): + if tuple(action_shape) == (1,): + actions_tsr = actions_tsr.unsqueeze(-1) + else: + raise ValueError( + "ChunkedHyperGridSampler produced actions with shape " + f"{actions_tsr.shape}, expected trailing dims {action_shape}." + ) + is_exit_seq = torch.cat(dones_seq, dim=0) + T = actions_tsr.shape[0] + first_exit = torch.argmax(is_exit_seq.to(torch.long), dim=0) + never_exited = ~is_exit_seq.any(dim=0) + first_exit = torch.where( + never_exited, torch.tensor(T - 1, device=device), first_exit + ) + terminating_idx = first_exit + 1 + else: + states_tsr = states_obj.tensor.unsqueeze(0) + actions_tsr = env.actions_from_batch_shape((0, states_tsr.shape[1])).tensor + terminating_idx = torch.zeros( + states_tsr.shape[1], dtype=torch.long, device=device + ) + + trajectories = Trajectories( + env=env, + states=env.states_from_tensor(states_tsr), + conditions=None, + actions=env.actions_from_tensor(actions_tsr), + terminating_idx=terminating_idx, + is_backward=False, + log_rewards=None, + log_probs=None, + estimator_outputs=None, + ) + return trajectories + + +class ChunkedDiffusionSampler(Sampler): + """Chunked fast-path sampler specialized for DiffusionSampling states.""" + + def __init__(self, estimator: PinnedBrownianMotionForward, chunk_size: int): + super().__init__(estimator) + self.chunk_size = int(chunk_size) + self._compiled_chunk_cache: dict[tuple[int, str], Callable] = {} + + def sample_trajectories( # noqa: C901 + self, + env: DiffusionSampling, + n: int | None = None, + states: States | None = None, + conditions: torch.Tensor | None = None, + save_estimator_outputs: bool = False, + save_logprobs: bool = False, + **policy_kwargs: Any, + ) -> Trajectories: + if save_estimator_outputs or save_logprobs: + raise NotImplementedError( + "ChunkedDiffusionSampler does not record estimator outputs/log-probs yet." + ) + if not isinstance(env, EnvFastPathMixin): + raise TypeError( + "ChunkedDiffusionSampler requires environments with tensor fast paths." + ) + if not isinstance(self.estimator, FastPolicyMixin): + raise TypeError( + "ChunkedDiffusionSampler requires a FastPolicy-compatible estimator." 
+ ) + + policy = cast(FastPolicyMixin, self.estimator) + chunk_size = max(1, self.chunk_size) + + if states is None: + assert n is not None + states_obj = env.reset(batch_shape=(n,)) + else: + states_obj = states + + curr_states = states_obj.tensor + done = states_obj.is_sink_state.clone() + exit_action_value = env.exit_action.to(device=curr_states.device) + dummy_action_value = env.dummy_action.to(device=curr_states.device) + + recorded_actions_seq: List[torch.Tensor] = [] + sink_seq: List[torch.Tensor] = [] + states_stack: List[torch.Tensor] = [curr_states.clone()] + + exit_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + dummy_template_cache: dict[tuple[int, torch.dtype], torch.Tensor] = {} + + def _expand_front(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = (1,) * expand_dims + tuple(tensor.shape) + return tensor.view(view_shape) + + def _expand_back(tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + expand_dims = target_ndim - tensor.ndim + if expand_dims <= 0: + return tensor + view_shape = tuple(tensor.shape) + (1,) * expand_dims + return tensor.view(view_shape) + + def _get_template( + cache: dict[tuple[int, torch.dtype], torch.Tensor], + base_value: torch.Tensor, + target_ndim: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + key = (target_ndim, dtype) + tensor = cache.get(key) + if tensor is None: + tensor = _expand_front( + base_value.to(device=device, dtype=dtype), target_ndim + ) + cache[key] = tensor + return tensor + + def _chunk_loop(current_states: torch.Tensor, done_mask: torch.Tensor) -> tuple[ + torch.Tensor, + torch.Tensor, + List[torch.Tensor], + List[torch.Tensor], + List[torch.Tensor], + torch.Tensor, + ]: + local_recorded_actions: List[torch.Tensor] = [] + local_sinks: List[torch.Tensor] = [] + local_states: List[torch.Tensor] = [] + action_template: torch.Tensor | None = None + steps_taken = 0 + + for _ in range(chunk_size): + if bool(done_mask.all().item()): + assert action_template is not None + pad_actions = _fill_like_reference( + action_template, dummy_action_value + ) + local_recorded_actions.append(pad_actions) + local_sinks.append(done_mask.clone()) + local_states.append(current_states.clone()) + continue + + features = policy.fast_features( + current_states, + forward_masks=None, + backward_masks=None, + conditions=conditions, + ) + dist = policy.fast_distribution( + features, + forward_masks=None, + backward_masks=None, + states_tensor=current_states, + **policy_kwargs, + ) + sampled_actions = dist.sample() + step_actions = sampled_actions + record_actions = sampled_actions + + if torch.any(done_mask): + mask = _expand_back(done_mask, sampled_actions.ndim) + exit_fill = _get_template( + exit_template_cache, + exit_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, + ) + dummy_fill = _get_template( + dummy_template_cache, + dummy_action_value, + sampled_actions.ndim, + sampled_actions.dtype, + sampled_actions.device, + ) + step_actions = torch.where(mask, exit_fill, sampled_actions) + record_actions = torch.where(mask, dummy_fill, sampled_actions) + + step_res = env.step_tensor(current_states, step_actions) + current_states = step_res.next_states + sinks = step_res.is_sink_state + if sinks is None: + sinks = env.states_from_tensor(current_states).is_sink_state + + done_mask = done_mask | sinks + local_recorded_actions.append(record_actions) + action_template = 
record_actions.detach() + local_sinks.append(sinks) + local_states.append(current_states.clone()) + steps_taken += 1 + + return ( + current_states, + done_mask, + local_recorded_actions, + local_sinks, + local_states, + torch.tensor(steps_taken, device=current_states.device), + ) + + chunk_fn: Callable = _chunk_loop + chunk_fn_compiled = False + device_type = curr_states.device.type + compile_allowed = ( + hasattr(torch, "compile") + and device_type in ("cuda", "cpu") + and conditions is None + and not policy_kwargs + ) + cache_key = (id(env), device_type) + if compile_allowed: + cached = self._compiled_chunk_cache.get(cache_key) + if cached is not None: + chunk_fn = cached + chunk_fn_compiled = True + else: + try: + compiled = torch.compile(_chunk_loop, mode="reduce-overhead") # type: ignore[arg-type] + self._compiled_chunk_cache[cache_key] = compiled + chunk_fn = compiled + chunk_fn_compiled = True + except Exception as exc: # pragma: no cover - compile fallback + warnings.warn( + f"Compilation of diffusion chunk loop failed ({exc}); using eager version.", + stacklevel=2, + ) + chunk_fn = _chunk_loop + + while not bool(done.all().item()): + if chunk_fn_compiled: + _mark_cudagraph_step() + ( + curr_states, + done, + recorded_actions_chunk, + sinks_chunk, + states_chunk, + steps_taken_tensor, + ) = chunk_fn(curr_states, done) + steps_taken = int(steps_taken_tensor.item()) + if steps_taken: + recorded_actions_seq.extend(recorded_actions_chunk[:steps_taken]) + sink_seq.extend(sinks_chunk[:steps_taken]) + states_stack.extend(states_chunk[:steps_taken]) + + if recorded_actions_seq: + actions_tsr = torch.stack(recorded_actions_seq, dim=0) + states_tsr = torch.stack(states_stack, dim=0) + action_shape = getattr(env, "action_shape", None) + if action_shape: + tail_shape = tuple(actions_tsr.shape[-len(action_shape) :]) + if tail_shape != tuple(action_shape): + if tuple(action_shape) == (1,): + actions_tsr = actions_tsr.unsqueeze(-1) + else: + raise ValueError( + "ChunkedDiffusionSampler produced actions with shape " + f"{actions_tsr.shape}, expected trailing dims {action_shape}." 
+ )
+ T = actions_tsr.shape[0]
+ sinks_tsr = torch.stack(sink_seq, dim=0)
+ first_sink = torch.argmax(sinks_tsr.to(torch.long), dim=0)
+ never_sink = ~sinks_tsr.any(dim=0)
+ first_sink = torch.where(
+ never_sink,
+ torch.tensor(T - 1, device=curr_states.device),
+ first_sink,
+ )
+ terminating_idx = first_sink + 1
+ else:
+ states_tsr = states_obj.tensor.unsqueeze(0)
+ actions_tsr = env.actions_from_batch_shape((0, states_tsr.shape[1])).tensor
+ terminating_idx = torch.zeros(
+ states_tsr.shape[1], dtype=torch.long, device=curr_states.device
+ )
+
+ return Trajectories(
+ env=env,
+ states=env.states_from_tensor(states_tsr),
+ conditions=conditions,
+ actions=env.actions_from_tensor(actions_tsr),
+ terminating_idx=terminating_idx,
+ is_backward=False,
+ log_rewards=None,
+ log_probs=None,
+ estimator_outputs=None,
+ )
+
+
+class FastKHotDiscretePolicyEstimator(FastPolicyMixin, DiscretePolicyEstimator):
+ """Discrete forward policy with tensor-only helpers for HyperGrid."""
+
+ def __init__(
+ self,
+ env: HyperGrid,
+ module: torch.nn.Module,
+ preprocessor: KHotPreprocessor,
+ ) -> None:
+ super().__init__(
+ module=module,
+ n_actions=env.n_actions,
+ preprocessor=preprocessor,
+ is_backward=False,
+ )
+ self.height = int(env.height)
+ self.ndim = int(env.ndim)
+ self.exit_idx = env.n_actions - 1
+
+ def fast_features(
+ self,
+ states_tensor: torch.Tensor,
+ *,
+ forward_masks: torch.Tensor | None = None,
+ backward_masks: torch.Tensor | None = None,
+ conditions: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ assert states_tensor.dtype == torch.long
+ sink_mask = states_tensor < 0 # HyperGrid sink state stores -1 in every dim.
+ safe_states = torch.where(
+ sink_mask, torch.zeros_like(states_tensor), states_tensor
+ )
+ khot = torch.nn.functional.one_hot(safe_states, num_classes=self.height).to(
+ dtype=torch.get_default_dtype()
+ )
+ if sink_mask.any():
+ khot = khot * (~sink_mask).unsqueeze(-1).to(khot.dtype)
+ return khot.view(states_tensor.shape[0], -1)
+
+ def fast_distribution(
+ self,
+ features: torch.Tensor,
+ *,
+ states_tensor: torch.Tensor | None = None,
+ forward_masks: torch.Tensor | None = None,
+ backward_masks: torch.Tensor | None = None,
+ epsilon: float = 0.0,
+ **policy_kwargs: Any,
+ ) -> torch.distributions.Categorical:
+ if states_tensor is None:
+ raise ValueError(
+ "states_tensor is required for FastKHotDiscretePolicyEstimator."
+ )
+
+ logits = self.module(features)
+ batch = states_tensor.shape[0]
+ masks = torch.zeros(
+ batch,
+ self.ndim + 1,
+ dtype=torch.bool,
+ device=states_tensor.device,
+ )
+ masks[:, : self.ndim] = states_tensor < (self.height - 1)
+ masks[:, self.exit_idx] = True
+
+ masked_logits = logits.masked_fill(~masks, float("-inf"))
+ probs = torch.softmax(masked_logits, dim=-1)
+
+ if epsilon > 0.0:
+ valid_counts = masks.sum(dim=-1, keepdim=True).clamp_min(1)
+ uniform = masks.to(probs.dtype) / valid_counts.to(probs.dtype)
+ probs = (1.0 - epsilon) * probs + epsilon * uniform
+
+ return torch.distributions.Categorical(probs=probs)
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Compare baseline vs. fast-path HyperGrid training pipelines."
+
+ ) + parser.add_argument("--n-iterations", type=int, default=50, dest="n_iterations") + parser.add_argument("--batch-size", type=int, default=16, dest="batch_size") + parser.add_argument("--warmup-iters", type=int, default=25, dest="warmup_iters") + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr-logz", type=float, default=1e-1, dest="lr_logz") + parser.add_argument("--lr-logf", type=float, default=1e-3, dest="lr_logf") + parser.add_argument("--epsilon", type=float, default=0.0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--environments", + nargs="+", + choices=sorted(ENVIRONMENT_BENCHMARKS), + default=list(DEFAULT_ENV_ORDER), + help="Benchmark environments to include (e.g., hypergrid diffusion).", + ) + parser.add_argument( + "--validation-interval", type=int, default=100, dest="validation_interval" + ) + parser.add_argument( + "--validation-samples", type=int, default=200_000, dest="validation_samples" + ) + parser.add_argument( + "--device", + choices=["auto", "cpu", "mps", "cuda"], + default="auto", + help="Device to run on; auto prefers CUDA>MPS>CPU.", + ) + parser.add_argument( + "--benchmark-output", + type=str, + default=str(Path.home() / "hypergrid_benchmark.png"), + help="Output path for optional benchmark plot.", + ) + parser.add_argument( + "--skip-plot", + action="store_true", + help="Skip writing the benchmark plot (still prints the summary).", + ) + parser.add_argument( + "--gflownets", + nargs="+", + default=DEFAULT_FLOW_ORDER, + help="GFlowNet variants to benchmark (any of: tb, dbg, subtb).", + ) + parser.add_argument( + "--subtb-weighting", + choices=[ + "DB", + "ModifiedDB", + "TB", + "geometric", + "equal", + "equal_within", + ], + default="ModifiedDB", + dest="subtb_weighting", + help="Weighting strategy for SubTBGFlowNet runs.", + ) + parser.add_argument( + "--subtb-lamda", + type=float, + default=0.9, + dest="subtb_lamda", + help="Lambda discount factor for SubTBGFlowNet geometric weighting.", + ) + # Diffusion-specific knobs (ignored unless `diffusion` is selected). 
+ parser.add_argument( + "--diffusion-target", + type=str, + default="gmm2", + help="Diffusion target alias (see DiffusionSampling.DIFFUSION_TARGETS).", + ) + parser.add_argument( + "--diffusion-dim", + type=int, + default=None, + help="Override target dimensionality when supported.", + ) + parser.add_argument( + "--diffusion-num-components", + type=int, + default=None, + help="Override mixture component count for Gaussian targets.", + ) + parser.add_argument( + "--diffusion-target-seed", + type=int, + default=2, + help="Seed controlling random targets (centers, covariances, etc.).", + ) + parser.add_argument( + "--diffusion-num-steps", + type=int, + default=32, + help="Number of discretization steps for the diffusion process.", + ) + parser.add_argument( + "--diffusion-sigma", + type=float, + default=5.0, + help="Pinned Brownian motion diffusion coefficient.", + ) + parser.add_argument( + "--diffusion-harmonics-dim", + type=int, + default=64, + help="Harmonics embedding dimension for DiffusionPISGradNetForward.", + ) + parser.add_argument( + "--diffusion-t-emb-dim", + type=int, + default=64, + help="Temporal embedding dimension for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-s-emb-dim", + type=int, + default=64, + help="State embedding dimension for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-hidden-dim", + type=int, + default=64, + help="Hidden dimension for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-joint-layers", + type=int, + default=2, + help="Joint layers count for diffusion forward model.", + ) + parser.add_argument( + "--diffusion-zero-init", + action="store_true", + help="Initialize diffusion forward model heads to zero.", + ) + return parser.parse_args() + + +def init_metrics() -> Dict[str, Any]: + return { + "validation_info": {"l1_dist": float("nan")}, + "discovered_modes": set(), + "total_steps": 0, + "measured_steps": 0, + } + + +def main() -> None: + args = parse_args() + device = resolve_device(args.device) + + flow_keys = _normalize_flow_keys(args.gflownets) + env_keys = _normalize_env_keys(args.environments) + if not flow_keys: + raise ValueError("At least one GFlowNet variant must be specified.") + + results: list[dict[str, Any]] = [] + grouped_results: dict[str, dict[str, list[dict[str, Any]]]] = {} + + for env_key in env_keys: + env_cfg = ENVIRONMENT_BENCHMARKS[env_key] + env_flow_keys = [ + flow_key for flow_key in flow_keys if flow_key in env_cfg.supported_flows + ] + if not env_flow_keys: + print( + f"\nSkipping environment '{env_cfg.label}' " + f"(no compatible flows among {', '.join(flow_keys)})." + ) + continue + + grouped_results.setdefault(env_key, {}) + print(f"\n### Environment: {env_cfg.label} ###\n" f"{env_cfg.description}\n") + + for flow_key in env_flow_keys: + flow_variant = FLOW_VARIANTS[flow_key] + grouped_results[env_key].setdefault(flow_key, []) + print( + f"\n=== GFlowNet Variant: {flow_variant.label} " + f"@ {env_cfg.label} ===\n{flow_variant.description}\n" + ) + for scenario in env_cfg.scenarios: + print( + f"\n--- Scenario: {scenario.name} | " + f"{flow_variant.label} ({env_cfg.label}) ---\n" + f"{scenario.description}\n" + ) + result = run_scenario(args, device, scenario, flow_variant, env_cfg) + result["label"] = scenario.name + result["description"] = scenario.description + result["env_key"] = env_cfg.key + result["env_label"] = env_cfg.label + results.append(result) + grouped_results[env_key][flow_key].append(result) + + print("\nBenchmark summary (speedups vs. 
per-environment baselines):") + for env_key in env_keys: + env_cfg = ENVIRONMENT_BENCHMARKS.get(env_key) + if env_cfg is None: + continue + env_flow_results = grouped_results.get(env_key, {}) + if not env_flow_results: + continue + + baseline_name = env_cfg.scenarios[0].name if env_cfg.scenarios else "baseline" + print(f"\n[{env_cfg.label}] scenario baseline = {baseline_name}") + + for flow_key, flow_results in env_flow_results.items(): + if not flow_results: + continue + flow_variant = FLOW_VARIANTS[flow_key] + baseline_candidate = next( + (res for res in flow_results if res.get("label") == baseline_name), + flow_results[0], + ) + baseline_time = baseline_candidate.get("elapsed", 0.0) or 1.0 + print( + f"\n - {flow_variant.label}: " + f"{baseline_time:.2f}s baseline ({baseline_candidate['label']})" + ) + for result in flow_results: + elapsed = result["elapsed"] + speedup = baseline_time / elapsed if elapsed else float("inf") + print( + f" • {result['label']}: {elapsed:.2f}s ({speedup:.2f}x) | " + f"compile={'yes' if result['use_compile'] else 'no'} | " + f"vmap={'yes' if result['use_vmap'] else 'no'} | " + f"sampler={result['sampler']}" + ) + + if not args.skip_plot: + plot_benchmark(results, args.benchmark_output) + + +def run_scenario( + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, + env_cfg: EnvironmentBenchmark, +) -> dict[str, Any]: + set_seed(args.seed) + ( + env, + gflownet, + sampler, + optimizer, + visited_states, + ) = build_training_components(args, device, scenario, flow_variant, env_cfg) + metrics = init_metrics() + use_vmap = scenario.use_vmap and flow_variant.supports_vmap + compiled_any = False + + if scenario.use_compile: + compile_results = try_compile_gflownet( + gflownet, + mode=DEFAULT_COMPILE_MODE, + ) + compiled_any = any(compile_results.values()) + formatted = ", ".join( + f"{name}:{'✓' if success else 'x'}" + for name, success in compile_results.items() + ) + print(f"[compile] {formatted}") + + if args.warmup_iters > 0: + run_iterations( + env, + gflownet, + sampler, + optimizer, + visited_states, + metrics, + args, + n_iters=args.warmup_iters, + use_vmap=use_vmap, + quiet=True, + collect_metrics=False, + track_time=False, + record_history=False, + supports_validation=env_cfg.supports_validation, + mark_compiled_step=compiled_any, + ) + + elapsed, history = run_iterations( + env, + gflownet, + sampler, + optimizer, + visited_states, + metrics, + args, + n_iters=args.n_iterations, + use_vmap=use_vmap, + quiet=False, + collect_metrics=True, + track_time=True, + record_history=True, + supports_validation=env_cfg.supports_validation, + mark_compiled_step=compiled_any, + ) + + validation_info = metrics["validation_info"] + l1 = validation_info.get("l1_dist", float("nan")) + modes_total = getattr(env, "n_modes", None) + if modes_total is None: + modes_str = "modes=n/a" + else: + modes_str = f"modes={len(metrics['discovered_modes'])} / {modes_total}" + if env_cfg.supports_validation: + validation_str = f"L1 distance={l1:.6f}" + else: + validation_str = "validation=skipped" + print( + f"Finished training ({env_cfg.label}) | " + f"iterations={metrics['measured_steps']} | " + f"{modes_str} | {validation_str}" + ) + + return { + "elapsed": elapsed or 0.0, + "losses": history["losses"] if history else None, + "iter_times": history["iter_times"] if history else None, + "use_compile": scenario.use_compile, + "use_vmap": use_vmap, + "sampler": scenario.sampler, + "gflownet_key": flow_variant.key, + "gflownet_label": 
flow_variant.label, + } + + +def run_iterations( + env: Env, + gflownet: PFBasedGFlowNet, + sampler: Sampler, + optimizer: torch.optim.Optimizer, + visited_states, + metrics: Dict[str, Any], + args: argparse.Namespace, + *, + n_iters: int, + use_vmap: bool, + quiet: bool, + collect_metrics: bool, + track_time: bool, + record_history: bool, + supports_validation: bool, + mark_compiled_step: bool = False, +) -> tuple[float | None, Dict[str, list[float]] | None]: + if n_iters <= 0: + empty_history = {"losses": [], "iter_times": []} if record_history else None + return (0.0 if track_time else None), empty_history + + iterator: Iterable[int] + if quiet: + iterator = range(n_iters) + else: + iterator = tqdm(range(n_iters), dynamic_ncols=True) + + start_time = time.perf_counter() if track_time else None + last_loss = 0.0 + losses_history: list[float] | None = [] if record_history else None + iter_time_history: list[float] | None = [] if record_history else None + + for _ in iterator: + iter_start = time.perf_counter() if (track_time or record_history) else None + if mark_compiled_step: + _mark_cudagraph_step() + trajectories = sampler.sample_trajectories( + env, + n=args.batch_size, + save_logprobs=False, + save_estimator_outputs=False, + epsilon=args.epsilon, + ) + + terminating_states = cast(States, trajectories.terminating_states) + visited_states.extend(terminating_states) + + optimizer.zero_grad() + loss = compute_loss(gflownet, env, trajectories, use_vmap=use_vmap) + loss.backward() + gflownet.assert_finite_gradients() + torch.nn.utils.clip_grad_norm_(gflownet.parameters(), 1.0) + optimizer.step() + gflownet.assert_finite_parameters() + + metrics["total_steps"] += 1 + if collect_metrics: + metrics["measured_steps"] += 1 + + last_loss = loss.item() + if ( + record_history + and losses_history is not None + and iter_time_history is not None + ): + losses_history.append(last_loss) + iter_duration = ( + (time.perf_counter() - iter_start) if iter_start is not None else 0.0 + ) + iter_time_history.append(iter_duration) + + if collect_metrics and supports_validation: + run_validation_if_needed( + cast(HyperGrid, env), + gflownet, + visited_states, + metrics, + args, + quiet=quiet, + ) + + if not quiet and isinstance(iterator, tqdm): + iterator.set_postfix( + { + "loss": last_loss, + "trajectories_sampled": ( + metrics["measured_steps"] * args.batch_size + ), + } + ) + + if track_time: + env_device = getattr(env, "device", torch.device("cpu")) + synchronize_if_needed(env_device) + assert start_time is not None + elapsed_time = time.perf_counter() - start_time + else: + elapsed_time = None + + history = None + if record_history and losses_history is not None and iter_time_history is not None: + history = { + "losses": losses_history, + "iter_times": iter_time_history, + } + + return elapsed_time, history + + +def compute_loss( + gflownet: PFBasedGFlowNet, + env: Env, + trajectories, + *, + use_vmap: bool, +) -> torch.Tensor: + if use_vmap: + if not isinstance(gflownet, TBGFlowNet): + raise ValueError("vmap trajectory balance loss requires a TBGFlowNet.") + return trajectory_balance_loss_vmap(cast(TBGFlowNet, gflownet), trajectories) + + return gflownet.loss_from_trajectories( + env, trajectories, recalculate_all_logprobs=False + ) + + +def trajectory_balance_loss_vmap( + gflownet: TBGFlowNet, + trajectories, +) -> torch.Tensor: + log_pf, log_pb = gflownet.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=False + ) + log_rewards = trajectories.log_rewards + if log_rewards is None: + raise 
ValueError("Log rewards required for TB loss.") + + def tb_residual( + log_pf_seq: torch.Tensor, log_pb_seq: torch.Tensor, log_reward: torch.Tensor + ) -> torch.Tensor: + return log_pf_seq.sum() - log_pb_seq.sum() - log_reward + + residuals = vmap(tb_residual)( + log_pf.transpose(0, 1), + log_pb.transpose(0, 1), + log_rewards, + ) + + log_z_value = gflownet.logZ + if not isinstance(log_z_value, torch.Tensor): + log_z_tensor = torch.as_tensor(log_z_value, device=residuals.device) + else: + log_z_tensor = log_z_value + log_z_tensor = log_z_tensor.squeeze() + scores = (residuals + log_z_tensor).pow(2) + + return scores.mean() + + +def run_validation_if_needed( + env: HyperGrid, + gflownet: PFBasedGFlowNet, + visited_states: DiscreteStates, + metrics: Dict[str, Any], + args: argparse.Namespace, + *, + quiet: bool, +) -> None: + if args.validation_interval <= 0: + return + measured_steps = metrics["measured_steps"] + if measured_steps == 0: + return + if measured_steps % args.validation_interval != 0: + return + + validation_info, _ = validate( + env, + gflownet, + args.validation_samples, + visited_states, + ) + metrics["validation_info"] = validation_info + modes_found = env.modes_found(visited_states) + metrics["discovered_modes"].update(modes_found) + + if not quiet: + str_info = ( + f"Iter {measured_steps}: " + f"L1 distance={validation_info.get('l1_dist', float('nan')):.8f} " + f"modes discovered={len(metrics['discovered_modes'])} / {env.n_modes} " + f"n terminating states {len(visited_states)}" + ) + print(str_info) + + +def build_training_components( + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, + env_cfg: EnvironmentBenchmark, +) -> tuple[Env, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, States]: + if env_cfg.key == "hypergrid": + return _build_hypergrid_components(args, device, scenario, flow_variant) + if env_cfg.key == "diffusion": + return _build_diffusion_components(args, device, scenario, flow_variant) + raise ValueError(f"Unsupported environment key: {env_cfg.key}") + + +def _build_hypergrid_components( + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, +) -> tuple[HyperGrid, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, DiscreteStates]: + env_kwargs = dict(HYPERGRID_KWARGS) + env_kwargs["device"] = device + EnvClass = HyperGridWithTensorStep if scenario.use_script_env else HyperGrid + env = EnvClass(**env_kwargs) + + preprocessor = KHotPreprocessor(height=env.height, ndim=env.ndim) + module_pf = MLP( + input_dim=preprocessor.output_dim, + output_dim=env.n_actions, + ) + module_pb = MLP( + input_dim=preprocessor.output_dim, + output_dim=env.n_actions - 1, + trunk=module_pf.trunk, + ) + + if scenario.sampler in {"compiled_chunk", "script_chunk"}: + pf_estimator = FastKHotDiscretePolicyEstimator(env, module_pf, preprocessor) + else: + pf_estimator = DiscretePolicyEstimator( + module_pf, env.n_actions, preprocessor=preprocessor, is_backward=False + ) + pb_estimator = DiscretePolicyEstimator( + module_pb, env.n_actions, preprocessor=preprocessor, is_backward=True + ) + + logF_estimator: ScalarEstimator | None = None + if flow_variant.requires_logf: + logF_module = MLP( + input_dim=preprocessor.output_dim, + output_dim=1, + ) + logF_estimator = ScalarEstimator(module=logF_module, preprocessor=preprocessor) + + if flow_variant.key == "tb": + gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + elif flow_variant.key == "dbg": + assert 
logF_estimator is not None + gflownet = DBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator) + elif flow_variant.key == "subtb": + assert logF_estimator is not None + gflownet = SubTBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + logF=logF_estimator, + weighting=args.subtb_weighting, + lamda=args.subtb_lamda, + ) + else: + raise ValueError(f"Unsupported GFlowNet variant: {flow_variant.key}") + + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + + logz_params = getattr(gflownet, "logz_parameters", None) + if callable(logz_params): + params = logz_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logz}) + + logf_params = getattr(gflownet, "logF_parameters", None) + if callable(logf_params): + params = logf_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logf}) + + if scenario.sampler == "compiled_chunk": + sampler: Sampler = CompiledChunkSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + elif scenario.sampler == "script_chunk": + sampler = ChunkedHyperGridSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + else: + sampler = Sampler(estimator=pf_estimator) + + visited_states = env.states_from_batch_shape((0,)) + return env, gflownet, sampler, optimizer, visited_states + + +def _build_diffusion_components( + args: argparse.Namespace, + device: torch.device, + scenario: ScenarioConfig, + flow_variant: FlowVariant, +) -> tuple[DiffusionSampling, PFBasedGFlowNet, Sampler, torch.optim.Optimizer, States]: + target_kwargs: dict[str, Any] = {"seed": args.diffusion_target_seed} + if args.diffusion_dim is not None: + target_kwargs["dim"] = args.diffusion_dim + if args.diffusion_num_components is not None: + target_kwargs["num_components"] = args.diffusion_num_components + + env = DiffusionSampling( + target_str=args.diffusion_target, + target_kwargs=target_kwargs, + num_discretization_steps=args.diffusion_num_steps, + device=device, + check_action_validity=False, + ) + + s_dim = env.dim + pf_module = DiffusionPISGradNetForward( + s_dim=s_dim, + harmonics_dim=args.diffusion_harmonics_dim, + t_emb_dim=args.diffusion_t_emb_dim, + s_emb_dim=args.diffusion_s_emb_dim, + hidden_dim=args.diffusion_hidden_dim, + joint_layers=args.diffusion_joint_layers, + zero_init=args.diffusion_zero_init, + ) + pb_module = DiffusionFixedBackwardModule(s_dim=s_dim) + + pf_estimator = PinnedBrownianMotionForward( + s_dim=s_dim, + pf_module=pf_module, + sigma=args.diffusion_sigma, + num_discretization_steps=args.diffusion_num_steps, + ) + pb_estimator = PinnedBrownianMotionBackward( + s_dim=s_dim, + pb_module=pb_module, + sigma=args.diffusion_sigma, + num_discretization_steps=args.diffusion_num_steps, + ) + + logF_estimator: ScalarEstimator | None = None + if flow_variant.requires_logf: + logF_module = MLP( + input_dim=env.state_shape[-1], + output_dim=1, + ) + logF_preprocessor = IdentityPreprocessor(output_dim=env.state_shape[-1]) + logF_estimator = ScalarEstimator( + module=logF_module, preprocessor=logF_preprocessor + ) + + if flow_variant.key == "tb": + gflownet: PFBasedGFlowNet = TBGFlowNet( + pf=pf_estimator, pb=pb_estimator, init_logZ=0.0 + ) + elif flow_variant.key == "dbg": + assert logF_estimator is not None + gflownet = DBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + logF=logF_estimator, + ) + elif flow_variant.key == "subtb": + assert logF_estimator is not None + gflownet = SubTBGFlowNet( + pf=pf_estimator, + pb=pb_estimator, + 
logF=logF_estimator, + weighting=args.subtb_weighting, + lamda=args.subtb_lamda, + ) + else: + raise ValueError( + f"Unsupported GFlowNet variant for diffusion: {flow_variant.key}" + ) + + gflownet = gflownet.to(device) + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + + logz_params = getattr(gflownet, "logz_parameters", None) + if callable(logz_params): + params = logz_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logz}) + + logf_params = getattr(gflownet, "logF_parameters", None) + if callable(logf_params): + params = logf_params() + if params: + optimizer.add_param_group({"params": params, "lr": args.lr_logf}) + + if scenario.sampler == "compiled_chunk": + sampler: Sampler = CompiledChunkSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + elif scenario.sampler == "script_chunk": + sampler = ChunkedDiffusionSampler( + estimator=pf_estimator, chunk_size=DEFAULT_CHUNK_SIZE + ) + else: + sampler = Sampler(estimator=pf_estimator) + + visited_states = env.states_from_batch_shape((0,)) + return env, gflownet, sampler, optimizer, visited_states + + +def _mps_backend_available() -> bool: + backend = getattr(torch.backends, "mps", None) + return bool(backend and backend.is_available()) + + +def resolve_device(requested: str) -> torch.device: + """MPS backend is not supported for the Diffusion Sampling environment.""" + if requested == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + # if _mps_backend_available(): + # return torch.device("mps") + return torch.device("cpu") + + device = torch.device(requested) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available.") + if device.type == "mps" and not _mps_backend_available(): + raise RuntimeError("MPS requested but not available.") + return device + + +def synchronize_if_needed(device: torch.device) -> None: + if device.type == "cuda" and torch.cuda.is_available(): + torch.cuda.synchronize() + elif device.type == "mps" and _mps_backend_available() and hasattr(torch, "mps"): + torch.mps.synchronize() + + +def _summarize_iteration_times(times: list[float]) -> tuple[float, float]: + if not times: + return 0.0, 0.0 + mean_time = statistics.fmean(times) + std_time = statistics.pstdev(times) if len(times) > 1 else 0.0 + return mean_time, std_time + + +def _render_env_row( + row_axes, + env_results: list[Dict[str, Any]], + env_cfg: EnvironmentBenchmark | None, + palette: list[str], +) -> None: + env_label = env_cfg.label if env_cfg else env_results[0].get("env_label", "Env") + labels = [ + f"{res.get('label', f'Run {idx+1}')} [{res.get('gflownet_label', '?')}]" + for idx, res in enumerate(env_results) + ] + times = [res.get("elapsed", 0.0) for res in env_results] + bar_colors = [palette[i % len(palette)] for i in range(len(env_results))] + + baseline_name = env_cfg.scenarios[0].name if env_cfg and env_cfg.scenarios else None + + # Determine per-flow baselines (default to the baseline scenario if present, else first run). 
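+ # Baselines are grouped by gflownet_key, so each run's speedup is measured + # against a run of the same flow variant (never TB against DB/SubTB).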
+ flow_baselines: dict[str, float] = {} + for res in env_results: + flow_key = res.get("gflownet_key") + if flow_key is None or flow_key in flow_baselines: + continue + if baseline_name is not None and res.get("label") == baseline_name: + flow_baselines[flow_key] = res.get("elapsed", 0.0) or 0.0 + for res in env_results: + flow_key = res.get("gflownet_key") + if flow_key is None: + continue + flow_baselines.setdefault(flow_key, res.get("elapsed", 0.0) or 0.0) + + bars = row_axes[0].bar(labels, times, color=bar_colors) + row_axes[0].set_ylabel("Wall-clock time (s)") + row_axes[0].set_title(f"{env_label} | Total Training Time") + + for bar, value, res in zip(bars, times, env_results): + if value == 0.0: + continue + flow_key = res.get("gflownet_key", "") + flow_baseline = flow_baselines.get(flow_key, value) or value + pct_speedup = ( + (flow_baseline / value - 1.0) * 100.0 if value > 0.0 else float("inf") + ) + row_axes[0].text( + bar.get_x() + bar.get_width() / 2, + value, + f"{value:.2f}s\n{pct_speedup:+.1f}%", + ha="center", + va="bottom", + color="black", + fontsize=9, + ) + + # Subplot 2: training curves + loss_ax = row_axes[1] + for idx, res in enumerate(env_results): + losses = res.get("losses") or [] + if not losses: + continue + variant_key = res.get("gflownet_key", "") + scenario_label = res.get("label", "") + color = VARIANT_COLORS.get(variant_key, palette[idx % len(palette)]) + linestyle = SCENARIO_LINESTYLES.get(scenario_label, "-") + loss_ax.plot( + range(1, len(losses) + 1), + losses, + label=labels[idx], + color=color, + linestyle=linestyle, + linewidth=2.0, + alpha=LOSS_LINE_ALPHA, + ) + loss_ax.set_title(f"{env_label} | Training Loss") + loss_ax.set_xlabel("Iteration") + loss_ax.set_ylabel("Loss") + if loss_ax.lines: + loss_ax.legend(fontsize="small") + + # Subplot 3: per-iteration timing with error bars + + iter_ax = row_axes[2] + iter_stats = [ + _summarize_iteration_times(res.get("iter_times") or []) for res in env_results + ] + means_ms = [mean * 1000.0 for mean, _ in iter_stats] + stds_ms = [std * 1000.0 for _, std in iter_stats] + iter_ax.bar( + labels, + means_ms, + yerr=stds_ms, + capsize=6, + color=bar_colors, + ) + iter_ax.set_ylabel("Per-iteration time (ms)") + iter_ax.set_title(f"{env_label} | Iteration Timing") + + for ax in row_axes: + for label in ax.get_xticklabels(): + label.set_rotation(30) + label.set_ha("right") + + +def plot_benchmark(results: list[Dict[str, Any]], output_path: str) -> None: + try: + import matplotlib.pyplot as plt + except ImportError as exc: + raise RuntimeError( + "matplotlib is required for plotting; install it or omit --benchmark." 
+ ) from exc + + if not results: + print("No benchmark results to plot.") + return + + env_order: list[str] = [] + for res in results: + env_key = res.get("env_key", "unknown") + if env_key not in env_order: + env_order.append(env_key) + + n_rows = max(1, len(env_order)) + fig, axes = plt.subplots(n_rows, 3, figsize=(20, 5 * n_rows)) + if n_rows == 1: + axes = [axes] # type: ignore[list-item] + + palette = ["#6c757d", "#1f77b4", "#2ca02c", "#d62728", "#9467bd", "#8c564b"] + + for row_idx, env_key in enumerate(env_order): + env_results = [res for res in results if res.get("env_key") == env_key] + if not env_results: + continue + + env_cfg = ENVIRONMENT_BENCHMARKS.get(env_key) + row_axes = axes[row_idx] + _render_env_row(row_axes, env_results, env_cfg, palette) + + output = Path(output_path) + output.parent.mkdir(parents=True, exist_ok=True) + fig.tight_layout() + fig.savefig(output, dpi=150) + plt.close(fig) + print(f"Saved benchmark plot to {output}") + + +if __name__ == "__main__": + main() diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index 4a5d492d..5fbfc8f8 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -7,7 +7,7 @@ from torch.distributions.independent import Independent from tqdm import trange -from gfn.estimators import Estimator, PolicyMixin +from gfn.estimators import Estimator, FastPolicyMixin from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet from gfn.gym.line import Line from gfn.states import States @@ -168,13 +168,14 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: return out -class StepEstimator(Estimator, PolicyMixin): +class StepEstimator(FastPolicyMixin, Estimator): """Estimator for PF and PB of the Line environment.""" def __init__(self, env: Line, module: torch.nn.Module, backward: bool): super().__init__(module, is_backward=backward) self.backward = backward self.n_steps_per_trajectory = env.n_steps_per_trajectory + self.env = env @property def expected_output_dim(self) -> int: @@ -207,6 +208,31 @@ def to_probability_distribution( n_steps=self.n_steps_per_trajectory, ) + def fast_features( + self, + states_tensor: torch.Tensor, + *, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + conditions: torch.Tensor | None = None, + ) -> torch.Tensor: + return states_tensor + + def fast_distribution( + self, + features: torch.Tensor, + *, + states_tensor: torch.Tensor | None = None, + forward_masks: torch.Tensor | None = None, + backward_masks: torch.Tensor | None = None, + **policy_kwargs, + ) -> Distribution: + if states_tensor is None: + raise ValueError("states_tensor is required for StepEstimator fast path.") + module_output = self.module(features) + states = self.env.states_from_tensor_fast(states_tensor) + return self.to_probability_distribution(states, module_output, **policy_kwargs) + def train( gflownet, diff --git a/tutorials/notebooks/torch_compile_discrete_states.ipynb b/tutorials/notebooks/torch_compile_discrete_states.ipynb new file mode 100644 index 00000000..84b993ee --- /dev/null +++ b/tutorials/notebooks/torch_compile_discrete_states.ipynb @@ -0,0 +1,4889 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# torch.compile with `DiscreteStates`\n", + "\n", + "This short experiment shows that a `DiscreteStates` wrapper can safely flow through `torch.compile`. 
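Note that 'safely' means the compiled function reproduces the eager outputs; the object-based `_step` still triggers Dynamo graph breaks (e.g. the `Tensor.item()` call inside `is_action_valid`), as the warnings below show. 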
We instantiate a simple environment, grab its states/actions, and compare the eager and compiled results of a single `_step` call.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jdv/code/torchgfn/src/gfn/gym/hypergrid.py:90: UserWarning: + Warning: height <= 4 can lead to unsolvable environments.\n", + " warnings.warn(\"+ Warning: height <= 4 can lead to unsolvable environments.\")\n", + "/Users/jdv/code/torchgfn/src/gfn/env.py:495: UserWarning: You're using advanced parameters: (sf). These are only needed for custom action handling. For basic environments, you can omit these.\n", + " warnings.warn(\n", + "[W1126 23:25:08.072794000 unwind.cpp:12] Warning: record_context_cpp is not support on non-linux non-x86_64 platforms (function operator())\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break from `Tensor.item()`, consider setting:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] torch._dynamo.config.capture_scalar_outputs = True\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] or:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] to include these operations in the captured graph.\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break: from user code at:\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] File \"/Users/jdv/code/torchgfn/src/gfn/env.py\", line 309, in torch_dynamo_resume_in__step_at_307\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] if not self.is_action_valid(valid_states, valid_actions):\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] File \"/Users/jdv/code/torchgfn/src/gfn/env.py\", line 643, in is_action_valid\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] return bool(torch.gather(masks_tensor, 1, actions.tensor).all().item())\n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n", + "W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outputs match: True\n", + "Output device: cpu\n", + "Example compiled output:\n", + " tensor([[0, 1],\n", + " [0, 1],\n", + " [0, 1],\n", + " [0, 1]])\n" + ] + } + ], + "source": [ + "import torch\n", + "from gfn.gym.hypergrid import HyperGrid\n", + "\n", + "# Resolve device (CUDA if available, else CPU)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "# Instantiate a small environment and grab states/actions.\n", + "env = HyperGrid(ndim=2, height=4, device=device)\n", + "states = env.reset(batch_shape=4)\n", + "actions = env.actions_from_batch_shape((4,))\n", + "actions.tensor = torch.ones((4, 1), dtype=torch.long, 
device=device)\n", + "\n", + "# Define a helper that takes raw tensors, rebuilds the wrappers, and returns the step result.\n", + "def step_once(states_tensor: torch.Tensor, actions_tensor: torch.Tensor) -> torch.Tensor:\n", + " s = env.States(states_tensor)\n", + " a = env.Actions(actions_tensor)\n", + " return env._step(s, a).tensor\n", + "\n", + "compiled_step = torch.compile(step_once, dynamic=True)\n", + "\n", + "eager_out = step_once(states.tensor, actions.tensor)\n", + "compiled_out = compiled_step(states.tensor, actions.tensor)\n", + "\n", + "print(\"Outputs match:\", torch.equal(eager_out, compiled_out))\n", + "print(\"Output device:\", compiled_out.device)\n", + "print(\"Example compiled output:\\n\", compiled_out)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Microbenchmark harness\n", + "\n", + "The cells below build a small timing helper so we can compare `step_once` in eager mode vs the `torch.compile(..., dynamic=True)` variant under identical inputs. We run everything on CPU for consistency.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import statistics\n", + "import warnings\n", + "from typing import Callable, Dict\n", + "\n", + "import torch.utils.benchmark as benchmark\n", + "\n", + "\n", + "def _sync_if_needed() -> None:\n", + " if torch.cuda.is_available():\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "def benchmark_step_fn(\n", + " step_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n", + " label: str,\n", + " states_tensor: torch.Tensor,\n", + " actions_tensor: torch.Tensor,\n", + " *,\n", + " iters: int = 200,\n", + ") -> Dict[str, float]:\n", + " \"\"\"Time repeated calls to `step_fn` under identical inputs.\"\"\"\n", + "\n", + " torch.manual_seed(0)\n", + " if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(0)\n", + "\n", + " warmup_iters = max(5, iters // 10)\n", + " for _ in range(warmup_iters):\n", + " step_fn(states_tensor, actions_tensor)\n", + " _sync_if_needed()\n", + "\n", + " timer = benchmark.Timer(\n", + " stmt=\"fn(states_tensor, actions_tensor)\",\n", + " globals={\n", + " \"fn\": step_fn,\n", + " \"states_tensor\": states_tensor,\n", + " \"actions_tensor\": actions_tensor,\n", + " },\n", + " label=label,\n", + " sub_label=f\"device={states_tensor.device}\",\n", + " description=\"step_once microbenchmark\",\n", + " )\n", + " result = timer.timeit(iters)\n", + " std_ms = statistics.pstdev(result.raw_times) * 1000 if result.raw_times else float(\"nan\")\n", + " run_count = len(result.raw_times) if result.raw_times else iters\n", + " return {\n", + " \"label\": label,\n", + " \"mean_ms\": result.mean * 1000,\n", + " \"std_ms\": std_ms,\n", + " \"iters\": run_count,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'Eager step_once',\n", + " 'mean_ms': 0.10937216250458733,\n", + " 'std_ms': 0.0,\n", + " 'iters': 1},\n", + " {'label': 'torch.compile(step_once)',\n", + " 'mean_ms': 0.3456614166498184,\n", + " 'std_ms': 0.0,\n", + " 'iters': 1}]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "benchmark_iters = 20000\n", + "results = []\n", + "\n", + "results.append(\n", + " benchmark_step_fn(\n", + " step_once,\n", + " label=\"Eager step_once\",\n", + " states_tensor=states.tensor,\n", + " actions_tensor=actions.tensor,\n", + " 
iters=benchmark_iters,\n", + " )\n", + ")\n", + "\n", + "with warnings.catch_warnings(record=True) as caught:\n", + " warnings.simplefilter(\"always\")\n", + " results.append(\n", + " benchmark_step_fn(\n", + " compiled_step,\n", + " label=\"torch.compile(step_once)\",\n", + " states_tensor=states.tensor,\n", + " actions_tensor=actions.tensor,\n", + " iters=benchmark_iters,\n", + " )\n", + " )\n", + " compile_warning_messages = sorted({str(w.message) for w in caught})\n", + "\n", + "results\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mode mean (ms) std (ms) iters\n", + "-----------------------------------------------------------------\n", + "Eager step_once 0.1094 0.0000 1\n", + "torch.compile(step_once) 0.3457 0.0000 1\n", + "\n", + "Speedup (eager / compiled): 0.316x\n", + "\n", + "Dynamo summary -> Graphs: 12, Graph breaks: 11, Break reasons: ['Dynamic shape operator', 'Unsupported Tensor.item() call with capture_scalar_outputs=False']\n", + "\n", + "Warnings during compiled execution: none captured\n" + ] + } + ], + "source": [ + "import torch._dynamo as dynamo\n", + "\n", + "\n", + "def _format_results(rows):\n", + " header = f\"{'Mode':<30} {'mean (ms)':>12} {'std (ms)':>12} {'iters':>8}\"\n", + " lines = [header, \"-\" * len(header)]\n", + " for row in rows:\n", + " lines.append(\n", + " f\"{row['label']:<30} {row['mean_ms']:>12.4f} {row['std_ms']:>12.4f} {row['iters']:>8d}\"\n", + " )\n", + " return \"\\n\".join(lines)\n", + "\n", + "\n", + "def _extract_count(report: str, prefix: str) -> int:\n", + " for line in report.splitlines():\n", + " if line.startswith(prefix):\n", + " return int(line.split(\":\", 1)[1].strip())\n", + " return -1\n", + "\n", + "\n", + "print(_format_results(results))\n", + "\n", + "eager_mean = next(r for r in results if r[\"label\"] == \"Eager step_once\")[\"mean_ms\"]\n", + "compiled_mean = next(r for r in results if \"torch.compile\" in r[\"label\"])[\"mean_ms\"]\n", + "speedup = eager_mean / compiled_mean if compiled_mean else float(\"nan\")\n", + "print(f\"\\nSpeedup (eager / compiled): {speedup:.3f}x\")\n", + "\n", + "compiled_report = dynamo.explain(step_once)(states.tensor, actions.tensor)\n", + "compiled_report_text = str(compiled_report)\n", + "\n", + "graph_count = _extract_count(compiled_report_text, \"Graph Count\")\n", + "graph_breaks = _extract_count(compiled_report_text, \"Graph Break Count\")\n", + "break_reasons = sorted(\n", + " {\n", + " line.strip().split(\":\", 1)[1].strip()\n", + " for line in compiled_report_text.splitlines()\n", + " if line.strip().startswith(\"Reason:\")\n", + " }\n", + ")\n", + "\n", + "print(\n", + " f\"\\nDynamo summary -> Graphs: {graph_count}, Graph breaks: {graph_breaks}, \"\n", + " f\"Break reasons: {break_reasons or ['None']}\"\n", + ")\n", + "\n", + "if compile_warning_messages:\n", + " print(\"\\nWarnings during compiled execution:\")\n", + " for msg in compile_warning_messages:\n", + " print(f\" - {msg}\")\n", + "else:\n", + " print(\"\\nWarnings during compiled execution: none captured\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Full GFlowNet benchmark\n", + "\n", + "The cell below reuses `train_hypergrid_optimized.py`'s benchmarking entry-point so we can time a larger training loop (Baseline vs compiled) directly from this notebook.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "calculated tensor of all states in 0.0009723345438639323 minutes\n", + "+ Environment has 1024 states\n", + "+ Environment log partition is 5.711750507354736\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jdv/code/torchgfn/src/gfn/env.py:495: UserWarning: You're using advanced parameters: (sf). These are only needed for custom action handling. For basic environments, you can omit these.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "calculated tensor of all states in 0.0007189313570658366 minutes\n", + "+ Environment has 1024 states\n", + "+ Environment log partition is 5.711750507354736\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'elapsed': 6.105575958034024,\n", + " 'losses': [7.572955131530762,\n", + " 2.491760015487671,\n", + " 3.2979516983032227,\n", + " 3.0194754600524902,\n", + " 0.9618801474571228,\n", + " 0.8673862218856812,\n", + " 1.6195870637893677,\n", + " 0.5269477367401123,\n", + " 1.0297844409942627,\n", + " 1.332466959953308,\n", + " 0.6973802447319031,\n", + " 1.6610175371170044,\n", + " 0.47196799516677856,\n", + " 1.980200171470642,\n", + " 2.484879970550537,\n", + " 1.153307557106018,\n", + " 0.5622124671936035,\n", + " 0.7249877452850342,\n", + " 1.2468236684799194,\n", + " 1.9157636165618896,\n", + " 1.5802578926086426,\n", + " 0.9950146675109863,\n", + " 0.9827088713645935,\n", + " 0.9594094753265381,\n", + " 2.0273141860961914,\n", + " 1.0678741931915283,\n", + " 1.7654989957809448,\n", + " 1.8363938331604004,\n", + " 0.5704580545425415,\n", + " 2.0948450565338135,\n", + " 0.8548241853713989,\n", + " 4.518639087677002,\n", + " 1.0827535390853882,\n", + " 1.2317500114440918,\n", + " 0.6395683288574219,\n", + " 1.3933279514312744,\n", + " 1.7131190299987793,\n", + " 1.1856663227081299,\n", + " 1.428055763244629,\n", + " 0.8084158897399902,\n", + " 0.37907153367996216,\n", + " 1.583935260772705,\n", + " 2.161365270614624,\n", + " 1.4849199056625366,\n", + " 1.6980212926864624,\n", + " 0.4082474708557129,\n", + " 1.0781633853912354,\n", + " 0.6617383360862732,\n", + " 0.8540241718292236,\n", + " 0.7804931998252869,\n", + " 1.5323201417922974,\n", + " 1.175217628479004,\n", + " 0.4573594331741333,\n", + " 1.7341632843017578,\n", + " 1.3420581817626953,\n", + " 0.6467706561088562,\n", + " 0.7274094223976135,\n", + " 0.4892826974391937,\n", + " 0.4271280765533447,\n", + " 1.8384681940078735,\n", + " 0.7235350608825684,\n", + " 1.1163365840911865,\n", + " 1.350170612335205,\n", + " 0.42495375871658325,\n", + " 1.4814311265945435,\n", + " 0.9633447527885437,\n", + " 0.7441744804382324,\n", + " 0.49172350764274597,\n", + " 0.8439239859580994,\n", + " 1.7822604179382324,\n", + " 2.700016975402832,\n", + " 1.2268513441085815,\n", + " 1.689005732536316,\n", + " 0.8238610029220581,\n", + " 1.1699678897857666,\n", + " 0.5300710201263428,\n", + " 0.3418184518814087,\n", + " 0.6137208342552185,\n", + " 0.7867594957351685,\n", + " 2.185892343521118,\n", + " 0.521285355091095,\n", + " 1.2726870775222778,\n", + " 0.450527161359787,\n", + " 1.4032469987869263,\n", + " 0.4590965509414673,\n", + " 0.4651802182197571,\n", + " 0.8082336187362671,\n", + " 0.7147279977798462,\n", + " 0.5436726212501526,\n", + " 1.4215888977050781,\n", + " 1.2204718589782715,\n", + " 0.25053414702415466,\n", + " 1.3869221210479736,\n", + " 1.0469372272491455,\n", + " 1.4329015016555786,\n", + " 0.6535708904266357,\n", + " 1.5620619058609009,\n", + " 
2.5244903564453125,\n", + " 0.7888399362564087,\n", + " 5.477772235870361,\n", + " 2.438631057739258,\n", + " 1.3707287311553955,\n", + " 0.9398314952850342,\n", + " 0.46227526664733887,\n", + " 0.5834711790084839,\n", + " 1.6012723445892334,\n", + " 1.7451685667037964,\n", + " 0.9019611477851868,\n", + " 0.7608010172843933,\n", + " 1.2909115552902222,\n", + " 0.4646669626235962,\n", + " 0.4738759398460388,\n", + " 0.7258164882659912,\n", + " 1.7222124338150024,\n", + " 2.4299848079681396,\n", + " 2.2367324829101562,\n", + " 2.0737838745117188,\n", + " 0.5498713254928589,\n", + " 1.791583776473999,\n", + " 1.1461840867996216,\n", + " 0.2634006142616272,\n", + " 0.8089989423751831,\n", + " 0.283713698387146,\n", + " 0.14579910039901733,\n", + " 0.8403635621070862,\n", + " 0.5281050205230713,\n", + " 0.3584972620010376,\n", + " 0.6050671935081482,\n", + " 0.4479628801345825,\n", + " 0.5756915211677551,\n", + " 1.256803035736084,\n", + " 0.7478235960006714,\n", + " 0.375349760055542,\n", + " 0.35438239574432373,\n", + " 1.180328369140625,\n", + " 1.3136284351348877,\n", + " 3.27740478515625,\n", + " 1.0790156126022339,\n", + " 1.540788173675537,\n", + " 1.0326931476593018,\n", + " 0.9449985027313232,\n", + " 3.155139684677124,\n", + " 0.9995787143707275,\n", + " 0.7784771919250488,\n", + " 1.443956732749939,\n", + " 0.8618024587631226,\n", + " 0.3689582049846649,\n", + " 0.4708964228630066,\n", + " 1.133431315422058,\n", + " 1.1145482063293457,\n", + " 0.4921965003013611,\n", + " 0.415180504322052,\n", + " 1.5828590393066406,\n", + " 2.7756614685058594,\n", + " 0.6000064611434937,\n", + " 0.7194350957870483,\n", + " 0.8013563752174377,\n", + " 1.2213236093521118,\n", + " 1.1368153095245361,\n", + " 1.3761565685272217,\n", + " 0.346245139837265,\n", + " 0.3820178508758545,\n", + " 0.4246070384979248,\n", + " 0.5360602140426636,\n", + " 0.6117574572563171,\n", + " 1.0365396738052368,\n", + " 0.191411092877388,\n", + " 0.6832408905029297,\n", + " 1.0592424869537354,\n", + " 0.6425431966781616,\n", + " 0.498931348323822,\n", + " 0.774171769618988,\n", + " 0.32929566502571106,\n", + " 0.42261868715286255,\n", + " 0.3470837473869324,\n", + " 0.4950379729270935,\n", + " 0.5027860403060913,\n", + " 0.35800668597221375,\n", + " 1.288243293762207,\n", + " 0.6900274753570557,\n", + " 1.4558058977127075,\n", + " 1.1142147779464722,\n", + " 0.2911098003387451,\n", + " 0.7661944031715393,\n", + " 1.0826146602630615,\n", + " 1.19940984249115,\n", + " 0.884093701839447,\n", + " 0.5238901972770691,\n", + " 0.6807741522789001,\n", + " 0.5270069241523743,\n", + " 0.43598586320877075,\n", + " 0.31679433584213257,\n", + " 0.7662327885627747,\n", + " 0.4052656292915344,\n", + " 0.4683819115161896,\n", + " 0.4934506416320801,\n", + " 0.17495952546596527,\n", + " 0.5440036654472351,\n", + " 0.5274096131324768,\n", + " 0.6581551432609558],\n", + " 'iter_times': [0.029843291034922004,\n", + " 0.02665925002656877,\n", + " 0.026591749861836433,\n", + " 0.02578529203310609,\n", + " 0.0249330410733819,\n", + " 0.02374075003899634,\n", + " 0.02827424998395145,\n", + " 0.016017750138416886,\n", + " 0.025069749914109707,\n", + " 0.02290912508033216,\n", + " 0.01798625010997057,\n", + " 0.022280167089775205,\n", + " 0.021828250028192997,\n", + " 0.025819166796281934,\n", + " 0.023717750096693635,\n", + " 0.027398916892707348,\n", + " 0.0268092080950737,\n", + " 0.02269370900467038,\n", + " 0.028668625047430396,\n", + " 0.025364749832078815,\n", + " 0.026112499879673123,\n", + " 0.03050654218532145,\n", + " 
0.026381792034953833,\n", + " 0.02492345799691975,\n", + " 0.024277166929095984,\n", + " 0.025802375050261617,\n", + " 0.023120166966691613,\n", + " 0.022838874952867627,\n", + " 0.02152595785446465,\n", + " 0.02243420807644725,\n", + " 0.026837416924536228,\n", + " 0.02476345794275403,\n", + " 0.027901625027880073,\n", + " 0.02595195802859962,\n", + " 0.029563874937593937,\n", + " 0.02456283406354487,\n", + " 0.023599833017215133,\n", + " 0.027492624940350652,\n", + " 0.022965540876612067,\n", + " 0.02508420799858868,\n", + " 0.028139542089775205,\n", + " 0.025828625075519085,\n", + " 0.028102665906772017,\n", + " 0.026154458988457918,\n", + " 0.027773250127211213,\n", + " 0.025903834030032158,\n", + " 0.028939666924998164,\n", + " 0.024504374945536256,\n", + " 0.02638370799832046,\n", + " 0.025218249997124076,\n", + " 0.023709542118012905,\n", + " 0.03099012514576316,\n", + " 0.021817500004544854,\n", + " 0.025243084179237485,\n", + " 0.031247250037267804,\n", + " 0.026461832923814654,\n", + " 0.02370858401991427,\n", + " 0.02921529207378626,\n", + " 0.02194329211488366,\n", + " 0.02969479188323021,\n", + " 0.02624008315615356,\n", + " 0.025315249804407358,\n", + " 0.030012167058885098,\n", + " 0.032295542070642114,\n", + " 0.029138999991118908,\n", + " 0.02831891691312194,\n", + " 0.02835141704417765,\n", + " 0.02719200006686151,\n", + " 0.027825874974951148,\n", + " 0.024933042004704475,\n", + " 0.03004787489771843,\n", + " 0.0257573330309242,\n", + " 0.02414166694507003,\n", + " 0.028785540955141187,\n", + " 0.029787667095661163,\n", + " 0.03012562496587634,\n", + " 0.0186317500192672,\n", + " 0.026522708125412464,\n", + " 0.025528999976813793,\n", + " 0.02337670815177262,\n", + " 0.028197707841172814,\n", + " 0.028563749976456165,\n", + " 0.023193708853796124,\n", + " 0.023533583153039217,\n", + " 0.023977917153388262,\n", + " 0.025529957842081785,\n", + " 0.025657958118245006,\n", + " 0.028129874961450696,\n", + " 0.0250042078550905,\n", + " 0.024240750120952725,\n", + " 0.02651733416132629,\n", + " 0.03236395795829594,\n", + " 0.03119733394123614,\n", + " 0.023247458040714264,\n", + " 0.02916412497870624,\n", + " 0.03493220801465213,\n", + " 0.026215665973722935,\n", + " 0.0358433339279145,\n", + " 0.03177366708405316,\n", + " 0.03824399993754923,\n", + " 0.03338041715323925,\n", + " 0.034712209133431315,\n", + " 0.03465991700068116,\n", + " 0.02921774983406067,\n", + " 0.02734670788049698,\n", + " 0.032816499937325716,\n", + " 0.02972904103808105,\n", + " 0.032379542011767626,\n", + " 0.03371258289553225,\n", + " 0.03127762512303889,\n", + " 0.02719833399169147,\n", + " 0.025000832974910736,\n", + " 0.03524758294224739,\n", + " 0.03120354190468788,\n", + " 0.03501374996267259,\n", + " 0.03635912504978478,\n", + " 0.0354501660913229,\n", + " 0.03287862497381866,\n", + " 0.029975875047966838,\n", + " 0.03692404204048216,\n", + " 0.02574570896103978,\n", + " 0.03072458296082914,\n", + " 0.03142145904712379,\n", + " 0.034465207951143384,\n", + " 0.032932084053754807,\n", + " 0.03766591614112258,\n", + " 0.031669416930526495,\n", + " 0.03097916697151959,\n", + " 0.024389415979385376,\n", + " 0.026578041957691312,\n", + " 0.028716041008010507,\n", + " 0.032391542103141546,\n", + " 0.03359537501819432,\n", + " 0.029333041980862617,\n", + " 0.03852045815438032,\n", + " 0.03522874996997416,\n", + " 0.039978290908038616,\n", + " 0.03800995904020965,\n", + " 0.03813070897012949,\n", + " 0.03231212496757507,\n", + " 0.039890124928206205,\n", + " 0.03974391706287861,\n", + " 
0.040384375024586916,\n", + " 0.03584462497383356,\n", + " 0.03564045880921185,\n", + " 0.03651641705073416,\n", + " 0.037402458023279905,\n", + " 0.03648191690444946,\n", + " 0.03885470796376467,\n", + " 0.03411237499676645,\n", + " 0.03694487502798438,\n", + " 0.02758374996483326,\n", + " 0.03918912494555116,\n", + " 0.03915116610005498,\n", + " 0.03766429191455245,\n", + " 0.034370541106909513,\n", + " 0.03439179202541709,\n", + " 0.03841937496326864,\n", + " 0.039793833857402205,\n", + " 0.03862112481147051,\n", + " 0.03574187494814396,\n", + " 0.03246379108168185,\n", + " 0.036486207973212004,\n", + " 0.03896020818501711,\n", + " 0.03466787491925061,\n", + " 0.0354886669665575,\n", + " 0.03665566607378423,\n", + " 0.03877758304588497,\n", + " 0.03936316608451307,\n", + " 0.04036858305335045,\n", + " 0.03711925004608929,\n", + " 0.03579937503673136,\n", + " 0.0394124158192426,\n", + " 0.03487608302384615,\n", + " 0.033077584113925695,\n", + " 0.03582812496460974,\n", + " 0.03510708408430219,\n", + " 0.03271691710688174,\n", + " 0.024238749872893095,\n", + " 0.0317631671205163,\n", + " 0.037570209009572864,\n", + " 0.037720915861427784,\n", + " 0.03836658294312656,\n", + " 0.03753329208120704,\n", + " 0.03856374998576939,\n", + " 0.03874720796011388,\n", + " 0.038331500021740794,\n", + " 0.03494950011372566,\n", + " 0.03934533311985433,\n", + " 0.034773917170241475,\n", + " 0.0353236251976341,\n", + " 0.03172524995170534,\n", + " 0.03164450009353459,\n", + " 0.02425049990415573,\n", + " 0.03322154190391302,\n", + " 0.036195874912664294,\n", + " 0.03410979197360575,\n", + " 0.03734816703945398,\n", + " 0.039303625002503395,\n", + " 0.03886774997226894],\n", + " 'compile_mode': 'none',\n", + " 'use_compile': False,\n", + " 'requested_vmap': False,\n", + " 'effective_vmap': False,\n", + " 'chunk_size_effective': 0,\n", + " 'label': 'Eager'},\n", + " {'elapsed': 6.3454663751181215,\n", + " 'losses': [7.57296085357666,\n", + " 2.4917588233947754,\n", + " 3.297954797744751,\n", + " 3.0194761753082275,\n", + " 0.9618798494338989,\n", + " 0.8673862218856812,\n", + " 1.61958646774292,\n", + " 0.5269474983215332,\n", + " 1.0297842025756836,\n", + " 1.332466721534729,\n", + " 0.6973803639411926,\n", + " 1.6610198020935059,\n", + " 0.471968412399292,\n", + " 1.9802007675170898,\n", + " 2.4848790168762207,\n", + " 1.1533082723617554,\n", + " 0.5622122883796692,\n", + " 0.7249880433082581,\n", + " 1.2468247413635254,\n", + " 1.9157638549804688,\n", + " 1.580257773399353,\n", + " 0.9950148463249207,\n", + " 0.9827099442481995,\n", + " 0.9594108462333679,\n", + " 2.0273146629333496,\n", + " 1.0678751468658447,\n", + " 1.7654999494552612,\n", + " 1.8363943099975586,\n", + " 0.5704579949378967,\n", + " 2.0948455333709717,\n", + " 0.8548235893249512,\n", + " 4.518638610839844,\n", + " 1.0827548503875732,\n", + " 1.2317492961883545,\n", + " 0.6395676732063293,\n", + " 1.3933277130126953,\n", + " 1.7131195068359375,\n", + " 1.1856666803359985,\n", + " 1.4280558824539185,\n", + " 0.8084155917167664,\n", + " 0.3790717124938965,\n", + " 1.5839354991912842,\n", + " 2.1613659858703613,\n", + " 1.4849202632904053,\n", + " 1.6980226039886475,\n", + " 0.4082470238208771,\n", + " 1.0781641006469727,\n", + " 0.6617385149002075,\n", + " 0.8540250062942505,\n", + " 0.7804922461509705,\n", + " 1.532320499420166,\n", + " 1.1752188205718994,\n", + " 0.45735907554626465,\n", + " 1.7341620922088623,\n", + " 1.3420584201812744,\n", + " 0.6467709541320801,\n", + " 0.7274090647697449,\n", + " 0.4892843961715698,\n", + " 
0.42712870240211487,\n", + " 1.8384690284729004,\n", + " 0.7235339879989624,\n", + " 1.1163374185562134,\n", + " 1.3501720428466797,\n", + " 0.42495307326316833,\n", + " 1.4814316034317017,\n", + " 0.9633446335792542,\n", + " 0.7441746592521667,\n", + " 0.49172279238700867,\n", + " 0.8439255952835083,\n", + " 1.7822625637054443,\n", + " 2.70001482963562,\n", + " 1.226853370666504,\n", + " 1.6889946460723877,\n", + " 0.8238599300384521,\n", + " 1.1699568033218384,\n", + " 0.5300586223602295,\n", + " 0.34182092547416687,\n", + " 0.6136969327926636,\n", + " 0.7866867780685425,\n", + " 2.1859169006347656,\n", + " 0.5212369561195374,\n", + " 1.2727247476577759,\n", + " 0.45056644082069397,\n", + " 1.4032206535339355,\n", + " 0.4590488374233246,\n", + " 0.4650927782058716,\n", + " 0.8082565665245056,\n", + " 0.7144088745117188,\n", + " 0.543725311756134,\n", + " 1.4212666749954224,\n", + " 1.2204252481460571,\n", + " 0.25085633993148804,\n", + " 1.3868451118469238,\n", + " 1.0468615293502808,\n", + " 1.4329462051391602,\n", + " 0.6538652777671814,\n", + " 1.561804175376892,\n", + " 2.5266776084899902,\n", + " 0.7883028984069824,\n", + " 5.479814529418945,\n", + " 2.439664840698242,\n", + " 1.371835470199585,\n", + " 0.9409430027008057,\n", + " 0.46193727850914,\n", + " 0.5832473039627075,\n", + " 1.6056299209594727,\n", + " 1.7503899335861206,\n", + " 0.9054805040359497,\n", + " 0.7594331502914429,\n", + " 1.2933785915374756,\n", + " 0.4637310206890106,\n", + " 0.47456109523773193,\n", + " 0.7240238189697266,\n", + " 1.725350260734558,\n", + " 2.4346182346343994,\n", + " 2.2380924224853516,\n", + " 2.0717313289642334,\n", + " 0.5502132177352905,\n", + " 1.7969448566436768,\n", + " 1.1502264738082886,\n", + " 0.26300373673439026,\n", + " 0.8080339431762695,\n", + " 0.28202709555625916,\n", + " 0.145426943898201,\n", + " 0.8435786962509155,\n", + " 0.5270522236824036,\n", + " 0.3560110628604889,\n", + " 0.609140157699585,\n", + " 0.4524146318435669,\n", + " 0.5814225077629089,\n", + " 1.2689967155456543,\n", + " 0.5928022861480713,\n", + " 0.2952418625354767,\n", + " 0.24432311952114105,\n", + " 0.7377331256866455,\n", + " 0.8195648193359375,\n", + " 1.2766600847244263,\n", + " 0.9409489035606384,\n", + " 0.830878496170044,\n", + " 0.4374370574951172,\n", + " 0.3859502971172333,\n", + " 2.2971324920654297,\n", + " 0.41110262274742126,\n", + " 0.7398536801338196,\n", + " 0.43272972106933594,\n", + " 0.5752124190330505,\n", + " 0.3000510632991791,\n", + " 0.8548444509506226,\n", + " 0.373995304107666,\n", + " 1.8866313695907593,\n", + " 1.077983021736145,\n", + " 0.5132556557655334,\n", + " 1.1473032236099243,\n", + " 0.6178485155105591,\n", + " 0.4333594739437103,\n", + " 1.4737694263458252,\n", + " 0.9747253060340881,\n", + " 1.5080058574676514,\n", + " 1.314931869506836,\n", + " 0.9588497877120972,\n", + " 0.39719900488853455,\n", + " 1.0430501699447632,\n", + " 1.1309837102890015,\n", + " 0.43614932894706726,\n", + " 0.58064204454422,\n", + " 1.1400829553604126,\n", + " 0.3988802433013916,\n", + " 0.963148832321167,\n", + " 2.17482328414917,\n", + " 1.3901969194412231,\n", + " 0.3156747817993164,\n", + " 0.5436887741088867,\n", + " 0.36854708194732666,\n", + " 0.37455135583877563,\n", + " 0.2726321220397949,\n", + " 0.3139721155166626,\n", + " 0.5012255311012268,\n", + " 0.820480227470398,\n", + " 1.0951125621795654,\n", + " 0.6761919856071472,\n", + " 0.790934145450592,\n", + " 0.9907330274581909,\n", + " 0.8022286891937256,\n", + " 0.3866922855377197,\n", + " 0.7084116339683533,\n", + " 
0.866324245929718,\n", + " 0.46766072511672974,\n", + " 0.26419714093208313,\n", + " 0.32584092020988464,\n", + " 1.2846601009368896,\n", + " 0.39885473251342773,\n", + " 1.0205860137939453,\n", + " 0.27573293447494507,\n", + " 0.24224549531936646,\n", + " 0.6909810900688171,\n", + " 0.3044925034046173,\n", + " 0.25011563301086426,\n", + " 0.44614750146865845,\n", + " 0.6451624035835266,\n", + " 0.5779326558113098],\n", + " 'iter_times': [0.03169812494888902,\n", + " 0.027584708062931895,\n", + " 0.02858670800924301,\n", + " 0.026709083002060652,\n", + " 0.0258675420191139,\n", + " 0.02468925016000867,\n", + " 0.030730166006833315,\n", + " 0.016805625054985285,\n", + " 0.026752999983727932,\n", + " 0.023861791007220745,\n", + " 0.01803375012241304,\n", + " 0.022661666851490736,\n", + " 0.02204920817166567,\n", + " 0.02710441709496081,\n", + " 0.02583591709844768,\n", + " 0.027296457905322313,\n", + " 0.027498583076521754,\n", + " 0.023009249940514565,\n", + " 0.029670832911506295,\n", + " 0.026142749935388565,\n", + " 0.026309833861887455,\n", + " 0.032059666933491826,\n", + " 0.02789283310994506,\n", + " 0.02576716709882021,\n", + " 0.02486624987795949,\n", + " 0.027173375012353063,\n", + " 0.02338912500999868,\n", + " 0.024146082811057568,\n", + " 0.02213795785792172,\n", + " 0.023384332889690995,\n", + " 0.02873941697180271,\n", + " 0.0250537081155926,\n", + " 0.028922915924340487,\n", + " 0.02695966698229313,\n", + " 0.03005241695791483,\n", + " 0.025074792094528675,\n", + " 0.02407074999064207,\n", + " 0.028803542023524642,\n", + " 0.024359500035643578,\n", + " 0.025547208031639457,\n", + " 0.02867558295838535,\n", + " 0.02669191686436534,\n", + " 0.02979383314959705,\n", + " 0.026735665975138545,\n", + " 0.028341416968032718,\n", + " 0.026234874967485666,\n", + " 0.030128249898552895,\n", + " 0.02529741614125669,\n", + " 0.027294500032439828,\n", + " 0.026511665899306536,\n", + " 0.024816541001200676,\n", + " 0.03207574994303286,\n", + " 0.02337020798586309,\n", + " 0.026867209002375603,\n", + " 0.03262049984186888,\n", + " 0.026773500023409724,\n", + " 0.025173667119815946,\n", + " 0.030965832993388176,\n", + " 0.023638332961127162,\n", + " 0.03076474997214973,\n", + " 0.026863333070650697,\n", + " 0.025366832967847586,\n", + " 0.03145316708832979,\n", + " 0.03426345810294151,\n", + " 0.03046545898541808,\n", + " 0.029700499959290028,\n", + " 0.029783915961161256,\n", + " 0.02838366595096886,\n", + " 0.029486834071576595,\n", + " 0.02670641685836017,\n", + " 0.031350333942100406,\n", + " 0.027266208082437515,\n", + " 0.025731500005349517,\n", + " 0.030091083142906427,\n", + " 0.031152040930464864,\n", + " 0.03152075014077127,\n", + " 0.01954770809970796,\n", + " 0.026827292051166296,\n", + " 0.027664915891364217,\n", + " 0.024289875058457255,\n", + " 0.02911649993620813,\n", + " 0.029995207907631993,\n", + " 0.02480362495407462,\n", + " 0.024720708839595318,\n", + " 0.024827249813824892,\n", + " 0.027033083140850067,\n", + " 0.027515250025317073,\n", + " 0.029572209110483527,\n", + " 0.02584737492725253,\n", + " 0.02511425013653934,\n", + " 0.02771108318120241,\n", + " 0.03237133310176432,\n", + " 0.032289124792441726,\n", + " 0.024745125090703368,\n", + " 0.03070554183796048,\n", + " 0.0387380828615278,\n", + " 0.027526458026841283,\n", + " 0.038875750033184886,\n", + " 0.033374999882653356,\n", + " 0.04179483302868903,\n", + " 0.03558491705916822,\n", + " 0.03673141589388251,\n", + " 0.03694891603663564,\n", + " 0.031565249897539616,\n", + " 0.029297207947820425,\n", + " 
0.03449812508188188,\n", + " 0.032412667060270905,\n", + " 0.03472112491726875,\n", + " 0.03546695807017386,\n", + " 0.03327808412723243,\n", + " 0.028670792002230883,\n", + " 0.02604879206046462,\n", + " 0.037921792129054666,\n", + " 0.03335845796391368,\n", + " 0.03821487491950393,\n", + " 0.038532041013240814,\n", + " 0.038210750091820955,\n", + " 0.034492208855226636,\n", + " 0.03084229095838964,\n", + " 0.03950933297164738,\n", + " 0.026532666059210896,\n", + " 0.03208570904098451,\n", + " 0.03370637493208051,\n", + " 0.03659145790152252,\n", + " 0.03467929200269282,\n", + " 0.03974362509325147,\n", + " 0.03336887480691075,\n", + " 0.03240949986502528,\n", + " 0.025664541870355606,\n", + " 0.02777195884846151,\n", + " 0.030273291980847716,\n", + " 0.039200125029310584,\n", + " 0.02961912495084107,\n", + " 0.030783750116825104,\n", + " 0.03524379199370742,\n", + " 0.039315874921157956,\n", + " 0.04285550001077354,\n", + " 0.03660220908932388,\n", + " 0.04133837507106364,\n", + " 0.03899241704493761,\n", + " 0.03553324984386563,\n", + " 0.03786025010049343,\n", + " 0.032377874944359064,\n", + " 0.028963749995455146,\n", + " 0.03383512492291629,\n", + " 0.028710374841466546,\n", + " 0.03110875003039837,\n", + " 0.030619458062574267,\n", + " 0.03356295800767839,\n", + " 0.03945783409290016,\n", + " 0.039811542024835944,\n", + " 0.03152520814910531,\n", + " 0.040032291086390615,\n", + " 0.038056124933063984,\n", + " 0.03658033302053809,\n", + " 0.03990195784717798,\n", + " 0.038062167121097445,\n", + " 0.041290540946647525,\n", + " 0.040220499970018864,\n", + " 0.036755625158548355,\n", + " 0.03777787508442998,\n", + " 0.04015366593375802,\n", + " 0.0368197918869555,\n", + " 0.038513916078954935,\n", + " 0.03783300006762147,\n", + " 0.04108829190954566,\n", + " 0.038560917135328054,\n", + " 0.04114095913246274,\n", + " 0.04210358299314976,\n", + " 0.04152379208244383,\n", + " 0.03358387481421232,\n", + " 0.03409449988976121,\n", + " 0.03836424998007715,\n", + " 0.03440766618587077,\n", + " 0.0377882921602577,\n", + " 0.03659795899875462,\n", + " 0.037964458810165524,\n", + " 0.03920579212717712,\n", + " 0.04200679203495383,\n", + " 0.03882258292287588,\n", + " 0.039652334060519934,\n", + " 0.037086208118125796,\n", + " 0.04016891587525606,\n", + " 0.04108100011944771,\n", + " 0.038782832911238074,\n", + " 0.037558333948254585,\n", + " 0.04046133300289512,\n", + " 0.03616187488660216,\n", + " 0.03772379201836884,\n", + " 0.040625750087201595,\n", + " 0.0360937500372529,\n", + " 0.039527124958112836,\n", + " 0.02647212496958673,\n", + " 0.03991162497550249,\n", + " 0.034061457961797714,\n", + " 0.03235474997200072,\n", + " 0.025128833018243313,\n", + " 0.03827429190278053,\n", + " 0.03287583403289318,\n", + " 0.03880300000309944],\n", + " 'compile_mode': 'reduce-overhead',\n", + " 'use_compile': True,\n", + " 'requested_vmap': False,\n", + " 'effective_vmap': False,\n", + " 'chunk_size_effective': 0,\n", + " 'label': 'Compiled'}]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import importlib\n", + "import json\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add project root to path (notebook is in tutorials/notebooks/)\n", + "project_root = Path.cwd().parent.parent\n", + "sys.path.append(str(project_root))\n", + "from tutorials.examples import train_hypergrid_optimized as hypergrid_train\n", + "\n", + "\n", + "# Reload to pick up local edits without restarting the kernel.\n", + 
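"# (notebook_benchmark_run below re-parses the script's CLI defaults and\n", + "# overrides only the benchmark-relevant fields before calling\n", + "# train_with_options.)\n", + 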
"importlib.reload(hypergrid_train)\n", + "\n", + "\n", + "def notebook_benchmark_run(\n", + " *,\n", + " compile_mode: str = \"none\",\n", + " use_compile: bool = False,\n", + " chunk_size: int = 0,\n", + " n_iterations: int = 200,\n", + " warmup_iters: int = 50,\n", + " seed: int = 0,\n", + " device: str = \"cpu\",\n", + " label: str,\n", + ") -> dict:\n", + " argv_backup = sys.argv\n", + " try:\n", + " sys.argv = [sys.argv[0]]\n", + " args = hypergrid_train.parse_args()\n", + " finally:\n", + " sys.argv = argv_backup\n", + " args.compile = use_compile\n", + " args.compile_mode = compile_mode\n", + " args.chunk_size = chunk_size\n", + " args.n_iterations = n_iterations\n", + " args.warmup_iters = warmup_iters\n", + " args.seed = seed\n", + " args.device = device\n", + " args.benchmark = True\n", + " args.use_vmap = False\n", + " args.loss = \"TB\"\n", + " args.batch_size = 16\n", + " args.height = 32\n", + " args.ndim = 2\n", + "\n", + " result = hypergrid_train.train_with_options(\n", + " args,\n", + " device=hypergrid_train.resolve_device(device),\n", + " enable_compile=use_compile,\n", + " use_vmap=False,\n", + " warmup_iters=warmup_iters,\n", + " quiet=True,\n", + " timing=True,\n", + " record_history=True,\n", + " use_chunk=(chunk_size > 0),\n", + " )\n", + " result[\"label\"] = label\n", + " result[\"compile_mode\"] = compile_mode if use_compile else \"none\"\n", + " return result\n", + "\n", + "\n", + "scenarios = [\n", + " dict(label=\"Eager\", use_compile=False),\n", + " dict(label=\"Compiled\", use_compile=True, compile_mode=\"reduce-overhead\"),\n", + "]\n", + "\n", + "benchmark_runs = []\n", + "for scenario in scenarios:\n", + " run_result = notebook_benchmark_run(**scenario)\n", + " benchmark_runs.append(run_result)\n", + "\n", + "benchmark_runs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelelapsedcompile_modeeffective_vmapchunk_size_effective
0Eager6.105576noneFalse0
1Compiled6.345466reduce-overheadFalse0
\n", + "
" + ], + "text/plain": [ + " label elapsed compile_mode effective_vmap chunk_size_effective\n", + "0 Eager 6.105576 none False 0\n", + "1 Compiled 6.345466 reduce-overhead False 0" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline label: Eager elapsed: 6.11s\n", + "Compiled elapsed=6.35s (0.96x vs baseline), compile_mode=reduce-overhead\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "benchmark_df = pd.DataFrame(benchmark_runs)\n", + "display(\n", + " benchmark_df[\n", + " [\n", + " \"label\",\n", + " \"elapsed\",\n", + " \"compile_mode\",\n", + " \"effective_vmap\",\n", + " \"chunk_size_effective\",\n", + " ]\n", + " ]\n", + ")\n", + "\n", + "baseline = benchmark_df.iloc[0]\n", + "print(\"Baseline label:\", baseline[\"label\"], \"elapsed:\", f\"{baseline['elapsed']:.2f}s\")\n", + "for idx in range(1, len(benchmark_df)):\n", + " row = benchmark_df.iloc[idx]\n", + " speedup = baseline[\"elapsed\"] / row[\"elapsed\"] if row[\"elapsed\"] else float(\"inf\")\n", + " print(\n", + " f\"{row['label']} elapsed={row['elapsed']:.2f}s \"\n", + " f\"({speedup:.2f}x vs baseline), compile_mode={row['compile_mode']}\"\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dynamo trace analysis\n", + "\n", + "`torch._dynamo.explain` gives a per-graph summary: captured ops, guards, and where graph breaks (if any) occur. The cell below reuses the state/action tensors above and prints the explanation so you can confirm there is only one graph and zero breaks.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Graph Count: 16\n", + "Graph Break Count: 15\n", + "Op Count: 24\n", + "Break Reasons:\n", + " Break Reason 1:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 2:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 3:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 4:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " 
Break Reason 5:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 6:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " Break Reason 7:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " Break Reason 8:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " Break Reason 9:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + " Break Reason 10:\n", + " Reason: Dynamic shape operator\n", + " Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.\n", + " Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`\n", + "\n", + " Developer debug context: aten.repeat_interleave.Tensor\n", + "\n", + " User Stack:\n", + " \n", + " \n", + " \n", + " \n", + "Ops per Graph:\n", + " Ops 1:\n", + " \n", + " \n", + " Ops 2:\n", + " \n", + " Ops 3:\n", + " \n", + " Ops 4:\n", + " \n", + " Ops 5:\n", + " Ops 6:\n", + " \n", + " Ops 7:\n", + " \n", + " Ops 8:\n", + " Ops 9:\n", + " \n", + " Ops 10:\n", + " \n", + " Ops 11:\n", + " Ops 12:\n", + " Ops 13:\n", + " \n", + " aten._assert_async.msg\n", + " \n", + " \n", + " Ops 14:\n", + " Ops 15:\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " Ops 16:\n", + " \n", + " \n", + " \n", + " \n", + "Out Guards:\n", + " Guard 1:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 2:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 
3:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['states_tensor'].size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 4:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 5:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 6:\n", + " Name: \"L['states_tensor']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states_tensor'], 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 7:\n", + " Name: \"G['env'].States.s0\"\n", + " Source: global\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(G['env'].States.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 8:\n", + " Name: \"G['env']\"\n", + " Source: global\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(G['env'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 9:\n", + " Name: \"L['actions_tensor']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions_tensor'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 10:\n", + " Name: \"G['__builtins_dict___66']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 11:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 12:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 13:\n", + " Name: \"G['__builtins_dict___66']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 14:\n", + " Name: \"G['__builtins_dict___66']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 15:\n", + " Name: \"G['__builtins_dict___66']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", 
+ " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 16:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 17:\n", + " Name: \"L['states_tensor']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states_tensor'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 18:\n", + " Name: \"G['env'].States.sf\"\n", + " Source: global\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(G['env'].States.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 19:\n", + " Name: \"G['__builtins_dict___66']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___66']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 20:\n", + " Name: \"G['env'].States.n_actions\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].States.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 21:\n", + " Name: \"G['__import_gfn_dot_states'].torch.ones\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.ones, 4428225552)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 22:\n", + " Name: \"G['env'].Actions.action_shape\"\n", + " Source: global\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(G['env'].Actions.action_shape, 4305555088)\", \"len(G['env'].Actions.action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 23:\n", + " Name: \"G['env'].States.state_shape\"\n", + " Source: global\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(G['env'].States.state_shape, 4305555088)\", \"len(G['env'].States.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 24:\n", + " Name: \"G['env'].Actions.action_shape[0]\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].Actions.action_shape[0] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 25:\n", + " Name: \"G['env'].States.state_shape[0]\"\n", + " Source: global\n", + " Create Function: EQUALS_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['env'].States.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 26:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", 
+ " Guard 27:\n", + " Name: \"L['states_tensor'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['states_tensor'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 28:\n", + " Name: \"G['env'].States\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['env'].States, 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 29:\n", + " Name: \"G['env'].Actions\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['env'].Actions, 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 30:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 31:\n", + " Name: \"G['__builtins_dict___70']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 32:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 33:\n", + " Name: \"G['__builtins_dict___70']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 34:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 35:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 36:\n", + " Name: \"G['__builtins_dict___70']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 37:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 38:\n", + " Name: \"G['__builtins_dict___70']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___70']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 39:\n", + " Name: \"L['actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['actions'], 6146095712)\"]\n", + " Object Weakref: \n", + 
" Guarded Class Weakref: \n", + " Guard 40:\n", + " Name: \"G['__import_gfn_dot_states'].torch.Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 41:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 42:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 43:\n", + " Name: \"L['states'].__class__.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].__class__.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 44:\n", + " Name: \"L['actions'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 45:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: DUPLICATE_INPUT\n", + " Guard Types: ['DUPLICATE_INPUT']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch is G['torch']\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 46:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 47:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 48:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 49:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 50:\n", + " Name: \"L['self'].check_action_validity\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['BOOL_MATCH']\n", + " Code List: [\"L['self'].check_action_validity == True\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 51:\n", + " Name: \"L['states'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].state_shape, 4305555088)\", \"len(L['states'].state_shape) == 1\"]\n", + " 
Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 52:\n", + " Name: \"G['__import_gfn_dot_states'].ensure_same_device\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 53:\n", + " Name: \"L['states'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].tensor.shape, 4891320080)\", \"len(L['states'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 54:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 55:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 56:\n", + " Name: \"L['actions'].action_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['actions'].action_shape, 4305555088)\", \"len(L['actions'].action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 57:\n", + " Name: \"L['states'].state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['states'].state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 58:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 59:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 60:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 61:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 62:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 63:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 64:\n", + " Name: \"G['__builtins_dict___74']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___74']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 65:\n", + " Name: \"L['index']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['index'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 66:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 67:\n", + " Name: \"G['torch'].Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 68:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 69:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 70:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 71:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 72:\n", + " Name: \"L['self'].action_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].action_shape, 4305555088)\", \"len(L['self'].action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 73:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 74:\n", + " Name: \"G['__builtins_dict___77']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___77']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 75:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 76:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 77:\n", + " Name: \"G['__builtins_dict___77']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___77']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 78:\n", + " Name: \"G['boolean_mask_select']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 79:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 80:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 81:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 82:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 83:\n", + " Name: \"L['mask'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['mask'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 84:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 85:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 86:\n", + " Name: \"G['torch'].arange\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].arange, 4428197584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 87:\n", + " Name: \"G['torch'].repeat_interleave\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].repeat_interleave, 4428289808)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 88:\n", + " Name: \"G['_expand_mask_to_batch']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 89:\n", + " Name: \"L['batch_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['batch_shape'][0] == 4\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 90:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType, 
4305497480)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 91:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 92:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 93:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].cmp_eq\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 94:\n", + " Name: \"G['torch'].int64\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].int64 == torch.int64\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 95:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 96:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4891320080)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 97:\n", + " Name: \"L['data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['data'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 98:\n", + " Name: \"G['__builtins_dict___79']['type']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___79']['type'], 4305563712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 99:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 4305555088)\", \"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 100:\n", + " Name: \"L['device']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['device'] == device(type='cpu')\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 101:\n", + " Name: \"G['__builtins_dict___79']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___79']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 102:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 103:\n", + " Name: ''\n", + " Source: global\n", + " Create 
Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 104:\n", + " Name: \"G['__builtins_dict___79']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___79']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 105:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types, 4308052752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 106:\n", + " Name: \"G['__builtins_dict___79']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___79']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 107:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 108:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['value_shape'][0] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 109:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 110:\n", + " Name: \"L['data'].reshape\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['data'], 'reshape')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 111:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'], 5486842384)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 112:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 113:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 114:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 115:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 116:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard 
Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4891320080)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 117:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 118:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 119:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 120:\n", + " Name: \"L['original_ndim']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['original_ndim'] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 121:\n", + " Name: \"G['torch'].index_select\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].index_select, 4428298288)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 122:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 4305555088)\", \"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 123:\n", + " Name: \"L['flat_data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['flat_data'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 124:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['value_shape'][0] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 125:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 126:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 127:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 128:\n", + " Name: \"L['valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 129:\n", + " Name: ''\n", + " Source: global\n", 
+ " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 130:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 131:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 132:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 133:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 134:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 135:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 136:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 137:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 138:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 139:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 140:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 141:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 142:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 143:\n", + " Name: 
\"G['__builtins_dict___90']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___90']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 144:\n", + " Name: \"G['torch'].Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 145:\n", + " Name: \"L['index']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['index'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 146:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 147:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['self'].tensor.size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 148:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 149:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 150:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 151:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 152:\n", + " Name: \"L['self'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].state_shape, 4305555088)\", \"len(L['self'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 153:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 154:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 155:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object 
Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 156:\n", + " Name: \"G['__builtins_dict___93']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___93']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 157:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 158:\n", + " Name: \"L['self'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].tensor.shape, 4891320080)\", \"len(L['self'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 159:\n", + " Name: \"G['boolean_mask_select']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 160:\n", + " Name: \"G['__builtins_dict___93']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___93']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 161:\n", + " Name: \"L['self'].state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 162:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 163:\n", + " Name: \"L['mask'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['mask'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 164:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 165:\n", + " Name: \"G['__builtins_dict___96']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___96']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 166:\n", + " Name: \"G['torch'].arange\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].arange, 4428197584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 167:\n", + " Name: \"G['torch'].repeat_interleave\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].repeat_interleave, 4428289808)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 168:\n", + " Name: \"G['_expand_mask_to_batch']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", 
+ " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 169:\n", + " Name: \"L['batch_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['batch_shape'][0] == 4\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 170:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType, 4305497480)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 171:\n", + " Name: \"L['data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['data'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 172:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV']\n", + " Code List: [\"L['data'].stride()[0] == L['data'].size()[1]\", \"L['value_shape'][0] == L['data'].size()[1]\", \"2 <= L['data'].size()[0]\", \"2 <= L['data'].size()[1]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 173:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 174:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].cmp_eq\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 175:\n", + " Name: \"G['torch'].int64\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].int64 == torch.int64\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 176:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 177:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4305555088)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 178:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 4305555088)\", \"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 179:\n", + " Name: \"L['device']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['device'] == device(type='cpu')\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 180:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " 
Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 181:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 182:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['value_shape'][0], 4305558200)\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 183:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types, 4308052752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 184:\n", + " Name: \"G['__builtins_dict___96']['type']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___96']['type'], 4305563712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 185:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 186:\n", + " Name: \"G['__builtins_dict___96']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___96']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 187:\n", + " Name: \"G['__builtins_dict___96']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___96']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 188:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 189:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 190:\n", + " Name: \"L['data'].reshape\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['data'], 'reshape')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 191:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'], 5486842384)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 192:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV']\n", + " 
Code List: [\"L['flat_data'].stride()[0] == L['flat_data'].size()[1]\", \"L['flat_data']._base.stride()[0] == L['flat_data']._base.size()[1]\", \"L['value_shape'][0] == L['flat_data'].size()[1]\", \"2 <= L['flat_data'].size()[1]\", \"2 <= L['flat_data']._base.size()[0]\", \"2 <= L['flat_data']._base.size()[1]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 193:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 194:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['value_shape'][0], 4305558200)\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 195:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 196:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 197:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4305555088)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 198:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 199:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 200:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 201:\n", + " Name: \"L['original_ndim']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['original_ndim'] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 202:\n", + " Name: \"G['torch'].index_select\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].index_select, 4428298288)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 203:\n", + " Name: \"L['flat_data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['flat_data'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 204:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 
'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 4305555088)\", \"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 205:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 206:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['self'].tensor.size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 207:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 208:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 209:\n", + " Name: \"L['self'].backward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].backward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 210:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 211:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 212:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 213:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 214:\n", + " Name: \"L['self'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].tensor.shape, 4891320080)\", \"len(L['self'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 215:\n", + " Name: \"G['__builtins_dict___102']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___102']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 216:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 217:\n", + " Name: \"L['___stack0']\"\n", + 
" Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 218:\n", + " Name: \"L['self'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].state_shape, 4305555088)\", \"len(L['self'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 219:\n", + " Name: \"L['self'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 220:\n", + " Name: \"G['boolean_mask_select']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 221:\n", + " Name: \"G['__builtins_dict___102']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___102']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 222:\n", + " Name: \"L['bool_mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['bool_mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 223:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 224:\n", + " Name: \"L['mask'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['mask'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 225:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 226:\n", + " Name: \"G['torch'].arange\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].arange, 4428197584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 227:\n", + " Name: \"G['__builtins_dict___106']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___106']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 228:\n", + " Name: \"G['torch'].repeat_interleave\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].repeat_interleave, 4428289808)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 229:\n", + " Name: \"G['_expand_mask_to_batch']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code 
List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 230:\n", + " Name: \"L['batch_shape'][0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['batch_shape'][0] == 4\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 231:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types.FunctionType, 4305497480)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 232:\n", + " Name: \"G['__builtins_dict___106']['type']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___106']['type'], 4305563712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 233:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV']\n", + " Code List: [\"L['data'].stride()[0] == L['data'].size()[1]\", \"L['value_shape'][0] == L['data'].size()[1]\", \"2 <= L['data'].size()[1]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 234:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 235:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].cmp_eq\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 236:\n", + " Name: \"G['torch'].int64\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].int64 == torch.int64\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 237:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 238:\n", + " Name: \"L['mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 239:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4891320080)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 240:\n", + " Name: \"L['data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['data'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 241:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 
4305555088)\", \"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 242:\n", + " Name: \"L['device']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['device'] == device(type='cpu')\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 243:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 244:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 245:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['value_shape'][0], 4305558200)\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 246:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills'].types\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'].types, 4308052752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 247:\n", + " Name: \"G['__builtins_dict___106']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___106']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 248:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 249:\n", + " Name: \"G['__builtins_dict___106']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___106']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 250:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 251:\n", + " Name: \"L['data'].reshape\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['data'], 'reshape')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 252:\n", + " Name: \"G['__import_torch_dot__dynamo_dot_polyfills']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_torch_dot__dynamo_dot_polyfills'], 5486842384)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 253:\n", + " Name: \"L['flat_data']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['flat_data'], '_dynamo_dynamic_indices') == False\"]\n", + " 
Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 254:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV', 'SHAPE_ENV']\n", + " Code List: [\"L['flat_data'].stride()[0] == L['flat_data'].size()[1]\", \"L['flat_data']._base.stride()[0] == L['flat_data']._base.size()[1]\", \"L['value_shape'][0] == L['flat_data'].size()[1]\", \"2 <= L['flat_data'].size()[1]\", \"2 <= L['flat_data']._base.size()[1]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 255:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 256:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 257:\n", + " Name: \"L['value_shape'][0]\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['value_shape'][0], 4305558200)\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 258:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 259:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 260:\n", + " Name: \"L['value_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['value_shape'], 4891320080)\", \"len(L['value_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 261:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 262:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 263:\n", + " Name: \"L['original_ndim']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['original_ndim'] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 264:\n", + " Name: \"G['torch'].index_select\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].index_select, 4428298288)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 265:\n", + " Name: \"L['batch_shape']\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['batch_shape'], 4305555088)\", 
\"len(L['batch_shape']) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 266:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 267:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['self'].tensor.size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 268:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 269:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 270:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 271:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 272:\n", + " Name: \"G['__builtins_dict___112']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___112']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 273:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 274:\n", + " Name: \"L['self'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].state_shape, 4305555088)\", \"len(L['self'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 275:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 276:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 277:\n", + " Name: \"L['self'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].tensor.shape, 4891320080)\", \"len(L['self'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 278:\n", + " Name: 
\"G['boolean_mask_select']\"\n", + " Source: global\n", + " Create Function: CLOSURE_MATCH\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 279:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 280:\n", + " Name: \"L['bool_mask']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['bool_mask'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 281:\n", + " Name: \"L['self'].backward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].backward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 282:\n", + " Name: \"G['__builtins_dict___112']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___112']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 283:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 284:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 285:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 286:\n", + " Name: \"G['__builtins_dict___115']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___115']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 287:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 288:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 289:\n", + " Name: \"L['self'].__class__.n_actions\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].__class__.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 290:\n", + " Name: \"L['forward_masks']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['forward_masks'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 291:\n", + " Name: \"L['___stack0'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", 
+ " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['___stack0'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 292:\n", + " Name: \"L['self'].__class__.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].__class__.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 293:\n", + " Name: \"L['forward_masks'].to\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['forward_masks'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 294:\n", + " Name: \"L['self'].__class__.state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].__class__.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 295:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 296:\n", + " Name: \"L['self'].__class__.s0\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].__class__.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 297:\n", + " Name: \"G['__builtins_dict___115']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___115']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 298:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 299:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 300:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 301:\n", + " Name: \"G['__builtins_dict___115']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___115']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 302:\n", + " Name: \"L['self'].__class__.state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].__class__.state_shape, 4305555088)\", \"len(L['self'].__class__.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 303:\n", + " Name: \"L['states'].to\"\n", + " Source: local\n", 
+ " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['states'], 'to')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 304:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 305:\n", + " Name: \"G['__builtins_dict___115']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___115']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 306:\n", + " Name: \"G['__builtins_dict___115']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___115']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 307:\n", + " Name: \"L['valid_actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['valid_actions'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 308:\n", + " Name: \"G['torch'].ops.aten\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].ops.aten, 4848906736)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 309:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 310:\n", + " Name: \"G['__builtins_dict___118']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___118']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 311:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 312:\n", + " Name: \"L['actions'].__class__.exit_action\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions'].__class__.exit_action, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 313:\n", + " Name: \"G['torch'].ops\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].ops, 4697612160)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 314:\n", + " Name: \"L['___stack0'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 315:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: 
None\n", + " Guard 316:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 317:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 318:\n", + " Name: \"G['torch'].compiler.is_compiling\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].compiler.is_compiling, 4965572688)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 319:\n", + " Name: \"G['torch'].gather\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].gather, 4428298368)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 320:\n", + " Name: \"L['actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['actions'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 321:\n", + " Name: \"L['___stack0'].backward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].backward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 322:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 323:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 324:\n", + " Name: \"G['__builtins_dict___118']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___118']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 325:\n", + " Name: \"G['__builtins_dict___118']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___118']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 326:\n", + " Name: \"L['valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 327:\n", + " Name: \"L['valid_actions'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['valid_actions'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 328:\n", + " Name: 
\"G['__builtins_dict___118']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___118']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 329:\n", + " Name: \"L['self'].is_action_valid.__defaults__[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['BOOL_MATCH']\n", + " Code List: [\"L['self'].is_action_valid.__defaults__[0] == False\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 330:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 331:\n", + " Name: \"G['torch'].compiler\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].compiler, 4965862352)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 332:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 333:\n", + " Name: \"G['torch'].bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['torch'].bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 334:\n", + " Name: \"L['actions'].action_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['actions'].action_shape[0] == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 335:\n", + " Name: \"L['actions'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['actions'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 336:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 337:\n", + " Name: \"G['torch'].Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 338:\n", + " Name: \"L['actions'].action_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['actions'].action_shape, 4305555088)\", \"len(L['actions'].action_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 339:\n", + " Name: \"G['torch'].ops.aten._assert_async\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].ops.aten._assert_async, 4848920192)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 340:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " 
Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 341:\n", + " Name: \"L['___stack0'].__class__.state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['___stack0'].__class__.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 342:\n", + " Name: \"L['___stack0'].backward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].backward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 343:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 344:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 345:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 346:\n", + " Name: \"L['___stack0'].__class__.n_actions\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['___stack0'].__class__.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 347:\n", + " Name: \"L['___stack0'].__class__.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].__class__.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 348:\n", + " Name: \"G['__builtins_dict___122']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 349:\n", + " Name: \"L['actions']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['actions'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 350:\n", + " Name: \"L['___stack0'].backward_masks.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['___stack0'].backward_masks, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 351:\n", + " Name: \"L['___stack0'].__class__.s0\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].__class__.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 352:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class 
Weakref: None\n", + " Guard 353:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 354:\n", + " Name: \"G['__builtins_dict___122']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 355:\n", + " Name: \"L['___stack0'].__class__.state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['___stack0'].__class__.state_shape, 4305555088)\", \"len(L['___stack0'].__class__.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 356:\n", + " Name: \"L['new_valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['new_valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 357:\n", + " Name: \"L['___stack0'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 358:\n", + " Name: \"G['__builtins_dict___122']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 359:\n", + " Name: \"G['__builtins_dict___122']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 360:\n", + " Name: \"G['__builtins_dict___122']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___122']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 361:\n", + " Name: \"L['___stack0'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 362:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 363:\n", + " Name: \"L['___stack0'].tensor.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 364:\n", + " Name: \"L['___stack0'].forward_masks.clone\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code 
List: [\"hasattr(L['___stack0'].forward_masks, 'clone')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 365:\n", + " Name: \"L['self'].States.state_shape[0]\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].States.state_shape[0] == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 366:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 367:\n", + " Name: \"G['__builtins_dict___125']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 368:\n", + " Name: \"L['___stack0']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['___stack0'], 6146095712)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 369:\n", + " Name: \"G['__import_gfn_dot_states'].torch\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch, 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 370:\n", + " Name: \"L['not_done_states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['not_done_states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 371:\n", + " Name: \"G['__builtins_dict___125']['zip']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['zip'], 4305487296)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 372:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: ['SHAPE_ENV']\n", + " Code List: [\"2 <= L['states'].tensor.size()[0]\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 373:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 374:\n", + " Name: \"G['__import_gfn_dot_states']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'], 5499072048)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 375:\n", + " Name: \"G['__builtins_dict___125']['super']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['super'], 4305490664)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 376:\n", + " Name: \"G['__import_gfn_dot_states'].torch.Tensor\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.Tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " 
Guarded Class Weakref: \n", + " Guard 377:\n", + " Name: \"L['self'].States.s0\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].States.s0, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 378:\n", + " Name: \"G['__builtins_dict___125']['isinstance']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___125']['isinstance'], 4307013760)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 379:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 380:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 381:\n", + " Name: \"L['states'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['states'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 382:\n", + " Name: \"L['self'].States.sf.repeat\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['self'].States.sf, 'repeat')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 383:\n", + " Name: \"L['self'].States\"\n", + " Source: local\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(L['self'].States, 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 384:\n", + " Name: \"L['not_done_states'].tensor.scatter\"\n", + " Source: local\n", + " Create Function: HASATTR\n", + " Guard Types: ['HASATTR']\n", + " Code List: [\"hasattr(L['not_done_states'].tensor, 'scatter')\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 385:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146079840)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 386:\n", + " Name: \"G['States']\"\n", + " Source: global\n", + " Create Function: ID_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['States'], 6146032752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 387:\n", + " Name: \"G['__import_gfn_dot_states'].torch.bool\"\n", + " Source: global\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"G['__import_gfn_dot_states'].torch.bool == torch.bool\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 388:\n", + " Name: \"L['states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 389:\n", + " Name: \"L['states'].tensor\"\n", + " 
Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['states'].tensor, 5106997584)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 390:\n", + " Name: \"L['self'].States.state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].States.state_shape, 4305555088)\", \"len(L['self'].States.state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 391:\n", + " Name: \"L['states'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].state_shape, 4305555088)\", \"len(L['states'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 392:\n", + " Name: \"L['new_valid_states_idx']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['new_valid_states_idx'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 393:\n", + " Name: \"L['self'].States.sf\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].States.sf, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 394:\n", + " Name: \"L['self'].States.n_actions\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['EQUALS_MATCH']\n", + " Code List: [\"L['self'].States.n_actions == 3\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 395:\n", + " Name: \"G['__import_gfn_dot_states'].torch.ones\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__import_gfn_dot_states'].torch.ones, 4428225552)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 396:\n", + " Name: \"L['not_done_states']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['not_done_states'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 397:\n", + " Name: \"L['___stack0'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['___stack0'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 398:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 399:\n", + " Name: \"L['states'].tensor.shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['states'].tensor.shape, 4891320080)\", \"len(L['states'].tensor.shape) == 2\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 400:\n", + " Name: \"G['__builtins_dict___125']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: 
[\"___check_obj_id(G['__builtins_dict___125']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 401:\n", + " Name: \"L['self'].tensor\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].tensor, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 402:\n", + " Name: \"G['torch'].zeros\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].zeros, 4428230352)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 403:\n", + " Name: \"G['__builtins_dict___127']['tuple']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___127']['tuple'], 4305555088)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 404:\n", + " Name: ''\n", + " Source: shape_env\n", + " Create Function: SHAPE_ENV\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 405:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: GRAD_MODE\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 406:\n", + " Name: \"G['torch']\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'], 4351320976)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 407:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DETERMINISTIC_ALGORITHMS\n", + " Guard Types: None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 408:\n", + " Name: \"L['cond']\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['cond'], '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 409:\n", + " Name: \"L['self'].state_shape\"\n", + " Source: local\n", + " Create Function: SEQUENCE_LENGTH\n", + " Guard Types: ['TYPE_MATCH', 'SEQUENCE_LENGTH']\n", + " Code List: [\"___check_type_id(L['self'].state_shape, 4305555088)\", \"len(L['self'].state_shape) == 1\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 410:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS\n", + " Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']\n", + " Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 411:\n", + " Name: \"L['self']\"\n", + " Source: local\n", + " Create Function: TYPE_MATCH\n", + " Guard Types: ['TYPE_MATCH']\n", + " Code List: [\"___check_type_id(L['self'], 6146094768)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 412:\n", + " Name: \"G['torch'].cat\"\n", + " Source: global\n", + " Create Function: FUNCTION_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['torch'].cat, 4428310752)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 413:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: TORCH_FUNCTION_STATE\n", + " Guard Types: 
None\n", + " Code List: None\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + " Guard 414:\n", + " Name: \"L['allow_exit']\"\n", + " Source: local\n", + " Create Function: CONSTANT_MATCH\n", + " Guard Types: ['BOOL_MATCH']\n", + " Code List: [\"L['allow_exit'] == True\"]\n", + " Object Weakref: None\n", + " Guarded Class Weakref: \n", + " Guard 415:\n", + " Name: \"L['self'].forward_masks\"\n", + " Source: local\n", + " Create Function: TENSOR_MATCH\n", + " Guard Types: ['TENSOR_MATCH']\n", + " Code List: [\"hasattr(L['self'].forward_masks, '_dynamo_dynamic_indices') == False\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 416:\n", + " Name: \"G['__builtins_dict___127']['len']\"\n", + " Source: global\n", + " Create Function: BUILTIN_MATCH\n", + " Guard Types: ['ID_MATCH']\n", + " Code List: [\"___check_obj_id(G['__builtins_dict___127']['len'], 4307014080)\"]\n", + " Object Weakref: \n", + " Guarded Class Weakref: \n", + " Guard 417:\n", + " Name: ''\n", + " Source: global\n", + " Create Function: DEFAULT_DEVICE\n", + " Guard Types: ['DEFAULT_DEVICE']\n", + " Code List: ['utils_device.CURRENT_DEVICE == None']\n", + " Object Weakref: None\n", + " Guarded Class Weakref: None\n", + "Compile Times: TorchDynamo compilation metrics:\n", + "Function, Runtimes (s)\n", + "_compile.compile_inner, 0.0755, 0.0391, 0.0135, 0.0140, 0.0205, 0.0074, 0.0033, 0.0027, 0.0187, 0.0188, 0.0175, 0.0315, 0.0340, 0.0261, 0.0341, 0.0167, 0.0232, 0.0155, 0.0428, 0.1327, 0.0327, 0.0096\n", + "compile_attempt_0, 0.0499, 0.0194, 0.0083, 0.0081, 0.0069, 0.0045, 0.0030, 0.0024, 0.0108, 0.0124, 0.0096, 0.0136, 0.0277, 0.0127, 0.0168, 0.0121, 0.0115, 0.0118, 0.0223, 0.1121, 0.0279, 0.0063\n", + "bytecode_tracing, 0.0484, 0.0159, 0.0188, 0.0110, 0.0079, 0.0011, 0.0077, 0.0016, 0.0064, 0.0053, 0.0027, 0.0026, 0.0020, 0.0103, 0.0001, 0.0118, 0.0013, 0.0092, 0.0019, 0.0131, 0.0104, 0.0253, 0.0122, 0.0045, 0.0161, 0.0096, 0.0101, 0.0110, 0.0038, 0.0096, 0.0218, 0.0104, 0.1115, 0.0103, 0.0242, 0.0044\n", + "compile_attempt_1, 0.0200, 0.0156, 0.0023, 0.0028, 0.0097, 0.0051, 0.0035, 0.0043, 0.0127, 0.0090, 0.0120, 0.0081, 0.0161, 0.0166\n", + "OutputGraph.call_user_compiler, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001\n", + "build_guards, 0.0046, 0.0035, 0.0023, 0.0026, 0.0034, 0.0025, 0.0023, 0.0024, 0.0031, 0.0046, 0.0054, 0.0038, 0.0046, 0.0041, 0.0032, 0.0032, 0.0037, 0.0033, 0.0042, 0.0028\n", + "gc, 0.0006, 0.0003, 0.0002, 0.0002, 0.0002, 0.0001, 0.0001, 0.0001, 0.0003, 0.0002, 0.0003, 0.0003, 0.0004, 0.0003, 0.0003, 0.0002, 0.0002, 0.0002, 0.0005, 0.0003, 0.0002, 0.0002\n", + "pgo.dynamic_whitelist, 0.0000, 0.0000, 0.0000, 0.0000\n", + "\n" + ] + } + ], + "source": [ + "import torch._dynamo as dynamo\n", + "\n", + "explanation = dynamo.explain(step_once)(states.tensor, actions.tensor)\n", + "print(explanation)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torchgfn", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}