202 changes: 183 additions & 19 deletions python/gigl/nn/models.py
@@ -134,6 +134,19 @@ def unwrap_from_ddp(self) -> "LinkPredictionGNN":
return LinkPredictionGNN(encoder=encoder, decoder=decoder)


def _get_feature_key(node_type: Union[str, NodeType]) -> str:
[Review comment, Collaborator] I'd suggest a clearer name, e.g. _build_node_embedding_table_name_by_node_type. I think we might actually use something like get_feature_key for other generic functionality inside GiGL.

"""
Get the feature key for a node type's embedding table.

Args:
node_type: Node type as string or NodeType object.

Returns:
str: Feature key in format "{node_type}_id"
"""
return f"{node_type}_id"


# TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement.
# TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific
# TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer)
@@ -187,18 +200,29 @@ def __init__(
)

# Build TorchRec EBC (one table per node type)
# feature key naming convention: f"{node_type}_id"
# Sort node types for deterministic ordering across machines
self._feature_keys: list[str] = [
f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys()
_get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys())
[Review comment, Collaborator] Should this use e.g. the SortedDict functionality introduced in #391? I assume it was introduced for such use cases.

]

# Validate model configuration: restrict to homogeneous or bipartite graphs
num_node_types = len(self._feature_keys)
if num_node_types not in [1, 2]:
# TODO(kmonte, swong3): We should loosen this restriction and allow fully heterogenous graphs in the future.
raise ValueError(
f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; "
f"got {num_node_types} node types: {self._feature_keys}"
)

tables: list[EmbeddingBagConfig] = []
for node_type, num_nodes in self._node_type_to_num_nodes.items():
# Sort node types for deterministic ordering across machines
for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items()):
[Review comment, Collaborator] Flagging in case SortedDict should be used here.

tables.append(
EmbeddingBagConfig(
name=f"node_embedding_{node_type}",
embedding_dim=embedding_dim,
num_embeddings=num_nodes,
feature_names=[f"{node_type}_id"],
feature_names=[_get_feature_key(node_type)],
)
)
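
For a bipartite user-item graph, the loop above would produce two tables along these lines (all sizes hypothetical; note the sorted "item" < "user" ordering):

    tables = [
        EmbeddingBagConfig(
            name="node_embedding_item",
            embedding_dim=64,
            num_embeddings=1_000_000,  # hypothetical item count
            feature_names=["item_id"],
        ),
        EmbeddingBagConfig(
            name="node_embedding_user",
            embedding_dim=64,
            num_embeddings=500_000,  # hypothetical user count
            feature_names=["user_id"],
        ),
    ]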

@@ -215,32 +239,44 @@ def forward(
self,
data: Union[Data, HeteroData],
device: torch.device,
output_node_types: Optional[list[NodeType]] = None,
anchor_node_ids: Optional[torch.Tensor] = None,
anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None,
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
"""
Forward pass of the LightGCN model.

Args:
data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous).
data (Union[Data, HeteroData]): Graph data.
- For homogeneous: Data object with edge_index and node field
- For heterogeneous: HeteroData with node types and edge_index_dict
device (torch.device): Device to run the computation on.
output_node_types (Optional[List[NodeType]]): List of node types to return
embeddings for. Required for heterogeneous graphs. Default: None.
anchor_node_ids (Optional[torch.Tensor]): Local node indices to return
embeddings for. If None, returns embeddings for all nodes. Default: None.
anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]):
Local node indices to return embeddings for.
[Review comment, Collaborator] Just clarifying: "local" here refers specifically to the local index in the object, which is different from the global node ID, right?

- For homogeneous: torch.Tensor of shape [num_anchors]
- For heterogeneous: dict mapping node types to anchor tensors
If None, returns embeddings for all nodes. Default: None.

Returns:
Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings.
For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim].
For heterogeneous graphs, returns dict mapping node types to embeddings.
- For homogeneous: tensor of shape [num_nodes, embedding_dim]
- For heterogeneous: dict mapping node types to embeddings
"""
if isinstance(data, HeteroData):
raise NotImplementedError("HeteroData is not yet supported for LightGCN")
output_node_types = output_node_types or list(data.node_types)
return self._forward_heterogeneous(
data, device, output_node_types, anchor_node_ids
)
is_heterogeneous = isinstance(data, HeteroData)

if is_heterogeneous:
# For heterogeneous graphs, anchor_node_ids must be a dict, not a Tensor
if anchor_node_ids is not None and not isinstance(anchor_node_ids, dict):
raise TypeError(
f"For heterogeneous graphs, anchor_node_ids must be a dict or None, "
f"got {type(anchor_node_ids)}"
)
return self._forward_heterogeneous(data, device, anchor_node_ids)
else:
# For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict
if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor):
raise TypeError(
f"For homogeneous graphs, anchor_node_ids must be a Tensor or None, "
f"got {type(anchor_node_ids)}"
)
return self._forward_homogeneous(data, device, anchor_node_ids)
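
A usage sketch of the two call shapes (`model`, `homo_batch`, and `hetero_batch` are hypothetical; assumes the signature above):

    # Homogeneous: anchor_node_ids is a plain tensor of local indices.
    embeddings = model(data=homo_batch, device=device, anchor_node_ids=torch.tensor([0, 5]))

    # Heterogeneous: the dict's keys also select which node types are returned.
    embeddings_by_type = model(
        data=hetero_batch,
        device=device,
        anchor_node_ids={NodeType("user"): torch.tensor([0, 1])},
    )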

def _forward_homogeneous(
@@ -323,6 +359,134 @@ def _forward_homogeneous(
final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph
)

def _forward_heterogeneous(
self,
data: HeteroData,
device: torch.device,
anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None,
) -> dict[NodeType, torch.Tensor]:
"""
Forward pass for heterogeneous graphs using LightGCN propagation.

For heterogeneous graphs (e.g., user-item), we have multiple node types.
LightGCN propagates embeddings across all node types by creating a unified
node space, running propagation, and then splitting back into per-type
embeddings.

Note: All node types in the graph are processed during message passing, as this is
required for correct GNN computation. Use anchor_node_ids to filter which node types
and specific nodes are returned in the output.

Args:
data (HeteroData): PyG HeteroData object with node types.
device (torch.device): Device to run computation on.
anchor_node_ids (Optional[Dict[NodeType, torch.Tensor]]): Dict mapping node types
to local anchor indices. If None, returns all nodes for all types.
If provided, only returns embeddings for the specified node types and indices.

Returns:
Dict[NodeType, torch.Tensor]: Dict mapping node types to their embeddings,
each of shape [num_nodes_of_type, embedding_dim] (or [num_anchors, embedding_dim]
if anchor_node_ids is provided for that type).
"""
# Process all node types - this is required for correct message passing in GNNs
# Sort node types for deterministic ordering across machines
all_node_types_in_data = [NodeType(nt) for nt in sorted(data.node_types)]

# Lookup initial embeddings e^(0) for each node type
node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {}

for node_type in all_node_types_in_data:
node_type_str = str(node_type)
[Review comment, Collaborator] Nit: it might be less confusing to just have this variable be implicit and use _get_feature_key(str(node_type)). Managing two variables that are both strings and both slightly differently reflect the node type is actually somehow more confusing than just using one and applying a function modifier when you need it, IMO.

key = _get_feature_key(node_type_str)
[Review comment, Collaborator] Nit: use a more descriptive name, e.g. embedding_table_key or something.


assert hasattr(data[node_type_str], "node"), (
f"Subgraph must include .node field for node type {node_type_str}"
)

global_ids = data[node_type_str].node.to(device).long() # shape [N_type]

embeddings = self._lookup_embeddings_for_single_node_type(
key, global_ids
[Review comment, Collaborator] Nit: kwargs.

) # shape [N_type, D]

# Handle DMP Awaitable
if isinstance(embeddings, Awaitable):
embeddings = embeddings.wait()

node_type_to_embeddings_0[node_type] = embeddings

# For heterogeneous graphs, we need to create a unified edge representation
# Collect all edges and map node indices to a combined space
# E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1)
node_type_to_offset: dict[NodeType, int] = {}
offset = 0
for node_type in all_node_types_in_data:
node_type_to_offset[node_type] = offset
node_type_str = str(node_type)
offset += data[node_type_str].num_nodes
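# Example (hypothetical sizes): with sorted node types ["item", "user"],
# 3 item nodes and 4 user nodes, offsets are {item: 0, user: 3} and the
# combined space has indices 0..6.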

# Combine all embeddings into a single tensor
combined_embeddings_0 = torch.cat(
[node_type_to_embeddings_0[nt] for nt in all_node_types_in_data], dim=0
) # shape [total_nodes, D]

# Combine all edges into a single edge_index
# Sort edge types for deterministic ordering across machines
combined_edge_list: list[torch.Tensor] = []
for edge_type_tuple in sorted(data.edge_types):
src_nt_str, _, dst_nt_str = edge_type_tuple
src_node_type = NodeType(src_nt_str)
dst_node_type = NodeType(dst_nt_str)

edge_index = data[edge_type_tuple].edge_index.to(device) # shape [2, E]

# Offset the indices to the combined node space
src_offset = node_type_to_offset[src_node_type]
dst_offset = node_type_to_offset[dst_node_type]

offset_edge_index = edge_index.clone()
offset_edge_index[0] += src_offset
offset_edge_index[1] += dst_offset

combined_edge_list.append(offset_edge_index)

combined_edge_index = torch.cat(combined_edge_list, dim=1) # shape [2, total_edges]
[Comment thread on lines +419 to +454]

[Review comment, @nshah-sc, Nov 24, 2025] I'm confused why we need to unify all these edge tensors and node tensors with this offset logic.

Why can't we use native PyG operations like message or propagate on the HeteroData or Data objects, using e.g. torch_geometric.nn.conv.LGConv?

I think the important part of your current implementation is that it needs to be able to fetch embeddings at large scale. I am not sure you need to hand-roll your own convolution logic for messaging and aggregation too. The advantage of PyG is that it enables you not to have to do that, right?

[Review comment, Collaborator] Also, in general, if you have K different types of edges between 2 types of nodes (e.g. user and item), it feels reasonable that the application of LightGCN in this setting would be LightGCN convolution applied to each of these edge types + some pooling (sum, mean, max, etc.) applied to the K resulting user embeddings or K resulting item embeddings. Basic LightGCN can be seen as an application that only works on 1 type of edge and skips the pooling (effectively an identity function).

What you're doing here in terms of merging all the lists of edges actually feels like a specific sub-case of LightGCN, and I'm not sure it is worth implementing it so specifically.

[Review comment, Collaborator] I feel like our goal in implementing LightGCN in this exercise is to "plug and play" with TorchRec + PyG. I think we are currently doing something like TorchRec + hand-rolling, which feels clunky, if that makes sense.

After all, the end goal of all this work (whether LightGCN or otherwise) is to connect the TorchRec elements to PyG elements more generally for GiGL models, so following that pattern seems logical here.

[Review comment, Collaborator]
> I'm confused why we need to unify all these edge tensors and node tensors with this offset logic

I may be a PyG noob here, but my understanding is that [part of] the highlighted code is there to build node_type_to_offset so that we can do the one lookup into the TorchRec tables and then later reference the embeddings by node type?

Maybe I'm a bit confused as to what you'd be suggesting here. Is it to do the conv by edge index, e.g. roughly `for edge_index in data.edge_indices(): embeddings.append(self.conv(edge_index))`?

> also in general if you have K different types of edges between 2 types of nodes (e.g. user and item), it feels reasonable that the application of LightGCN on this setting would be like LightGCN convolution applied to each of these edge-types + some pooling ...

This makes sense actually! Again, a bit of a PyG noob, so I'm not sure of the "idiomatic" way to solve this problem. Would it be to have some base class and then have subclasses override a pool method on it?

[Review comment, Collaborator]
> I may be a pyg noob here but my understanding is that [part of] the highlighted code is to build node_type_to_offset so that we can do the one lookup into the torchrec tables and then later reference the embeddings by node type?

The block from lines 399-417 already gets all the node embeddings for each node_type by making len(all_node_types_in_data) calls to the TorchRec tables, IIUC. I think we already went through the bother of making these separate calls. You are right that perhaps they could be batched and fetched in one lookup if we were using a different implementation, but @swong3-sc is relying on _lookup_embeddings_for_single_node_type being called multiple times, so that isn't happening here.

> this makes sense actually! again a bit of a pyg noob so I'm not sure the "idiomatic" way to solve this problem? ...

I don't know the most "idiomatic" way here, as others have likely read more PyG code than me, but looking at LGConv's implementation (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/lg_conv.html#LGConv), it operates on a single tensor of node features x (which holds multiple node types in one (num_nodes, F) tensor) and a single edge_index tensor containing the edges between src and dst nodes. Presumably this edge_index is defined on a single edge type; we could apply it over the edge types in a for loop and accumulate the embeddings for each one using some pooling function. I would say our best bet when building these "TorchRec-enhanced modules" is to rely on a pattern of "doing the TorchRec munging to get into PyG format" + "calling PyG", so we don't have to reinvent parts of the wheel that are already built and tested.

[Review comment, Collaborator] I think @swong3-sc is currently creating one big tensor of node features and one big edge_index which is a union of all the individual edge-typed edge_index tensors, but the issue with this (I think) is that these should be treated as different edge_indices and convolutions, rather than just one. The LGConv operation normalizes messages based on node degree, which is specific to each edge type. Message passing should (I believe) not be mixed across all edge types but rather be specific to each edge type, so that we end up with multiple embeddings for each node that is adjacent to multiple edge types.
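
To make the suggestion concrete, a minimal sketch of the per-edge-type variant the thread describes (not what this PR implements): one LGConv pass per edge type over the shared offset node tensor, then mean pooling across edge types so that degree normalization stays specific to each edge type. All names are hypothetical:

    import torch
    from torch_geometric.nn.conv import LGConv

    conv = LGConv()  # parameter-free LightGCN convolution

    def per_edge_type_layer(
        x: torch.Tensor, edge_indices_by_type: list[torch.Tensor]
    ) -> torch.Tensor:
        # x: [total_nodes, D] in the combined node space;
        # edge_indices_by_type: one offset edge_index ([2, E_t]) per edge type.
        outs = [conv(x, edge_index) for edge_index in edge_indices_by_type]
        # Mean-pool the per-edge-type results; sum or max would also fit here.
        return torch.stack(outs, dim=0).mean(dim=0)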


# Track all layer embeddings
all_layer_embeddings: list[torch.Tensor] = [combined_embeddings_0]
current_embeddings = combined_embeddings_0

# Perform K layers of propagation
for conv in self._convs:
current_embeddings = conv(current_embeddings, combined_edge_index) # shape [total_nodes, D]
all_layer_embeddings.append(current_embeddings)

# Weighted sum across layers
combined_final_embeddings = self._weighted_layer_sum(all_layer_embeddings) # shape [total_nodes, D]
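# Note: standard LightGCN weights each of the K+1 layer outputs by
# alpha_k = 1 / (K + 1); the exact weighting here is whatever
# _weighted_layer_sum implements.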

# Split back into per-node-type embeddings
final_embeddings: dict[NodeType, torch.Tensor] = {}
for node_type in all_node_types_in_data:
start_idx = node_type_to_offset[node_type]
node_type_str = str(node_type)
num_nodes = data[node_type_str].num_nodes
end_idx = start_idx + num_nodes

final_embeddings[node_type] = combined_final_embeddings[start_idx:end_idx] # shape [num_nodes, D]

# Extract anchor nodes if specified
if anchor_node_ids is not None:
# Only return embeddings for node types specified in anchor_node_ids
filtered_embeddings: dict[NodeType, torch.Tensor] = {}
for node_type in all_node_types_in_data:
if node_type in anchor_node_ids:
anchors = anchor_node_ids[node_type].to(device).long()
filtered_embeddings[node_type] = final_embeddings[node_type][anchors]
return filtered_embeddings

return final_embeddings

def _lookup_embeddings_for_single_node_type(
self, node_type: str, ids: torch.Tensor
) -> torch.Tensor: