From ce2537073c4be731ab086e4471a6f4e55bd1d455 Mon Sep 17 00:00:00 2001 From: Muhammed Hasan Celik Date: Mon, 10 Nov 2025 03:18:17 +0000 Subject: [PATCH 01/11] versining --- src/decima/cli/attributions.py | 3 +- src/decima/cli/callback.py | 7 +-- src/decima/cli/download.py | 3 +- src/decima/cli/modisco.py | 5 +- src/decima/cli/predict_genes.py | 6 ++- src/decima/cli/vep.py | 4 +- src/decima/constants.py | 17 +++++-- src/decima/core/result.py | 4 +- src/decima/data/dataset.py | 6 +-- src/decima/data/read_hdf5.py | 4 +- src/decima/hub/__init__.py | 3 +- src/decima/hub/download.py | 9 ++-- src/decima/interpret/attributions.py | 7 +-- src/decima/model/lightning.py | 5 +- src/decima/tools/inference.py | 3 +- src/decima/vep/__init__.py | 8 +-- tests/test_cli.py | 6 +-- tests/test_predict_gene_expression.py | 3 +- tests/test_sequence.py | 4 +- tests/test_vep.py | 71 ++++++++++++++------------- 20 files changed, 102 insertions(+), 76 deletions(-) diff --git a/src/decima/cli/attributions.py b/src/decima/cli/attributions.py index f2dab90..da8112d 100644 --- a/src/decima/cli/attributions.py +++ b/src/decima/cli/attributions.py @@ -17,6 +17,7 @@ """ import click +from decima.constants import DEFAULT_ENSEMBLE from decima.cli.callback import parse_genes, parse_model, parse_attributions from decima.interpret.attributions import ( plot_attributions, @@ -196,7 +197,7 @@ def cli_attributions_predict( "--model", type=str, required=False, - default="ensemble", + default=DEFAULT_ENSEMBLE, callback=parse_model, help="Model to use for attribution analysis either replicate number or path to the model.", show_default=True, diff --git a/src/decima/cli/callback.py b/src/decima/cli/callback.py index 1a24cc5..47b2246 100644 --- a/src/decima/cli/callback.py +++ b/src/decima/cli/callback.py @@ -1,13 +1,14 @@ import click from pathlib import Path +from decima.constants import AVAILABLE_ENSEMBLES def parse_model(ctx, param, value): if value is None: return None elif isinstance(value, str): - if value == "ensemble": - return "ensemble" + if value in AVAILABLE_ENSEMBLES: + return value elif value in ["0", "1", "2", "3"]: return int(value) @@ -32,7 +33,7 @@ def parse_genes(ctx, param, value): def validate_save_replicates(ctx, param, value): if value: - if ctx.params["model"] == "ensemble": + if ctx.params["model"] in AVAILABLE_ENSEMBLES: return value elif isinstance(ctx.params["model"], list) and (len(ctx.params["model"]) > 1): return value diff --git a/src/decima/cli/download.py b/src/decima/cli/download.py index a145cd1..6cee6b9 100644 --- a/src/decima/cli/download.py +++ b/src/decima/cli/download.py @@ -10,6 +10,7 @@ """ import click +from decima.constants import DEFAULT_ENSEMBLE from decima.cli.callback import parse_model from decima.hub.download import ( cache_decima_data, @@ -27,7 +28,7 @@ def cli_cache(): @click.command() @click.option( - "--model", type=str, default="ensemble", help="Model to download. Default: ensemble.", callback=parse_model + "--model", type=str, default=DEFAULT_ENSEMBLE, help="Model to download. Default: ensemble.", callback=parse_model ) @click.option( "--download-dir", diff --git a/src/decima/cli/modisco.py b/src/decima/cli/modisco.py index 264f285..a252488 100644 --- a/src/decima/cli/modisco.py +++ b/src/decima/cli/modisco.py @@ -21,6 +21,7 @@ import click from typing import List, Optional, Union +from decima.constants import DEFAULT_ENSEMBLE from decima.cli.callback import parse_model, parse_genes, parse_attributions from decima.interpret.modisco import ( predict_save_modisco_attributions, @@ -313,7 +314,7 @@ def cli_modisco_seqlet_bed( @click.option( "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, show_default=True, help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files. Default: `ensemble`.", callback=parse_model, @@ -406,7 +407,7 @@ def cli_modisco( tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, tss_distance: int = 10_000, - model: Optional[Union[str, int]] = "ensemble", + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata: Optional[str] = None, method: str = "saliency", batch_size: int = 1, diff --git a/src/decima/cli/predict_genes.py b/src/decima/cli/predict_genes.py index 32a5934..f1fd587 100644 --- a/src/decima/cli/predict_genes.py +++ b/src/decima/cli/predict_genes.py @@ -8,6 +8,7 @@ import click from pathlib import Path +from decima.constants import DEFAULT_ENSEMBLE from decima.cli.callback import parse_model, parse_genes, validate_save_replicates from decima.tools.inference import predict_gene_expression @@ -25,9 +26,10 @@ "-m", "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, + show_default=True, callback=parse_model, - help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to checkpoint files", + help=f"`0`, `1`, `2`, `3`, `{DEFAULT_ENSEMBLE}` or a path or a comma-separated list of paths to checkpoint files", ) @click.option( "--metadata", diff --git a/src/decima/cli/vep.py b/src/decima/cli/vep.py index 795e74e..7cbb581 100644 --- a/src/decima/cli/vep.py +++ b/src/decima/cli/vep.py @@ -21,7 +21,7 @@ """ import click -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE from decima.cli.callback import parse_model, validate_save_replicates from decima.utils.dataframe import ensemble_predictions from decima.vep import predict_variant_effect @@ -46,7 +46,7 @@ @click.option( "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, callback=parse_model, help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files to perform variant effect prediction. Default: `ensemble`.", ) diff --git a/src/decima/constants.py b/src/decima/constants.py index ca2e086..976b409 100644 --- a/src/decima/constants.py +++ b/src/decima/constants.py @@ -1,13 +1,24 @@ """Decima constants.""" +import json import os -DECIMA_CONTEXT_SIZE = 524288 +DECIMA_CONTEXT_SIZE = 524_288 SUPPORTED_GENOMES = {"hg38"} NUM_CELLS = 8856 +DEFAULT_ENSEMBLE = "ensemble" +AVAILABLE_ENSEMBLES = [DEFAULT_ENSEMBLE] + +ENSEMBLE_MODELS_NAMES = dict() + if "DECIMA_ENSEMBLE_MODELS_NAMES" in os.environ: - ENSEMBLE_MODELS_NAMES = os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"].split(",") + ENSEMBLE_MODELS_NAMES = json.loads(os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"]) else: - ENSEMBLE_MODELS_NAMES = ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"] + ENSEMBLE_MODELS_NAMES["ensemble"] = ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"] + +assert all(ensemble_name in AVAILABLE_ENSEMBLES for ensemble_name in ENSEMBLE_MODELS_NAMES.keys()), ( + f"Invalid ensemble names: {ENSEMBLE_MODELS_NAMES.keys()}. Available ensembles are: {AVAILABLE_ENSEMBLES}" + "Check your `DECIMA_ENSEMBLE_MODELS_NAMES` environment variable if you are customizing the ensemble models." +) diff --git a/src/decima/core/result.py b/src/decima/core/result.py index 2c14d3b..fa19fae 100644 --- a/src/decima/core/result.py +++ b/src/decima/core/result.py @@ -7,7 +7,7 @@ from grelu.sequence.format import intervals_to_strings, strings_to_one_hot -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, AVAILABLE_ENSEMBLES from decima.hub import load_decima_metadata, load_decima_model from decima.core.metadata import GeneMetadata, CellMetadata from decima.tools.evaluate import marker_zscores @@ -172,7 +172,7 @@ def predicted_expression_matrix( Returns: pd.DataFrame: Predicted expression matrix (cells x genes) """ - model_name = "preds" if (model_name is None) or (model_name == "ensemble") else model_name + model_name = "preds" if (model_name is None) or (model_name in AVAILABLE_ENSEMBLES) else model_name if genes is None: return pd.DataFrame(self.anndata.layers[model_name], index=self.cells, columns=self.genes) else: diff --git a/src/decima/data/dataset.py b/src/decima/data/dataset.py index 0c1a341..3024ec3 100644 --- a/src/decima/data/dataset.py +++ b/src/decima/data/dataset.py @@ -22,7 +22,7 @@ from grelu.sequence.format import strings_to_one_hot from grelu.sequence.utils import reverse_complement -from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES +from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES, AVAILABLE_ENSEMBLES from decima.data.read_hdf5 import _extract_center, index_genes from decima.core.result import DecimaResult from decima.utils.io import read_fasta_gene_mask @@ -647,8 +647,8 @@ def __init__( if (model_name is None) or (not reference_cache): self.model_names = list() # no reference caching - elif model_name == "ensemble": - self.model_names = ENSEMBLE_MODELS_NAMES + elif model_name in AVAILABLE_ENSEMBLES: + self.model_names = ENSEMBLE_MODELS_NAMES[model_name] else: self.model_names = [model_name] diff --git a/src/decima/data/read_hdf5.py b/src/decima/data/read_hdf5.py index d588369..440b306 100644 --- a/src/decima/data/read_hdf5.py +++ b/src/decima/data/read_hdf5.py @@ -3,6 +3,8 @@ import torch from grelu.sequence.format import BASE_TO_INDEX_HASH, indices_to_one_hot +from decima.constants import DECIMA_CONTEXT_SIZE + def count_genes(h5_file, key=None): with h5py.File(h5_file, "r") as f: @@ -42,7 +44,7 @@ def _extract_center(x, seq_len, shift=0): return x[..., start : start + seq_len] -def extract_gene_data(h5_file, gene, seq_len=524288, merge=True): +def extract_gene_data(h5_file, gene, seq_len=DECIMA_CONTEXT_SIZE, merge=True): gene_idx = get_gene_idx(h5_file, key=None, gene=gene) with h5py.File(h5_file, "r") as f: diff --git a/src/decima/hub/__init__.py b/src/decima/hub/__init__.py index b6e7611..5263d30 100644 --- a/src/decima/hub/__init__.py +++ b/src/decima/hub/__init__.py @@ -6,6 +6,7 @@ from tempfile import TemporaryDirectory import anndata from grelu.resources import get_artifact, DEFAULT_WANDB_HOST +from decima.constants import DEFAULT_ENSEMBLE, AVAILABLE_ENSEMBLES from decima.model.lightning import LightningModel, EnsembleLightningModel @@ -37,7 +38,7 @@ def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[st if isinstance(model, LightningModel): return model - elif model == "ensemble": + elif model in AVAILABLE_ENSEMBLES: return EnsembleLightningModel([load_decima_model(i, device) for i in range(4)]) elif isinstance(model, List): diff --git a/src/decima/hub/download.py b/src/decima/hub/download.py index 9e2fe72..76574d8 100644 --- a/src/decima/hub/download.py +++ b/src/decima/hub/download.py @@ -3,6 +3,7 @@ import logging import genomepy from grelu.resources import get_artifact +from decima.constants import DEFAULT_ENSEMBLE, AVAILABLE_ENSEMBLES, ENSEMBLE_MODELS_NAMES from decima.hub import login_wandb, load_decima_model, load_decima_metadata @@ -36,7 +37,7 @@ def cache_decima_data(): cache_decima_metadata() -def download_decima_weights(model_name: Union[str, int], download_dir: str): +def download_decima_weights(model_name: Union[str, int], download_dir: str, ensemble: str = DEFAULT_ENSEMBLE): """Download pre-trained Decima model weights from wandb. Args: @@ -46,11 +47,11 @@ def download_decima_weights(model_name: Union[str, int], download_dir: str): Returns: Path to the downloaded model weights. """ - if "ensemble" == model_name: + if DEFAULT_ENSEMBLE in AVAILABLE_ENSEMBLES: return [download_decima_weights(model, download_dir) for model in range(4)] if model_name in {0, 1, 2, 3}: - model_name = f"rep{model_name}" + model_name = ENSEMBLE_MODELS_NAMES[ensemble][model_name] download_dir = Path(download_dir) download_dir.mkdir(parents=True, exist_ok=True) @@ -92,6 +93,6 @@ def download_decima(download_dir: str): download_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Downloading Decima model weights and metadata to {download_dir}:") - download_decima_weights("ensemble", download_dir) + download_decima_weights(DEFAULT_ENSEMBLE, download_dir) download_decima_metadata(download_dir) return download_dir diff --git a/src/decima/interpret/attributions.py b/src/decima/interpret/attributions.py index 10388f9..1e536de 100644 --- a/src/decima/interpret/attributions.py +++ b/src/decima/interpret/attributions.py @@ -40,6 +40,7 @@ from torch.utils.data import DataLoader from pyfaidx import Faidx +from decima.constants import DEFAULT_ENSEMBLE, AVAILABLE_ENSEMBLES from decima.core.attribution import AttributionResult from decima.core.result import DecimaResult from decima.data.dataset import GeneDataset, SeqDataset @@ -119,8 +120,8 @@ def predict_save_attributions( ... genome="hg38", ... ) """ - if (model == "ensemble") or isinstance(model, (list, tuple)): - if model == "ensemble": + if (model in AVAILABLE_ENSEMBLES) or isinstance(model, (list, tuple)): + if model in AVAILABLE_ENSEMBLES: models = [0, 1, 2, 3] else: models = model @@ -338,7 +339,7 @@ def predict_attributions_seqlet_calling( seqs: Optional[Union[pd.DataFrame, np.ndarray, torch.Tensor]] = None, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[Union[str, int]] = "ensemble", + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, method: str = "inputxgradient", transform: str = "specificity", diff --git a/src/decima/model/lightning.py b/src/decima/model/lightning.py index fac5555..0a6fea1 100644 --- a/src/decima/model/lightning.py +++ b/src/decima/model/lightning.py @@ -19,6 +19,7 @@ from torchmetrics import MetricCollection import safetensors +from decima.constants import DEFAULT_ENSEMBLE from .decima_model import DecimaModel from .loss import TaskWisePoissonMultinomialLoss from .metrics import DiseaseLfcMSE, WarningCounter, GenePearsonCorrCoef @@ -523,7 +524,7 @@ def load_safetensor(cls, path: str, device: str = "cpu"): class EnsembleLightningModel(LightningModel): - def __init__(self, models: List[LightningModel], name="ensemble"): + def __init__(self, models: List[LightningModel], name=DEFAULT_ENSEMBLE): super().__init__( name=name, model_params=models[0].model_params, @@ -532,7 +533,7 @@ def __init__(self, models: List[LightningModel], name="ensemble"): ) self.models = nn.ModuleList(models) self.reset_transform() - self.name = "ensemble" + self.name = DEFAULT_ENSEMBLE def forward(self, x: Tensor) -> Tensor: return torch.concat([model(x) for model in self.models], dim=0) diff --git a/src/decima/tools/inference.py b/src/decima/tools/inference.py index 2e5aa32..425a49f 100644 --- a/src/decima/tools/inference.py +++ b/src/decima/tools/inference.py @@ -1,6 +1,7 @@ import anndata import logging import numpy as np +from decima.constants import DEFAULT_ENSEMBLE from decima.data.dataset import GeneDataset from decima.hub import load_decima_model from decima.utils import get_compute_device @@ -8,7 +9,7 @@ def predict_gene_expression( genes=None, - model="ensemble", + model=DEFAULT_ENSEMBLE, metadata_anndata=None, device=None, batch_size=1, diff --git a/src/decima/vep/__init__.py b/src/decima/vep/__init__.py index 39eb33c..82d137e 100644 --- a/src/decima/vep/__init__.py +++ b/src/decima/vep/__init__.py @@ -7,7 +7,7 @@ import pandas as pd from grelu.transforms.prediction_transforms import Aggregate -from decima.constants import SUPPORTED_GENOMES +from decima.constants import SUPPORTED_GENOMES, DEFAULT_ENSEMBLE from decima.model.metrics import WarningType from decima.utils import get_compute_device from decima.utils.dataframe import chunk_df, ChunkDataFrameWriter @@ -19,7 +19,7 @@ def _predict_variant_effect( df_variant: Union[pd.DataFrame, str], tasks: Optional[Union[str, List[str]]] = None, - model: Union[int, str] = "ensemble", + model: Union[int, str] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, batch_size: int = 1, num_workers: int = 16, @@ -120,7 +120,7 @@ def predict_variant_effect( df_variant: Union[pd.DataFrame, str], output_pq: Optional[str] = None, tasks: Optional[Union[str, List[str]]] = None, - model: Union[int, str, List[str]] = "ensemble", + model: Union[int, str, List[str]] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, chunksize: int = 10_000, batch_size: int = 1, @@ -142,7 +142,7 @@ def predict_variant_effect( df_variant (pd.DataFrame or str): DataFrame with variant information or path to variant file output_pq (str, optional): Path to save the parquet file. Defaults to None. tasks (str, optional): Tasks to predict. Defaults to None. - model (int, optional): Model to use. Defaults to "ensemble". + model (int, optional): Model to use. Defaults to DEFAULT_ENSEMBLE. metadata_anndata (str, optional): Path to anndata file. Defaults to None. chunksize (int, optional): Number of variants to predict in each chunk. Defaults to 10_000. batch_size (int, optional): Batch size. Defaults to 1. diff --git a/tests/test_cli.py b/tests/test_cli.py index 60ba05c..4439c03 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ from decima.cli import main from conftest import device -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE def test_cli_main(): @@ -254,7 +254,7 @@ def test_cli_vep_all_tasks_ensemble_custom_genome(tmp_path): "vep", "-v", "tests/data/variants.tsv", "-o", str(output_file), - "--model", "ensemble", + "--model", DEFAULT_ENSEMBLE, "--device", device, "--max-distance", "20000", "--chunksize", "5", @@ -277,7 +277,7 @@ def test_cli_vep_all_tasks_ensemble(tmp_path): "vep", "-v", "tests/data/variants.tsv", "-o", str(output_file), - "--model", "ensemble", + "--model", DEFAULT_ENSEMBLE, "--device", device, "--max-distance", "20000", "--chunksize", "5", diff --git a/tests/test_predict_gene_expression.py b/tests/test_predict_gene_expression.py index 66954e9..2dd54c6 100644 --- a/tests/test_predict_gene_expression.py +++ b/tests/test_predict_gene_expression.py @@ -1,4 +1,5 @@ import pytest +from decima.constants import DEFAULT_ENSEMBLE from decima.tools.inference import predict_gene_expression from conftest import device @@ -19,7 +20,7 @@ def test_predict_gene_expression(): ad = predict_gene_expression( genes=["SPI1", "GATA1"], - model="ensemble", device=device, + model=DEFAULT_ENSEMBLE, device=device, save_replicates=True, ) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c0876aa..e2a964e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,8 +1,8 @@ - +from decima.constants import DECIMA_CONTEXT_SIZE from decima.utils.sequence import prepare_mask_gene def test_mask_gene(): mask = prepare_mask_gene(100, 200) - assert mask.shape == (1, 524288) + assert mask.shape == (1, DECIMA_CONTEXT_SIZE) assert mask[0, 150].item() == 1.0 diff --git a/tests/test_vep.py b/tests/test_vep.py index 650ffd0..313b7fa 100644 --- a/tests/test_vep.py +++ b/tests/test_vep.py @@ -5,6 +5,7 @@ import pyarrow.parquet as pq from scipy.stats import pearsonr +from decima.constants import DEFAULT_ENSEMBLE, DECIMA_CONTEXT_SIZE, NUM_CELLS from decima.core.result import DecimaResult from decima.hub import load_decima_model from decima.data.dataset import VariantDataset @@ -92,9 +93,9 @@ def test_VariantDataset(df_variant): ] assert len(dataset) == 82 * 2 - assert dataset[0]['seq'].shape == (5, 524288) + assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) - assert dataset[0]['pred_expr']['v1_rep0'].shape == (8856,) + assert dataset[0]['pred_expr']['v1_rep0'].shape == (NUM_CELLS,) assert not dataset[0]['pred_expr']['v1_rep0'].isnan().any() assert dataset[1]['pred_expr']['v1_rep0'].isnan().all() assert not dataset[2]['pred_expr']['v1_rep0'].isnan().any() @@ -113,7 +114,7 @@ def test_VariantDataset(df_variant): assert cols.tolist() == [38435, 38435] # should be the same for both for i in range(len(dataset)): - assert dataset[i]['seq'].shape == (5, 524288) + assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) rows, cols = np.where(dataset[162]['seq'] != dataset[163]['seq']) assert cols.min() == 505705 # the positions before should not be effected. @@ -122,11 +123,11 @@ def test_VariantDataset(df_variant): dataset = VariantDataset(df_variant, max_seq_shift=100) assert len(dataset) == 82 * 2 * 201 - assert dataset[0]['seq'].shape == (5, 524288) + assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) for i in range(20): assert dataset[i]["warning"] == [] - assert dataset[i]['seq'].shape == (5, 524288) + assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) assert dataset[44 * 2 * 201]["warning"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME] @@ -138,58 +139,58 @@ def test_VariantDataset(df_variant): @pytest.mark.long_running def test_VariantDataset_dataloader(df_variant): - dataset = VariantDataset(df_variant, model_name="ensemble") + dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE) dl = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=0, collate_fn=dataset.collate_fn) batches = iter(dl) batch = next(batches) - assert batch["seq"].shape == (64, 5, 524288) + assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856) + assert batch["pred_expr"]["v1_rep0"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep1"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep2"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep3"].shape == (64, NUM_CELLS) batch = next(batches) - assert batch["seq"].shape == (64, 5, 524288) + assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE) assert len(batch["warning"]) > 0 assert WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME in batch["warning"] - assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856) + assert batch["pred_expr"]["v1_rep0"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep1"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep2"].shape == (64, NUM_CELLS) + assert batch["pred_expr"]["v1_rep3"].shape == (64, NUM_CELLS) @pytest.mark.long_running def test_VariantDataset_dataloader_vcf(): df_variant = next(read_vcf_chunks("tests/data/test.vcf", 10000)) - dataset = VariantDataset(df_variant, model_name="ensemble", max_distance=20000) + dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE, max_distance=20000) dl = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=0, collate_fn=dataset.collate_fn) batches = iter(dl) batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep1'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep2'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep3'].shape == (8, NUM_CELLS) batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep1'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep2'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep3'].shape == (8, NUM_CELLS) batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert len(batch["warning"]) > 0 - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep1'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep2'].shape == (8, NUM_CELLS) + assert batch["pred_expr"]['v1_rep3'].shape == (8, NUM_CELLS) @pytest.mark.long_running @@ -225,7 +226,7 @@ def test_predict_variant_effect_save(df_variant, tmp_path): predict_variant_effect( df_variant, output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, tasks=query, device=device, max_distance=5000, @@ -305,7 +306,7 @@ def test_predict_variant_effect_vcf_ensemble(tmp_path): predict_variant_effect( "tests/data/test.vcf", output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, device=device, max_distance=20000, ) @@ -321,7 +322,7 @@ def test_predict_variant_effect_vcf_ensemble_replicates(tmp_path): predict_variant_effect( "tests/data/test.vcf", output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, device=device, max_distance=20000, save_replicates=True, From e79417a9699e73fd20dafcac67517dab9ae6d595 Mon Sep 17 00:00:00 2001 From: Muhammed Hasan Celik Date: Thu, 13 Nov 2025 21:58:33 +0000 Subject: [PATCH 02/11] versioning new version --- .../1-attribution-motif-discovery.ipynb | 12 +-- docs/tutorials/3-finetune.html | 4 +- docs/tutorials/3-finetune.ipynb | 4 +- docs/tutorials/4-modisco.ipynb | 6 +- setup.cfg | 3 +- src/decima/__init__.py | 5 +- src/decima/cli/attributions.py | 2 +- src/decima/cli/callback.py | 8 +- src/decima/cli/download.py | 30 +++++-- src/decima/cli/finetune.py | 2 +- src/decima/cli/modisco.py | 2 +- src/decima/cli/query_cell.py | 13 ++- src/decima/constants.py | 45 +++++++--- src/decima/core/attribution.py | 10 +-- src/decima/core/result.py | 29 ++++--- src/decima/data/dataset.py | 10 +-- src/decima/hub/__init__.py | 85 +++++++++---------- src/decima/hub/download.py | 34 ++++---- src/decima/interpret/attributions.py | 22 ++--- src/decima/interpret/modisco.py | 18 ++-- src/decima/model/lightning.py | 2 +- src/decima/vep/__init__.py | 2 +- tests/test_lightning.py | 29 ++++--- tests/test_vep.py | 47 +++++----- 24 files changed, 244 insertions(+), 180 deletions(-) diff --git a/docs/tutorials/1-attribution-motif-discovery.ipynb b/docs/tutorials/1-attribution-motif-discovery.ipynb index c9f4400..ae114d2 100644 --- a/docs/tutorials/1-attribution-motif-discovery.ipynb +++ b/docs/tutorials/1-attribution-motif-discovery.ipynb @@ -122,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -159,7 +159,7 @@ } ], "source": [ - "! decima attributions --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes" + "! decima attributions --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes" ] }, { @@ -311,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -338,7 +338,7 @@ } ], "source": [ - "! decima attributions-predict --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0" + "! decima attributions-predict --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0" ] }, { @@ -1005,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1041,7 +1041,7 @@ } ], "source": [ - "! decima attributions --model 0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs" + "! decima attributions --model v1_rep0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs" ] }, { diff --git a/docs/tutorials/3-finetune.html b/docs/tutorials/3-finetune.html index 91e1626..d28d4aa 100644 --- a/docs/tutorials/3-finetune.html +++ b/docs/tutorials/3-finetune.html @@ -9135,7 +9135,7 @@

7. Generate training commands