15 changes: 15 additions & 0 deletions tf2_gnn/layers/gnn.py
@@ -16,6 +16,7 @@
GraphGlobalGRUExchange,
GraphGlobalMLPExchange,
)
from .graph_norm import GraphNorm, GraphNormInput


class GNNInput(NamedTuple):
@@ -61,6 +62,7 @@ def get_default_hyperparameters(cls, mp_style: Optional[str] = None) -> Dict[str
"dense_every_num_layers": 2,
"residual_every_num_layers": 2,
"use_inter_layer_layernorm": False,
"use_graphnorm": False,
"hidden_dim": 16,
"layer_input_dropout_rate": 0.0,
"global_exchange_mode": "gru", # One of "mean", "mlp", "gru"
@@ -87,6 +89,7 @@ def __init__(self, params: Dict[str, Any]):
self._dense_every_num_layers = params["dense_every_num_layers"]
self._residual_every_num_layers = params["residual_every_num_layers"]
self._use_inter_layer_layernorm = params["use_inter_layer_layernorm"]
self._use_graphnorm = params.get("use_graphnorm") or False

Would it maybe make sense to add something like

assert not (self._use_graphnorm and self._use_inter_layer_layernorm), "Layer normalization and graph normalization should not be used together."

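For reference, a minimal sketch of where such a guard could sit, right after the two flags are read in `__init__` (hypothetical placement, not part of this diff):

self._use_inter_layer_layernorm = params["use_inter_layer_layernorm"]
self._use_graphnorm = params.get("use_graphnorm") or False
# Guard against enabling both normalisation schemes at once:
assert not (self._use_graphnorm and self._use_inter_layer_layernorm), (
    "Layer normalization and graph normalization should not be used together."
)
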
self._initial_node_representation_activation_fn = get_activation_function(
params["initial_node_representation_activation"]
)
@@ -110,6 +113,7 @@ def __init__(self, params: Dict[str, Any]):
# Layer member variables. To be filled in in the `build` method.
self._initial_projection_layer: tf.keras.layers.Layer = None
self._mp_layers: List[MessagePassing] = []
self._graphnorm_layers: List[GraphNorm] = []
self._inter_layer_layernorms: List[tf.keras.layers.Layer] = []
self._dense_layers: Dict[str, tf.keras.layers.Layer] = {}
self._global_exchange_layers: Dict[str, GraphGlobalExchange] = {}
@@ -151,6 +155,12 @@ def build(self, tensor_shapes: GNNInput):
MessagePassingInput(embedded_shape, adjacency_list_shapes)
)

if self._use_graphnorm:
self._graphnorm_layers.append(GraphNorm())
self._graphnorm_layers[-1].build(
GraphNormInput(embedded_shape, tf.TensorShape((None,))),
)

# If required, prepare for a LayerNorm:
if self._use_inter_layer_layernorm:
with tf.name_scope(f"LayerNorm"):
@@ -302,6 +312,11 @@ def _internal_call(self, inputs: GNNInput, training: bool = False):
),
training=training,
)
if self._use_graphnorm:
cur_node_representations = self._graphnorm_layers[layer_idx](
GraphNormInput(cur_node_representations, inputs.node_to_graph_map),
training=training,
)
all_node_representations.append(cur_node_representations)

if layer_idx and layer_idx % self._global_exchange_every_num_layers == 0:
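The new option defaults to off. A minimal sketch of enabling it through the usual hyperparameter dictionary (assuming the `GNN` layer exported from `tf2_gnn.layers`, whose methods are shown above):

from tf2_gnn.layers import GNN

params = GNN.get_default_hyperparameters()
params["use_graphnorm"] = True  # apply GraphNorm after each message-passing layer
gnn_layer = GNN(params)
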
109 changes: 109 additions & 0 deletions tf2_gnn/layers/graph_norm.py
@@ -0,0 +1,109 @@
from typing import NamedTuple, Optional

import tensorflow as tf
from tensorflow.python.keras import initializers

from tf2_gnn.utils.constants import SMALL_NUMBER


class GraphNormInput(NamedTuple):
"""Input named tuple for the GraphNorm."""

node_features: tf.Tensor
node_to_graph_map: tf.Tensor


class GraphNorm(tf.keras.layers.Layer):
"""Implementation of Graph Norm (https://arxiv.org/pdf/2009.03294.pdf).
Normalises node representations by the graph mean/variance.
Given node representations h_{i, j} from a single graph, computes
GraphNorm(h_{i, j}) = \gamma_j * (h_{i, j} - \alpha_j * \mu_j) / \sigma_j + \beta_j
with \alpha_j, \beta_j, \gamma_j learnable, and
\mu_j = 1/n \sum_i^n h_{i, j}
\sigma_j^2 = 1/n \sum_i^n (h_{i, j} - \alpha_j * \mu_j)^2
"""
def __init__(
self,
center: bool = True,
scale: bool = True,
learnable_shift: bool = True,
**kwargs
):
super().__init__(**kwargs)
self._center = center
self._scale = scale
self._learnable_shift = learnable_shift

def build(self, input_shape: GraphNormInput):
params_shape = (input_shape.node_features[-1],)

if self._learnable_shift:
self.alpha = self.add_weight(
name='alpha',
shape=params_shape,
initializer=initializers.get('ones'),
trainable=True,
dtype=tf.float32,
)
else:
self.alpha = None

if self._center:
self.beta = self.add_weight(
name='beta',
shape=params_shape,
initializer=initializers.get('zero'),
trainable=True,
dtype=tf.float32,
)
else:
self.beta = None

if self._scale:
self.gamma = self.add_weight(
name='gamma',
shape=params_shape,
initializer=initializers.get('ones'),
trainable=True,
dtype=tf.float32,
)
else:
self.gamma = None

super().build(input_shape)

def call(self, inputs: GraphNormInput, training: Optional[bool] = None):
# Compute mean
graph_means = tf.math.segment_mean(
data=inputs.node_features, segment_ids=inputs.node_to_graph_map
) # Shape [G, GD]

per_node_graph_means = tf.gather(
params=graph_means,
indices=inputs.node_to_graph_map,
)

if self._learnable_shift:
centered_node_features = inputs.node_features - self.alpha * per_node_graph_means
else:
centered_node_features = inputs.node_features - per_node_graph_means

graph_variances = tf.math.segment_mean(
data=tf.square(centered_node_features),
segment_ids=inputs.node_to_graph_map,
) # Shape [G, GD]

per_node_graph_stdev = tf.gather(
params=tf.sqrt(graph_variances),
indices=inputs.node_to_graph_map,
)

output = centered_node_features / (per_node_graph_stdev + SMALL_NUMBER)

if self._scale:
output *= self.gamma

if self._center:
output += self.beta

return output
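
For illustration, a minimal standalone sketch of the new layer with hypothetical shapes (two graphs with three and two nodes, feature dimension 4), mirroring how `build` and `_internal_call` drive it in gnn.py above:

import tensorflow as tf

from tf2_gnn.layers.graph_norm import GraphNorm, GraphNormInput

node_features = tf.random.normal((5, 4))          # 5 nodes, 4 features each
node_to_graph_map = tf.constant([0, 0, 0, 1, 1])  # nodes 0-2 -> graph 0, nodes 3-4 -> graph 1

layer = GraphNorm()
layer.build(GraphNormInput(tf.TensorShape((None, 4)), tf.TensorShape((None,))))

normed = layer(GraphNormInput(node_features, node_to_graph_map), training=False)
print(normed.shape)  # (5, 4): same shape, normalised per graph and per feature dimension

With `alpha`, `gamma`, and `beta` still at their initial values (ones, ones, zeros), the output has zero mean and unit variance per graph and feature dimension, up to the `SMALL_NUMBER` stabiliser.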