From c4860d4871e5703f616652975a0959b84cd4231f Mon Sep 17 00:00:00 2001 From: Marc Brockschmidt Date: Tue, 8 Sep 2020 09:28:49 +0000 Subject: [PATCH] feat(GraphNorm): Support GraphNorm as option --- tf2_gnn/layers/gnn.py | 15 +++++ tf2_gnn/layers/graph_norm.py | 109 +++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 tf2_gnn/layers/graph_norm.py diff --git a/tf2_gnn/layers/gnn.py b/tf2_gnn/layers/gnn.py index 45c7305..617a422 100644 --- a/tf2_gnn/layers/gnn.py +++ b/tf2_gnn/layers/gnn.py @@ -16,6 +16,7 @@ GraphGlobalGRUExchange, GraphGlobalMLPExchange, ) +from .graph_norm import GraphNorm, GraphNormInput class GNNInput(NamedTuple): @@ -61,6 +62,7 @@ def get_default_hyperparameters(cls, mp_style: Optional[str] = None) -> Dict[str "dense_every_num_layers": 2, "residual_every_num_layers": 2, "use_inter_layer_layernorm": False, + "use_graphnorm": False, "hidden_dim": 16, "layer_input_dropout_rate": 0.0, "global_exchange_mode": "gru", # One of "mean", "mlp", "gru" @@ -87,6 +89,7 @@ def __init__(self, params: Dict[str, Any]): self._dense_every_num_layers = params["dense_every_num_layers"] self._residual_every_num_layers = params["residual_every_num_layers"] self._use_inter_layer_layernorm = params["use_inter_layer_layernorm"] + self._use_graphnorm = params.get("use_graphnorm") or False self._initial_node_representation_activation_fn = get_activation_function( params["initial_node_representation_activation"] ) @@ -110,6 +113,7 @@ def __init__(self, params: Dict[str, Any]): # Layer member variables. To be filled in in the `build` method. self._initial_projection_layer: tf.keras.layers.Layer = None self._mp_layers: List[MessagePassing] = [] + self._graphnorm_layers: List[GraphNorm] = [] self._inter_layer_layernorms: List[tf.keras.layers.Layer] = [] self._dense_layers: Dict[str, tf.keras.layers.Layer] = {} self._global_exchange_layers: Dict[str, GraphGlobalExchange] = {} @@ -151,6 +155,12 @@ def build(self, tensor_shapes: GNNInput): MessagePassingInput(embedded_shape, adjacency_list_shapes) ) + if self._use_graphnorm: + self._graphnorm_layers.append(GraphNorm()) + self._graphnorm_layers[-1].build( + GraphNormInput(embedded_shape, tf.TensorShape((None,))), + ) + # If required, prepare for a LayerNorm: if self._use_inter_layer_layernorm: with tf.name_scope(f"LayerNorm"): @@ -302,6 +312,11 @@ def _internal_call(self, inputs: GNNInput, training: bool = False): ), training=training, ) + if self._use_graphnorm: + cur_node_representations = self._graphnorm_layers[layer_idx]( + GraphNormInput(cur_node_representations, inputs.node_to_graph_map), + training=training, + ) all_node_representations.append(cur_node_representations) if layer_idx and layer_idx % self._global_exchange_every_num_layers == 0: diff --git a/tf2_gnn/layers/graph_norm.py b/tf2_gnn/layers/graph_norm.py new file mode 100644 index 0000000..b2c0d80 --- /dev/null +++ b/tf2_gnn/layers/graph_norm.py @@ -0,0 +1,109 @@ +from typing import NamedTuple, Optional + +import tensorflow as tf +from tensorflow.python.keras import initializers + +from tf2_gnn.utils.constants import SMALL_NUMBER + + +class GraphNormInput(NamedTuple): + """Input named tuple for the GraphNorm.""" + + node_features: tf.Tensor + node_to_graph_map: tf.Tensor + + +class GraphNorm(tf.keras.layers.Layer): + """Implementation of Graph Norm (https://arxiv.org/pdf/2009.03294.pdf). + Normalises node representations by the graph mean/variance. + Given node representations h_{i, j} from a single graph, computes + GraphNorm(h_{i, j}) = \gamma_j * (h_{i, j} - \alpha_j * \mu_j) / \sigma_j + \beta_j + with \alpha_j, \beta_j, \gamma_j learnable, and + \mu_j = 1/n \sum_i^n h_{i, j} + \sigma_j^2 = 1/n \sum_i^n (h_{i, j} - \alpha_j * \mu_j)^2 + """ + def __init__( + self, + center: bool = True, + scale: bool = True, + learnable_shift: bool=True, + **kwargs + ): + super().__init__(**kwargs) + self._center = center + self._scale = scale + self._learnable_shift = learnable_shift + + def build(self, input_shape: GraphNormInput): + params_shape = (input_shape.node_features[-1],) + + if self._learnable_shift: + self.alpha = self.add_weight( + name='alpha', + shape=params_shape, + initializer=initializers.get('ones'), + trainable=True, + dtype=tf.float32, + ) + else: + self.alpha = None + + if self._center: + self.beta = self.add_weight( + name='beta', + shape=params_shape, + initializer=initializers.get('zero'), + trainable=True, + dtype=tf.float32, + ) + else: + self.beta = None + + if self._scale: + self.gamma = self.add_weight( + name='gamma', + shape=params_shape, + initializer=initializers.get('ones'), + trainable=True, + dtype=tf.float32, + ) + else: + self.gamma = None + + super().build(input_shape) + + def call(self, inputs: GraphNormInput, training: Optional[bool]=None): + # Compute mean + graph_means = tf.math.segment_mean( + data=inputs.node_features, segment_ids=inputs.node_to_graph_map + ) # Shape [G, GD] + + per_node_graph_means = tf.gather( + params=graph_means, + indices=inputs.node_to_graph_map, + ) + + if self._learnable_shift: + centered_node_features = inputs.node_features - self.alpha * per_node_graph_means + else: + centered_node_features = inputs.node_features - per_node_graph_means + + graph_variances = tf.math.segment_mean( + data=tf.square(centered_node_features), + segment_ids=inputs.node_to_graph_map, + ) # Shape [G, GD]) + + per_node_graph_stdev = tf.gather( + params=tf.sqrt(graph_variances), + indices=inputs.node_to_graph_map, + ) + + output = centered_node_features / (per_node_graph_stdev + SMALL_NUMBER) + + if self._scale: + output *= self.gamma + + if self._center: + output += self.beta + + return output