From c4736bc83f57b5f816eec2d9a3df21f1de21594e Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Mon, 24 Nov 2025 14:03:58 -0500 Subject: [PATCH 1/6] Updated code for same message passing function --- omtra/models/gvp.py | 72 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index 30f3763c..bd52c92e 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -141,7 +141,9 @@ def forward(self, data): # feats has shape (batch_size, n_feats) # vectors has shape (batch_size, n_vectors, 3) - + print("Scalar features dimensions:", feats.shape) + print("Size of n:", n) + print("Expected input dimensions:", self.dim_feats_in) assert c == 3 and v == self.dim_vectors_in, "vectors have wrong dimensions" assert n == self.dim_feats_in, "scalar features have wrong dimensions" @@ -307,7 +309,7 @@ def __init__( self.node_types = node_types self.edge_types = edge_types - self.scalar_size = scalar_size + self.scalar_size = scalar_size # node feature embeddings size self.vector_size = vector_size self.n_cp_feats = n_cp_feats self.scalar_activation = scalar_activation @@ -317,7 +319,7 @@ def __init__( self.edge_feat_size = edge_feat_size self.use_dst_feats = use_dst_feats self.rbf_dmax = rbf_dmax - self.rbf_dim = rbf_dim + self.rbf_dim = rbf_dim #radial basis function embedding dimension self.dropout_rate = dropout self.message_norm = message_norm self.attention = attention @@ -399,23 +401,42 @@ def __init__( else: s_dst_feats_for_messages = 0 v_dst_feats_for_messages = 0 + + # Edge feature projector (just have one edge projector) + + # self.edge_feat_projector = nn.ModuleDict() + # for etype in self.edge_types: + # # if edge feature exists, project to consistent size + # if self.edge_feat_size.get(etype, 0) > 0: + # output_dim_edge_projector = self.rbf_dim + self.edge_feat_size[etype] + # if self.use_dst_feats: + # output_dim_edge_projector += s_dst_feats_for_messages + + # self.edge_feat_projector[etype] = nn.Sequential( + # nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), + # ) + # else: + # #use identity if no edge features + # self.edge_feat_projector[etype] = nn.Identity() + + # Single edge feature projector for all edge types with consistent output dimension + self.edge_feat_projector_dim = None #set in the message method self.edge_message_fns = nn.ModuleDict() - for etype in edge_types: - src_ntype, _, dst_ntype = to_canonical_etype(etype) - inv_etype = get_inv_edge_type(etype) - if inv_etype in self.edge_message_fns: - self.edge_message_fns[etype] = self.edge_message_fns[inv_etype] - continue - message_gvps = [] + + message_gvps = [] + for etype in self.edge_types: for i in range(n_message_gvps): dim_vectors_in = self.v_message_dim dim_feats_in = self.s_message_dim if i == 0: dim_vectors_in += 1 - dim_feats_in += rbf_dim + self.edge_feat_size.get(etype, 0) + # use max edge feature size across all edge types + edge projector? + #max_edge_feat_size = max(self.edge_feat_size.values(), default=0) + dim_feats_in += rbf_dim + self.edge_feat_size[etype] + print("dim_feats_in for first message gvp:", dim_feats_in) else: # if not first layer, input size is the output size of the previous layer dim_feats_in = dim_feats_out @@ -424,6 +445,7 @@ def __init__( if use_dst_feats and i == 0: dim_vectors_in += v_dst_feats_for_messages dim_feats_in += s_dst_feats_for_messages + print("dim_feats_in after adding dst feats:", dim_feats_in) dim_feats_out = self.s_message_dim + extra_scalar_feats dim_vectors_out = self.v_message_dim @@ -440,7 +462,12 @@ def __init__( vector_gating=True, ) ) - self.edge_message_fns[etype] = nn.Sequential(*message_gvps) + + #Shared message function across all edge types + shared_message_fn = nn.Sequential(*message_gvps) + self.edge_message_fns = nn.ModuleDict() + for etype in self.edge_types: + self.edge_message_fns[etype] = shared_message_fn self.node_update_fns = nn.ModuleDict() self.dropout_layers = nn.ModuleDict() @@ -810,16 +837,31 @@ def message(self, edges, etype): vec_feats.append(edges.dst["v"]) vec_feats = torch.cat(vec_feats, dim=1) - # create scalar features + # create scalar features (output dim of the edge feat projector = scalar_feats + rbf_dim + (if the self.use_dst_feats) scalar_fears) scalar_feats = [edges.src["s"], edges.data["d"]] - if self.edge_feat_size[etype] > 0: - scalar_feats.append(edges.data["ef"]) if self.use_dst_feats: scalar_feats.append(edges.dst["s_dst_msg"]) scalar_feats = torch.cat(scalar_feats, dim=1) + # construct edge feature projector + scalar_feats_dim = scalar_feats.size(1) + print("Scalar feats shape before edge projector", scalar_feats.shape) + + if self.edge_feat_size.get(etype, 0) > 0: + if self.edge_feat_projector is None: + # set up edge feature projector with consistent output size + output_dim_edge_projector = self.rbf_dim + scalar_feats_dim + print("Output dim of edge projector", output_dim_edge_projector) + self.edge_feat_projector = nn.Sequential( + nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), + ).to(edges.data["ef"].device) + + #project edge features + projected_edge_feats = self.edge_feat_projector(edges.data["ef"]) + scalar_feats = torch.cat([scalar_feats, projected_edge_feats], dim=1) + print("Scalar feats shape after edge projector", scalar_feats.shape) scalar_message, vector_message = self.edge_message_fns[etype]( (scalar_feats, vec_feats) ) From 3e9df0344da487f5ea6bfbfb9d7abd59241f8169 Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Mon, 24 Nov 2025 15:26:07 -0500 Subject: [PATCH 2/6] updated dim_feats_in initialization --- omtra/models/gvp.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index bd52c92e..f3173e50 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -420,7 +420,7 @@ def __init__( # self.edge_feat_projector[etype] = nn.Identity() # Single edge feature projector for all edge types with consistent output dimension - self.edge_feat_projector_dim = None #set in the message method + self.edge_feat_projector = None #set in the message method self.edge_message_fns = nn.ModuleDict() @@ -434,8 +434,8 @@ def __init__( if i == 0: dim_vectors_in += 1 # use max edge feature size across all edge types + edge projector? - #max_edge_feat_size = max(self.edge_feat_size.values(), default=0) - dim_feats_in += rbf_dim + self.edge_feat_size[etype] + max_edge_feat_size = max(self.edge_feat_size.values(), default=0) + dim_feats_in += rbf_dim + max_edge_feat_size print("dim_feats_in for first message gvp:", dim_feats_in) else: # if not first layer, input size is the output size of the previous layer @@ -838,6 +838,8 @@ def message(self, edges, etype): vec_feats = torch.cat(vec_feats, dim=1) # create scalar features (output dim of the edge feat projector = scalar_feats + rbf_dim + (if the self.use_dst_feats) scalar_fears) + print("Edges src scalar shape", edges.src["s"].shape) + print("Edges data d shape", edges.data["d"].shape) scalar_feats = [edges.src["s"], edges.data["d"]] if self.use_dst_feats: @@ -850,13 +852,12 @@ def message(self, edges, etype): print("Scalar feats shape before edge projector", scalar_feats.shape) if self.edge_feat_size.get(etype, 0) > 0: - if self.edge_feat_projector is None: - # set up edge feature projector with consistent output size - output_dim_edge_projector = self.rbf_dim + scalar_feats_dim - print("Output dim of edge projector", output_dim_edge_projector) - self.edge_feat_projector = nn.Sequential( - nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), - ).to(edges.data["ef"].device) + # set up edge feature projector with consistent output size + output_dim_edge_projector = self.rbf_dim + scalar_feats_dim + print("Output dim of edge projector", output_dim_edge_projector) + self.edge_feat_projector = nn.Sequential( + nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), + ).to(edges.data["ef"].device) #project edge features projected_edge_feats = self.edge_feat_projector(edges.data["ef"]) From afe72b8ed87b2615a925f7aa994547095f72cb4f Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Mon, 24 Nov 2025 17:02:16 -0500 Subject: [PATCH 3/6] Updated message --- omtra/models/gvp.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index f3173e50..f2095428 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -420,7 +420,12 @@ def __init__( # self.edge_feat_projector[etype] = nn.Identity() # Single edge feature projector for all edge types with consistent output dimension - self.edge_feat_projector = None #set in the message method + #keep edge feature projector here + in_dim = list(self.edge_feat_size.values())[0] + out_dim = self.scalar_size + self.rbf_dim + if self.use_dst_feats: + out_dim += self.scalar_size + self.edge_feat_projector = nn.Linear(in_dim, out_dim) self.edge_message_fns = nn.ModuleDict() @@ -434,8 +439,9 @@ def __init__( if i == 0: dim_vectors_in += 1 # use max edge feature size across all edge types + edge projector? - max_edge_feat_size = max(self.edge_feat_size.values(), default=0) - dim_feats_in += rbf_dim + max_edge_feat_size + #max_edge_feat_size = max(self.edge_feat_size.values(), default=0) + dim_feats_in += rbf_dim + print("Edge feature size for etype init", etype, self.edge_feat_size.get(etype, 0)) print("dim_feats_in for first message gvp:", dim_feats_in) else: # if not first layer, input size is the output size of the previous layer @@ -838,8 +844,6 @@ def message(self, edges, etype): vec_feats = torch.cat(vec_feats, dim=1) # create scalar features (output dim of the edge feat projector = scalar_feats + rbf_dim + (if the self.use_dst_feats) scalar_fears) - print("Edges src scalar shape", edges.src["s"].shape) - print("Edges data d shape", edges.data["d"].shape) scalar_feats = [edges.src["s"], edges.data["d"]] if self.use_dst_feats: @@ -849,20 +853,10 @@ def message(self, edges, etype): # construct edge feature projector scalar_feats_dim = scalar_feats.size(1) - print("Scalar feats shape before edge projector", scalar_feats.shape) if self.edge_feat_size.get(etype, 0) > 0: - # set up edge feature projector with consistent output size - output_dim_edge_projector = self.rbf_dim + scalar_feats_dim - print("Output dim of edge projector", output_dim_edge_projector) - self.edge_feat_projector = nn.Sequential( - nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), - ).to(edges.data["ef"].device) - - #project edge features - projected_edge_feats = self.edge_feat_projector(edges.data["ef"]) - scalar_feats = torch.cat([scalar_feats, projected_edge_feats], dim=1) - print("Scalar feats shape after edge projector", scalar_feats.shape) + scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"]) + scalar_message, vector_message = self.edge_message_fns[etype]( (scalar_feats, vec_feats) ) From 5db04678497a441b0c650ea4dcf850dfc0117380 Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Wed, 26 Nov 2025 15:22:22 -0500 Subject: [PATCH 4/6] Fixed bug in message passing fn initialization --- omtra/models/gvp.py | 92 ++++++++++++++++----------------------------- 1 file changed, 33 insertions(+), 59 deletions(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index f2095428..c46918d3 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -141,9 +141,6 @@ def forward(self, data): # feats has shape (batch_size, n_feats) # vectors has shape (batch_size, n_vectors, 3) - print("Scalar features dimensions:", feats.shape) - print("Size of n:", n) - print("Expected input dimensions:", self.dim_feats_in) assert c == 3 and v == self.dim_vectors_in, "vectors have wrong dimensions" assert n == self.dim_feats_in, "scalar features have wrong dimensions" @@ -402,25 +399,8 @@ def __init__( s_dst_feats_for_messages = 0 v_dst_feats_for_messages = 0 - # Edge feature projector (just have one edge projector) - - # self.edge_feat_projector = nn.ModuleDict() - # for etype in self.edge_types: - # # if edge feature exists, project to consistent size - # if self.edge_feat_size.get(etype, 0) > 0: - # output_dim_edge_projector = self.rbf_dim + self.edge_feat_size[etype] - # if self.use_dst_feats: - # output_dim_edge_projector += s_dst_feats_for_messages - - # self.edge_feat_projector[etype] = nn.Sequential( - # nn.Linear(self.edge_feat_size[etype], output_dim_edge_projector), - # ) - # else: - # #use identity if no edge features - # self.edge_feat_projector[etype] = nn.Identity() - - # Single edge feature projector for all edge types with consistent output dimension - #keep edge feature projector here + # Edge feature projector + in_dim = list(self.edge_feat_size.values())[0] out_dim = self.scalar_size + self.rbf_dim if self.use_dst_feats: @@ -431,43 +411,38 @@ def __init__( message_gvps = [] - for etype in self.edge_types: - for i in range(n_message_gvps): - dim_vectors_in = self.v_message_dim - dim_feats_in = self.s_message_dim - - if i == 0: - dim_vectors_in += 1 - # use max edge feature size across all edge types + edge projector? - #max_edge_feat_size = max(self.edge_feat_size.values(), default=0) - dim_feats_in += rbf_dim - print("Edge feature size for etype init", etype, self.edge_feat_size.get(etype, 0)) - print("dim_feats_in for first message gvp:", dim_feats_in) - else: - # if not first layer, input size is the output size of the previous layer - dim_feats_in = dim_feats_out - dim_vectors_in = dim_vectors_out - - if use_dst_feats and i == 0: - dim_vectors_in += v_dst_feats_for_messages - dim_feats_in += s_dst_feats_for_messages - print("dim_feats_in after adding dst feats:", dim_feats_in) - - dim_feats_out = self.s_message_dim + extra_scalar_feats - dim_vectors_out = self.v_message_dim - - message_gvps.append( - GVP( - dim_vectors_in=dim_vectors_in, - dim_vectors_out=dim_vectors_out, - n_cp_feats=n_cp_feats, - dim_feats_in=dim_feats_in, - dim_feats_out=dim_feats_out, - feats_activation=scalar_activation(), - vectors_activation=vector_activation(), - vector_gating=True, - ) + for i in range(n_message_gvps): + dim_vectors_in = self.v_message_dim + dim_feats_in = self.s_message_dim + + if i == 0: + dim_vectors_in += 1 + dim_feats_in += rbf_dim + + else: + # if not first layer, input size is the output size of the previous layer + dim_feats_in = dim_feats_out + dim_vectors_in = dim_vectors_out + + if use_dst_feats and i == 0: + dim_vectors_in += v_dst_feats_for_messages + dim_feats_in += s_dst_feats_for_messages + + dim_feats_out = self.s_message_dim + extra_scalar_feats + dim_vectors_out = self.v_message_dim + + message_gvps.append( + GVP( + dim_vectors_in=dim_vectors_in, + dim_vectors_out=dim_vectors_out, + n_cp_feats=n_cp_feats, + dim_feats_in=dim_feats_in, + dim_feats_out=dim_feats_out, + feats_activation=scalar_activation(), + vectors_activation=vector_activation(), + vector_gating=True, ) + ) #Shared message function across all edge types shared_message_fn = nn.Sequential(*message_gvps) @@ -852,7 +827,6 @@ def message(self, edges, etype): scalar_feats = torch.cat(scalar_feats, dim=1) # construct edge feature projector - scalar_feats_dim = scalar_feats.size(1) if self.edge_feat_size.get(etype, 0) > 0: scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"]) From 662f01c9d4c30098a3f32ee87986ce405e5e7d57 Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Sat, 6 Dec 2025 13:41:56 -0500 Subject: [PATCH 5/6] Updated with concatenation of edge features --- omtra/models/gvp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index c46918d3..90d07c33 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -418,6 +418,8 @@ def __init__( if i == 0: dim_vectors_in += 1 dim_feats_in += rbf_dim + edge_feat_out_dim = self.edge_feat_projector.out_features + dim_feats_in += edge_feat_out_dim else: # if not first layer, input size is the output size of the previous layer @@ -828,8 +830,17 @@ def message(self, edges, etype): # construct edge feature projector + # if self.edge_feat_size.get(etype, 0) > 0: + # scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"]) + + # concatenate edge features if self.edge_feat_size.get(etype, 0) > 0: - scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"]) + edge_feat_proj = self.edge_feat_projector(edges.data["ef"]) + scalar_feats = torch.cat([scalar_feats, edge_feat_proj], dim=1) + else: + # zeros tensor of size (num_edges, edge_feat_projector.out_features) + no_edge_feats = torch.zeros((scalar_feats.size(0), self.edge_feat_projector.out_features), device=scalar_feats.device) + scalar_feats = torch.cat([scalar_feats, no_edge_feats], dim=1) scalar_message, vector_message = self.edge_message_fns[etype]( (scalar_feats, vec_feats) From 159b679951bcd6607ec856a4d3cf3cf26669f0a5 Mon Sep 17 00:00:00 2001 From: juhimgupta Date: Sat, 13 Dec 2025 13:31:19 -0500 Subject: [PATCH 6/6] Removed unnecessary edge feat projection --- omtra/models/gvp.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py index 90d07c33..a8f4f931 100644 --- a/omtra/models/gvp.py +++ b/omtra/models/gvp.py @@ -418,8 +418,9 @@ def __init__( if i == 0: dim_vectors_in += 1 dim_feats_in += rbf_dim - edge_feat_out_dim = self.edge_feat_projector.out_features - dim_feats_in += edge_feat_out_dim + dim_feats_in += list(self.edge_feat_size.values())[0] + #edge_feat_out_dim = self.edge_feat_projector.out_features + #dim_feats_in += edge_feat_out_dim else: # if not first layer, input size is the output size of the previous layer @@ -832,14 +833,15 @@ def message(self, edges, etype): # if self.edge_feat_size.get(etype, 0) > 0: # scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"]) - # concatenate edge features if self.edge_feat_size.get(etype, 0) > 0: - edge_feat_proj = self.edge_feat_projector(edges.data["ef"]) - scalar_feats = torch.cat([scalar_feats, edge_feat_proj], dim=1) + #edge_feat_proj = self.edge_feat_projector(edges.data["ef"]) + edge_feats = edges.data["ef"] + scalar_feats = torch.cat([scalar_feats, edge_feats], dim=1) else: # zeros tensor of size (num_edges, edge_feat_projector.out_features) - no_edge_feats = torch.zeros((scalar_feats.size(0), self.edge_feat_projector.out_features), device=scalar_feats.device) + #no_edge_feats = torch.zeros((scalar_feats.size(0), self.edge_feat_projector.out_features), device=scalar_feats.device) + no_edge_feats = torch.zeros((scalar_feats.size(0), list(self.edge_feat_size.values())[0]), device=scalar_feats.device) scalar_feats = torch.cat([scalar_feats, no_edge_feats], dim=1) scalar_message, vector_message = self.edge_message_fns[etype](