From 8d964641063d0491a0ae2f86c7ea725bce53eae1 Mon Sep 17 00:00:00 2001 From: beabevi Date: Thu, 25 Aug 2022 23:42:43 -0400 Subject: [PATCH 1/3] Added subgraph_mode Use as adjacency matrix the subgraph around nodes having __ground-truth__ hints that changed from the previous iteration. --- clrs/_src/baselines.py | 11 ++++-- clrs/_src/nets.py | 35 ++++++++++++++--- clrs/_src/subgraphs_utils.py | 74 ++++++++++++++++++++++++++++++++++++ clrs/examples/run.py | 9 +++++ 4 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 clrs/_src/subgraphs_utils.py diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index ee08d08b..0807213f 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -71,6 +71,7 @@ def __init__( dropout_prob: float = 0.0, hint_teacher_forcing_noise: float = 0.0, name: str = 'base_model', + subgraph_mode: str = None, ): """Constructor for BaselineModel. @@ -133,18 +134,20 @@ def __init__( self.nb_dims.append(nb_dims) self._create_net_fns(hidden_dim, encode_hints, processor_factory, use_lstm, - dropout_prob, hint_teacher_forcing_noise) + dropout_prob, hint_teacher_forcing_noise, subgraph_mode) self.params = None self.opt_state = None self.opt_state_skeleton = None def _create_net_fns(self, hidden_dim, encode_hints, processor_factory, - use_lstm, dropout_prob, hint_teacher_forcing_noise): + use_lstm, dropout_prob, hint_teacher_forcing_noise, subgraph_mode): def _use_net(*args, **kwargs): return nets.Net(self._spec, hidden_dim, encode_hints, self.decode_hints, self.decode_diffs, processor_factory, use_lstm, dropout_prob, - hint_teacher_forcing_noise, self.nb_dims)(*args, **kwargs) + hint_teacher_forcing_noise, + subgraph_mode, + self.nb_dims)(*args, **kwargs) self.net_fn = hk.transform(_use_net) self.net_fn_apply = jax.jit(self.net_fn.apply, @@ -328,7 +331,7 @@ class BaselineModelChunked(BaselineModel): """ def _create_net_fns(self, hidden_dim, encode_hints, processor_factory, - use_lstm, dropout_prob, hint_teacher_forcing_noise): + use_lstm, dropout_prob, hint_teacher_forcing_noise, subgraph_mode): def _use_net(*args, **kwargs): return nets.NetChunked( self._spec, hidden_dim, encode_hints, diff --git a/clrs/_src/nets.py b/clrs/_src/nets.py index 4f48e8c4..3559c5b6 100644 --- a/clrs/_src/nets.py +++ b/clrs/_src/nets.py @@ -27,6 +27,7 @@ from clrs._src import processors from clrs._src import samplers from clrs._src import specs +from clrs._src import subgraphs_utils import haiku as hk import jax @@ -86,9 +87,11 @@ def __init__( use_lstm: bool, dropout_prob: float, hint_teacher_forcing_noise: float, + subgraph_mode: str, nb_dims=None, name: str = 'net', ): + """Constructs a `Net`.""" super().__init__(name=name) @@ -102,6 +105,7 @@ def __init__( self.processor_factory = processor_factory self.nb_dims = nb_dims self.use_lstm = use_lstm + self.subgraph_mode = subgraph_mode def _msg_passing_step(self, mp_state: _MessagePassingScanState, @@ -147,29 +151,38 @@ def _msg_passing_step(self, probing.DataPoint( name=hint.name, location=loc, type_=typ, data=hint_data)) - gt_diffs = None - if hints[0].data.shape[0] > 1 and self.decode_diffs: + def get_gt_diffs(hints, first_idx, second_idx, batch_size, nb_nodes): gt_diffs = { _Location.NODE: jnp.zeros((batch_size, nb_nodes)), _Location.EDGE: jnp.zeros((batch_size, nb_nodes, nb_nodes)), _Location.GRAPH: jnp.zeros((batch_size)) } for hint in hints: - hint_cur = jax.lax.dynamic_index_in_dim(hint.data, i, 0, keepdims=False) + hint_cur = jax.lax.dynamic_index_in_dim(hint.data, first_idx, 0, keepdims=False) hint_nxt = jax.lax.dynamic_index_in_dim( - hint.data, i+1, 0, keepdims=False) + hint.data, second_idx, 0, keepdims=False) if len(hint_cur.shape) == len(gt_diffs[hint.location].shape): hint_cur = jnp.expand_dims(hint_cur, -1) hint_nxt = jnp.expand_dims(hint_nxt, -1) gt_diffs[hint.location] += jnp.any(hint_cur != hint_nxt, axis=-1) for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: gt_diffs[loc] = (gt_diffs[loc] > 0.0).astype(jnp.float32) * 1.0 + return gt_diffs + + gt_diffs = None + if hints[0].data.shape[0] > 1 and self.decode_diffs: + gt_diffs = get_gt_diffs(hints, i, i+1, batch_size, nb_nodes) + + gt_diffs_prev = None + if hints[0].data.shape[0] > 1 and self.subgraph_mode is not None: + if not first_step: + gt_diffs_prev = get_gt_diffs(hints, i, i-1, batch_size, nb_nodes) (hiddens, output_preds_cand, hint_preds, diff_logits, lstm_state) = self._one_step_pred(inputs, cur_hint, mp_state.hiddens, batch_size, nb_nodes, mp_state.lstm_state, - spec, encs, decs, diff_decs) + spec, encs, decs, diff_decs, gt_diffs_prev) if first_step: output_preds = output_preds_cand @@ -377,6 +390,7 @@ def _one_step_pred( encs: Dict[str, List[hk.Module]], decs: Dict[str, Tuple[hk.Module]], diff_decs: Dict[str, Any], + gt_diffs_prev: Dict[_Location, Any], ): """Generates one-step predictions.""" @@ -405,12 +419,21 @@ def _one_step_pred( except Exception as e: raise Exception(f'Failed to process {dp}') from e + msg_adj_mat = adj_mat + if gt_diffs_prev is not None: + if self.subgraph_mode == "egonets": + msg_adj_mat = subgraphs_utils.get_egonets(gt_diffs_prev[_Location.NODE], adj_mat) + elif self.subgraph_mode == "stars": + msg_adj_mat = subgraphs_utils.get_stars(gt_diffs_prev[_Location.NODE], adj_mat) + else: + raise ValueError(f"Invalid subgraph_mode {self.subgraph_mode}") + # PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nxt_hidden = self.processor( node_fts, edge_fts, graph_fts, - adj_mat, + msg_adj_mat, hidden, batch_size=batch_size, nb_nodes=nb_nodes, diff --git a/clrs/_src/subgraphs_utils.py b/clrs/_src/subgraphs_utils.py new file mode 100644 index 00000000..2458d455 --- /dev/null +++ b/clrs/_src/subgraphs_utils.py @@ -0,0 +1,74 @@ +import jax.numpy as jnp + + +# ([B, N], [B, N, N]) -> [B, N, N] +# NOTE: also keeps edges across egonets +def get_egonets(center_nodes, adj_mat): + """Returns the adj matrix consisting of ego nets around `center_nodes` + + Since jnp.nonzero is not compatible with JIT, we add an auxiliary node for each graph, + and an auxiliary graph in the batch and use the `fill_value` arg to return indices to those + added tensors. + """ + num_graphs, num_nodes = center_nodes.shape + + # Add one node and one graph + center_nodes = jnp.concatenate([jnp.zeros((1, num_nodes)), center_nodes]) + center_nodes = jnp.concatenate([jnp.zeros((num_graphs + 1, 1)), center_nodes], axis=-1) + + adj_mat = jnp.concatenate([jnp.zeros((1, num_nodes, num_nodes)), adj_mat]) + adj_mat = jnp.concatenate([jnp.zeros((num_graphs + 1, 1, num_nodes)), adj_mat], axis=1) + adj_mat = jnp.concatenate([jnp.zeros((num_graphs + 1, num_nodes + 1, 1)), adj_mat], axis=-1) + + graph_idx, node_idx = center_nodes.nonzero(size=(num_graphs+1)*(num_nodes+1), fill_value=0) + + # [K, N] where K is the total number of center_nodes (summed over graphs) + center_adj_cols = adj_mat[graph_idx, :, node_idx] + + # [B, N]: for each graph, whether node n is a neighbour of a center_node + center_neighbors = jnp.zeros((adj_mat.shape[0], adj_mat.shape[-1])) + center_neighbors = jnp.array(center_neighbors).at[graph_idx].add(center_adj_cols) + + # Add center nodes + ego_nodes = center_neighbors + center_nodes + + graph_idx, removed_node_idx = (ego_nodes == 0).nonzero(size=(num_graphs+1)*(num_nodes+1), fill_value=0) + + # Zero out edges incoming/outgoing to/from removed nodes + adj_mat = jnp.array(adj_mat).at[graph_idx, removed_node_idx].set(0) + adj_mat = jnp.array(adj_mat).at[graph_idx, :, removed_node_idx].set(0) + + # Remove the added node and graph + adj_mat = adj_mat[1:, 1:, 1:] + return adj_mat + +# ([B, N], [B, N, N]) -> [B, N, N] +def get_stars(center_nodes, adj_mat): + """Returns the adj matrix consisting of star subgraphs around `center_nodes` + + Since jnp.nonzero is not compatible with JIT, we add an auxiliary node for each graph, + and an auxiliary graph in the batch and use the `fill_value` arg to return indices to those + added tensors. + """ + num_graphs, num_nodes = center_nodes.shape + + # Add one node and one graph + center_nodes = jnp.concatenate([jnp.zeros((1, num_nodes)), center_nodes]) + center_nodes = jnp.concatenate([jnp.zeros((num_graphs + 1, 1)), center_nodes], axis=-1) + + adj_mat = jnp.concatenate([jnp.zeros((1, num_nodes, num_nodes)), adj_mat]) + adj_mat = jnp.concatenate([jnp.zeros((num_graphs + 1, 1, num_nodes)), adj_mat], axis=1) + adj_mat = jnp.concatenate([jnp.zeros((num_graphs + 1, num_nodes + 1, 1)), adj_mat], axis=-1) + + graph_idx, node_idx = center_nodes.nonzero(size=(num_graphs+1)*(num_nodes+1), fill_value=0) + + # [K, N] where K is the total number of center_nodes (summed over graphs) + center_adj_cols = adj_mat[graph_idx, :, node_idx] + + # Zero out all edges, except those outgoing from center_nodes + new_adj_mat = jnp.zeros(adj_mat.shape) + new_adj_mat = new_adj_mat.at[graph_idx, :, node_idx].set(center_adj_cols) + + # Remove the added node and graph + new_adj_mat = new_adj_mat[1:, 1:, 1:] + return new_adj_mat diff --git a/clrs/examples/run.py b/clrs/examples/run.py index 49967318..1fe8228d 100644 --- a/clrs/examples/run.py +++ b/clrs/examples/run.py @@ -95,6 +95,11 @@ 'Path in which dataset is stored.') flags.DEFINE_boolean('freeze_processor', False, 'Whether to freeze the processor of the model.') +flags.DEFINE_enum('subgraph_mode', 'none', + ['stars', 'egonets', 'none'], + 'If not `None`, then use as adjacency matrix the subgraph ' + 'around the nodes having hints that changed from the ' + 'last timestep ') FLAGS = flags.FLAGS @@ -219,6 +224,9 @@ def main(unused_argv): else: raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.') + if FLAGS.subgraph_mode == 'none': + FLAGS.subgraph_mode = None + common_args = dict(folder=dataset_folder, algorithm=FLAGS.algorithm, batch_size=FLAGS.batch_size) @@ -257,6 +265,7 @@ def main(unused_argv): freeze_processor=FLAGS.freeze_processor, dropout_prob=FLAGS.dropout_prob, hint_teacher_forcing_noise=FLAGS.hint_teacher_forcing_noise, + subgraph_mode=FLAGS.subgraph_mode, ) eval_model = clrs.models.BaselineModel( From 26e1cc83874f9a1db27e0afd856f3c9634b3c734 Mon Sep 17 00:00:00 2001 From: beabevi Date: Tue, 30 Aug 2022 22:02:14 -0400 Subject: [PATCH 2/3] Made subgraph adj_mat symmetric Keep all edges incoming to and outgoing from center nodes --- clrs/_src/subgraphs_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/clrs/_src/subgraphs_utils.py b/clrs/_src/subgraphs_utils.py index 2458d455..f3cd0678 100644 --- a/clrs/_src/subgraphs_utils.py +++ b/clrs/_src/subgraphs_utils.py @@ -27,7 +27,7 @@ def get_egonets(center_nodes, adj_mat): # [B, N]: for each graph, whether node n is a neighbour of a center_node center_neighbors = jnp.zeros((adj_mat.shape[0], adj_mat.shape[-1])) - center_neighbors = jnp.array(center_neighbors).at[graph_idx].add(center_adj_cols) + center_neighbors = center_neighbors.at[graph_idx].add(center_adj_cols) # Add center nodes ego_nodes = center_neighbors + center_nodes @@ -35,8 +35,8 @@ def get_egonets(center_nodes, adj_mat): graph_idx, removed_node_idx = (ego_nodes == 0).nonzero(size=(num_graphs+1)*(num_nodes+1), fill_value=0) # Zero out edges incoming/outgoing to/from removed nodes - adj_mat = jnp.array(adj_mat).at[graph_idx, removed_node_idx].set(0) - adj_mat = jnp.array(adj_mat).at[graph_idx, :, removed_node_idx].set(0) + adj_mat = adj_mat.at[graph_idx, removed_node_idx].set(0) + adj_mat = adj_mat.at[graph_idx, :, removed_node_idx].set(0) # Remove the added node and graph adj_mat = adj_mat[1:, 1:, 1:] @@ -64,10 +64,12 @@ def get_stars(center_nodes, adj_mat): # [K, N] where K is the total number of center_nodes (summed over graphs) center_adj_cols = adj_mat[graph_idx, :, node_idx] + center_adj_rows = adj_mat[graph_idx, node_idx, :] - # Zero out all edges, except those outgoing from center_nodes + # Zero out all edges, except those incoming/outgoing to/from center_nodes new_adj_mat = jnp.zeros(adj_mat.shape) new_adj_mat = new_adj_mat.at[graph_idx, :, node_idx].set(center_adj_cols) + new_adj_mat = new_adj_mat.at[graph_idx, node_idx].add(center_adj_rows) # Remove the added node and graph new_adj_mat = new_adj_mat[1:, 1:, 1:] From d57c3df496dc2c0134a260db83fdf0355c50bfb7 Mon Sep 17 00:00:00 2001 From: beabevi Date: Tue, 30 Aug 2022 22:14:28 -0400 Subject: [PATCH 3/3] Added gradient clipping --- clrs/_src/baselines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index 0807213f..43e2daab 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -117,7 +117,7 @@ def __init__( self.checkpoint_path = checkpoint_path self.name = name self._freeze_processor = freeze_processor - self.opt = optax.adam(learning_rate) + self.opt = optax.chain(optax.clip_by_global_norm(1), optax.adam(learning_rate)) self.nb_dims = [] if isinstance(dummy_trajectory, _Feedback):