diff --git a/omtra/models/gvp.py b/omtra/models/gvp.py
index 30f3763c..a8f4f931 100644
--- a/omtra/models/gvp.py
+++ b/omtra/models/gvp.py
@@ -141,7 +141,6 @@ def forward(self, data):
         # feats has shape (batch_size, n_feats)
         # vectors has shape (batch_size, n_vectors, 3)
 
-        assert c == 3 and v == self.dim_vectors_in, "vectors have wrong dimensions"
         assert n == self.dim_feats_in, "scalar features have wrong dimensions"
 
@@ -307,7 +306,7 @@ def __init__(
         self.node_types = node_types
         self.edge_types = edge_types
-        self.scalar_size = scalar_size
+        self.scalar_size = scalar_size  # node feature embedding size
         self.vector_size = vector_size
         self.n_cp_feats = n_cp_feats
         self.scalar_activation = scalar_activation
@@ -317,7 +316,7 @@ def __init__(
         self.edge_feat_size = edge_feat_size
         self.use_dst_feats = use_dst_feats
         self.rbf_dmax = rbf_dmax
-        self.rbf_dim = rbf_dim
+        self.rbf_dim = rbf_dim  # radial basis function embedding dimension
        self.dropout_rate = dropout
         self.message_norm = message_norm
         self.attention = attention
@@ -399,48 +398,60 @@ def __init__(
         else:
             s_dst_feats_for_messages = 0
             v_dst_feats_for_messages = 0
+
+        # Edge feature projector
-        self.edge_message_fns = nn.ModuleDict()
-        for etype in edge_types:
-            src_ntype, _, dst_ntype = to_canonical_etype(etype)
-            inv_etype = get_inv_edge_type(etype)
-            if inv_etype in self.edge_message_fns:
-                self.edge_message_fns[etype] = self.edge_message_fns[inv_etype]
-                continue
-            message_gvps = []
-
-            for i in range(n_message_gvps):
-                dim_vectors_in = self.v_message_dim
-                dim_feats_in = self.s_message_dim
+        in_dim = list(self.edge_feat_size.values())[0]
+        out_dim = self.scalar_size + self.rbf_dim
+        if self.use_dst_feats:
+            out_dim += self.scalar_size
+        self.edge_feat_projector = nn.Linear(in_dim, out_dim)
-                if i == 0:
-                    dim_vectors_in += 1
-                    dim_feats_in += rbf_dim + self.edge_feat_size.get(etype, 0)
-                else:
-                    # if not first layer, input size is the output size of the previous layer
-                    dim_feats_in = dim_feats_out
-                    dim_vectors_in = dim_vectors_out
+        self.edge_message_fns = nn.ModuleDict()
+
+        message_gvps = []
-                if use_dst_feats and i == 0:
-                    dim_vectors_in += v_dst_feats_for_messages
-                    dim_feats_in += s_dst_feats_for_messages
+        for i in range(n_message_gvps):
+            dim_vectors_in = self.v_message_dim
+            dim_feats_in = self.s_message_dim
-                dim_feats_out = self.s_message_dim + extra_scalar_feats
-                dim_vectors_out = self.v_message_dim
+            if i == 0:
+                dim_vectors_in += 1
+                dim_feats_in += rbf_dim
+                dim_feats_in += list(self.edge_feat_size.values())[0]
+                # edge_feat_out_dim = self.edge_feat_projector.out_features
+                # dim_feats_in += edge_feat_out_dim
-                message_gvps.append(
-                    GVP(
-                        dim_vectors_in=dim_vectors_in,
-                        dim_vectors_out=dim_vectors_out,
-                        n_cp_feats=n_cp_feats,
-                        dim_feats_in=dim_feats_in,
-                        dim_feats_out=dim_feats_out,
-                        feats_activation=scalar_activation(),
-                        vectors_activation=vector_activation(),
-                        vector_gating=True,
-                    )
+            else:
+                # if not first layer, input size is the output size of the previous layer
+                dim_feats_in = dim_feats_out
+                dim_vectors_in = dim_vectors_out
+
+            if use_dst_feats and i == 0:
+                dim_vectors_in += v_dst_feats_for_messages
+                dim_feats_in += s_dst_feats_for_messages
+
+            dim_feats_out = self.s_message_dim + extra_scalar_feats
+            dim_vectors_out = self.v_message_dim
+
+            message_gvps.append(
+                GVP(
+                    dim_vectors_in=dim_vectors_in,
+                    dim_vectors_out=dim_vectors_out,
+                    n_cp_feats=n_cp_feats,
+                    dim_feats_in=dim_feats_in,
+                    dim_feats_out=dim_feats_out,
+                    feats_activation=scalar_activation(),
+                    vectors_activation=vector_activation(),
+                    vector_gating=True,
+                )
-                )
-            self.edge_message_fns[etype] = nn.Sequential(*message_gvps)
+            )
+
+        # Shared message function across all edge types
+        shared_message_fn = nn.Sequential(*message_gvps)
+        self.edge_message_fns = nn.ModuleDict()
+        for etype in self.edge_types:
+            self.edge_message_fns[etype] = shared_message_fn
 
         self.node_update_fns = nn.ModuleDict()
         self.dropout_layers = nn.ModuleDict()
@@ -810,16 +821,29 @@ def message(self, edges, etype):
             vec_feats.append(edges.dst["v"])
         vec_feats = torch.cat(vec_feats, dim=1)
 
-        # create scalar features
+        # create scalar features (edge feat projector output dim = scalar_size + rbf_dim, plus scalar_size if self.use_dst_feats)
         scalar_feats = [edges.src["s"], edges.data["d"]]
-        if self.edge_feat_size[etype] > 0:
-            scalar_feats.append(edges.data["ef"])
         if self.use_dst_feats:
             scalar_feats.append(edges.dst["s_dst_msg"])
         scalar_feats = torch.cat(scalar_feats, dim=1)
 
+        # apply the edge feature projector (currently disabled)
+
+        # if self.edge_feat_size.get(etype, 0) > 0:
+        #     scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"])
+        # concatenate edge features
+        if self.edge_feat_size.get(etype, 0) > 0:
+            # edge_feat_proj = self.edge_feat_projector(edges.data["ef"])
+            edge_feats = edges.data["ef"]
+            scalar_feats = torch.cat([scalar_feats, edge_feats], dim=1)
+        else:
+            # zeros of shape (num_edges, edge feature dim) for edge types without edge features
+            # no_edge_feats = torch.zeros((scalar_feats.size(0), self.edge_feat_projector.out_features), device=scalar_feats.device)
+            no_edge_feats = torch.zeros((scalar_feats.size(0), list(self.edge_feat_size.values())[0]), device=scalar_feats.device)
+            scalar_feats = torch.cat([scalar_feats, no_edge_feats], dim=1)
+
         scalar_message, vector_message = self.edge_message_fns[etype](
             (scalar_feats, vec_feats)
         )
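
Note (not part of the diff): the rewritten hunks rely on two ideas that the sketch below illustrates in isolation: (a) registering one shared message function under every key of the edge-type ModuleDict so all edge types update the same parameters, and (b) zero-padding the scalar input for edge types that carry no edge features so every edge type matches the shared function's fixed input width. This is a minimal standalone sketch under those assumptions; all names in it (SharedMessageDemo, edge_feat_dim, the edge-type strings) are hypothetical and not the omtra API.

import torch
import torch.nn as nn


class SharedMessageDemo(nn.Module):
    """Minimal sketch: one message function shared across all edge types."""

    def __init__(self, scalar_dim: int, edge_feat_dim: int, edge_types):
        super().__init__()
        self.edge_feat_dim = edge_feat_dim
        # A single module instance registered under every edge type: the
        # ModuleDict entries all reference the same object, so parameters
        # are shared (and deduplicated by .parameters()).
        shared_fn = nn.Linear(scalar_dim + edge_feat_dim, scalar_dim)
        self.message_fns = nn.ModuleDict({etype: shared_fn for etype in edge_types})

    def forward(self, etype: str, scalar_feats: torch.Tensor, edge_feats=None):
        if edge_feats is None:
            # Edge type carries no features: pad with zeros so the input width
            # still matches what the shared message function expects.
            edge_feats = scalar_feats.new_zeros(scalar_feats.size(0), self.edge_feat_dim)
        return self.message_fns[etype](torch.cat([scalar_feats, edge_feats], dim=1))


if __name__ == "__main__":
    demo = SharedMessageDemo(scalar_dim=8, edge_feat_dim=4,
                             edge_types=["lig_to_lig", "lig_to_rec"])
    s = torch.randn(5, 8)
    out_with_ef = demo("lig_to_lig", s, torch.randn(5, 4))   # edge type with features
    out_without_ef = demo("lig_to_rec", s)                   # no features -> zero pad
    # Both edge types route through the same parameters:
    assert demo.message_fns["lig_to_lig"] is demo.message_fns["lig_to_rec"]
    print(out_with_ef.shape, out_without_ef.shape)

Keeping the per-edge-type ModuleDict interface (rather than a single attribute) means the message() code path in the diff is unchanged; only the parameterization becomes shared across edge types.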