From c48b8d51f3461004e56665be5274da1fb275e24b Mon Sep 17 00:00:00 2001 From: qiuyinzhang Date: Mon, 8 Jul 2024 01:52:58 +0000 Subject: [PATCH] feat: vae omics net --- dooc/datasets.py | 101 +++++++++++++- dooc/loss.py | 116 +++++++++++++++++ dooc/models.py | 33 ++++- dooc/nets/__init__.py | 86 ++++++++++++ dooc/nets/vaeomic.py | 297 ++++++++++++++++++++++++++++++++++++++++++ dooc/pipelines.py | 87 ++++++++++++- 6 files changed, 714 insertions(+), 6 deletions(-) create mode 100644 dooc/nets/vaeomic.py diff --git a/dooc/datasets.py b/dooc/datasets.py index 2930db7..7d3f19e 100644 --- a/dooc/datasets.py +++ b/dooc/datasets.py @@ -4,7 +4,7 @@ from moltx import tokenizers, datasets -class _SmiMutBase: +class _SmiBase: def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None: self.smi_ds = datasets.Base(smi_tokenizer, device) self.device = device @@ -22,7 +22,7 @@ def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> to """ -class _DrugcellAdamrBase(_SmiMutBase): +class _DrugcellAdamrBase(_SmiBase): """Base datasets, convert smiles and genes to torch.Tensor.""" def __init__( @@ -141,6 +141,72 @@ def __call__( return mut_x, smi_tgt, out +class _VAEOmicsAdamr2Base(_SmiBase): + """Base datasets, convert smiles and omics to torch.Tensor.""" + + def __init__( + self, + smi_tokenizer: tokenizers.MoltxTokenizer, + device: torch.device = torch.device("cpu") + ) -> None: + super().__init__(smi_tokenizer, device) + self.smi_tokenizer = smi_tokenizer + + def _smi_tokens( + self, + smiles: typing.Sequence[str], + seq_len: int = 200, + ) -> torch.Tensor: + tgt = self._smi_tokenize( + [f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len) + return tgt + + def _omics_tokens(self, omics_seq: typing.Sequence[list]) -> typing.Sequence[torch.Tensor]: + return [torch.tensor(omic, device=self.device) for omic in omics_seq] + + def _out(self, values: typing.Sequence[float]) -> torch.Tensor: + return torch.tensor(values, device=self.device) + + +class _VAEOmicsAdamr2OmicsSmi(_VAEOmicsAdamr2Base): + def __call__( + self, + omics_seq: typing.Sequence[list], + smis: typing.Sequence[str], + vals: typing.Sequence[float], + seq_len: int = 200 + ) -> typing.Tuple[torch.Tensor]: + assert len(smis) == len(vals) and len(omics_seq[0]) == len(vals) + omics_x = self._omics_tokens(omics_seq) + smi_tgt = self._smi_tokens(smis, seq_len) + out = self._out(vals).unsqueeze(-1) + return omics_x, smi_tgt, out + + +class _VAEOmicsAdamr2OmicsSmis(_VAEOmicsAdamr2Base): + def __call__( + self, + omics_seq: typing.Sequence[list], + lsmis: typing.Sequence[typing.Sequence[str]], + lvals: typing.Sequence[typing.Sequence[float]], + seq_len: int = 200 + ) -> typing.Tuple[torch.Tensor]: + """ + omics_seq: [omic1, omic2, ...](omics type len) omic1: [omic11, omic12, ...](batch size) omics1_1: [gene1, gene2, ...] + bsmiles: [[smi11, smi12], [smi21, smi22], ...] + bvlaues: [[val11, val12], [val21, val22], ...] 
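+
+        For illustration only (hypothetical values), with 2 omics types,
+        a batch of 2 samples and 2 candidate smiles per sample:
+            omics_seq = [[[0, 1, 0], [1, 0, 1]], [[0.2, 0.1, 0.9], [0.4, 0.5, 0.3]]]
+            lsmis = [["CCO", "c1ccccc1"], ["CC(=O)O", "CCN"]]
+            lvals = [[0.7, 0.2], [0.1, 0.9]]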
+        """
+        assert len(lsmis) == len(lvals) and len(omics_seq[0]) == len(lvals)
+        omics_x = self._omics_tokens(omics_seq)
+        batchlen = len(lsmis)
+        listlen = len(lsmis[0])
+        smiles = [smi for bsmi in lsmis for smi in bsmi]
+        smi_tgt = self._smi_tokens(smiles, seq_len)
+        smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
+        out = self._out(lvals)
+        return omics_x, smi_tgt, out
+
+
 """
 Mutations(Individual Sample) and Smiles Interaction
 
@@ -170,3 +236,34 @@ def __call__(
 
 class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
     pass
+
+
+"""
+Omics (Individual Sample) and Smiles Interaction
+
+OmicsSmiReg
+OmicsSmisPairwiseRank
+OmicsSmisListwiseRank
+"""
+
+
+class OmicsSmiReg(_VAEOmicsAdamr2OmicsSmi):
+    pass
+
+
+class OmicsSmisPairwiseRank(_VAEOmicsAdamr2OmicsSmis):
+    def __call__(
+        self,
+        omics_seq: typing.Sequence[list],
+        lsmiles: typing.Sequence[typing.Sequence[str]],
+        lvalues: typing.Sequence[typing.Sequence[float]],
+        seq_len: int = 200
+    ) -> typing.Tuple[torch.Tensor]:
+        omics_x, smi_tgt, rout = super().__call__(omics_seq, lsmiles, lvalues, seq_len)
+        out = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
+        out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
+        return omics_x, smi_tgt, out
+
+
+class OmicsSmisListwiseRank(_VAEOmicsAdamr2OmicsSmis):
+    pass
diff --git a/dooc/loss.py b/dooc/loss.py
index c0ce9f2..ef9f3e7 100644
--- a/dooc/loss.py
+++ b/dooc/loss.py
@@ -1,5 +1,7 @@
+import typing
 import torch
 import torch.nn as nn
+from torch.distributions import Normal, kl_divergence
 
 
 class ListNetLoss(nn.Module):
@@ -11,3 +13,117 @@ def __init__(self, reduction: str = 'mean') -> None:
 
     def forward(self, predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         out = - (target.softmax(dim=-1) * predict.log_softmax(dim=-1))
         return getattr(out, self.reduction)()
+
+
+class VAEOmicsLoss(nn.Module):
+
+    def __init__(self, loss_type: str, omics_num: int) -> None:
+        super().__init__()
+        self.loss_type = loss_type
+        self.k = omics_num
+        self.kl_loss_weight = 0.1  # TODO: weight still to be tuned
+
+    def forward(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
+        return getattr(self, f"_forward_{self.loss_type}")(x, out_x, **kwargs)
+
+    def _forward_generate(self, x: typing.Sequence, out_x: typing.Sequence, labels: torch.Tensor, **kwargs) -> typing.Sequence:
+        # out_encoder, out_self, out_cross, out_dsc, out_cl = out_x
+        out_encoder, out_self, out_cross, out_dsc = out_x
+        self_loss = self._calc_self_vae_loss(x, out_self)
+        cross_loss, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
+        cross_infer_loss = self._calc_cross_infer_loss(out_encoder)
+        dsc_loss = self._calc_dsc_loss(out_dsc)
+        # contrastive_loss = self._calc_contrastive_loss(out_cl, labels)
+        generate_loss = (
+            self_loss + 0.1 * (cross_loss + cross_infer_loss * cross_infer_loss)
+            - (dsc_loss + cross_infer_dsc_loss) * 0.01  # + contrastive_loss
+        )
+        # return generate_loss, self_loss, cross_loss, cross_infer_loss, dsc_loss
+        return generate_loss
+
+    def _forward_dsc(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
+        out_encoder, out_cross, out_dsc = out_x
+        _, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
+        dsc_loss = self._calc_dsc_loss(out_dsc)
+        return cross_infer_dsc_loss + dsc_loss
+
+    def _calc_self_vae_loss(self, x: typing.Sequence, out_self: typing.Sequence) -> float:
+        loss = 0.
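+        # per-omics ELBO: kl_loss_weight * KL(q(z|x_i) || N(0, I)) + MSE reconstruction of x_i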
+        for i, v in enumerate(out_self):
+            recon_omics, mu, log_var = v
+            loss += (self.kl_loss_weight * self._kl_loss(mu, log_var, 1.0) + self.reconstruction_loss(x[i], recon_omics))
+        return loss
+
+    def _calc_cross_vae_loss(self, x: typing.Sequence, out_cross: typing.Sequence, out_encoder: typing.Sequence) -> typing.Sequence:
+        batch_size = x[0].size(0)
+        device = x[0].device
+        cross_elbo, cross_infer_loss, cross_kl_loss, cross_dsc_loss = 0, 0, 0, 0
+        for i, v in enumerate(out_cross):
+            _, real_mu, real_log_var = out_encoder[i][i]
+            reconstruct_omic, poe_mu, poe_log_var, pred_real_modal, pred_infer_modal = v
+            cross_elbo += (
+                self.kl_loss_weight * self._kl_loss(poe_mu, poe_log_var, 1.0)
+                + self.reconstruction_loss(x[i], reconstruct_omic)
+            )
+            cross_infer_loss += self.reconstruction_loss(real_mu, poe_mu)
+            cross_kl_loss += self._kl_divergence(poe_mu, real_mu, poe_log_var, real_log_var)
+
+            real_modal = torch.tensor([1 for _ in range(batch_size)]).to(device)
+            infer_modal = torch.tensor([0 for _ in range(batch_size)]).to(device)
+            cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_real_modal, real_modal)
+            cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_infer_modal, infer_modal)
+
+        cross_dsc_loss = cross_dsc_loss.sum(0) / (len(out_cross) * batch_size)
+        return cross_elbo + cross_infer_loss + self.kl_loss_weight * cross_kl_loss, cross_dsc_loss
+
+    def _calc_cross_infer_loss(self, out_encoder: typing.Sequence) -> float:
+        infer_loss = 0
+        for i in range(self.k):
+            _, latent_mu, _ = out_encoder[i][i]
+            for j in range(self.k):
+                if i == j:
+                    continue
+                _, latent_mu_infer, _ = out_encoder[j][i]
+                infer_loss += self.reconstruction_loss(latent_mu_infer, latent_mu)
+        return infer_loss / self.k
+
+    def _calc_dsc_loss(self, out_dsc: typing.Sequence) -> float:
+        dsc_loss = 0
+        batch_size = out_dsc[0].size(0)
+        for i in range(self.k):
+            real_modal = torch.tensor([i for _ in range(batch_size)])
+            dsc_loss += torch.nn.CrossEntropyLoss()(out_dsc[i], real_modal.to(out_dsc[i].device))
+        return dsc_loss.sum(0) / (self.k * batch_size)
+
+    def _calc_contrastive_loss(self, out_cl: typing.Sequence, labels: torch.Tensor) -> float:
+        margin = 1.0
+        distances = torch.cdist(out_cl, out_cl)
+
+        labels_matrix = labels.view(-1, 1) == labels.view(1, -1)
+
+        positive_pair_distances = distances * labels_matrix.float()
+        negative_pair_distances = distances * (1 - labels_matrix.float())
+
+        positive_loss = positive_pair_distances.sum() / labels_matrix.float().sum()
+        negative_loss = torch.nn.ReLU()(margin - negative_pair_distances).sum() / (1 - labels_matrix.float()).sum()
+
+        return positive_loss + negative_loss
+
+    def _kl_loss(self, mu, logvar, beta):
+        # KL divergence loss
+        kld_1 = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        return beta * kld_1
+
+    def _kl_divergence(self, mu1, mu2, log_sigma1, log_sigma2):
+        p = Normal(mu1, torch.exp(log_sigma1))
+        q = Normal(mu2, torch.exp(log_sigma2))
+
+        # KL divergence between the two Gaussians
+        kl_loss = kl_divergence(p, q).mean()
+        return kl_loss
+
+    def reconstruction_loss(self, recon_x, x):
+        # batch_size = recon_x.size(0)
+        mse = nn.MSELoss()  # reduction='sum'
+        recons_loss = mse(recon_x, x)  # / batch_size
+        return recons_loss
diff --git a/dooc/models.py b/dooc/models.py
index 0e0da94..d81a599 100644
--- a/dooc/models.py
+++ b/dooc/models.py
@@ -1,8 +1,9 @@
 import torch
+import typing
 from moltx import nets as mnets
 from moltx import models as mmodels
 from dooc import nets as dnets
-from dooc.nets import heads, drugcell
+from dooc.nets import heads, drugcell, vaeomic
 
 
 """
@@ -42,3 +43,33 @@ 
def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
         assert mut_x.dim() == 1 and smi_tgt.dim() == 2
         out = self.forward(mut_x, smi_tgt)  # [2]
         return (out[0] - out[1]).item()
+
+
+class OmicsSmiReg(dnets.VAEOmicsAdamr2OmicsSmiXattn):
+
+    def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
+        super().__init__(omics_conf, smi_conf)
+        self.reg = heads.RegHead(self.smi_conf.d_model)
+
+    def forward(
+            self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
+        return self.reg(super().forward(omics, omics_x, smi_tgt))  # [b, 1]
+
+
+class OmicsSmisRank(dnets.VAEOmicsAdamr2OmicsSmisXattn):
+
+    def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
+        super().__init__(omics_conf, smi_conf)
+        self.reg = heads.RegHead(self.smi_conf.d_model)
+
+    def forward(
+            self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
+        return self.reg(super().forward(omics, omics_x, smi_tgt)).squeeze(-1)  # [b, n]
+
+    def forward_cmp(self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> float:
+        """
+        for inference, no batch dim
+        """
+        assert omics_x[0].dim() == 1 and smi_tgt.dim() == 2
+        out = self.forward(omics, omics_x, smi_tgt)  # [2]
+        return (out[0] - out[1]).item()
diff --git a/dooc/nets/__init__.py b/dooc/nets/__init__.py
index 13bb44a..b9fad21 100644
--- a/dooc/nets/__init__.py
+++ b/dooc/nets/__init__.py
@@ -1,7 +1,9 @@
 import torch
+import typing
 from torch import nn
 from moltx.models import AdaMR, AdaMR2
 from dooc.nets.drugcell import Drugcell
+from dooc.nets.vaeomic import VAEOmics
 
 
 """
@@ -209,3 +211,87 @@ def forward(
         mut_out = self._forward_mut(mut_x)
         smi_out = self._forward_smi(smi_tgt)
         return self.cross_attn(smi_out, mut_out)  # [b, n, dmodel]
+
+
+class _VAEOmicsAdamr2(nn.Module):
+
+    def __init__(self, omics_conf, smi_conf) -> None:
+        super().__init__()
+        self.omics_conf = omics_conf
+        self.smi_conf = smi_conf
+
+        self.omics_encoder = VAEOmics(omics_conf)
+        self.smi_encoder = AdaMR2(smi_conf)
+
+    def load_ckpt(self, *ckpt_files: str) -> None:
+        self.load_state_dict(
+            torch.load(ckpt_files[0], map_location=torch.device("cpu"))
+        )
+
+    def load_pretrained_ckpt(self, omics_ckpt: str, smi_ckpt: str, freeze_omics: bool = False, freeze_smi: bool = False) -> None:
+        self.omics_encoder.load_ckpt(omics_ckpt)
+        self.smi_encoder.load_ckpt(smi_ckpt)
+        if freeze_smi:
+            self.smi_encoder.requires_grad_(False)
+        if freeze_omics:
+            self.omics_encoder.requires_grad_(False)
+
+
+class VAEOmicsAdamr2OmicsSmiXattn(_VAEOmicsAdamr2):
+    def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
+        super().__init__(omics_conf, smi_conf)
+        d_model = self.smi_conf.d_model
+        layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
+        self.cross_attn = nn.TransformerDecoder(layer, num_layers)
+
+    def forward(
+            self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
+        """
+        omics_x: sequence of omics tensors, each [b, omics_dim]
+        smi_tgt: [b, smi_seqlen]
+        """
+        omics_out = self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2)  # [b, 1, dmodel]
+        smi_out = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2)  # [b, 1, dmodel]
+        return self.cross_attn(smi_out, omics_out).squeeze(-2)  # [b, dmodel]
+
+
+class 
VAEOmicsAdamr2OmicsSmisAdd(_VAEOmicsAdamr2): + def _forward_omics(self, omics: dict, omics_x: typing.Sequence[torch.Tensor]) -> torch.Tensor: + """ + omics_x: [b, omics_seqlen] + out: [b, 1, dmodel] + """ + return self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2) + + def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + smi_tgt: [b, n, smi_seqlen] + out: [b, n, dmodel] + """ + batched = smi_tgt.dim() == 3 + if batched: + n = smi_tgt.shape[1] + smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1]) + out = self.smi_encoder.forward_feature(smi_tgt) + return out.reshape(-1, n, out.shape[-1]) + return self.smi_encoder.forward_feature(smi_tgt) + + def forward( + self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor: + omics_out = self._forward_omics(omics, omics_x) + smi_out = self._forward_smi(smi_tgt) + return smi_out + omics_out # [b, n, dmodel] + + +class VAEOmicsAdamr2OmicsSmisXattn(VAEOmicsAdamr2OmicsSmisAdd): + def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(omics_conf, smi_conf) + d_model = smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def forward( + self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor: + omics_out = self._forward_omics(omics, omics_x) + smi_out = self._forward_smi(smi_tgt) + return self.cross_attn(smi_out, omics_out) # [b, n, dmodel] diff --git a/dooc/nets/vaeomic.py b/dooc/nets/vaeomic.py new file mode 100644 index 0000000..4952087 --- /dev/null +++ b/dooc/nets/vaeomic.py @@ -0,0 +1,297 @@ +import torch +import typing +import torch.nn as nn +from dataclasses import dataclass + + +def product_of_experts(mu_set, log_var_set): + tmp = 0 + for i in range(len(mu_set)): + tmp += torch.div(1, torch.exp(log_var_set[i])) + + poe_var = torch.div(1., tmp) + poe_log_var = torch.log(poe_var) + + tmp = 0. 
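+    # precision-weighted mean: mu_poe = var_poe * sum_i(mu_i / var_i)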
+ for i in range(len(mu_set)): + tmp += torch.div(1., torch.exp(log_var_set[i])) * mu_set[i] + poe_mu = poe_var * tmp + return poe_mu, poe_log_var + + +def reparameterize(mean, logvar): + std = torch.exp(logvar / 2) + epsilon = torch.randn_like(std) + return epsilon * std + mean + + +class LinearLayer(nn.Module): + def __init__(self, input_dim: int, output_dim: int, dropout: float = 0.2, batchnorm: bool = False, activation=None) -> None: + super(LinearLayer, self).__init__() + self.linear_layer = nn.Linear(input_dim, output_dim) + + self.dropout = nn.Dropout(dropout) if dropout > 0 else None + self.batchnorm = nn.BatchNorm1d(output_dim) if batchnorm else None + + self.activation = None + if activation is not None: + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'sigmoid': + self.activation = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_layer(x) + if self.dropout is not None: + x = self.dropout(x) + if self.batchnorm is not None: + x = self.batchnorm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class VAEEncoder(nn.Module): + def __init__(self, input_dim: int, output_dim: int, hidden_dims: typing.Sequence) -> None: + super(VAEEncoder, self).__init__() + self.encoders = nn.ModuleList([ + LinearLayer( + input_dim, hidden_dims[0], + batchnorm=True, activation='relu' + ) + ]) + for i in range(len(hidden_dims) - 1): + self.encoders.append( + LinearLayer( + hidden_dims[i], hidden_dims[i + 1], + batchnorm=True, activation='relu' + ) + ) + + self.mu_predictor = nn.Sequential( + nn.Linear(hidden_dims[-1], output_dim), nn.ReLU() + ) + self.log_var_predictor = nn.Sequential( + nn.Linear(hidden_dims[-1], output_dim), nn.ReLU() + ) + + def reparameterize(self, mean, logvar): + std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two + epsilon = torch.randn_like(std) + return epsilon * std + mean + + def forward(self, x: torch.Tensor) -> typing.Sequence: + for layer in self.encoders: + x = layer(x) + mu = self.mu_predictor(x) + log_var = self.log_var_predictor(x) + latent_z = self.reparameterize(mu, log_var) + return latent_z, mu, log_var + + +class VAEDecoder(nn.Module): + def __init__(self, input_dim: int, output_dim: int, hidden_dims: typing.Sequence) -> None: + super(VAEDecoder, self).__init__() + + self.decoders = nn.ModuleList([LinearLayer( + input_dim, hidden_dims[0], + dropout=0.1, batchnorm=True, + activation='relu' + )]) + for i in range(len(hidden_dims) - 1): + self.decoders.append(LinearLayer( + hidden_dims[i], hidden_dims[i + 1], + dropout=0.1, batchnorm=True, + activation='relu' + )) + + self.recons_predictor = LinearLayer(hidden_dims[-1], output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + x = layer(x) + data_recons = self.recons_predictor(x) + return data_recons + + +@dataclass +class VAEOmicsConfig: + modal_num: int + modal_dim: list[int] + latent_dim: int + encoder_hidden_dims: list[int] + decoder_hidden_dims: list[int] + complete_omics: dict[str] + + +class VAEOmics(nn.Module): + DEFAULT_CONFIG = VAEOmicsConfig( + modal_num=3, + modal_dim=[2827, 2739, 8802], + latent_dim=256, + encoder_hidden_dims = [2048, 512], + decoder_hidden_dims = [512, 2048], + complete_omics={'methylation': 0, 'mutation': 1, 'rna': 2}, + ) + + def __init__(self, config: VAEOmicsConfig) -> None: + super(VAEOmics, self).__init__() + + self.config = config + self.device = torch.device('cpu') + + self.k = config.modal_num + self.encoders = 
nn.ModuleList( + nn.ModuleList([ + VAEEncoder( + self.config.modal_dim[i], self.config.latent_dim, self.config.encoder_hidden_dims + ) for j in range(self.k) + ]) for i in range(self.k) + ) + + self.self_decoders = nn.ModuleList([ + VAEDecoder( + self.config.latent_dim, self.config.modal_dim[i], self.config.decoder_hidden_dims + ) for i in range(self.k) + ]) + + self.cross_decoders = nn.ModuleList([ + VAEDecoder( + self.config.latent_dim, self.config.modal_dim[i], self.config.decoder_hidden_dims + ) for i in range(self.k) + ]) + + self.share_encoder = nn.Sequential( + nn.Linear(self.config.latent_dim, self.config.latent_dim), + nn.BatchNorm1d(self.config.latent_dim), + nn.ReLU() + ) + + self.discriminator = nn.Sequential( + nn.Linear(self.config.latent_dim, 16), + nn.BatchNorm1d(16), + nn.ReLU(), + nn.Linear(16, self.k) + ) + + self.infer_discriminator = nn.ModuleList(nn.Sequential( + nn.Linear(self.config.latent_dim, 16), + nn.BatchNorm1d(16), + nn.ReLU(), + nn.Linear(16, 2) + ) for i in range(self.k)) + + def load_ckpt(self, *ckpt_files: str) -> None: + self.load_state_dict( + torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False + ) + + def freeze_switch(self, layer_name: str, freeze: bool = False) -> None: + layer = getattr(self, layer_name) + + def switch(model, freeze: bool = False) -> None: + for _, child in model.named_children(): + for param in child.parameters(): + param.requires_grad = not freeze + switch(child) + + switch(layer, freeze) + + def forward_generate(self, *x: torch.Tensor) -> typing.Sequence: + ''' + x: list of omics tensor + ''' + self.device = x[0].device + output = [[0 for _ in range(self.k)] for _ in range(self.k)] + for (idx, i) in enumerate(range(self.k)): + for j in range(self.k): + output[i][j] = self.encoders[i][j](x[idx]) + + se = [output[i][i] for i in range(self.k)] + out_self = self._forward_self_vae(se) + out_cross = self._forward_cross_vae(output) + out_disc = self._forward_discriminator(output) + # out_cl = self._forward_contrastive_learning(*x) + return output, out_self, out_cross, out_disc # , out_cl + + def forward_dsc(self, *x: torch.Tensor) -> typing.Sequence: + self.device = x[0].device + output = [[0 for _ in range(self.k)] for _ in range(self.k)] + for (idx, i) in enumerate(range(self.k)): + for j in range(self.k): + output[i][j] = self.encoders[i][j](x[idx]) + + out_cross = self._forward_cross_vae(output) + out_disc = self._forward_discriminator(output) + return output, out_cross, out_disc + + # can input incomplete omics + def forward_encoder(self, omics: dict, *x: torch.Tensor) -> torch.Tensor: + self.device = x[0].device + values = list(omics.values()) + output = [[0 for _ in range(self.k)] for _ in range(self.k)] + + for (item, i) in enumerate(values): + for j in range(self.k): + output[i][j] = self.encoders[i][j](x[item]) + + embedding_tensor = [] + for i in range(self.k): + mu_set = [] + log_var_set = [] + for j in range(self.k): + if i == j or j not in values: + continue + _, mu, log_var = output[j][i] + mu_set.append(mu) + log_var_set.append(log_var) + poe_mu, _ = product_of_experts(mu_set, log_var_set) + if i in values: + _, omic_mu, _ = output[i][i] + joint_mu = (omic_mu + poe_mu) / 2 + else: + joint_mu = poe_mu + embedding_tensor.append(joint_mu.to(self.device)) + embedding_tensor = torch.cat(embedding_tensor, dim=1) + return embedding_tensor + + def _forward_self_vae(self, se: typing.Sequence) -> typing.Sequence: + output = [] + for i in range(self.k): + latent_z, mu, log_var = se[i] + recon_omics = 
self.self_decoders[i](latent_z) + output.append((recon_omics, mu, log_var)) + return output + + def _forward_cross_vae(self, e: typing.Sequence) -> typing.Sequence: + output = [] + for i in range(self.k): + _, real_mu, _ = e[i][i] + mus = [] + log_vars = [] + for j in range(self.k): + _, mu, log_var = e[j][i] + mus.append(mu) + log_vars.append(log_var) + poe_mu, poe_log_var = product_of_experts(mus, log_vars) + poe_mu = poe_mu.to(self.device) + poe_log_var = poe_log_var.to(self.device) + poe_latent_z = reparameterize(poe_mu, poe_log_var).to(self.device) + reconstruct_omic = self.self_decoders[i](poe_latent_z) + + pred_real_modal = self.infer_discriminator[i](real_mu) + pred_infer_modal = self.infer_discriminator[i](poe_mu) + + output.append((reconstruct_omic, poe_mu, poe_log_var, pred_real_modal, pred_infer_modal)) + return output + + def _forward_discriminator(self, e: typing.Sequence) -> typing.Sequence: + output = [] + for i in range(self.k): + _, mu, _ = e[i][i] + pred_modal = self.discriminator(mu) + output.append(pred_modal) + return output + + def _forward_contrastive_learning(self, *x: torch.Tensor) -> typing.Sequence: + return self.forward_encoder(self.config.complete_omics, *x) diff --git a/dooc/pipelines.py b/dooc/pipelines.py index e635e1f..e13be34 100644 --- a/dooc/pipelines.py +++ b/dooc/pipelines.py @@ -5,7 +5,7 @@ from moltx import tokenizers -class _MutSmiBase: +class _Base: def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, model: nn.Module, device: torch.device = torch.device("cpu")) -> None: self.smi_tokenizer = smi_tokenizer self.device = device @@ -25,7 +25,7 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int = None) -> torc return out.to(self.device) -class _MutSmi(_MutSmiBase): +class _MutSmi(_Base): def _model_args(self, mut: typing.Sequence[int], smi: str) -> typing.Tuple[torch.Tensor]: mut_x = torch.tensor(mut, device=self.device) @@ -50,7 +50,7 @@ def cmp(smi1, smi2): return cmp -class _MutSmis(_MutSmiBase): +class _MutSmis(_Base): def _smi_args( self, smis: typing.Sequence[str] @@ -97,6 +97,79 @@ def __call__(self, mut: typing.Sequence[int], smis: typing.Sequence[str]) -> typ return sorted(smis, key=cmp_to_key(self.cmp_smis_func(mut))) +class _OmicsSmi(_Base): + + def _model_args(self, omics_seq: typing.Sequence[list], smi: str) -> typing.Tuple[torch.Tensor]: + omics_x = [torch.tensor(omic, device=self.device) for omic in omics_seq] + smi_tgt = self._tokens2tensor(self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS)) + return omics_x, smi_tgt + + def reg(self, omics: dict, omics_seq: typing.Sequence[list], smi: str) -> float: + return self.model(omics, *self._model_args(omics_seq, smi)).item() + + def cmp_smis_func(self, omics: dict, omics_seq: typing.Sequence[list]) -> typing.Callable: + cmped = {} + + def cmp(smi1, smi2): + query = '-'.join([smi1, smi2]) + if query in cmped: + return cmped[query] + out1 = self.reg(omics, omics_seq, smi1) + out2 = self.reg(omics, omics_seq, smi2) + out = out1 - out2 + cmped[query] = out + return out + return cmp + + +class _OmicsSmis(_Base): + + def _smi_args( + self, smis: typing.Sequence[str] + ) -> torch.Tensor: + smi_tgt = [self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS) for smi in smis] + size_tgt = max(map(len, smi_tgt)) + smi_tgt = torch.concat([self._tokens2tensor(smi, size_tgt).unsqueeze(0) for smi in smi_tgt]) + return smi_tgt + + def cmp_smis_func(self, omics: dict, omics_seq: typing.Sequence[list]) -> typing.Callable: + omics_x = 
[torch.tensor(omic, device=self.device) for omic in omics_seq]
+        cmped = {}
+
+        def cmp(smi1, smi2):
+            smis = [smi1, smi2]
+            query = '-'.join(smis)
+            if query in cmped:
+                return cmped[query]
+            smi_tgt = self._smi_args(smis)
+            out = self.model.forward_cmp(omics, omics_x, smi_tgt)
+            cmped[query] = out
+            return out
+        return cmp
+
+
+class _OmicsSmiReg:
+
+    def __call__(self, omics: dict, omics_seq: typing.Sequence[list], smi: str) -> float:
+        return self.reg(omics, omics_seq, smi)
+
+
+class _OmicsSmisRank:
+
+    def __call__(self, omics: dict, omics_seq: typing.Sequence[list], smis: typing.Sequence[str]) -> typing.Sequence[str]:
+        """
+        The returned smiles are sorted in ascending order of the predicted value,
+        so the earlier a smiles appears, the better its expected effect.
+
+        The training labels must follow the same convention: smaller values rank higher.
+
+        IC50 values can therefore be used directly, while metrics such as inhibition
+        rate must be converted first so that smaller means better.
+        """
+        return sorted(smis, key=cmp_to_key(self.cmp_smis_func(omics, omics_seq)))
+
+
+
 """
 Mutations(Individual Sample) and Smiles Interaction
 
@@ -112,3 +185,11 @@ class MutSmiReg(_MutSmi, _MutSmiReg):
 
 class MutSmisRank(_MutSmis, _MutSmisRank):
     pass
+
+
+class OmicsSmiReg(_OmicsSmi, _OmicsSmiReg):
+    pass
+
+
+class OmicsSmisRank(_OmicsSmis, _OmicsSmisRank):
+    pass
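
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new
pieces are expected to fit together, dataset -> model -> pipeline. The tokenizer `smi_tok`,
the checkpoint paths, the random omics vectors and the smiles/values are hypothetical
placeholders (build the MoltxTokenizer as described in moltx's own docs); everything else
uses only names introduced by this patch.

    import torch
    from dooc import datasets, models, pipelines

    smi_tok = ...  # a moltx tokenizers.MoltxTokenizer instance (see moltx docs)
    omics = {"methylation": 0, "mutation": 1, "rna": 2}  # modality name -> index, as in VAEOmics.DEFAULT_CONFIG

    # training batch: 3 omics modalities (dims from DEFAULT_CONFIG), 2 samples, 2 candidate smiles each
    omics_seq = [torch.rand(2, d).tolist() for d in (2827, 2739, 8802)]
    lsmis = [["CCO", "c1ccccc1"], ["CC(=O)O", "CCN"]]
    lvals = [[0.2, 0.7], [0.9, 0.1]]
    omics_x, smi_tgt, out = datasets.OmicsSmisPairwiseRank(smi_tok)(omics_seq, lsmis, lvals)

    # model: VAE omics encoder + AdaMR2 smiles encoder fused by cross-attention
    model = models.OmicsSmisRank()
    model.load_pretrained_ckpt("vaeomics.ckpt", "adamr2.ckpt", freeze_omics=True)
    pred = model(omics, omics_x, smi_tgt)  # [2, 2] scores; `out` holds the pairwise 0/1 labels

    # inference: rank candidate smiles for a single sample (no batch dim)
    ranked = pipelines.OmicsSmisRank(smi_tokenizer=smi_tok, model=model)(
        omics, [o[0] for o in omics_seq], ["CCO", "c1ccccc1", "CCN"]
    )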