diff --git a/.gitignore b/.gitignore index a6a1268..c599a11 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,4 @@ wandb/ docs/api docs/tutorials/example docs/wandb +run_tutorial.s* diff --git a/docs/tutorials/1-attribution-motif-discovery.ipynb b/docs/tutorials/1-attribution-motif-discovery.ipynb index c9f4400..76a75d5 100644 --- a/docs/tutorials/1-attribution-motif-discovery.ipynb +++ b/docs/tutorials/1-attribution-motif-discovery.ipynb @@ -122,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -159,7 +159,7 @@ } ], "source": [ - "! decima attributions --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes" + "! decima attributions --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes" ] }, { @@ -311,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -338,12 +338,12 @@ } ], "source": [ - "! decima attributions-predict --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0" + "! decima attributions-predict --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -370,7 +370,7 @@ } ], "source": [ - "! decima attributions-predict --model 1 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_1" + "! decima attributions-predict --model v1_rep1 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_1" ] }, { @@ -1005,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1041,7 +1041,7 @@ } ], "source": [ - "! decima attributions --model 0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs" + "! decima attributions --model v1_rep0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs" ] }, { diff --git a/docs/tutorials/3-finetune.html b/docs/tutorials/3-finetune.html deleted file mode 100644 index 91e1626..0000000 --- a/docs/tutorials/3-finetune.html +++ /dev/null @@ -1,10657 +0,0 @@ - - - - - -3-finetune - - - - - - - - - - - - -
-
- -
- -
- -
-
- -
-
- -
- - -
-
- -
- - -
-
- -
- - -
-
- -
- - -
-
- -
-
- -
- - -
- - -
-
- -
-
- -
- - -
- - -
-
- -
- -
- - -
-
- -
- - -
- - -
-
- -
-
- -
-
- -
- - -
-
- -
- - -
-
- -
- -
-
- -
- - -
- - -
- - -
-
- -
-
- -
-
- -
- -
-
- -
-
- -
- -
- - -
-
- -
- -
-
- -
- -
- - -
- -
- -
-
- -
-
- -
- - -
- - -
- - -
-
- -
- -
- - -
-
- -
- - -
- - -
-
- - diff --git a/docs/tutorials/3-finetune.ipynb b/docs/tutorials/3-finetune.ipynb index 2c55b93..eecb02e 100644 --- a/docs/tutorials/3-finetune.ipynb +++ b/docs/tutorials/3-finetune.ipynb @@ -1507,7 +1507,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "id": "d0fdaa9d", "metadata": {}, "outputs": [ @@ -2579,7 +2579,7 @@ "source": [ "! CUDA_VISIBLE_DEVICES=0 decima finetune \\\n", "--name finetune_test_0 \\\n", - "--model 0 \\\n", + "--model v1_rep0 \\\n", "--device 0 \\\n", "--matrix-file {ad_file_path} \\\n", "--h5-file {h5_file_path} \\\n", diff --git a/docs/tutorials/4-modisco.ipynb b/docs/tutorials/4-modisco.ipynb index 24c7c5a..3944dd6 100644 --- a/docs/tutorials/4-modisco.ipynb +++ b/docs/tutorials/4-modisco.ipynb @@ -640,7 +640,7 @@ " --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n", " --transform \"specificity\" \\\n", " --batch-size 1 \\\n", - " --model 0 \\\n", + " --model v1_rep0 \\\n", " --num-workers 8 \\\n", " -o example/modisco_neurons" ] @@ -1511,7 +1511,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "3b3faf06", "metadata": {}, "outputs": [ @@ -1547,7 +1547,7 @@ " --top-n-markers 50 \\\n", " --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n", " --batch-size 1 \\\n", - " --model 0 \\\n", + " --model v1_rep0 \\\n", " --num-workers 8 \\\n", " -o example/modisco_subcommands/modisco_neurons_0" ] @@ -1562,7 +1562,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "dffdc854", "metadata": {}, "outputs": [ @@ -1598,7 +1598,7 @@ " --top-n-markers 50 \\\n", " --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n", " --batch-size 1 \\\n", - " --model 1 \\\n", + " --model v1_rep1 \\\n", " --num-workers 8 \\\n", " -o example/modisco_subcommands/modisco_neurons_1" ] diff --git a/docs/tutorials/5-gene-expression-prediction.ipynb b/docs/tutorials/5-gene-expression-prediction.ipynb index da560b2..25413ab 100644 --- a/docs/tutorials/5-gene-expression-prediction.ipynb +++ b/docs/tutorials/5-gene-expression-prediction.ipynb @@ -46,10 +46,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmhcelik\u001b[0m (\u001b[33mmhcw\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact metadata:latest, 3122.32MB. 1 files... \n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...\n", "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", - "Done. 0:0:1.6 (1995.7MB/s)\n" + "Done. 00:00:05.9 (533.2MB/s)\n" ] }, { @@ -1868,7 +1872,7 @@ ], "metadata": { "kernelspec": { - "display_name": "decima", + "display_name": "decima2", "language": "python", "name": "python3" }, @@ -1882,7 +1886,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.12" + "version": "3.11.14" } }, "nbformat": 4, diff --git a/setup.cfg b/setup.cfg index 467b7e1..d52eec3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,7 +71,8 @@ install_requires = pyarrow safetensors tangermeme>=1.0.0 - modisco-lite @ git+https://github.com/MuhammedHasan/tfmodisco-lite.git@faster-modisco + faster-modisco-lite>=3.0.0 + [options.packages.find] where = src diff --git a/src/decima/__init__.py b/src/decima/__init__.py index ebf1885..c94d98c 100644 --- a/src/decima/__init__.py +++ b/src/decima/__init__.py @@ -1,5 +1,5 @@ import sys -from decima.constants import NUM_CELLS, DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA from decima.core.result import DecimaResult from decima.interpret.attributions import predict_attributions_seqlet_calling from decima.vep import predict_variant_effect @@ -25,6 +25,7 @@ "DecimaResult", "predict_variant_effect", "predict_attributions_seqlet_calling", - "NUM_CELLS", "DECIMA_CONTEXT_SIZE", + "DEFAULT_ENSEMBLE", + "MODEL_METADATA", ] diff --git a/src/decima/cli/attributions.py b/src/decima/cli/attributions.py index f2dab90..783271a 100644 --- a/src/decima/cli/attributions.py +++ b/src/decima/cli/attributions.py @@ -17,7 +17,8 @@ """ import click -from decima.cli.callback import parse_genes, parse_model, parse_attributions +from decima.constants import DEFAULT_ENSEMBLE +from decima.cli.callback import parse_genes, parse_model, parse_attributions, parse_metadata from decima.interpret.attributions import ( plot_attributions, predict_save_attributions, @@ -47,16 +48,16 @@ "--model", type=str, required=False, - default=0, + default=DEFAULT_ENSEMBLE, callback=parse_model, help="Model to use for attribution analysis either replicate number or path to the model.", show_default=True, ) @click.option( "--metadata", - type=click.Path(exists=True), default=None, - help="Path to the metadata anndata file. If not provided, the default metadata will be downloaded from wandb.", + callback=parse_metadata, + help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.", show_default=True, ) @click.option( @@ -196,16 +197,16 @@ def cli_attributions_predict( "--model", type=str, required=False, - default="ensemble", + default=DEFAULT_ENSEMBLE, callback=parse_model, help="Model to use for attribution analysis either replicate number or path to the model.", show_default=True, ) @click.option( "--metadata", - type=click.Path(exists=True), + callback=parse_metadata, default=None, - help="Path to the metadata anndata file. Default: None.", + help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.", ) @click.option( "--method", type=str, required=False, default="inputxgradient", help="Method to use for attribution analysis." @@ -329,7 +330,12 @@ def cli_attributions( "--off-tasks", type=str, required=False, help="Optional query string to filter cell types to contrast against." ) @click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.") -@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") +@click.option( + "--metadata", + callback=parse_metadata, + default=None, + help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.", +) @click.option( "--genes", type=str, @@ -427,7 +433,12 @@ def cli_attributions_recursive_seqlet_calling( callback=parse_genes, help="Comma-separated list of gene symbols or IDs to analyze.", ) -@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") +@click.option( + "--metadata", + callback=parse_metadata, + default=None, + help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.", +) @click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.") @click.option("--seqlogo-window", type=int, default=50, help="Window size for sequence logo plots") @click.option("--custom-genome", is_flag=True, help="Use custom genome") diff --git a/src/decima/cli/callback.py b/src/decima/cli/callback.py index 1a24cc5..4af9770 100644 --- a/src/decima/cli/callback.py +++ b/src/decima/cli/callback.py @@ -1,15 +1,17 @@ import click from pathlib import Path +from decima.constants import MODEL_METADATA, ENSEMBLE_MODELS, DEFAULT_ENSEMBLE def parse_model(ctx, param, value): + if isinstance(value, int): + value = str(value) + if value is None: return None elif isinstance(value, str): - if value == "ensemble": - return "ensemble" - elif value in ["0", "1", "2", "3"]: - return int(value) + if value in MODEL_METADATA: + return value paths = value.split(",") for path in paths: @@ -32,7 +34,7 @@ def parse_genes(ctx, param, value): def validate_save_replicates(ctx, param, value): if value: - if ctx.params["model"] == "ensemble": + if ctx.params["model"] in ENSEMBLE_MODELS: return value elif isinstance(ctx.params["model"], list) and (len(ctx.params["model"]) > 1): return value @@ -43,6 +45,32 @@ def validate_save_replicates(ctx, param, value): return value +def parse_metadata(ctx, param, value): + if value is None: + if "model" in ctx.params: + model = ctx.params["model"] + else: + model = DEFAULT_ENSEMBLE + + if isinstance(model, list): + model = model[0] + if Path(model).exists(): + raise click.ClickException( + f"File path passed for model {model} but metadata filepath is not provided. Also, provide the metadata filepath." + ) + else: + return model + elif isinstance(value, str): + if value in MODEL_METADATA: + return value + elif Path(value).exists(): + return value + else: + raise click.ClickException( + f"Invalid name for the metadata dataset: {value}. Check if the name is correct or the metadata file exists." + ) + + def parse_attributions(ctx, param, value): value = value.split(",") for i in value: diff --git a/src/decima/cli/download.py b/src/decima/cli/download.py index a145cd1..726a43b 100644 --- a/src/decima/cli/download.py +++ b/src/decima/cli/download.py @@ -10,6 +10,7 @@ """ import click +from decima.constants import DEFAULT_ENSEMBLE from decima.cli.callback import parse_model from decima.hub.download import ( cache_decima_data, @@ -27,7 +28,11 @@ def cli_cache(): @click.command() @click.option( - "--model", type=str, default="ensemble", help="Model to download. Default: ensemble.", callback=parse_model + "--model", + type=str, + default=DEFAULT_ENSEMBLE, + help=f"Model to download. Default: {DEFAULT_ENSEMBLE}.", + callback=parse_model, ) @click.option( "--download-dir", @@ -41,21 +46,34 @@ def cli_download_weights(model, download_dir): @click.command() +@click.option( + "--metadata", + default=DEFAULT_ENSEMBLE, + help=f"Model to download metadata for using wandb. Default: {DEFAULT_ENSEMBLE}.", + callback=parse_model, +) @click.option( "--download-dir", type=click.Path(), default=".", help="Directory to download the metadata. Default: current directory.", ) -def cli_download_metadata(download_dir): - """Download pre-trained Decima metadata.""" - download_decima_metadata(str(download_dir)) +def cli_download_metadata(metadata=DEFAULT_ENSEMBLE, download_dir=None): + """Download pre-trained Decima metadata for a given model.""" + download_decima_metadata(metadata, str(download_dir)) @click.command() +@click.option( + "--model-name", + type=str, + default=DEFAULT_ENSEMBLE, + help=f"Model to download. Default: {DEFAULT_ENSEMBLE}.", + callback=parse_model, +) @click.option( "--download-dir", type=click.Path(), default=".", help="Directory to download the data. Default: current directory." ) -def cli_download(download_dir): +def cli_download(model_name=DEFAULT_ENSEMBLE, download_dir="."): """Download model weights and metadata for Decima.""" - download_decima(str(download_dir)) + download_decima(model_name, str(download_dir)) diff --git a/src/decima/cli/finetune.py b/src/decima/cli/finetune.py index 3105797..70105bb 100755 --- a/src/decima/cli/finetune.py +++ b/src/decima/cli/finetune.py @@ -68,7 +68,7 @@ def cli_finetune( Args: name: Name of the run for logging and checkpointing - model: Model path or replication number (0-3) + model: Borzoi model path or replication number (0-3) device: Device to use for training. Default: "0" matrix_file: Path to the matrix file containing training data h5_file: Path to the H5 file containing sequences diff --git a/src/decima/cli/modisco.py b/src/decima/cli/modisco.py index 264f285..c69800f 100644 --- a/src/decima/cli/modisco.py +++ b/src/decima/cli/modisco.py @@ -21,7 +21,8 @@ import click from typing import List, Optional, Union -from decima.cli.callback import parse_model, parse_genes, parse_attributions +from decima.constants import DEFAULT_ENSEMBLE +from decima.cli.callback import parse_model, parse_genes, parse_attributions, parse_metadata from decima.interpret.modisco import ( predict_save_modisco_attributions, modisco_patterns, @@ -45,8 +46,20 @@ default=None, help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.", ) -@click.option("--model", type=str, default=0, help="Model to use for the prediction.", callback=parse_model) -@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") +@click.option( + "--model", + type=str, + default=DEFAULT_ENSEMBLE, + help=f"Model to use for the prediction. Default: {DEFAULT_ENSEMBLE}.", + callback=parse_model, + show_default=True, +) +@click.option( + "--metadata", + default=None, + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.", +) @click.option( "--method", type=click.Choice(["saliency", "inputxgradient", "integratedgradients"]), @@ -59,7 +72,7 @@ type=click.Choice(["specificity", "aggregate"]), default="specificity", show_default=True, - help="Transform to use for attribution analysis.", + help="Transform to use for attribution analysis. Available options: 'specificity', 'aggregate'. Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns.", ) @click.option("--batch-size", type=int, default=1, show_default=True, help="Batch size for the prediction.") @click.option( @@ -88,7 +101,7 @@ def cli_modisco_attributions( output_prefix: str, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[Union[str, int]] = 0, + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata: Optional[str] = None, method: str = "saliency", transform: str = "specificity", @@ -144,7 +157,12 @@ def cli_modisco_attributions( help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.", ) @click.option("--tss-distance", type=int, default=10_000, show_default=True, help="TSS distance for the prediction.") -@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") +@click.option( + "--metadata", + default=None, + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.", +) @click.option( "--genes", type=str, @@ -279,7 +297,12 @@ def cli_modisco_reports( @click.command() @click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files.") @click.option("--modisco-h5", type=click.Path(exists=True), required=True, help="Path to the modisco HDF5 file.") -@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.") +@click.option( + "--metadata", + default=None, + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.", +) @click.option("--trim-threshold", type=float, default=0.2, show_default=True, help="Trim threshold.") def cli_modisco_seqlet_bed( output_prefix: str, @@ -313,12 +336,17 @@ def cli_modisco_seqlet_bed( @click.option( "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, show_default=True, help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files. Default: `ensemble`.", callback=parse_model, ) -@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.") +@click.option( + "--metadata", + default=None, + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.", +) @click.option( "--method", type=click.Choice(["saliency", "inputxgradient", "integratedgradients"]), @@ -326,6 +354,13 @@ def cli_modisco_seqlet_bed( show_default=True, help="Method to use for attribution analysis.", ) +@click.option( + "--transform", + type=click.Choice(["specificity", "aggregate"]), + default="specificity", + show_default=True, + help="Transform to use for attribution analysis. Available options: 'specificity', 'aggregate'. Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns.", +) @click.option("--batch-size", type=int, default=1, show_default=True, help="Batch size for the prediction.") @click.option( "--genes", @@ -406,9 +441,10 @@ def cli_modisco( tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, tss_distance: int = 10_000, - model: Optional[Union[str, int]] = "ensemble", + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata: Optional[str] = None, method: str = "saliency", + transform: str = "specificity", batch_size: int = 1, genes: Optional[str] = None, top_n_markers: Optional[int] = None, @@ -445,6 +481,7 @@ def cli_modisco( model=model, metadata_anndata=metadata, method=method, + transform=transform, batch_size=batch_size, genes=genes, top_n_markers=top_n_markers, diff --git a/src/decima/cli/predict_genes.py b/src/decima/cli/predict_genes.py index 32a5934..317706a 100644 --- a/src/decima/cli/predict_genes.py +++ b/src/decima/cli/predict_genes.py @@ -8,7 +8,8 @@ import click from pathlib import Path -from decima.cli.callback import parse_model, parse_genes, validate_save_replicates +from decima.constants import DEFAULT_ENSEMBLE +from decima.cli.callback import parse_model, parse_genes, validate_save_replicates, parse_metadata from decima.tools.inference import predict_gene_expression @@ -25,15 +26,16 @@ "-m", "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, + show_default=True, callback=parse_model, - help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to checkpoint files", + help=f"`0`, `1`, `2`, `3`, `{DEFAULT_ENSEMBLE}` or a path or a comma-separated list of paths to checkpoint files", ) @click.option( "--metadata", - type=click.Path(exists=True), default=None, - help="Path to the metadata anndata file. Default: None.", + callback=parse_metadata, + help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.", ) @click.option( "--device", diff --git a/src/decima/cli/query_cell.py b/src/decima/cli/query_cell.py index a6841c8..f86ba31 100644 --- a/src/decima/cli/query_cell.py +++ b/src/decima/cli/query_cell.py @@ -18,15 +18,20 @@ """ import click +from decima.constants import DEFAULT_ENSEMBLE +from decima.cli.callback import parse_metadata from decima.core.result import DecimaResult @click.command() @click.argument("query", default="") @click.option( - "--metadata-anndata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file." + "--metadata", + default=None, + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.", ) -def cli_query_cell(query="", metadata_anndata=None): +def cli_query_cell(query="", metadata=None): """ Query a cell using query string @@ -42,7 +47,7 @@ def cli_query_cell(query="", metadata_anndata=None): ... """ - result = DecimaResult.load(metadata_anndata) + result = DecimaResult.load(metadata) df = result.cell_metadata if query != "": diff --git a/src/decima/cli/vep.py b/src/decima/cli/vep.py index 795e74e..7fc290b 100644 --- a/src/decima/cli/vep.py +++ b/src/decima/cli/vep.py @@ -21,8 +21,8 @@ """ import click -from decima.constants import DECIMA_CONTEXT_SIZE -from decima.cli.callback import parse_model, validate_save_replicates +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE +from decima.cli.callback import parse_model, validate_save_replicates, parse_metadata from decima.utils.dataframe import ensemble_predictions from decima.vep import predict_variant_effect @@ -46,15 +46,15 @@ @click.option( "--model", type=str, - default="ensemble", + default=DEFAULT_ENSEMBLE, callback=parse_model, help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files to perform variant effect prediction. Default: `ensemble`.", ) @click.option( "--metadata", - type=click.Path(exists=True), default=None, - help="Path to the metadata anndata file. Default: None.", + callback=parse_metadata, + help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.", ) @click.option( "--device", type=str, default=None, help="Device to use. Default: None which automatically selects the best device." diff --git a/src/decima/constants.py b/src/decima/constants.py index ca2e086..af70f35 100644 --- a/src/decima/constants.py +++ b/src/decima/constants.py @@ -1,13 +1,51 @@ """Decima constants.""" +import json import os -DECIMA_CONTEXT_SIZE = 524288 +# constants for all models +DECIMA_CONTEXT_SIZE = 524_288 SUPPORTED_GENOMES = {"hg38"} -NUM_CELLS = 8856 -if "DECIMA_ENSEMBLE_MODELS_NAMES" in os.environ: - ENSEMBLE_MODELS_NAMES = os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"].split(",") -else: - ENSEMBLE_MODELS_NAMES = ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"] +# EDIT: following metadata to add new models; +# metadata of models models +# models has dict as values +# ensemble models have list of model names as values and fetched metadata from the models +# following fields are required in the metadata: +# - name of the models on wandb +# - number of cells of the model +# - metadata name in the wandb +# - model_path [optional] to the local model path +# - metadata_path [optional] to the local metadata path +MODEL_METADATA = { + "v1_rep0": {"name": "rep0", "num_tasks": 8856, "metadata": "metadata"}, + "v1_rep1": {"name": "rep1", "num_tasks": 8856, "metadata": "metadata"}, + "v1_rep2": {"name": "rep2", "num_tasks": 8856, "metadata": "metadata"}, + "v1_rep3": {"name": "rep3", "num_tasks": 8856, "metadata": "metadata"}, + "ensemble": ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"], +} +MODEL_METADATA["rep0"] = MODEL_METADATA["v1_rep0"] +MODEL_METADATA["rep1"] = MODEL_METADATA["v1_rep1"] +MODEL_METADATA["rep2"] = MODEL_METADATA["v1_rep2"] +MODEL_METADATA["rep3"] = MODEL_METADATA["v1_rep3"] +MODEL_METADATA[0] = MODEL_METADATA["v1_rep0"] +MODEL_METADATA[1] = MODEL_METADATA["v1_rep1"] +MODEL_METADATA[2] = MODEL_METADATA["v1_rep2"] +MODEL_METADATA[3] = MODEL_METADATA["v1_rep3"] +MODEL_METADATA["0"] = MODEL_METADATA["v1_rep0"] +MODEL_METADATA["1"] = MODEL_METADATA["v1_rep1"] +MODEL_METADATA["2"] = MODEL_METADATA["v1_rep2"] +MODEL_METADATA["3"] = MODEL_METADATA["v1_rep3"] + +# default version +DEFAULT_ENSEMBLE = "ensemble" + +# overwrite model metadata from environment variables +if "MODEL_METADATA" in os.environ: + MODEL_METADATA = json.loads(os.environ["MODEL_METADATA"]) + +if "DEFAULT_ENSEMBLE" in os.environ: + DEFAULT_ENSEMBLE = os.environ["DEFAULT_ENSEMBLE"] + +ENSEMBLE_MODELS = [k for k, v in MODEL_METADATA.items() if isinstance(v, list)] diff --git a/src/decima/core/attribution.py b/src/decima/core/attribution.py index 015a5dc..43d7938 100644 --- a/src/decima/core/attribution.py +++ b/src/decima/core/attribution.py @@ -20,7 +20,7 @@ from grelu.sequence.format import convert_input_type, strings_to_one_hot from grelu.visualize import plot_attributions -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA from decima.core.result import DecimaResult from decima.interpret.attributer import DecimaAttributer from decima.utils.sequence import one_hot_to_seq @@ -188,11 +188,11 @@ def from_seq( inputs: Union[str, torch.Tensor, np.ndarray], tasks: Optional[list] = None, off_tasks: Optional[list] = None, - model: Optional[Union[str, int]] = 0, + model: Optional[Union[str, int]] = MODEL_METADATA[DEFAULT_ENSEMBLE][0], transform: str = "specificity", method: str = "inputxgradient", device: Optional[str] = "cpu", - result: Optional[DecimaResult] = None, + result: Optional[str] = None, gene: Optional[str] = "", chrom: Optional[str] = None, start: Optional[int] = None, @@ -218,6 +218,7 @@ def from_seq( transform: Transformation to apply to attributions method: Method to use for attribution analysis available options: "saliency", "inputxgradient", "integratedgradients". device: Device to use for attribution analysis + result: Result object or path to result object or name of the model to load the result for. gene: Gene name chrom: Chromosome name start: Start position @@ -259,9 +260,7 @@ def from_seq( else: raise ValueError("`inputs` must be a string, torch.Tensor, or np.ndarray") - if result is None: - result = DecimaResult.load() - + result = DecimaResult.load(result or model) tasks, off_tasks = result.query_tasks(tasks, off_tasks) attrs = ( @@ -762,9 +761,7 @@ def _load_attribution( pattern_type=pattern_type, ) - def _get_metadata( - self, genes: List[str], metadata_anndata: Optional[DecimaResult] = None, custom_genome: bool = False - ): + def _get_metadata(self, genes: List[str], metadata_anndata: Optional[str] = None, custom_genome: bool = False): if custom_genome: chroms = genes starts = [0] * len(genes) @@ -773,7 +770,10 @@ def _get_metadata( else: ends = [DECIMA_CONTEXT_SIZE] * len(genes) else: - result = DecimaResult.load(metadata_anndata) + model_name = self.model_name + if isinstance(model_name, list): + model_name = model_name[0] + result = DecimaResult.load(metadata_anndata or model_name) chroms = result.gene_metadata.loc[genes].chrom if self.tss_distance is not None: tss_pos = np.where( @@ -791,7 +791,7 @@ def _get_metadata( def load_attribution( self, gene: str, - metadata_anndata: Optional[DecimaResult] = None, + metadata_anndata: Optional[str] = None, custom_genome: bool = False, threshold: float = 5e-4, min_seqlet_len: int = 4, @@ -871,7 +871,7 @@ def _recursive_seqlet_calling( def recursive_seqlet_calling( self, genes: Optional[List[str]] = None, - metadata_anndata: Optional[DecimaResult] = None, + metadata_anndata: Optional[str] = None, custom_genome: bool = False, threshold: float = 5e-4, min_seqlet_len: int = 4, diff --git a/src/decima/core/metadata.py b/src/decima/core/metadata.py index 51677a4..0d1361b 100644 --- a/src/decima/core/metadata.py +++ b/src/decima/core/metadata.py @@ -101,9 +101,6 @@ class CellMetadata: disease: str study: str dataset: str - region: Optional[str] - subregion: Optional[str] - celltype_coarse: Optional[str] n_cells: int total_counts: float n_genes: int @@ -111,6 +108,12 @@ class CellMetadata: train_pearson: float val_pearson: float test_pearson: float + region: Optional[str] = field(default=None) + subregion: Optional[str] = field(default=None) + celltype_coarse: Optional[str] = field(default=None) + co_term: Optional[str] = field(default=None) + co_name: Optional[str] = field(default=None) + frac_nan: Optional[float] = field(default=None) @classmethod def from_series(cls, name: str, series: pd.Series) -> "CellMetadata": diff --git a/src/decima/core/result.py b/src/decima/core/result.py index 2c14d3b..a7c3ac3 100644 --- a/src/decima/core/result.py +++ b/src/decima/core/result.py @@ -7,7 +7,7 @@ from grelu.sequence.format import intervals_to_strings, strings_to_one_hot -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA from decima.hub import load_decima_metadata, load_decima_model from decima.core.metadata import GeneMetadata, CellMetadata from decima.tools.evaluate import marker_zscores @@ -59,11 +59,12 @@ def __init__(self, anndata): self._model = None @classmethod - def load(cls, anndata_path: Optional[Union[str, anndata.AnnData]] = None): + def load(cls, anndata_name_or_path: Optional[Union[str, anndata.AnnData]] = None): """Load a DecimaResult object from an anndata file or a path to an anndata file. Args: - anndata_path: Path to anndata file or anndata object + anndata_name_or_path: Name of the model or path to anndata file or anndata object + model: Model name or path to model checkpoint. If not provided, the default model will be loaded. Returns: DecimaResult object @@ -74,16 +75,19 @@ def load(cls, anndata_path: Optional[Union[str, anndata.AnnData]] = None): ... "path/to/anndata.h5ad" ... ) # Load custom anndata object from file """ - if anndata_path is None: - return cls(load_decima_metadata()) - elif isinstance(anndata_path, str): - return cls(anndata.read_h5ad(anndata_path)) - elif isinstance(anndata_path, anndata.AnnData): - return cls(anndata_path) - elif isinstance(anndata_path, DecimaResult): - return anndata_path + if isinstance(anndata_name_or_path, list): + anndata_name_or_path = anndata_name_or_path[0] + + if (anndata_name_or_path is None) or (anndata_name_or_path in MODEL_METADATA): + return cls(load_decima_metadata(name_or_path=anndata_name_or_path)) + elif isinstance(anndata_name_or_path, str): + return cls(anndata.read_h5ad(anndata_name_or_path)) + elif isinstance(anndata_name_or_path, anndata.AnnData): + return cls(anndata_name_or_path) + elif isinstance(anndata_name_or_path, DecimaResult): + return anndata_name_or_path else: - raise ValueError(f"Invalid anndata path: {anndata_path}") + raise ValueError(f"Invalid anndata path: {anndata_name_or_path}") @property def model(self): @@ -92,7 +96,7 @@ def model(self): self.load_model() return self._model - def load_model(self, model: Optional[Union[str, int]] = 0, device: str = "cpu"): + def load_model(self, model: Optional[Union[str, int]] = MODEL_METADATA[DEFAULT_ENSEMBLE][0], device: str = "cpu"): """Load the trained model from a checkpoint path. Args: @@ -172,7 +176,7 @@ def predicted_expression_matrix( Returns: pd.DataFrame: Predicted expression matrix (cells x genes) """ - model_name = "preds" if (model_name is None) or (model_name == "ensemble") else model_name + model_name = "preds" if (model_name is None) or (model_name in MODEL_METADATA) else model_name if genes is None: return pd.DataFrame(self.anndata.layers[model_name], index=self.cells, columns=self.genes) else: @@ -222,7 +226,7 @@ def prepare_one_hot( Returns: torch.Tensor: One-hot encoding of the gene """ - assert gene in self.genes, f"{gene} is not in the anndata object" + assert gene in self.genes, f"{gene} is not in the anndata object. See avaliable genes with `result.genes`." gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding) if variants is None: @@ -249,11 +253,7 @@ def gene_sequence(self, gene: str, stranded: bool = True, genome: str = "hg38") Returns: str: Sequence for the gene """ - try: - assert gene in self.genes, f"{gene} is not in the anndata object" - except AssertionError: - print(gene) - print(self.genes) + assert gene in self.genes, f"{gene} is not in the anndata object. See avaliable genes with `result.genes`." gene_meta = self.gene_metadata.loc[gene] if not stranded: gene_meta = {"chrom": gene_meta.chrom, "start": gene_meta.start, "end": gene_meta.end} diff --git a/src/decima/data/dataset.py b/src/decima/data/dataset.py index 0c1a341..f73df53 100644 --- a/src/decima/data/dataset.py +++ b/src/decima/data/dataset.py @@ -8,7 +8,7 @@ - VariantDataset: Dataset for variant effect prediction. """ -from typing import List +from typing import List, Optional import warnings import torch import h5py @@ -22,7 +22,7 @@ from grelu.sequence.format import strings_to_one_hot from grelu.sequence.utils import reverse_complement -from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES +from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS, MODEL_METADATA from decima.data.read_hdf5 import _extract_center, index_genes from decima.core.result import DecimaResult from decima.utils.io import read_fasta_gene_mask @@ -186,11 +186,11 @@ class GeneDataset(Dataset): def __init__( self, genes=None, - metadata_anndata=None, - max_seq_shift=0, - seed=0, - augment_mode="random", - genome="hg38", + metadata_anndata: Optional[str] = None, + max_seq_shift: int = 0, + seed: int = 0, + augment_mode: str = "random", + genome: str = "hg38", ): super().__init__() @@ -599,24 +599,24 @@ class VariantDataset(Dataset): def __init__( self, variants, - metadata_anndata=None, - max_seq_shift=0, - seed=0, + metadata_anndata: Optional[str] = None, + max_seq_shift: int = 0, + seed: int = 0, include_cols=None, - gene_col=None, - min_from_end=0, - distance_type="tss", - min_distance=0, - max_distance=float("inf"), - model_name=None, - reference_cache=True, - genome="hg38", + gene_col: Optional[str] = None, + min_from_end: int = 0, + distance_type: str = "tss", + min_distance: int = 0, + max_distance: float = float("inf"), + model_name: Optional[str] = None, + reference_cache: bool = True, + genome: str = "hg38", ): super().__init__() self.reference_cache = reference_cache self.genome = genome - self.result = DecimaResult.load(metadata_anndata) + self.result = DecimaResult.load(metadata_anndata or model_name) self.variants = self._overlap_genes( variants, @@ -647,8 +647,8 @@ def __init__( if (model_name is None) or (not reference_cache): self.model_names = list() # no reference caching - elif model_name == "ensemble": - self.model_names = ENSEMBLE_MODELS_NAMES + elif model_name in ENSEMBLE_MODELS: + self.model_names = MODEL_METADATA[model_name] else: self.model_names = [model_name] diff --git a/src/decima/data/read_hdf5.py b/src/decima/data/read_hdf5.py index d588369..440b306 100644 --- a/src/decima/data/read_hdf5.py +++ b/src/decima/data/read_hdf5.py @@ -3,6 +3,8 @@ import torch from grelu.sequence.format import BASE_TO_INDEX_HASH, indices_to_one_hot +from decima.constants import DECIMA_CONTEXT_SIZE + def count_genes(h5_file, key=None): with h5py.File(h5_file, "r") as f: @@ -42,7 +44,7 @@ def _extract_center(x, seq_len, shift=0): return x[..., start : start + seq_len] -def extract_gene_data(h5_file, gene, seq_len=524288, merge=True): +def extract_gene_data(h5_file, gene, seq_len=DECIMA_CONTEXT_SIZE, merge=True): gene_idx = get_gene_idx(h5_file, key=None, gene=gene) with h5py.File(h5_file, "r") as f: diff --git a/src/decima/hub/__init__.py b/src/decima/hub/__init__.py index b6e7611..e50c657 100644 --- a/src/decima/hub/__init__.py +++ b/src/decima/hub/__init__.py @@ -1,4 +1,5 @@ import os +import json from typing import Union, Optional, List import warnings import wandb @@ -6,18 +7,20 @@ from tempfile import TemporaryDirectory import anndata from grelu.resources import get_artifact, DEFAULT_WANDB_HOST +from decima.constants import DEFAULT_ENSEMBLE, ENSEMBLE_MODELS, MODEL_METADATA from decima.model.lightning import LightningModel, EnsembleLightningModel def login_wandb(): """Login to wandb either as anonymous or as a user.""" + host = os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST) try: - wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), anonymous="never", timeout=0) + wandb.login(host=host, anonymous="never", timeout=0) except wandb.errors.UsageError: # login anonymously if not logged in already - wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), relogin=True, anonymous="must", timeout=0) + wandb.login(host=host, relogin=True, anonymous="must", timeout=0) -def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[str] = None): +def load_decima_model(model: Union[str, int, List[str]] = DEFAULT_ENSEMBLE, device: Optional[str] = None): """Load a pre-trained Decima model from wandb or local path. Args: @@ -37,74 +40,80 @@ def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[st if isinstance(model, LightningModel): return model - elif model == "ensemble": - return EnsembleLightningModel([load_decima_model(i, device) for i in range(4)]) + elif model in ENSEMBLE_MODELS: + return EnsembleLightningModel( + [load_decima_model(model_name, device) for model_name in MODEL_METADATA[model]], + name=model, + ) elif isinstance(model, List): if len(model) == 1: return load_decima_model(model[0], device) else: - return EnsembleLightningModel([load_decima_model(path, device) for path in model]) - - elif model in {0, 1, 2, 3}: - model_name = f"rep{model}" + return EnsembleLightningModel([load_decima_model(path, device) for path in model], name=model) # Load directly from a path - elif isinstance(model, str): - if Path(model).exists(): - if model.endswith("ckpt"): - return LightningModel.load_from_checkpoint(model, map_location=device) - else: - return LightningModel.load_safetensor(model, device=device) + if model in MODEL_METADATA: + model_name = MODEL_METADATA[model]["name"] + if "model_path" in MODEL_METADATA[model]: # if model path exist in metadata load it from the path + return load_decima_model(MODEL_METADATA[model]["model_path"], device) + elif isinstance(model, str) and Path(model).exists(): + if model.endswith("ckpt"): + return LightningModel.load_from_checkpoint(model, map_location=device) else: - model_name = model - + return LightningModel.load_safetensor(model, device=device) else: raise ValueError( - f"Invalid model: {model} it needs to be either a string of model_names on wandb, " - "an integer of replicate number {0, 1, 2, 3}, a path to a local model or a list of paths." + f"Invalid model: {model} it needs to be either a string of model_names on wandb (" + f"{list(MODEL_METADATA.keys())}), path to a local model, or a list of paths." ) - - # If left with a model name, load from environment/wandb - if model_name.upper() in os.environ: - if Path(os.environ[model_name.upper()]).exists(): - return LightningModel.load_safetensor(os.environ[model_name.upper()], device=device) - else: - warnings.warn( - f"Model `{model_name}` provided in environment variables, " - f"but not found in `{os.environ[model_name.upper()]}` " - f"Trying to download `{model_name}` from wandb." - ) - - art = get_artifact(model_name, project="decima") + # load model from wandb + art = get_artifact( + model_name, + project="decima", + host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), + ) with TemporaryDirectory() as d: art.download(d) return LightningModel.load_safetensor(Path(d) / f"{model_name}.safetensors", device=device) -def load_decima_metadata(path: Optional[str] = None): +def load_decima_metadata(name_or_path: Optional[str] = None): """Load the Decima metadata from wandb. Args: - path: Path to local metadata file. If None, downloads from wandb. + name_or_path: Path to local metadata file or name of the model to load metadata for using wandb. If None, default model's metadata will be downloaded from wandb. Returns: An AnnData object containing the Decima metadata. """ - if path is not None: - return anndata.read_h5ad(path) + if name_or_path is not None: + if Path(name_or_path).exists(): + return anndata.read_h5ad(name_or_path) + + name_or_path = name_or_path or DEFAULT_ENSEMBLE + + if name_or_path in ENSEMBLE_MODELS: + name_or_path = MODEL_METADATA[name_or_path][0] + + if name_or_path in MODEL_METADATA: + metadata = MODEL_METADATA[name_or_path] - if "DECIMA_METADATA" in os.environ: - if Path(os.environ["DECIMA_METADATA"]).exists(): - return anndata.read_h5ad(os.environ["DECIMA_METADATA"]) + if "metadata_path" in metadata: + if Path(metadata["metadata_path"]).exists(): + return anndata.read_h5ad(metadata["metadata_path"]) else: warnings.warn( - f"Metadata `{os.environ['DECIMA_METADATA']}` provided in environment variables, " - f"but not found in `{os.environ['DECIMA_METADATA']}` " + f"Metadata `{metadata['metadata_path']}` provided in environment variables, " + f"but not found in `{metadata['metadata_path']}` " f"Trying to download `metadata` from wandb." ) - art = get_artifact("metadata", project="decima") + art = get_artifact( + metadata["metadata"], + project="decima", + host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), + ) with TemporaryDirectory() as d: art.download(d) - return anndata.read_h5ad(Path(d) / "metadata.h5ad") + return anndata.read_h5ad(Path(d) / f"{metadata['metadata']}.h5ad") diff --git a/src/decima/hub/download.py b/src/decima/hub/download.py index 9e2fe72..6788e74 100644 --- a/src/decima/hub/download.py +++ b/src/decima/hub/download.py @@ -1,8 +1,10 @@ +import os from pathlib import Path from typing import Union import logging import genomepy -from grelu.resources import get_artifact +from grelu.resources import get_artifact, DEFAULT_WANDB_HOST +from decima.constants import DEFAULT_ENSEMBLE, ENSEMBLE_MODELS, MODEL_METADATA from decima.hub import login_wandb, load_decima_model, load_decima_metadata @@ -18,7 +20,7 @@ def cache_hg38(): def cache_decima_weights(): """Download pre-trained Decima model weights from wandb.""" logger.info("Downloading Decima model weights...") - for rep in range(4): + for rep in MODEL_METADATA[DEFAULT_ENSEMBLE]: load_decima_model(rep) @@ -36,7 +38,7 @@ def cache_decima_data(): cache_decima_metadata() -def download_decima_weights(model_name: Union[str, int], download_dir: str): +def download_decima_weights(model: Union[str, int] = DEFAULT_ENSEMBLE, download_dir: str = "."): """Download pre-trained Decima model weights from wandb. Args: @@ -46,40 +48,44 @@ def download_decima_weights(model_name: Union[str, int], download_dir: str): Returns: Path to the downloaded model weights. """ - if "ensemble" == model_name: - return [download_decima_weights(model, download_dir) for model in range(4)] - - if model_name in {0, 1, 2, 3}: - model_name = f"rep{model_name}" + if model in ENSEMBLE_MODELS: + return [download_decima_weights(model, download_dir) for model in MODEL_METADATA[model]] + model_name = MODEL_METADATA[model]["name"] download_dir = Path(download_dir) download_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Downloading Decima model weights for {model_name} to {download_dir / f'{model_name}.safetensors'}") + logger.info(f"Downloading Decima model weights for {model} to {download_dir / f'{model_name}.safetensors'}") - art = get_artifact(model_name, project="decima") + art = get_artifact(model_name, project="decima", host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST)) art.download(str(download_dir)) return download_dir / f"{model_name}.safetensors" -def download_decima_metadata(download_dir: str): +def download_decima_metadata(metadata: str = DEFAULT_ENSEMBLE, download_dir: str = "."): """Download pre-trained Decima model data from wandb. Args: download_dir: Directory to download the metadata. + metadata: Name of the model to download metadata for using wandb. Returns: Path to the downloaded metadata. """ - art = get_artifact("metadata", project="decima") + metadata = metadata or DEFAULT_ENSEMBLE + if metadata in ENSEMBLE_MODELS: + metadata = MODEL_METADATA[metadata][0] + + metadata_name = MODEL_METADATA[metadata]["metadata"] + art = get_artifact(metadata_name, project="decima", host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST)) download_dir = Path(download_dir) download_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Downloading Decima metadata to {download_dir / 'metadata.h5ad'}.") + logger.info(f"Downloading Decima metadata to {download_dir / f'{metadata_name}.h5ad'}.") art.download(str(download_dir)) - return download_dir / "metadata.h5ad" + return download_dir / f"{metadata_name}.h5ad" -def download_decima(download_dir: str): +def download_decima(model: str = DEFAULT_ENSEMBLE, download_dir: str = "."): """Download all required data for Decima. Args: @@ -92,6 +98,6 @@ def download_decima(download_dir: str): download_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Downloading Decima model weights and metadata to {download_dir}:") - download_decima_weights("ensemble", download_dir) - download_decima_metadata(download_dir) + download_decima_weights(model, download_dir) + download_decima_metadata(model, download_dir) return download_dir diff --git a/src/decima/interpret/attributions.py b/src/decima/interpret/attributions.py index 10388f9..31b5922 100644 --- a/src/decima/interpret/attributions.py +++ b/src/decima/interpret/attributions.py @@ -40,9 +40,11 @@ from torch.utils.data import DataLoader from pyfaidx import Faidx +from decima.constants import DEFAULT_ENSEMBLE, MODEL_METADATA, ENSEMBLE_MODELS from decima.core.attribution import AttributionResult from decima.core.result import DecimaResult from decima.data.dataset import GeneDataset, SeqDataset +from decima.hub import load_decima_model from decima.interpret.attributer import DecimaAttributer from decima.utils import get_compute_device, _get_on_off_tasks, _get_genes from decima.utils.io import AttributionWriter @@ -53,7 +55,7 @@ def predict_save_attributions( output_prefix: str, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[int] = 0, + model: Optional[int] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, method: str = "inputxgradient", transform: str = "specificity", @@ -119,9 +121,9 @@ def predict_save_attributions( ... genome="hg38", ... ) """ - if (model == "ensemble") or isinstance(model, (list, tuple)): - if model == "ensemble": - models = [0, 1, 2, 3] + if (model in ENSEMBLE_MODELS) or isinstance(model, (list, tuple)): + if model in ENSEMBLE_MODELS: + models = MODEL_METADATA[model] else: models = model return [ @@ -154,17 +156,17 @@ def predict_save_attributions( device = get_compute_device(device) logger.info(f"Using device: {device}") - logger.info("Loading model and metadata to compute attributions...") - result = DecimaResult.load(metadata_anndata) + logger.info(f"Loading model {model} and metadata to compute attributions...") + model = load_decima_model(model, device=device) + result = DecimaResult.load(metadata_anndata or model.name) tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks) + attributer = DecimaAttributer(model, tasks, off_tasks, method, transform) - with QCLogger(str(output_prefix) + ".warnings.qc.log", metadata_anndata=metadata_anndata) as qc: + with QCLogger(str(output_prefix) + ".warnings.qc.log", metadata_anndata=result) as qc: if result.ground_truth is not None: qc.log_correlation(tasks, off_tasks, plot=True) - attributer = DecimaAttributer.load_decima_attributer(model, tasks, off_tasks, method, transform, device=device) - if (genes is not None) and (seqs is not None): raise ValueError("Only one of `genes` or `seqs` arguments must be provided not both.") elif seqs is not None: @@ -295,8 +297,6 @@ def recursive_seqlet_calling( logger = logging.getLogger("decima") logger.info("Loading model and metadata to compute attributions...") - result = DecimaResult.load(metadata_anndata) - if isinstance(attributions, (str, Path)): attributions_files = [Path(attributions).as_posix()] else: @@ -305,6 +305,8 @@ def recursive_seqlet_calling( with AttributionResult( attributions_files, tss_distance, correct_grad=False, num_workers=num_workers, agg_func=agg_func ) as ar: + result = DecimaResult.load(metadata_anndata or ar.model_name) + if top_n_markers is not None: tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks) all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks) @@ -338,7 +340,7 @@ def predict_attributions_seqlet_calling( seqs: Optional[Union[pd.DataFrame, np.ndarray, torch.Tensor]] = None, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[Union[str, int]] = "ensemble", + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, method: str = "inputxgradient", transform: str = "specificity", diff --git a/src/decima/interpret/modisco.py b/src/decima/interpret/modisco.py index 3a4ddc2..4147fb6 100644 --- a/src/decima/interpret/modisco.py +++ b/src/decima/interpret/modisco.py @@ -21,12 +21,12 @@ import h5py import numpy as np import pandas as pd -import modiscolite +import fastermodiscolite from tqdm import tqdm from grelu.resources import get_meme_file_path from grelu.interpret.motifs import trim_pwm -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE from decima.core.result import DecimaResult from decima.utils import _get_on_off_tasks, _get_genes from decima.utils.motifs import motif_start_end @@ -38,7 +38,7 @@ def predict_save_modisco_attributions( output_prefix: str, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[int] = 0, + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, method: str = "saliency", transform: str = "specificity", @@ -156,7 +156,7 @@ def modisco_patterns( tasks: Tasks to analyze either list of task names or query string to filter cell types to analyze attributions for (e.g. 'cell_type == 'classical monocyte''). If not provided, all tasks will be analyzed. off_tasks: Off tasks to analyze either list of task names or query string to filter cell types to contrast against (e.g. 'cell_type == 'classical monocyte''). If not provided, all tasks will be used as off tasks. tss_distance: Distance from TSS to analyze for pattern discovery default is 10000. Controls the genomic window size around TSS for seqlet detection and motif discovery. - metadata_anndata: Path to metadata anndata file or DecimaResult object. If not provided, the default metadata will be used from the attribution files. + metadata_anndata: Name of the model or path to metadata anndata file or DecimaResult object. If not provided, the compatible metadata of the saved attribution files will be used. genes: Genes to analyze for pattern discovery if not provided, all genes will be used. Can be list of gene symbols or IDs to focus analysis on specific genes. top_n_markers: Top n markers to analyze for pattern discovery if not provided, all markers will be analyzed. Useful for focusing on the most important marker genes for the specified tasks. correct_grad: Whether to correct gradient for attribution analysis default is True. Applies gradient correction for better attribution quality before pattern discovery. @@ -206,23 +206,27 @@ def modisco_patterns( ... ) """ logger = logging.getLogger("decima") - logger.info("Loading metadata") - result = DecimaResult.load(metadata_anndata) if isinstance(attributions, (str, Path)): attributions_files = [Path(attributions).as_posix()] else: attributions_files = attributions - tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks) - all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks) - - with AttributionResult(attributions_files, tss_distance, correct_grad, num_workers=1, agg_func="mean") as ar: - sequences, attributions = ar.load(all_genes) + with AttributionResult( + attributions_files, tss_distance, correct_grad, num_workers=num_workers, agg_func="mean" + ) as ar: genome = ar.genome model_names = ar.model_name - pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco( + metadata_anndata = metadata_anndata or model_names[0] + logger.info(f"Loading metadata for model {metadata_anndata}...") + result = DecimaResult.load(metadata_anndata) + + tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks) + all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks) + sequences, attributions = ar.load(all_genes) + + pos_patterns, neg_patterns = fastermodiscolite.tfmodisco.TFMoDISco( hypothetical_contribs=attributions.transpose(0, 2, 1), one_hot=sequences.transpose(0, 2, 1), sliding_window_size=sliding_window_size, @@ -258,7 +262,7 @@ def modisco_patterns( verbose=True, ) h5_path = Path(output_prefix).with_suffix(".modisco.h5").as_posix() - modiscolite.io.save_hdf5( + fastermodiscolite.io.save_hdf5( h5_path, pos_patterns, neg_patterns, @@ -310,7 +314,7 @@ def modisco_reports( """ output_dir = Path(f"{output_prefix}_report") output_dir.mkdir(parents=True, exist_ok=True) - modiscolite.report.report_motifs( + fastermodiscolite.report.report_motifs( modisco_h5, output_dir.as_posix(), img_path_suffix, @@ -328,7 +332,7 @@ def modisco_reports( def modisco_seqlet_bed( output_prefix: str, modisco_h5: str, - metadata_anndata: str = None, + metadata_anndata: Optional[str] = None, trim_threshold: float = 0.2, ): """Extract seqlet locations from MoDISco results and save as BED format file. @@ -351,11 +355,13 @@ def modisco_seqlet_bed( ... trim_threshold=0.15, ... ) """ - result = DecimaResult.load(metadata_anndata) df = list() with h5py.File(modisco_h5, "r") as f: + model_name = f.attrs["model_names"].split(",")[0] + result = DecimaResult.load(metadata_anndata or model_name) + tss_distance = f.attrs["tss_distance"] genes = [gene.decode("utf-8") for gene in f["genes"][:]] genes_idx = dict(enumerate(genes)) @@ -420,7 +426,7 @@ def modisco( output_prefix: str, tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, - model: Optional[Union[str, int]] = 0, + model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE, tss_distance: int = 1000, metadata_anndata: Optional[str] = None, genes: Optional[List[str]] = None, @@ -429,6 +435,7 @@ def modisco( num_workers: int = 4, genome: str = "hg38", method: str = "saliency", + transform: Optional[str] = "specificity", batch_size: int = 2, device: Optional[str] = None, # tfmodisco parameters @@ -492,6 +499,7 @@ def modisco( num_workers: Number of workers for parallel processing default is 4. Increasing number of workers will speed up the process but requires more memory. genome: Genome reference to use default is "hg38". Can be genome name or path to custom genome fasta file. method: Method to use for attribution analysis default is "saliency". Available options: "saliency", "inputxgradient", "integratedgradients". For MoDISco, "saliency" is often preferred for pattern discovery. + transform: Transform to use for attribution analysis default is "specificity". Available options: "specificity", "aggregate". Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns. batch_size: Batch size for attribution analysis default is 2. Increasing batch size may speed up computation but requires more memory. device: Device to use for computation (e.g. 'cuda', 'cpu'). If not provided, the best available device will be used automatically. sliding_window_size: Sliding window size. @@ -544,6 +552,7 @@ def modisco( genes=genes, top_n_markers=top_n_markers, method=method, + transform=transform, batch_size=batch_size, correct_grad_bigwig=correct_grad, device=device, diff --git a/src/decima/model/lightning.py b/src/decima/model/lightning.py index b94dd0e..1c20ac7 100644 --- a/src/decima/model/lightning.py +++ b/src/decima/model/lightning.py @@ -19,6 +19,7 @@ from torchmetrics import MetricCollection import safetensors +from decima.constants import DEFAULT_ENSEMBLE from decima.utils import get_compute_device from .decima_model import DecimaModel from .loss import TaskWisePoissonMultinomialLoss @@ -515,7 +516,7 @@ def load_safetensor(cls, path: str, device: str = "cpu"): class EnsembleLightningModel(LightningModel): - def __init__(self, models: List[LightningModel], name="ensemble"): + def __init__(self, models: List[LightningModel], name: str = DEFAULT_ENSEMBLE): super().__init__( name=name, model_params=models[0].model_params, @@ -524,7 +525,7 @@ def __init__(self, models: List[LightningModel], name="ensemble"): ) self.models = nn.ModuleList(models) self.reset_transform() - self.name = "ensemble" + self.name = DEFAULT_ENSEMBLE def forward(self, x: Tensor) -> Tensor: return torch.concat([model(x) for model in self.models], dim=0) diff --git a/src/decima/tools/inference.py b/src/decima/tools/inference.py index c31234f..448a8c8 100644 --- a/src/decima/tools/inference.py +++ b/src/decima/tools/inference.py @@ -1,6 +1,8 @@ -import anndata import logging +from typing import Optional +import anndata import numpy as np +from decima.constants import DEFAULT_ENSEMBLE from decima.data.dataset import GeneDataset from decima.hub import load_decima_model from decima.utils import get_compute_device @@ -8,15 +10,15 @@ def predict_gene_expression( genes=None, - model="ensemble", - metadata_anndata=None, - device=None, - batch_size=1, - num_workers=4, + model=DEFAULT_ENSEMBLE, + metadata_anndata: Optional[str] = None, + device: Optional[str] = None, + batch_size: int = 1, + num_workers: int = 4, max_seq_shift=0, - genome="hg38", - save_replicates=False, - float_precision="32", + genome: str = "hg38", + save_replicates: bool = False, + float_precision: str = "32", ): """Predict gene expression for a list of genes @@ -42,10 +44,13 @@ def predict_gene_expression( device = get_compute_device(device) logger.info(f"Using device: {device} and genome: {genome} for prediction.") - logger.info("Making predictions") + logger.info(f"Loading model {model}...") model = load_decima_model(model, device=device) - ds = GeneDataset(genes=genes, metadata_anndata=metadata_anndata, max_seq_shift=max_seq_shift, genome=genome) + logger.info("Making predictions") + ds = GeneDataset( + genes=genes, metadata_anndata=metadata_anndata or model.name, max_seq_shift=max_seq_shift, genome=genome + ) preds = model.predict_on_dataset( ds, device=device, batch_size=batch_size, num_workers=num_workers, float_precision=float_precision ) @@ -68,7 +73,10 @@ def predict_gene_expression( ad.layers[f"preds_{model.name}"] = pred.T logger.info("Evaluating performance") - evaluate_gene_expression_predictions(ad) + if ad.X is not None: + evaluate_gene_expression_predictions(ad) + else: + logger.warning("No ground truth expression matrix found in the metadata. Skipping evaluation.") return ad diff --git a/src/decima/utils/__init__.py b/src/decima/utils/__init__.py index a5ce9b7..c016533 100644 --- a/src/decima/utils/__init__.py +++ b/src/decima/utils/__init__.py @@ -77,9 +77,12 @@ def get_compute_device(device: Optional[str] = None) -> torch.device: torch.device: The selected device for computation """ if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - + return 0 if torch.cuda.is_available() else "cpu" + elif device == "cuda": + return 0 + elif isinstance(device, int) or (device == "cpu"): + return device elif isinstance(device, str) and device.isdigit(): - device = int(device) - - return torch.device(device) + return int(device) + else: + raise ValueError(f"Invalid device: {device}") diff --git a/src/decima/utils/io.py b/src/decima/utils/io.py index c8d5a4d..e695cda 100644 --- a/src/decima/utils/io.py +++ b/src/decima/utils/io.py @@ -183,7 +183,7 @@ class AttributionWriter: path: Output HDF5 file path. genes: Gene names to write. model_name: Model identifier for metadata. - metadata_anndata: Gene metadata source. None uses default Decima data. + metadata_anndata: Gene metadata source or path to the metadata anndata file. If not provided, the compatible metadata for the model will be used. genome: Reference genome version. bigwig: Create BigWig file for genome browser. correct_grad_bigwig: Correct gradient bigwig for bias. @@ -221,7 +221,7 @@ def __init__( self.bigwig = bigwig self.model_name = model_name self.idx = {g: i for i, g in enumerate(self.genes)} - self.result = DecimaResult.load(metadata_anndata) + self.result = DecimaResult.load(metadata_anndata or model_name) self.correct_grad_bigwig = correct_grad_bigwig self.custom_genes = custom_genes diff --git a/src/decima/vep/__init__.py b/src/decima/vep/__init__.py index 7b5b057..53fb49a 100644 --- a/src/decima/vep/__init__.py +++ b/src/decima/vep/__init__.py @@ -7,7 +7,7 @@ import pandas as pd from grelu.transforms.prediction_transforms import Aggregate -from decima.constants import SUPPORTED_GENOMES +from decima.constants import SUPPORTED_GENOMES, DEFAULT_ENSEMBLE from decima.model.metrics import WarningType from decima.utils import get_compute_device from decima.utils.dataframe import chunk_df, ChunkDataFrameWriter @@ -19,7 +19,7 @@ def _predict_variant_effect( df_variant: Union[pd.DataFrame, str], tasks: Optional[Union[str, List[str]]] = None, - model: Union[int, str] = "ensemble", + model: Union[str, int] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, batch_size: int = 1, num_workers: int = 16, @@ -61,8 +61,6 @@ def _predict_variant_effect( raise ValueError(f"Genome {genome} not supported. Currently only hg38 is supported.") include_cols = include_cols or list() - model = load_decima_model(model=model, device=device) - try: dataset = VariantDataset( df_variant, @@ -83,8 +81,6 @@ def _predict_variant_effect( else: raise e - model = load_decima_model(model=model) - if tasks is not None: tasks = dataset.result.query_cells(tasks) @@ -120,7 +116,7 @@ def predict_variant_effect( df_variant: Union[pd.DataFrame, str], output_pq: Optional[str] = None, tasks: Optional[Union[str, List[str]]] = None, - model: Union[int, str, List[str]] = "ensemble", + model: Union[int, str, List[str]] = DEFAULT_ENSEMBLE, metadata_anndata: Optional[str] = None, chunksize: int = 10_000, batch_size: int = 1, @@ -142,7 +138,7 @@ def predict_variant_effect( df_variant (pd.DataFrame or str): DataFrame with variant information or path to variant file output_pq (str, optional): Path to save the parquet file. Defaults to None. tasks (str, optional): Tasks to predict. Defaults to None. - model (int, optional): Model to use. Defaults to "ensemble". + model (int, optional): Model to use. Defaults to DEFAULT_ENSEMBLE. metadata_anndata (str, optional): Path to anndata file. Defaults to None. chunksize (int, optional): Number of variants to predict in each chunk. Defaults to 10_000. batch_size (int, optional): Batch size. Defaults to 1. diff --git a/tests/conftest.py b/tests/conftest.py index 27ec0c8..82d5336 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,7 +86,7 @@ def attribution_h5_file(tmp_path, attribution_data): f.create_dataset('attribution', data=attribution_data['attributions']) f.create_dataset('gene_mask_start', data=attribution_data['gene_mask_start']) f.create_dataset('gene_mask_end', data=attribution_data['gene_mask_end']) - f.attrs['model_name'] = 'test_model' + f.attrs['model_name'] = 'v1_rep0' f.attrs['genome'] = 'hg38' return h5_path diff --git a/tests/test_attribution.py b/tests/test_attribution.py index 48acd4a..d0d33f8 100644 --- a/tests/test_attribution.py +++ b/tests/test_attribution.py @@ -14,7 +14,7 @@ def test_AttributionResult(attribution_h5_file, attribution_data): with AttributionResult(str(attribution_h5_file), tss_distance=10_000, num_workers=1) as ar: assert len(ar.genes) == 10 assert ar.genes == attribution_data['genes'] - assert ar.model_name == 'test_model' + assert ar.model_name == 'v1_rep0' assert ar.genome == 'hg38' assert ar.genes == attribution_data['genes'] @@ -73,7 +73,7 @@ def test_AttributionResult(attribution_h5_file, attribution_data): with AttributionResult([str(attribution_h5_file), str(attribution_h5_file)], tss_distance=10_000) as ar: assert len(ar.genes) == 10 assert ar.genes == attribution_data['genes'] - assert ar.model_name == ['test_model', 'test_model'] + assert ar.model_name == ['v1_rep0', 'v1_rep0'] assert ar.genome == 'hg38' with AttributionResult(str(attribution_h5_file), tss_distance=1_000_000) as ar: diff --git a/tests/test_cli.py b/tests/test_cli.py index 60ba05c..4439c03 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ from decima.cli import main from conftest import device -from decima.constants import DECIMA_CONTEXT_SIZE +from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE def test_cli_main(): @@ -254,7 +254,7 @@ def test_cli_vep_all_tasks_ensemble_custom_genome(tmp_path): "vep", "-v", "tests/data/variants.tsv", "-o", str(output_file), - "--model", "ensemble", + "--model", DEFAULT_ENSEMBLE, "--device", device, "--max-distance", "20000", "--chunksize", "5", @@ -277,7 +277,7 @@ def test_cli_vep_all_tasks_ensemble(tmp_path): "vep", "-v", "tests/data/variants.tsv", "-o", str(output_file), - "--model", "ensemble", + "--model", DEFAULT_ENSEMBLE, "--device", device, "--max-distance", "20000", "--chunksize", "5", diff --git a/tests/test_interpret_attribution.py b/tests/test_interpret_attribution.py index 96014e3..e2bca28 100644 --- a/tests/test_interpret_attribution.py +++ b/tests/test_interpret_attribution.py @@ -214,9 +214,9 @@ def test_predict_save_attributions_single_gene(tmp_path): @pytest.mark.long_running def test_predict_save_attributions_single_gene_list_models(tmp_path): # download models - download_decima_weights(0, str(tmp_path)) - download_decima_weights(1, str(tmp_path)) - download_decima_metadata(str(tmp_path)) + download_decima_weights("v1_rep0", str(tmp_path)) + download_decima_weights("v1_rep1", str(tmp_path)) + download_decima_metadata("v1_rep0", str(tmp_path)) output_prefix = tmp_path / "SPI1" predict_attributions_seqlet_calling( diff --git a/tests/test_io.py b/tests/test_io.py index 23acc01..36a5a7f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -35,14 +35,14 @@ def test_AttributionWriter(tmp_path): nucleotide_indices = np.random.randint(0, 4, DECIMA_CONTEXT_SIZE) seqs[nucleotide_indices, np.arange(DECIMA_CONTEXT_SIZE)] = 1 - with AttributionWriter(str(h5_file), genes, "test_model", bigwig=False) as writer: + with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=False) as writer: writer.add("STRADA", seqs, attrs) assert h5_file.exists() with h5py.File(h5_file, "r") as f: assert set(f.keys()) == {"genes", "attribution", "sequence", "gene_mask_start", "gene_mask_end"} - assert f.attrs["model_name"] == "test_model" + assert f.attrs["model_name"] == "v1_rep0" assert f["attribution"].shape == (1, 4, DECIMA_CONTEXT_SIZE) np.testing.assert_array_equal(convert_input_type(f["sequence"][0], "one_hot", input_type="indices"), seqs) np.testing.assert_array_almost_equal(f["attribution"][0], attrs) @@ -51,13 +51,13 @@ def test_AttributionWriter(tmp_path): assert f["gene_mask_end"][0] == 223490 h5_file = tmp_path / "test_bigwig.h5" - with AttributionWriter(str(h5_file), genes, "test_model", bigwig=True) as writer: + with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=True) as writer: writer.add("STRADA", seqs, attrs) assert h5_file.with_suffix(".bigwig").exists() h5_file = tmp_path / "test_bigwig_custom.h5" - with AttributionWriter(str(h5_file), genes, "test_model", bigwig=True, custom_genes=True) as writer: + with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=True, custom_genes=True) as writer: writer.add("STRADA", seqs, attrs, gene_mask_start=100, gene_mask_end=200) with h5py.File(h5_file, "r") as f: diff --git a/tests/test_lightning.py b/tests/test_lightning.py index d1af0b0..40287f1 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -1,6 +1,6 @@ import pytest import torch -from decima.constants import DECIMA_CONTEXT_SIZE, NUM_CELLS +from decima.constants import DECIMA_CONTEXT_SIZE, MODEL_METADATA, DEFAULT_ENSEMBLE from decima.data.dataset import VariantDataset from decima.model.lightning import LightningModel, GeneMaskLightningModel from decima.model.metrics import WarningType @@ -10,31 +10,34 @@ @pytest.fixture def lightning_model(): - model = LightningModel(model_params={'n_tasks': NUM_CELLS, 'init_borzoi': False}, name='v1_rep0').to(device) + model_name = "v1_rep0" + metadata = MODEL_METADATA[model_name] + model = LightningModel(model_params={'n_tasks': metadata['num_tasks'], 'init_borzoi': False}, name=model_name).to(device) return model @pytest.mark.long_running def test_LightningModel_predict_step(lightning_model): + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] seq = torch.randn(1, 5, DECIMA_CONTEXT_SIZE).to(device) preds = lightning_model.predict_step(seq, 0) - assert preds.shape == (1, NUM_CELLS, 1) + assert preds.shape == (1, metadata['num_tasks'], 1) batch = {"seq": seq, "warning": [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME]} results = lightning_model.predict_step(batch, 1) - assert results["expression"].shape == (1, NUM_CELLS, 1) + assert results["expression"].shape == (1, metadata['num_tasks'], 1) assert results["warnings"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME] batch = { "seq": seq.to(device), "warning": [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME], - "pred_expr": {"v1_rep0": torch.ones((1, NUM_CELLS), device=device)} + "pred_expr": {"v1_rep0": torch.ones((1, metadata['num_tasks']), device=device)} } results = lightning_model.predict_step(batch, 1) - assert results["expression"].shape == (1, NUM_CELLS, 1) - assert (results["expression"] == torch.ones((1, NUM_CELLS, 1), device=device)).all() + assert results["expression"].shape == (1, metadata['num_tasks'], 1) + assert (results["expression"] == torch.ones((1, metadata['num_tasks'], 1), device=device)).all() assert results["warnings"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME] @@ -42,7 +45,10 @@ def test_LightningModel_predict_step(lightning_model): def test_LightningModel_predict_on_dataset(lightning_model, df_variant): dataset = VariantDataset(df_variant, model_name="v1_rep0") results = lightning_model.predict_on_dataset(dataset) - assert results["expression"].shape == (82, NUM_CELLS) + + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] + + assert results["expression"].shape == (82, metadata["num_tasks"]) assert results["warnings"]['unknown'] == 0 assert results["warnings"]['allele_mismatch_with_reference_genome'] == 13 @@ -51,7 +57,8 @@ def test_LightningModel_predict_on_dataset(lightning_model, df_variant): def test_LightningModel_predict_on_dataset_ensemble(lightning_model, df_variant): dataset = VariantDataset(df_variant) results = lightning_model.predict_on_dataset(dataset) - assert results["expression"].shape == (82, NUM_CELLS) + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] + assert results["expression"].shape == (82, metadata["num_tasks"]) assert results["warnings"]['unknown'] == 0 assert results["warnings"]['allele_mismatch_with_reference_genome'] == 13 @@ -59,9 +66,10 @@ def test_LightningModel_predict_on_dataset_ensemble(lightning_model, df_variant) @pytest.mark.long_running def test_GeneMaskLightningModel_forward(): seq = torch.randn(1, 4, DECIMA_CONTEXT_SIZE).to(device) + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] model = GeneMaskLightningModel( gene_mask_start=200_000, gene_mask_end=300_000, - model_params={"n_tasks": NUM_CELLS, "init_borzoi": False}, name="v1_rep0" + model_params={"n_tasks": metadata["num_tasks"], "init_borzoi": False}, name=metadata["name"] ).to(device) preds = model(seq) - assert preds.shape == (1, NUM_CELLS, 1) + assert preds.shape == (1, metadata["num_tasks"], 1) diff --git a/tests/test_predict_gene_expression.py b/tests/test_predict_gene_expression.py index 66954e9..2dd54c6 100644 --- a/tests/test_predict_gene_expression.py +++ b/tests/test_predict_gene_expression.py @@ -1,4 +1,5 @@ import pytest +from decima.constants import DEFAULT_ENSEMBLE from decima.tools.inference import predict_gene_expression from conftest import device @@ -19,7 +20,7 @@ def test_predict_gene_expression(): ad = predict_gene_expression( genes=["SPI1", "GATA1"], - model="ensemble", device=device, + model=DEFAULT_ENSEMBLE, device=device, save_replicates=True, ) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index c0876aa..e2a964e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,8 +1,8 @@ - +from decima.constants import DECIMA_CONTEXT_SIZE from decima.utils.sequence import prepare_mask_gene def test_mask_gene(): mask = prepare_mask_gene(100, 200) - assert mask.shape == (1, 524288) + assert mask.shape == (1, DECIMA_CONTEXT_SIZE) assert mask[0, 150].item() == 1.0 diff --git a/tests/test_vep.py b/tests/test_vep.py index 650ffd0..45ad91b 100644 --- a/tests/test_vep.py +++ b/tests/test_vep.py @@ -5,6 +5,7 @@ import pyarrow.parquet as pq from scipy.stats import pearsonr +from decima.constants import DEFAULT_ENSEMBLE, DECIMA_CONTEXT_SIZE, MODEL_METADATA from decima.core.result import DecimaResult from decima.hub import load_decima_model from decima.data.dataset import VariantDataset @@ -92,9 +93,10 @@ def test_VariantDataset(df_variant): ] assert len(dataset) == 82 * 2 - assert dataset[0]['seq'].shape == (5, 524288) + assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) - assert dataset[0]['pred_expr']['v1_rep0'].shape == (8856,) + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] + assert dataset[0]['pred_expr']['v1_rep0'].shape == (metadata['num_tasks'],) assert not dataset[0]['pred_expr']['v1_rep0'].isnan().any() assert dataset[1]['pred_expr']['v1_rep0'].isnan().all() assert not dataset[2]['pred_expr']['v1_rep0'].isnan().any() @@ -113,7 +115,7 @@ def test_VariantDataset(df_variant): assert cols.tolist() == [38435, 38435] # should be the same for both for i in range(len(dataset)): - assert dataset[i]['seq'].shape == (5, 524288) + assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) rows, cols = np.where(dataset[162]['seq'] != dataset[163]['seq']) assert cols.min() == 505705 # the positions before should not be effected. @@ -122,11 +124,11 @@ def test_VariantDataset(df_variant): dataset = VariantDataset(df_variant, max_seq_shift=100) assert len(dataset) == 82 * 2 * 201 - assert dataset[0]['seq'].shape == (5, 524288) + assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) for i in range(20): assert dataset[i]["warning"] == [] - assert dataset[i]['seq'].shape == (5, 524288) + assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE) assert dataset[44 * 2 * 201]["warning"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME] @@ -138,58 +140,60 @@ def test_VariantDataset(df_variant): @pytest.mark.long_running def test_VariantDataset_dataloader(df_variant): - dataset = VariantDataset(df_variant, model_name="ensemble") + dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE) dl = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=0, collate_fn=dataset.collate_fn) batches = iter(dl) + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] batch = next(batches) - assert batch["seq"].shape == (64, 5, 524288) + assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856) + assert batch["pred_expr"]["v1_rep0"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep1"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep2"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep3"].shape == (64, metadata['num_tasks']) batch = next(batches) - assert batch["seq"].shape == (64, 5, 524288) + assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE) assert len(batch["warning"]) > 0 assert WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME in batch["warning"] - assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856) - assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856) + assert batch["pred_expr"]["v1_rep0"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep1"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep2"].shape == (64, metadata['num_tasks']) + assert batch["pred_expr"]["v1_rep3"].shape == (64, metadata['num_tasks']) @pytest.mark.long_running def test_VariantDataset_dataloader_vcf(): df_variant = next(read_vcf_chunks("tests/data/test.vcf", 10000)) - dataset = VariantDataset(df_variant, model_name="ensemble", max_distance=20000) + dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE, max_distance=20000) dl = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=0, collate_fn=dataset.collate_fn) batches = iter(dl) + metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]] batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks']) batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert batch["warning"] == [] - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks']) batch = next(batches) - assert batch["seq"].shape == (8, 5, 524288) + assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE) assert len(batch["warning"]) > 0 - assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856) - assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856) + assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks']) + assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks']) @pytest.mark.long_running @@ -198,7 +202,8 @@ def test_predict_variant_effect(df_variant): query = "cell_type == 'CD8-positive, alpha-beta T cell'" cells = DecimaResult.load().query_cells(query) - df, warnings, num_variants = _predict_variant_effect(df_variant, model=0, tasks=query, device=device, max_distance=5000) + model = load_decima_model(0, device) + df, warnings, num_variants = _predict_variant_effect(df_variant, model=model, tasks=query, device=device, max_distance=5000) assert num_variants == 4 assert df.shape == (4, 273) @@ -225,7 +230,7 @@ def test_predict_variant_effect_save(df_variant, tmp_path): predict_variant_effect( df_variant, output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, tasks=query, device=device, max_distance=5000, @@ -305,7 +310,7 @@ def test_predict_variant_effect_vcf_ensemble(tmp_path): predict_variant_effect( "tests/data/test.vcf", output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, device=device, max_distance=20000, ) @@ -321,7 +326,7 @@ def test_predict_variant_effect_vcf_ensemble_replicates(tmp_path): predict_variant_effect( "tests/data/test.vcf", output_pq=str(output_file), - model="ensemble", + model=DEFAULT_ENSEMBLE, device=device, max_distance=20000, save_replicates=True,