Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/huggingmolecules/configuration/configuration_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MatConfig(PretrainedConfigMixin):
ffn_n_layers: int = 1
ffn_d_hidden: int = 0

generator_return_representations_dict: bool = False
generator_aggregation: str = 'mean'
generator_n_layers: int = 1
generator_n_outputs: int = 1
Expand Down
1 change: 1 addition & 0 deletions src/huggingmolecules/configuration/configuration_rmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class RMatConfig(PretrainedConfigMixin):
ffn_d_hidden: int = 1536
ffn_d_output: int = 768

generator_return_representations_dict: bool = False
generator_aggregation: str = 'grover'
generator_n_layers: int = 1
generator_d_outputs: int = 1
Expand Down
2 changes: 1 addition & 1 deletion src/huggingmolecules/downloading/downloading_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import filelock
import gdown

default_cache_dir = '~/.cache/torch/huggingmolecules/'
HUGGINGMOLECULES_CACHE = os.getenv("HUGGINGMOLECULES_CACHE", default_cache_dir)
Expand All @@ -14,6 +13,7 @@ def get_cache_filepath(pretrained_name: str, archive_dict: dict, extension: str)


def download_file(src: str, target: str) -> None:
import gdown
dirname = os.path.dirname(target)
os.makedirs(dirname, exist_ok=True)
lock_path = target + ".lock"
Expand Down
17 changes: 16 additions & 1 deletion src/huggingmolecules/models/models_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def __init__(self, *,
n_layers: int,
dropout: float,
attn_hidden: int = 128,
attn_out: int = 4):
attn_out: int = 4,
return_representations_dict: bool = False):
super(Generator, self).__init__()
if aggregation_type == 'grover':
self.att_net = nn.Sequential(
Expand All @@ -225,6 +226,7 @@ def __init__(self, *,
self.proj.append(nn.Linear(attn_hidden, d_output))
self.proj = torch.nn.Sequential(*self.proj)
self.aggregation_type = aggregation_type
self.return_representations_dict = return_representations_dict

def forward(self, x, mask, generated_features):
mask = mask.unsqueeze(-1).float()
Expand All @@ -244,6 +246,19 @@ def forward(self, x, mask, generated_features):
if generated_features is not None:
out_avg_pooling = torch.cat([out_avg_pooling, generated_features], 1)
projected = self.proj(out_avg_pooling)
if self.return_representations_dict:
batch_size, nodes_per_graph, hidden_dim = x.shape
node_representations_mask = mask.view(-1).bool()
dummy_nodes_indices = torch.arange(0, batch_size * nodes_per_graph, step=nodes_per_graph).to(x.device)
node_representations_mask[dummy_nodes_indices] = False

node_representations = x.view(-1, hidden_dim)
node_representations = node_representations[node_representations_mask]
return {
'predictions': projected,
'graph_representations': out_avg_pooling,
'node_representations': node_representations
}
return projected


Expand Down
3 changes: 2 additions & 1 deletion src/huggingmolecules/models/models_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def __init__(self, config: MatConfig):
aggregation_type=config.generator_aggregation,
d_output=config.generator_n_outputs,
n_layers=config.generator_n_layers,
dropout=config.dropout)
dropout=config.dropout,
return_representations_dict=config.generator_return_representations_dict)

# Initialization
self.init_weights(config.init_type)
Expand Down
3 changes: 2 additions & 1 deletion src/huggingmolecules/models/models_rmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __init__(self, config: RMatConfig):
aggregation_type=config.generator_aggregation,
d_output=config.generator_d_outputs,
n_layers=config.generator_n_layers,
dropout=config.dropout)
dropout=config.dropout,
return_representations_dict=config.generator_return_representations_dict)

# Initialization
self.init_weights(config.init_type)
Expand Down