-
Notifications
You must be signed in to change notification settings - Fork 12
Add Bipartite Implementation for LightGCN #370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bb1844b
ed6a2cd
39e1d6d
50e1f47
1cdc76c
117925a
2928a9b
f880b85
e6b96ef
7cc6e93
62fc9ed
b069715
cf84fcd
9ca696d
15f9a07
00b06be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -134,6 +134,19 @@ def unwrap_from_ddp(self) -> "LinkPredictionGNN": | |
| return LinkPredictionGNN(encoder=encoder, decoder=decoder) | ||
|
|
||
|
|
||
| def _get_feature_key(node_type: Union[str, NodeType]) -> str: | ||
| """ | ||
| Get the feature key for a node type's embedding table. | ||
|
|
||
| Args: | ||
| node_type: Node type as string or NodeType object. | ||
|
|
||
| Returns: | ||
| str: Feature key in format "{node_type}_id" | ||
| """ | ||
| return f"{node_type}_id" | ||
|
|
||
|
|
||
| # TODO(swong3): Move specific models to gigl.nn.models whenever we restructure model placement. | ||
| # TODO(swong3): Abstract TorchRec functionality, and make this LightGCN specific | ||
| # TODO(swong3): Remove device context from LightGCN module (use meta, but will have to figure out how to handle buffer transfer) | ||
|
|
@@ -187,18 +200,29 @@ def __init__( | |
| ) | ||
|
|
||
| # Build TorchRec EBC (one table per node type) | ||
| # feature key naming convention: f"{node_type}_id" | ||
| # Sort node types for deterministic ordering across machines | ||
| self._feature_keys: list[str] = [ | ||
| f"{node_type}_id" for node_type in self._node_type_to_num_nodes.keys() | ||
| _get_feature_key(node_type) for node_type in sorted(self._node_type_to_num_nodes.keys()) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should this use e.g. SortedDict functionalities as introduced in e.g. #391? i assume it was introduced for such use cases |
||
| ] | ||
|
|
||
| # Validate model configuration: restrict to homogeneous or bipartite graphs | ||
| num_node_types = len(self._feature_keys) | ||
| if num_node_types not in [1, 2]: | ||
| # TODO(kmonte, swong3): We should loosen this restriction and allow fully heterogenous graphs in the future. | ||
| raise ValueError( | ||
swong3-sc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| f"LightGCN only supports homogeneous (1 node type) or bipartite (2 node types) graphs; " | ||
| f"got {num_node_types} node types: {self._feature_keys}" | ||
| ) | ||
|
|
||
| tables: list[EmbeddingBagConfig] = [] | ||
| for node_type, num_nodes in self._node_type_to_num_nodes.items(): | ||
| # Sort node types for deterministic ordering across machines | ||
| for node_type, num_nodes in sorted(self._node_type_to_num_nodes.items()): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. flagging in case sorteddict should be used here |
||
| tables.append( | ||
| EmbeddingBagConfig( | ||
| name=f"node_embedding_{node_type}", | ||
| embedding_dim=embedding_dim, | ||
| num_embeddings=num_nodes, | ||
| feature_names=[f"{node_type}_id"], | ||
| feature_names=[_get_feature_key(node_type)], | ||
| ) | ||
| ) | ||
|
|
||
|
|
@@ -215,32 +239,44 @@ def forward( | |
| self, | ||
| data: Union[Data, HeteroData], | ||
| device: torch.device, | ||
| output_node_types: Optional[list[NodeType]] = None, | ||
| anchor_node_ids: Optional[torch.Tensor] = None, | ||
| anchor_node_ids: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]] = None, | ||
| ) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]: | ||
| """ | ||
| Forward pass of the LightGCN model. | ||
|
|
||
| Args: | ||
| data (Union[Data, HeteroData]): Graph data (homogeneous or heterogeneous). | ||
| data (Union[Data, HeteroData]): Graph data. | ||
| - For homogeneous: Data object with edge_index and node field | ||
| - For heterogeneous: HeteroData with node types and edge_index_dict | ||
| device (torch.device): Device to run the computation on. | ||
| output_node_types (Optional[List[NodeType]]): List of node types to return | ||
| embeddings for. Required for heterogeneous graphs. Default: None. | ||
| anchor_node_ids (Optional[torch.Tensor]): Local node indices to return | ||
| embeddings for. If None, returns embeddings for all nodes. Default: None. | ||
| anchor_node_ids (Optional[Union[torch.Tensor, Dict[NodeType, torch.Tensor]]]): | ||
| Local node indices to return embeddings for. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just clarifying -- "local" here refers specifically to the local index in the object, which is different from the global node ID, right? |
||
| - For homogeneous: torch.Tensor of shape [num_anchors] | ||
| - For heterogeneous: dict mapping node types to anchor tensors | ||
| If None, returns embeddings for all nodes. Default: None. | ||
|
|
||
| Returns: | ||
| Union[torch.Tensor, Dict[NodeType, torch.Tensor]]: Node embeddings. | ||
| For homogeneous graphs, returns tensor of shape [num_nodes, embedding_dim]. | ||
| For heterogeneous graphs, returns dict mapping node types to embeddings. | ||
| - For homogeneous: tensor of shape [num_nodes, embedding_dim] | ||
| - For heterogeneous: dict mapping node types to embeddings | ||
| """ | ||
| if isinstance(data, HeteroData): | ||
| raise NotImplementedError("HeteroData is not yet supported for LightGCN") | ||
| output_node_types = output_node_types or list(data.node_types) | ||
| return self._forward_heterogeneous( | ||
| data, device, output_node_types, anchor_node_ids | ||
| ) | ||
| is_heterogeneous = isinstance(data, HeteroData) | ||
|
|
||
| if is_heterogeneous: | ||
| # For heterogeneous graphs, anchor_node_ids must be a dict, not a Tensor | ||
| if anchor_node_ids is not None and not isinstance(anchor_node_ids, dict): | ||
| raise TypeError( | ||
| f"For heterogeneous graphs, anchor_node_ids must be a dict or None, " | ||
| f"got {type(anchor_node_ids)}" | ||
| ) | ||
| return self._forward_heterogeneous(data, device, anchor_node_ids) | ||
| else: | ||
| # For homogeneous graphs, anchor_node_ids must be a Tensor, not a dict | ||
| if anchor_node_ids is not None and not isinstance(anchor_node_ids, torch.Tensor): | ||
| raise TypeError( | ||
| f"For homogeneous graphs, anchor_node_ids must be a Tensor or None, " | ||
| f"got {type(anchor_node_ids)}" | ||
| ) | ||
| return self._forward_homogeneous(data, device, anchor_node_ids) | ||
|
|
||
| def _forward_homogeneous( | ||
|
|
@@ -323,6 +359,134 @@ def _forward_homogeneous( | |
| final_embeddings # shape [N_sub, D], embeddings for all nodes in subgraph | ||
| ) | ||
|
|
||
| def _forward_heterogeneous( | ||
| self, | ||
| data: HeteroData, | ||
| device: torch.device, | ||
| anchor_node_ids: Optional[dict[NodeType, torch.Tensor]] = None, | ||
| ) -> dict[NodeType, torch.Tensor]: | ||
| """ | ||
| Forward pass for heterogeneous graphs using LightGCN propagation. | ||
|
|
||
| For heterogeneous graphs (e.g., user-item), we have | ||
| multiple node types. LightGCN propagates embeddings across | ||
| all node types by creating a unified node space, running propagation, then splitting | ||
| back into per-type embeddings. | ||
|
|
||
| Note: All node types in the graph are processed during message passing, as this is | ||
| required for correct GNN computation. Use anchor_node_ids to filter which node types | ||
| and specific nodes are returned in the output. | ||
|
|
||
| Args: | ||
| data (HeteroData): PyG HeteroData object with node types. | ||
| device (torch.device): Device to run computation on. | ||
| anchor_node_ids (Optional[Dict[NodeType, torch.Tensor]]): Dict mapping node types | ||
| to local anchor indices. If None, returns all nodes for all types. | ||
| If provided, only returns embeddings for the specified node types and indices. | ||
|
|
||
| Returns: | ||
| Dict[NodeType, torch.Tensor]: Dict mapping node types to their embeddings, | ||
| each of shape [num_nodes_of_type, embedding_dim] (or [num_anchors, embedding_dim] | ||
| if anchor_node_ids is provided for that type). | ||
| """ | ||
| # Process all node types - this is required for correct message passing in GNNs | ||
| # Sort node types for deterministic ordering across machines | ||
| all_node_types_in_data = [NodeType(nt) for nt in sorted(data.node_types)] | ||
|
|
||
| # Lookup initial embeddings e^(0) for each node type | ||
| node_type_to_embeddings_0: dict[NodeType, torch.Tensor] = {} | ||
|
|
||
| for node_type in all_node_types_in_data: | ||
| node_type_str = str(node_type) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: it might be less confusing to just have this variable be implicit and use `str(node_type)` inline where needed. Managing two variables that both are strings and both slightly differently reflect the node type is actually somehow more confusing than just using one and appropriately using a function modifier when you need it IMO |
||
| key = _get_feature_key(node_type_str) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: use a more descriptive name, e.g. embedding_table_key or something |
||
|
|
||
| assert hasattr(data[node_type_str], "node"), ( | ||
| f"Subgraph must include .node field for node type {node_type_str}" | ||
| ) | ||
|
|
||
| global_ids = data[node_type_str].node.to(device).long() # shape [N_type] | ||
|
|
||
| embeddings = self._lookup_embeddings_for_single_node_type( | ||
| key, global_ids | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: kwargs |
||
| ) # shape [N_type, D] | ||
|
|
||
| # Handle DMP Awaitable | ||
| if isinstance(embeddings, Awaitable): | ||
| embeddings = embeddings.wait() | ||
|
|
||
| node_type_to_embeddings_0[node_type] = embeddings | ||
|
|
||
| # For heterogeneous graphs, we need to create a unified edge representation | ||
| # Collect all edges and map node indices to a combined space | ||
| # E.g., node type 0 gets indices [0, num_type_0), node type 1 gets [num_type_0, num_type_0 + num_type_1) | ||
| node_type_to_offset: dict[NodeType, int] = {} | ||
| offset = 0 | ||
| for node_type in all_node_types_in_data: | ||
| node_type_to_offset[node_type] = offset | ||
| node_type_str = str(node_type) | ||
| offset += data[node_type_str].num_nodes | ||
|
|
||
| # Combine all embeddings into a single tensor | ||
| combined_embeddings_0 = torch.cat( | ||
| [node_type_to_embeddings_0[nt] for nt in all_node_types_in_data], dim=0 | ||
| ) # shape [total_nodes, D] | ||
|
|
||
| # Combine all edges into a single edge_index | ||
| # Sort edge types for deterministic ordering across machines | ||
| combined_edge_list: list[torch.Tensor] = [] | ||
| for edge_type_tuple in sorted(data.edge_types): | ||
| src_nt_str, _, dst_nt_str = edge_type_tuple | ||
| src_node_type = NodeType(src_nt_str) | ||
| dst_node_type = NodeType(dst_nt_str) | ||
|
|
||
| edge_index = data[edge_type_tuple].edge_index.to(device) # shape [2, E] | ||
|
|
||
| # Offset the indices to the combined node space | ||
| src_offset = node_type_to_offset[src_node_type] | ||
| dst_offset = node_type_to_offset[dst_node_type] | ||
|
|
||
| offset_edge_index = edge_index.clone() | ||
| offset_edge_index[0] += src_offset | ||
| offset_edge_index[1] += dst_offset | ||
|
|
||
| combined_edge_list.append(offset_edge_index) | ||
|
|
||
| combined_edge_index = torch.cat(combined_edge_list, dim=1) # shape [2, total_edges] | ||
|
Comment on lines
+419
to
+454
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm confused why we need to unify all these edge tensors and node tensors with this offset logic — why can't we use native PyG operations like `HeteroData.to_homogeneous()`? I think the important part of your current implementation is that it needs to be able to fetch embeddings at large scale. I am not sure you need to hand-roll your own convolution logic for messaging and aggregation too. The advantage of PyG is that it enables you to not have to do that, right?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also in general if you have K different types of edges between 2 types of nodes (e.g. user and item), it feels reasonable that the application of LightGCN on this setting would be like LightGCN convolution applied to each of these edge-types + some pooling (sum, mean, max, etc.) applied to the K resulting user embeddings or K resulting item embeddings. basic LightGCN can be seen as an application which only works on 1 type of edge, and skips the pooling for like an Identity function. What you're doing here in terms of merging all the lists of edges actually feels like a specific sub-case of LightGCN and i'm not sure it is worth implementing it so specifically?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like our goal in implementing LightGCN in this exercise is to "plug and play" with TorchRec + PyG. I think we are currently doing something like TorchRec + hand-rolling, which feels clunky if that makes sense. After all the end goal we have with all this work (whether LightGCN or otherwise) is to connect the TorchRec elements to PyG elements more generally for GiGl models so following that pattern seems logical here
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I may be a pyg noob here but my understanding is that [part of] the highlighted code is to build the combined edge index over the unified node space. Maybe I'm a bit confused as to what you'd be suggesting here? Is it to do the conv by edge index? e.g. ~
this makes sense actually! again a bit of a pyg noob so I'm not sure the "idiomatic" way to solve this problem? Would it be to have some base class and then have subclasses override a method?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
the block from line 399-417 already gets all the node embeddings for each node_type by making separate lookup calls. I think we already went through the bother of making these separate calls. You are right that perhaps they could be batched and fetched in one lookup if we were using a different implementation, but @swong3-sc is relying on the TorchRec EBC lookup here.
I don't totally know the best "most idiomatic" way here as others have likely read more PyG code than me, but just looking at LGConv's implementation: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/lg_conv.html#LGConv it seems like this operates on a single tensor of node features
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think @swong3-sc is currently creating one big tensor of node features and one big edge_index which is a union of all edges in the individual edge-typed |
||
|
|
||
| # Track all layer embeddings | ||
| all_layer_embeddings: list[torch.Tensor] = [combined_embeddings_0] | ||
| current_embeddings = combined_embeddings_0 | ||
|
|
||
| # Perform K layers of propagation | ||
| for conv in self._convs: | ||
| current_embeddings = conv(current_embeddings, combined_edge_index) # shape [total_nodes, D] | ||
| all_layer_embeddings.append(current_embeddings) | ||
|
|
||
| # Weighted sum across layers | ||
| combined_final_embeddings = self._weighted_layer_sum(all_layer_embeddings) # shape [total_nodes, D] | ||
|
|
||
| # Split back into per-node-type embeddings | ||
| final_embeddings: dict[NodeType, torch.Tensor] = {} | ||
| for node_type in all_node_types_in_data: | ||
| start_idx = node_type_to_offset[node_type] | ||
| node_type_str = str(node_type) | ||
| num_nodes = data[node_type_str].num_nodes | ||
| end_idx = start_idx + num_nodes | ||
|
|
||
| final_embeddings[node_type] = combined_final_embeddings[start_idx:end_idx] # shape [num_nodes, D] | ||
|
|
||
| # Extract anchor nodes if specified | ||
| if anchor_node_ids is not None: | ||
| # Only return embeddings for node types specified in anchor_node_ids | ||
| filtered_embeddings: dict[NodeType, torch.Tensor] = {} | ||
| for node_type in all_node_types_in_data: | ||
| if node_type in anchor_node_ids: | ||
| anchors = anchor_node_ids[node_type].to(device).long() | ||
| filtered_embeddings[node_type] = final_embeddings[node_type][anchors] | ||
| return filtered_embeddings | ||
|
|
||
| return final_embeddings | ||
|
|
||
| def _lookup_embeddings_for_single_node_type( | ||
| self, node_type: str, ids: torch.Tensor | ||
| ) -> torch.Tensor: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'd suggest a clearer name, e.g.
_build_node_embedding_table_name_by_node_type. i think we might actually use something like get_feature_key for other generic functionalities inside GiGL