Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 66 additions & 42 deletions omtra/models/gvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def forward(self, data):

# feats has shape (batch_size, n_feats)
# vectors has shape (batch_size, n_vectors, 3)

assert c == 3 and v == self.dim_vectors_in, "vectors have wrong dimensions"
assert n == self.dim_feats_in, "scalar features have wrong dimensions"

Expand Down Expand Up @@ -307,7 +306,7 @@ def __init__(

self.node_types = node_types
self.edge_types = edge_types
self.scalar_size = scalar_size
self.scalar_size = scalar_size # node feature embeddings size
self.vector_size = vector_size
self.n_cp_feats = n_cp_feats
self.scalar_activation = scalar_activation
Expand All @@ -317,7 +316,7 @@ def __init__(
self.edge_feat_size = edge_feat_size
self.use_dst_feats = use_dst_feats
self.rbf_dmax = rbf_dmax
self.rbf_dim = rbf_dim
self.rbf_dim = rbf_dim #radial basis function embedding dimension
self.dropout_rate = dropout
self.message_norm = message_norm
self.attention = attention
Expand Down Expand Up @@ -399,48 +398,60 @@ def __init__(
else:
s_dst_feats_for_messages = 0
v_dst_feats_for_messages = 0

# Edge feature projector

self.edge_message_fns = nn.ModuleDict()
for etype in edge_types:
src_ntype, _, dst_ntype = to_canonical_etype(etype)
inv_etype = get_inv_edge_type(etype)
if inv_etype in self.edge_message_fns:
self.edge_message_fns[etype] = self.edge_message_fns[inv_etype]
continue
message_gvps = []

for i in range(n_message_gvps):
dim_vectors_in = self.v_message_dim
dim_feats_in = self.s_message_dim
in_dim = list(self.edge_feat_size.values())[0]
out_dim = self.scalar_size + self.rbf_dim
if self.use_dst_feats:
out_dim += self.scalar_size
self.edge_feat_projector = nn.Linear(in_dim, out_dim)

if i == 0:
dim_vectors_in += 1
dim_feats_in += rbf_dim + self.edge_feat_size.get(etype, 0)
else:
# if not first layer, input size is the output size of the previous layer
dim_feats_in = dim_feats_out
dim_vectors_in = dim_vectors_out
self.edge_message_fns = nn.ModuleDict()

message_gvps = []

if use_dst_feats and i == 0:
dim_vectors_in += v_dst_feats_for_messages
dim_feats_in += s_dst_feats_for_messages
for i in range(n_message_gvps):
dim_vectors_in = self.v_message_dim
dim_feats_in = self.s_message_dim

dim_feats_out = self.s_message_dim + extra_scalar_feats
dim_vectors_out = self.v_message_dim
if i == 0:
dim_vectors_in += 1
dim_feats_in += rbf_dim
dim_feats_in += list(self.edge_feat_size.values())[0]
#edge_feat_out_dim = self.edge_feat_projector.out_features
#dim_feats_in += edge_feat_out_dim

message_gvps.append(
GVP(
dim_vectors_in=dim_vectors_in,
dim_vectors_out=dim_vectors_out,
n_cp_feats=n_cp_feats,
dim_feats_in=dim_feats_in,
dim_feats_out=dim_feats_out,
feats_activation=scalar_activation(),
vectors_activation=vector_activation(),
vector_gating=True,
)
else:
# if not first layer, input size is the output size of the previous layer
dim_feats_in = dim_feats_out
dim_vectors_in = dim_vectors_out

if use_dst_feats and i == 0:
dim_vectors_in += v_dst_feats_for_messages
dim_feats_in += s_dst_feats_for_messages

dim_feats_out = self.s_message_dim + extra_scalar_feats
dim_vectors_out = self.v_message_dim

message_gvps.append(
GVP(
dim_vectors_in=dim_vectors_in,
dim_vectors_out=dim_vectors_out,
n_cp_feats=n_cp_feats,
dim_feats_in=dim_feats_in,
dim_feats_out=dim_feats_out,
feats_activation=scalar_activation(),
vectors_activation=vector_activation(),
vector_gating=True,
)
self.edge_message_fns[etype] = nn.Sequential(*message_gvps)
)

#Shared message function across all edge types
shared_message_fn = nn.Sequential(*message_gvps)
self.edge_message_fns = nn.ModuleDict()
for etype in self.edge_types:
self.edge_message_fns[etype] = shared_message_fn

self.node_update_fns = nn.ModuleDict()
self.dropout_layers = nn.ModuleDict()
Expand Down Expand Up @@ -810,16 +821,29 @@ def message(self, edges, etype):
vec_feats.append(edges.dst["v"])
vec_feats = torch.cat(vec_feats, dim=1)

# create scalar features
# create scalar features (output dim of the edge feat projector = scalar_feats + rbf_dim + (if the self.use_dst_feats) scalar_fears)
scalar_feats = [edges.src["s"], edges.data["d"]]
if self.edge_feat_size[etype] > 0:
scalar_feats.append(edges.data["ef"])

if self.use_dst_feats:
scalar_feats.append(edges.dst["s_dst_msg"])

scalar_feats = torch.cat(scalar_feats, dim=1)

# construct edge feature projector

# if self.edge_feat_size.get(etype, 0) > 0:
# scalar_feats = scalar_feats + self.edge_feat_projector(edges.data["ef"])
# concatenate edge features
if self.edge_feat_size.get(etype, 0) > 0:
#edge_feat_proj = self.edge_feat_projector(edges.data["ef"])
edge_feats = edges.data["ef"]
scalar_feats = torch.cat([scalar_feats, edge_feats], dim=1)
else:
# zeros tensor of size (num_edges, edge_feat_projector.out_features)
#no_edge_feats = torch.zeros((scalar_feats.size(0), self.edge_feat_projector.out_features), device=scalar_feats.device)
no_edge_feats = torch.zeros((scalar_feats.size(0), list(self.edge_feat_size.values())[0]), device=scalar_feats.device)
scalar_feats = torch.cat([scalar_feats, no_edge_feats], dim=1)

scalar_message, vector_message = self.edge_message_fns[etype](
(scalar_feats, vec_feats)
)
Expand Down