From dc84439179fd00e6c7115e21e8505a65a43d0e8f Mon Sep 17 00:00:00 2001 From: Pan Piotr Date: Mon, 13 Mar 2023 12:56:30 +0000 Subject: [PATCH] Add return dict option to RMAT and MAT --- .../configuration/configuration_mat.py | 1 + .../configuration/configuration_rmat.py | 1 + .../downloading/downloading_utils.py | 2 +- .../models/models_common_utils.py | 17 ++++++++++++++++- src/huggingmolecules/models/models_mat.py | 3 ++- src/huggingmolecules/models/models_rmat.py | 3 ++- 6 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/huggingmolecules/configuration/configuration_mat.py b/src/huggingmolecules/configuration/configuration_mat.py index dcbd626..9b96c11 100644 --- a/src/huggingmolecules/configuration/configuration_mat.py +++ b/src/huggingmolecules/configuration/configuration_mat.py @@ -27,6 +27,7 @@ class MatConfig(PretrainedConfigMixin): ffn_n_layers: int = 1 ffn_d_hidden: int = 0 + generator_return_representations_dict: bool = False generator_aggregation: str = 'mean' generator_n_layers: int = 1 generator_n_outputs: int = 1 diff --git a/src/huggingmolecules/configuration/configuration_rmat.py b/src/huggingmolecules/configuration/configuration_rmat.py index b78f8d5..a50d9b2 100644 --- a/src/huggingmolecules/configuration/configuration_rmat.py +++ b/src/huggingmolecules/configuration/configuration_rmat.py @@ -28,6 +28,7 @@ class RMatConfig(PretrainedConfigMixin): ffn_d_hidden: int = 1536 ffn_d_output: int = 768 + generator_return_representations_dict: bool = False generator_aggregation: str = 'grover' generator_n_layers: int = 1 generator_d_outputs: int = 1 diff --git a/src/huggingmolecules/downloading/downloading_utils.py b/src/huggingmolecules/downloading/downloading_utils.py index 5c9b4d4..4e8f7fe 100644 --- a/src/huggingmolecules/downloading/downloading_utils.py +++ b/src/huggingmolecules/downloading/downloading_utils.py @@ -1,7 +1,6 @@ import os import filelock -import gdown default_cache_dir = '~/.cache/torch/huggingmolecules/' HUGGINGMOLECULES_CACHE = os.getenv("HUGGINGMOLECULES_CACHE", default_cache_dir) @@ -14,6 +13,7 @@ def get_cache_filepath(pretrained_name: str, archive_dict: dict, extension: str) def download_file(src: str, target: str) -> None: + import gdown dirname = os.path.dirname(target) os.makedirs(dirname, exist_ok=True) lock_path = target + ".lock" diff --git a/src/huggingmolecules/models/models_common_utils.py b/src/huggingmolecules/models/models_common_utils.py index 352ea9b..404da6d 100644 --- a/src/huggingmolecules/models/models_common_utils.py +++ b/src/huggingmolecules/models/models_common_utils.py @@ -202,7 +202,8 @@ def __init__(self, *, n_layers: int, dropout: float, attn_hidden: int = 128, - attn_out: int = 4): + attn_out: int = 4, + return_representations_dict: bool = False): super(Generator, self).__init__() if aggregation_type == 'grover': self.att_net = nn.Sequential( @@ -225,6 +226,7 @@ def __init__(self, *, self.proj.append(nn.Linear(attn_hidden, d_output)) self.proj = torch.nn.Sequential(*self.proj) self.aggregation_type = aggregation_type + self.return_representations_dict = return_representations_dict def forward(self, x, mask, generated_features): mask = mask.unsqueeze(-1).float() @@ -244,6 +246,19 @@ def forward(self, x, mask, generated_features): if generated_features is not None: out_avg_pooling = torch.cat([out_avg_pooling, generated_features], 1) projected = self.proj(out_avg_pooling) + if self.return_representations_dict: + batch_size, nodes_per_graph, hidden_dim = x.shape + node_representations_mask = mask.view(-1).bool() + dummy_nodes_indices = torch.arange(0, batch_size * nodes_per_graph, step=nodes_per_graph).to(x.device) + node_representations_mask[dummy_nodes_indices] = False + + node_representations = x.view(-1, hidden_dim) + node_representations = node_representations[node_representations_mask] + return { + 'predictions': projected, + 'graph_representations': out_avg_pooling, + 'node_representations': node_representations + } return projected diff --git a/src/huggingmolecules/models/models_mat.py b/src/huggingmolecules/models/models_mat.py index 79f93a0..5a74843 100644 --- a/src/huggingmolecules/models/models_mat.py +++ b/src/huggingmolecules/models/models_mat.py @@ -67,7 +67,8 @@ def __init__(self, config: MatConfig): aggregation_type=config.generator_aggregation, d_output=config.generator_n_outputs, n_layers=config.generator_n_layers, - dropout=config.dropout) + dropout=config.dropout, + return_representations_dict=config.generator_return_representations_dict) # Initialization self.init_weights(config.init_type) diff --git a/src/huggingmolecules/models/models_rmat.py b/src/huggingmolecules/models/models_rmat.py index d704748..13b20bc 100644 --- a/src/huggingmolecules/models/models_rmat.py +++ b/src/huggingmolecules/models/models_rmat.py @@ -72,7 +72,8 @@ def __init__(self, config: RMatConfig): aggregation_type=config.generator_aggregation, d_output=config.generator_d_outputs, n_layers=config.generator_n_layers, - dropout=config.dropout) + dropout=config.dropout, + return_representations_dict=config.generator_return_representations_dict) # Initialization self.init_weights(config.init_type)