diff --git a/dooc/models.py b/dooc/models.py index ee2a9e3..b961ae5 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -2,7 +2,7 @@ from moltx import nets as mnets from moltx import models as mmodels from dooc import nets as dnets -from dooc.nets import heads, drugcell, prmo +from dooc.nets import heads, drugcell, prmo, cnnmut """ @@ -44,6 +44,25 @@ def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float: return (out[0] - out[1]).item() +class MutSmisRankV2(dnets.CNNMutAdamr2MutSmisXattn): + + def __init__(self, mut_conf: cnnmut.CNNMutConfig = dnets.CNNMut.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None: + super().__init__(mut_conf, smi_conf) + self.reg = heads.RegHead(self.smi_conf.d_model) + + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + return self.reg(super().forward(mut_x, smi_tgt)).squeeze(-1) # [b, n] + + def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float: + """ + for infer, no batch dim + """ + assert mut_x.dim() == 1 and smi_tgt.dim() == 2 + out = self.forward(mut_x, smi_tgt) # [2] + return (out[0] - out[1]).item() + + class MultiOmicsSmisRank(dnets.PrmoAdamr2MultiOmicsSmisXattn): def __init__(self, multi_omics_conf: prmo.PrmoConfig = dnets.PrmoEncoder.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None: diff --git a/dooc/nets/__init__.py b/dooc/nets/__init__.py index ffa722a..909eaad 100644 --- a/dooc/nets/__init__.py +++ b/dooc/nets/__init__.py @@ -3,6 +3,7 @@ from moltx.models import AdaMR, AdaMR2 from dooc.nets.drugcell import Drugcell from dooc.nets.prmo import PrmoEncoder +from dooc.nets.cnnmut import CNNMut """ @@ -269,3 +270,59 @@ def forward( multi_omics_out = self._forward_multi_omics(mut_x, rna_x, pathway_x) # [b, 1, dmodel] smi_out = self._forward_smi(smi_tgt) return self.cross_attn(smi_out, multi_omics_out) # [b, n, dmodel] + + +class _CNNMutAdamr2(nn.Module): + def __init__(self, mut_conf, smi_conf) -> None: + super().__init__() + self.mut_conf = mut_conf + self.smi_conf = smi_conf + + self.mut_encoder = CNNMut(mut_conf) + self.smi_encoder = AdaMR2(smi_conf) + + def load_ckpt(self, *ckpt_files: str) -> None: + self.load_state_dict( + torch.load(ckpt_files[0], map_location=torch.device("cpu")) + ) + + def load_pretrained_ckpt(self, smi_ckpt: str, freeze_smi: bool = False) -> None: + self.smi_encoder.load_ckpt(smi_ckpt) + if freeze_smi: + self.smi_encoder.requires_grad_(False) + + +class CNNMutAdamr2MutSmisXattn(_CNNMutAdamr2): + def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(mut_conf, smi_conf) + d_model = smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def _forward_mut(self, mut_x: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + out: [b, 1, dmodel] + """ + mut_x = mut_x.unsqueeze(-2) + return self.mut_encoder(mut_x) + + def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + smi_tgt: [b, n, smi_seqlen] + out: [b, n, dmodel] + """ + batched = smi_tgt.dim() == 3 + if batched: + n = smi_tgt.shape[1] + smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1]) + out = self.smi_encoder.forward_feature(smi_tgt) + return out.reshape(-1, n, out.shape[-1]) + return self.smi_encoder.forward_feature(smi_tgt) + + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor + ) -> torch.Tensor: + mut_out = self._forward_mut(mut_x) # [b, 1, dmodel] + smi_out = self._forward_smi(smi_tgt) + return self.cross_attn(smi_out, mut_out) # [b, n, dmodel] diff --git a/dooc/nets/cnnmut.py b/dooc/nets/cnnmut.py new file mode 100644 index 0000000..9ed251e --- /dev/null +++ b/dooc/nets/cnnmut.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from dataclasses import dataclass + + +@dataclass +class CNNMutConfig: + mut_dim: int + kernal_size: int + out_dim: int + dropout: float + + +class CNNMut(nn.Module): + DEFAULT_CONFIG = CNNMutConfig( + mut_dim=3008, + kernal_size=32, + out_dim=768, + dropout=0.1, + ) + + def __init__(self, conf: CNNMutConfig = DEFAULT_CONFIG) -> None: + super().__init__() + self.conf = conf + stride = 2 + self.encoder = nn.Sequential( + nn.Conv1d(in_channels=1, out_channels=20, kernel_size=conf.kernal_size, stride=stride), + nn.BatchNorm1d(20), + nn.ReLU(), + nn.Conv1d(in_channels=20, out_channels=10, kernel_size=conf.kernal_size, stride=stride), + nn.BatchNorm1d(10), + nn.ReLU(), + nn.Conv1d(in_channels=10, out_channels=1, kernel_size=conf.kernal_size, stride=stride), + nn.ReLU(), + nn.Dropout(p=conf.dropout), + ) + + encoder_out_dim = 0 + input_dim = conf.mut_dim + for _ in range(3): + encoder_out_dim = int((input_dim - conf.kernal_size) / stride) + 1 + input_dim = encoder_out_dim + + self.out = nn.Linear(encoder_out_dim, conf.out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.size(-2) == 1 + x = x.float() + x_dim = x.dim() + x = x.unsqueeze(0) if x_dim != 3 else x + encoder_out = self.encoder(x) + out = self.out(encoder_out) + return out.squeeze(0) if x_dim != 3 else out diff --git a/tests/test_nets.py b/tests/test_nets.py index b3c9567..2ca50aa 100644 --- a/tests/test_nets.py +++ b/tests/test_nets.py @@ -16,6 +16,10 @@ def adamr2_conf(): def drugcell_conf(): return nets.Drugcell.DEFAULT_CONFIG +@pytest.fixture +def cnnmut_conf(): + return nets.CNNMut.DEFAULT_CONFIG + @pytest.fixture def prmo_conf(): return nets.PrmoEncoder.DEFAULT_CONFIG @@ -123,6 +127,20 @@ def test_DrugcellAdamr2MutSmis(adamr2_conf, drugcell_conf, drugcell_adamr2_mut_s assert out.size(0) == label.size(0) and out.size(1) == label.size(1) +def test_CNNMutAdamr2MutSmis(adamr2_conf, cnnmut_conf, drugcell_adamr2_mut_smis_ds): + label = drugcell_adamr2_mut_smis_ds[-1] + + model = nets.CNNMutAdamr2MutSmisXattn(cnnmut_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smis_ds[:-1]) + assert out.dim() == 3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) + + model = nets.CNNMutAdamr2MutSmisXattn(cnnmut_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smis_ds[:-1]) + assert out.dim() == 3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) + + def test_PrmoAdamr2MultiOmicsSmis(adamr2_conf, prmo_conf, prmo_adamr2_mut_smis_ds): label = prmo_adamr2_mut_smis_ds[-1] diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 5fba9c2..8ae3b1a 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -34,6 +34,13 @@ class Pointwise(pipelines._MutSmi, pipelines._MutSmisRank): assert len(out) == 3 assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" + model = models.MutSmisRankV2() + pipeline = pipelines.MutSmisRank(smi_tokenizer=smi_tkz, model=model) + out = pipeline(mutation, smiles) + assert isinstance(out, list) + assert len(out) == 3 + assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" + def test_MultiOmicsSmisRank(smi_tkz): mutation = [random.choice([1, 0]) for _ in range(3008)]