This repository was archived by the owner on May 21, 2025. It is now read-only.

Replicating GAT with CORA dataset #26

@Steboss89

Hello,

Thanks very much for such a wonderful product! I am trying to replicate the GAT paper on the CORA dataset, but I am running into some issues using jraph. I started from your example notebook and implemented GAT along with add_self_edges_fn:

from typing import Any, Callable, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
import optax


def add_self_edges_fn(receivers: jnp.ndarray,
                      senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    r"""Adds self edges. Assumes self edges are not in the graph yet."""
    receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
    senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
    return receivers, senders
  
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
    r""" Main GAT function"""
    # pylint: disable=g-long-lambda
    if node_update_fn is None:
        # By default, apply the leaky relu and then concatenate the heads on the
        # feature axis.
        node_update_fn = lambda x: jnp.reshape(jax.nn.leaky_relu(x), (x.shape[0], -1))

    def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Applies a Graph Attention layer."""
        nodes, edges, receivers, senders, _, _, _ = graph
        
        try:
            sum_n_node = nodes.shape[0]
        except IndexError:
            raise IndexError('GAT requires node features')

        nodes = attention_query_fn(nodes)
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

        if add_self_edges:
            receivers, senders = add_self_edges_fn(receivers, senders,
                                                    total_num_nodes)
        # Gather the transformed features of each edge's endpoints.
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        # Compute a per-edge attention logit from the endpoint features.
        att_softmax_logits = attention_logit_fn(sent_attributes,
                                                received_attributes, edges)

        # Normalise the logits over all edges incoming to each node.
        att_weights = jraph.segment_softmax(
            att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)

        # Weight each message by its attention coefficient, then aggregate
        # the messages at their receiving nodes.
        messages = sent_attributes * att_weights
        nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)

        nodes = node_update_fn(nodes)

        return graph._replace(nodes=nodes)

    return _ApplyGAT


def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Defines the two-layer GAT network to run.

    Parameters
    ----------
    graph: jraph.GraphsTuple, input network to be processed

    Returns
    -------
    jraph.GraphsTuple with updated node features
    """

    def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                            edges: jnp.ndarray) -> jnp.ndarray:
        del edges
        x = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
        return jax.nn.leaky_relu(hk.Linear(1)(x))

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=None,
        add_self_edges=True)
    graph = gn(graph)

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=hk.Linear(2),
        add_self_edges=True)
    graph = gn(graph)
    return graph
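For completeness, run_cora below expects a Haiku-transformed network. A minimal sketch of that step, assuming the standard hk.transform pattern from the jraph examples (the name network is only chosen to match the parameter of run_cora):

# Sketch (assumption): wrap the forward function so it exposes init/apply.
# hk.without_apply_rng drops the RNG argument from apply, since this
# forward pass is deterministic (no dropout in this setup).
network = hk.without_apply_rng(hk.transform(gat_definition))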

Then, after defining the main GAT, I run the training as:


def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  r"""Runs training on the CORA dataset."""
  cora_graph = cora_ds[0]['input_graph']
  labels = cora_ds[0]['target']
  params = network.init(jax.random.PRNGKey(42), cora_graph)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    preds = jnp.argmax(decoded_graph.nodes, axis=1)
    # We interpret the decoded nodes as a pair of logits for each node.
    loss = compute_bce_with_logits_loss(preds, labels)
    return loss

  opt_init, opt_update = optax.adam(5e-4)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    if step % 100 == 0:
      print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)
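The helper compute_bce_with_logits_loss is not shown in the snippet above; a plausible sketch, under the assumption that it one-hot encodes the integer class labels and applies optax's sigmoid binary cross-entropy (the actual implementation may differ):

def compute_bce_with_logits_loss(logits: jnp.ndarray,
                                 labels: jnp.ndarray) -> jnp.ndarray:
  # Hypothetical sketch: one-hot encode the integer class labels to
  # match the logits' last dimension, then average the element-wise
  # sigmoid binary cross-entropy.
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return jnp.mean(optax.sigmoid_binary_cross_entropy(logits, one_hot))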

The problem is that the accuracy sticks at the same value throughout all the steps I run (e.g. 1000 steps, accuracy = 0.13).
Could I ask you for some pointers on where I am going wrong?
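One quick diagnostic for a training loop whose accuracy never moves is to check whether the loss gradients are non-zero; a minimal sketch (assumed to be placed inside run_cora, where params and prediction_loss are in scope):

# Sketch: print the largest absolute gradient entry per parameter;
# an all-zero tree would mean the loss does not depend differentiably
# on the parameters.
g = jax.grad(prediction_loss)(params)
print(jax.tree_util.tree_map(lambda x: float(jnp.abs(x).max()), g))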
Thank you
