diff --git a/.gitignore b/.gitignore
index a6a1268..c599a11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -186,3 +186,4 @@ wandb/
docs/api
docs/tutorials/example
docs/wandb
+run_tutorial.s*
diff --git a/docs/tutorials/1-attribution-motif-discovery.ipynb b/docs/tutorials/1-attribution-motif-discovery.ipynb
index c9f4400..76a75d5 100644
--- a/docs/tutorials/1-attribution-motif-discovery.ipynb
+++ b/docs/tutorials/1-attribution-motif-discovery.ipynb
@@ -122,7 +122,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -159,7 +159,7 @@
}
],
"source": [
- "! decima attributions --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes"
+ "! decima attributions --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes"
]
},
{
@@ -311,7 +311,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -338,12 +338,12 @@
}
],
"source": [
- "! decima attributions-predict --model 0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0"
+ "! decima attributions-predict --model v1_rep0 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_0"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -370,7 +370,7 @@
}
],
"source": [
- "! decima attributions-predict --model 1 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_1"
+ "! decima attributions-predict --model v1_rep1 --genes \"SPI1,BRD3\" --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_classical_monoctypes_1"
]
},
{
@@ -1005,7 +1005,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1041,7 +1041,7 @@
}
],
"source": [
- "! decima attributions --model 0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs"
+ "! decima attributions --model v1_rep0 --seqs ../tests/data/seqs.fasta --tasks \"cell_type == 'classical monocyte'\" --output-prefix example/output_custom_seqs"
]
},
{
diff --git a/docs/tutorials/3-finetune.html b/docs/tutorials/3-finetune.html
deleted file mode 100644
index 91e1626..0000000
--- a/docs/tutorials/3-finetune.html
+++ /dev/null
@@ -1,10657 +0,0 @@
-
-
-
-
-
-3-finetune
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[3]:
-
-
AnnData object with n_obs × n_vars = 50 × 988
- obs: 'cell_type', 'tissue', 'disease', 'study'
- var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'
- uns: 'log1p'
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[4]:
-
-
-
-
-
-
- |
-cell_type |
-tissue |
-disease |
-study |
-
-
-
-
-| pseudobulk_0 |
-ct_0 |
-t_0 |
-d_0 |
-st_0 |
-
-
-| pseudobulk_1 |
-ct_0 |
-t_0 |
-d_1 |
-st_0 |
-
-
-| pseudobulk_2 |
-ct_0 |
-t_0 |
-d_2 |
-st_1 |
-
-
-| pseudobulk_3 |
-ct_0 |
-t_0 |
-d_0 |
-st_1 |
-
-
-| pseudobulk_4 |
-ct_0 |
-t_0 |
-d_1 |
-st_2 |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[5]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-
-
-
-
-| gene_0 |
-chr1 |
-28320920 |
-28845208 |
-+ |
-28484760 |
-29009048 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_1 |
-chr19 |
-39145337 |
-39669625 |
-- |
-38981497 |
-39505785 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_2 |
-chr1 |
-77807946 |
-78332234 |
-- |
-77644106 |
-78168394 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_3 |
-chr8 |
-143094013 |
-143618301 |
-- |
-142930173 |
-143454461 |
-524288 |
-163840 |
-524288 |
-val |
-
-
-| gene_4 |
-chr16 |
-1775288 |
-2299576 |
-- |
-1611448 |
-2135736 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[6]:
-
-
array([[0. , 7.2155137, 7.3277392, 0. , 7.2698054],
- [7.1914983, 7.3387527, 0. , 7.2105823, 7.180787 ],
- [7.045969 , 7.2056117, 7.15802 , 7.289302 , 7.282388 ],
- [7.2008514, 0. , 7.2667375, 7.321583 , 7.2398143],
- [7.2582483, 6.723016 , 0. , 0. , 7.3626666]],
- dtype=float32)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
WARNING: adata.X seems to be already log-transformed.
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[8]:
-
-
array([[0. , 7.2337112, 7.2491336, 0. , 7.241202 ],
- [7.2420583, 7.262313 , 0. , 7.244706 , 7.2405686],
- [7.207595 , 7.229983 , 7.223361 , 7.2415223, 7.240574 ],
- [7.2279363, 0. , 7.237038 , 7.2445517, 7.233329 ],
- [7.2675843, 7.191038 , 0. , 0. , 7.281858 ]],
- dtype=float32)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
- warnings.warn(
-/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
- warnings.warn(
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[10]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-
-
-
-
-| gene_0 |
-chr1 |
-28320920 |
-28845208 |
-+ |
-28484760 |
-29009048 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_1 |
-chr19 |
-39145337 |
-39669625 |
-- |
-38981497 |
-39505785 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_2 |
-chr1 |
-77807946 |
-78332234 |
-- |
-77644106 |
-78168394 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_3 |
-chr8 |
-143094013 |
-143618301 |
-- |
-142930173 |
-143454461 |
-524288 |
-163840 |
-524288 |
-val |
-
-
-| gene_4 |
-chr16 |
-1775288 |
-2299576 |
-- |
-1611448 |
-2135736 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[12]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-
-
-
-
-| gene_0 |
-chr1 |
-28320920 |
-28845208 |
-+ |
-28320920 |
-28845208 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_1 |
-chr19 |
-39145337 |
-39669625 |
-- |
-39145337 |
-39669625 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_2 |
-chr1 |
-77807946 |
-78332234 |
-- |
-77807946 |
-78332234 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_3 |
-chr8 |
-143094013 |
-143618301 |
-- |
-143094013 |
-143618301 |
-524288 |
-163840 |
-524288 |
-val |
-
-
-| gene_4 |
-chr16 |
-1775288 |
-2299576 |
-- |
-1775288 |
-2299576 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
-0 intervals extended beyond the chromosome start and have been shifted
-7 intervals extended beyond the chromosome end and have been shifted
-7 intervals did not extend far enough upstream of the TSS and have been dropped
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[14]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-
-
-
-
-| gene_0 |
-chr1 |
-28157080 |
-28681368 |
-+ |
-28320920 |
-28845208 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_1 |
-chr19 |
-39309177 |
-39833465 |
-- |
-39145337 |
-39669625 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_2 |
-chr1 |
-77971786 |
-78496074 |
-- |
-77807946 |
-78332234 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_3 |
-chr8 |
-143257853 |
-143782141 |
-- |
-143094013 |
-143618301 |
-524288 |
-163840 |
-524288 |
-val |
-
-
-| gene_4 |
-chr16 |
-1939128 |
-2463416 |
-- |
-1775288 |
-2299576 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[15]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-fold |
-
-
-
-
-| 0 |
-chr4 |
-82524421 |
-82721029 |
-fold0 |
-
-
-| 1 |
-chr13 |
-18604798 |
-18801406 |
-fold0 |
-
-
-| 2 |
-chr2 |
-189923408 |
-190120016 |
-fold0 |
-
-
-| 3 |
-chr10 |
-59875743 |
-60072351 |
-fold0 |
-
-
-| 4 |
-chr1 |
-117109467 |
-117306075 |
-fold0 |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[16]:
-
-
-
-
-
-
- |
-gene |
-fold_ |
-
-
-
-
-| 0 |
-gene_0 |
-fold5 |
-
-
-| 15 |
-gene_1 |
-fold0 |
-
-
-| 30 |
-gene_2 |
-fold0 |
-
-
-| 44 |
-gene_3 |
-fold4 |
-
-
-| 58 |
-gene_4 |
-fold0 |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
/tmp/slurmjob.11843307/ipykernel_1446559/3109841685.py:1: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual.
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[19]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-
-
-
-
-| gene_0 |
-chr1 |
-28157080 |
-28681368 |
-+ |
-28320920 |
-28845208 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_1 |
-chr19 |
-39309177 |
-39833465 |
-- |
-39145337 |
-39669625 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_2 |
-chr1 |
-77971786 |
-78496074 |
-- |
-77807946 |
-78332234 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-| gene_3 |
-chr8 |
-143257853 |
-143782141 |
-- |
-143094013 |
-143618301 |
-524288 |
-163840 |
-524288 |
-val |
-
-
-| gene_4 |
-chr16 |
-1939128 |
-2463416 |
-- |
-1775288 |
-2299576 |
-524288 |
-163840 |
-524288 |
-train |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[20]:
-
-
dataset
-train 803
-test 99
-val 79
-Name: count, dtype: int64
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Writing metadata
-Writing task indices
-Writing genes array of shape: (981, 2)
-Writing labels array of shape: (981, 50, 1)
-Making gene masks
-
-
-
-
-
-
-
Writing mask array of shape: (981, 534288)
-
-
-
-
-
-
-
-
Writing sequence array of shape: (981, 534288)
-Done!
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
-decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
-decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
-decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[30]:
-
-
'lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt,lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt'
-
-
-
-
-
-
-
-
-
-
-
-
-
-
/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
- warnings.warn(
-/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.
- warnings.warn(
-
-
-
-
-
-
-
decima - INFO - Using device: cuda:0 and genome: hg38 for prediction.
-decima - INFO - Making predictions
-
-
-
-
-
-
-
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
-
-
-
-
-
-
-
wandb: Currently logged in as: anony-mouse-591272909468377997 to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
-
-
-
-
-
-
-
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
-
-
-
-
-
-
-
wandb: 1 of 1 files downloaded.
-Done. 00:00:00.6 (1235.1MB/s)
-
-
-
-
-
-
-
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
-
-
-
-
-
-
-
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
-
-
-
-
-
-
-
wandb: 1 of 1 files downloaded.
-Done. 00:00:00.5 (1299.7MB/s)
-
-
-
-
-
-
-
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
-
-
-
-
-
-
-
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
-
-
-
-
-
-
-
wandb: 1 of 1 files downloaded.
-Done. 00:00:00.6 (1244.1MB/s)
-
-
-
-
-
-
-
/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
-/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /home/celikm5/miniforge3/envs/decima2/bin/decima ...
-💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
-
-
-
-
-
-
-
GPU available: True (cuda), used: True
-TPU available: False, using: 0 TPU cores
-HPU available: False, using: 0 HPUs
-/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
-
-
-
-
-
-
-
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
-
-
-
-
-
-
-
-Predicting: | | 0/? [00:00<?, ?it/s]
-
-
-
-
-
-
-Predicting: | | 0/? [00:00<?, ?it/s]
-Predicting DataLoader 0: 0%| | 0/123 [00:00<?, ?it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 1%|▏ | 1/123 [00:02<04:57, 0.41it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 2%|▎ | 2/123 [00:04<04:03, 0.50it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 2%|▍ | 3/123 [00:05<03:48, 0.52it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 3%|▌ | 4/123 [00:07<03:40, 0.54it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 4%|▋ | 5/123 [00:09<03:34, 0.55it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 5%|▉ | 6/123 [00:10<03:30, 0.56it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 6%|█ | 7/123 [00:12<03:26, 0.56it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 7%|█▏ | 8/123 [00:14<03:23, 0.56it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 7%|█▎ | 9/123 [00:15<03:21, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 8%|█▍ | 10/123 [00:17<03:18, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 9%|█▌ | 11/123 [00:19<03:16, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 10%|█▋ | 12/123 [00:20<03:13, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 11%|█▊ | 13/123 [00:22<03:11, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 11%|█▉ | 14/123 [00:24<03:09, 0.57it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 12%|██ | 15/123 [00:26<03:07, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 13%|██▏ | 16/123 [00:27<03:05, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 14%|██▎ | 17/123 [00:29<03:03, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 15%|██▍ | 18/123 [00:31<03:01, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 15%|██▋ | 19/123 [00:32<02:59, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 16%|██▊ | 20/123 [00:34<02:58, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 17%|██▉ | 21/123 [00:36<02:56, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 18%|███ | 22/123 [00:37<02:54, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 19%|███▏ | 23/123 [00:39<02:52, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 20%|███▎ | 24/123 [00:41<02:50, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 20%|███▍ | 25/123 [00:43<02:48, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 21%|███▌ | 26/123 [00:44<02:47, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 22%|███▋ | 27/123 [00:46<02:45, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 23%|███▊ | 28/123 [00:48<02:43, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 24%|████ | 29/123 [00:49<02:41, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 24%|████▏ | 30/123 [00:51<02:39, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 25%|████▎ | 31/123 [00:53<02:38, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 26%|████▍ | 32/123 [00:54<02:36, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 27%|████▌ | 33/123 [00:56<02:34, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 28%|████▋ | 34/123 [00:58<02:32, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 28%|████▊ | 35/123 [01:00<02:31, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 29%|████▉ | 36/123 [01:01<02:29, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 30%|█████ | 37/123 [01:03<02:27, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 31%|█████▎ | 38/123 [01:05<02:25, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 32%|█████▍ | 39/123 [01:06<02:24, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 33%|█████▌ | 40/123 [01:08<02:22, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 33%|█████▋ | 41/123 [01:10<02:20, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 34%|█████▊ | 42/123 [01:11<02:18, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 35%|█████▉ | 43/123 [01:13<02:17, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 36%|██████ | 44/123 [01:15<02:15, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 37%|██████▏ | 45/123 [01:17<02:13, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 37%|██████▎ | 46/123 [01:18<02:11, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 38%|██████▍ | 47/123 [01:20<02:10, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 39%|██████▋ | 48/123 [01:22<02:08, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 40%|██████▊ | 49/123 [01:23<02:06, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 41%|██████▉ | 50/123 [01:25<02:04, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 41%|███████ | 51/123 [01:27<02:03, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 42%|███████▏ | 52/123 [01:29<02:01, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 43%|███████▎ | 53/123 [01:30<01:59, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 44%|███████▍ | 54/123 [01:32<01:58, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 45%|███████▌ | 55/123 [01:34<01:56, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 46%|███████▋ | 56/123 [01:35<01:54, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 46%|███████▉ | 57/123 [01:37<01:52, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 47%|████████ | 58/123 [01:39<01:51, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 48%|████████▏ | 59/123 [01:40<01:49, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 49%|████████▎ | 60/123 [01:42<01:47, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 50%|████████▍ | 61/123 [01:44<01:46, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 50%|████████▌ | 62/123 [01:46<01:44, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 51%|████████▋ | 63/123 [01:47<01:42, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 52%|████████▊ | 64/123 [01:49<01:40, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 53%|████████▉ | 65/123 [01:51<01:39, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 54%|█████████ | 66/123 [01:52<01:37, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 54%|█████████▎ | 67/123 [01:54<01:35, 0.58it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 55%|█████████▍ | 68/123 [01:56<01:34, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 56%|█████████▌ | 69/123 [01:57<01:32, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 57%|█████████▋ | 70/123 [01:59<01:30, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 58%|█████████▊ | 71/123 [02:01<01:28, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 59%|█████████▉ | 72/123 [02:03<01:27, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 59%|██████████ | 73/123 [02:04<01:25, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 60%|██████████▏ | 74/123 [02:06<01:23, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 61%|██████████▎ | 75/123 [02:08<01:22, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 62%|██████████▌ | 76/123 [02:09<01:20, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 63%|██████████▋ | 77/123 [02:11<01:18, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 63%|██████████▊ | 78/123 [02:13<01:16, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 64%|██████████▉ | 79/123 [02:14<01:15, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 65%|███████████ | 80/123 [02:16<01:13, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 66%|███████████▏ | 81/123 [02:18<01:11, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 67%|███████████▎ | 82/123 [02:20<01:10, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 67%|███████████▍ | 83/123 [02:21<01:08, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 68%|███████████▌ | 84/123 [02:23<01:06, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 69%|███████████▋ | 85/123 [02:25<01:04, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 70%|███████████▉ | 86/123 [02:26<01:03, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 71%|████████████ | 87/123 [02:28<01:01, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 72%|████████████▏ | 88/123 [02:30<00:59, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 72%|████████████▎ | 89/123 [02:31<00:58, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 73%|████████████▍ | 90/123 [02:33<00:56, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 74%|████████████▌ | 91/123 [02:35<00:54, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 75%|████████████▋ | 92/123 [02:37<00:52, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 76%|████████████▊ | 93/123 [02:38<00:51, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 76%|████████████▉ | 94/123 [02:40<00:49, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 77%|█████████████▏ | 95/123 [02:42<00:47, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 78%|█████████████▎ | 96/123 [02:43<00:46, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 79%|█████████████▍ | 97/123 [02:45<00:44, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 80%|█████████████▌ | 98/123 [02:47<00:42, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 80%|█████████████▋ | 99/123 [02:49<00:40, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 81%|█████████████ | 100/123 [02:50<00:39, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 82%|█████████████▏ | 101/123 [02:52<00:37, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 83%|█████████████▎ | 102/123 [02:54<00:35, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 84%|█████████████▍ | 103/123 [02:55<00:34, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 85%|█████████████▌ | 104/123 [02:57<00:32, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 85%|█████████████▋ | 105/123 [02:59<00:30, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 86%|█████████████▊ | 106/123 [03:00<00:29, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 87%|█████████████▉ | 107/123 [03:02<00:27, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 88%|██████████████ | 108/123 [03:04<00:25, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 89%|██████████████▏ | 109/123 [03:06<00:23, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 89%|██████████████▎ | 110/123 [03:07<00:22, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 90%|██████████████▍ | 111/123 [03:09<00:20, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 91%|██████████████▌ | 112/123 [03:11<00:18, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 92%|██████████████▋ | 113/123 [03:12<00:17, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 93%|██████████████▊ | 114/123 [03:14<00:15, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 93%|██████████████▉ | 115/123 [03:16<00:13, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 94%|███████████████ | 116/123 [03:17<00:11, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 95%|███████████████▏| 117/123 [03:19<00:10, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 96%|███████████████▎| 118/123 [03:21<00:08, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 97%|███████████████▍| 119/123 [03:23<00:06, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 98%|███████████████▌| 120/123 [03:24<00:05, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 98%|███████████████▋| 121/123 [03:26<00:03, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 99%|███████████████▊| 122/123 [03:28<00:01, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 100%|████████████████| 123/123 [03:29<00:00, 0.59it/s]
-
-
-
-
-
-
-Predicting DataLoader 0: 100%|████████████████| 123/123 [03:29<00:00, 0.59it/s]
-
-
-
-
-
-
-
/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric WarningCounter was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
-decima - INFO - Creating anndata
-decima - INFO - Evaluating performance
-
-
-
-
-
-
-
Performance on genes in the train dataset.
-Mean Pearson Correlation per gene: Mean: 0.00.
-Mean Pearson Correlation per gene using size factor (baseline): 0.02.
-Mean Pearson Correlation per pseudobulk: 0.00
-
-Performance on genes in the val dataset.
-Mean Pearson Correlation per gene: Mean: 0.04.
-Mean Pearson Correlation per gene using size factor (baseline): 0.05.
-Mean Pearson Correlation per pseudobulk: 0.02
-
-
-
-
-
-
-
-
Performance on genes in the test dataset.
-Mean Pearson Correlation per gene: Mean: -0.02.
-Mean Pearson Correlation per gene using size factor (baseline): 0.01.
-Mean Pearson Correlation per pseudobulk: 0.01
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[33]:
-
-
AnnData object with n_obs × n_vars = 50 × 981
- obs: 'cell_type', 'tissue', 'disease', 'study', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'
- var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset', 'pearson', 'size_factor_pearson'
- layers: 'preds', 'preds_finetune_test_0'
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[34]:
-
-
-
-
-
-
- |
-cell_type |
-tissue |
-disease |
-study |
-size_factor |
-train_pearson |
-val_pearson |
-test_pearson |
-
-
-
-
-| pseudobulk_0 |
-ct_0 |
-t_0 |
-d_0 |
-st_0 |
-5193.049805 |
-0.070174 |
-0.214402 |
-0.088188 |
-
-
-| pseudobulk_1 |
-ct_0 |
-t_0 |
-d_1 |
-st_0 |
-5137.830566 |
--0.004344 |
-0.058580 |
--0.015836 |
-
-
-| pseudobulk_2 |
-ct_0 |
-t_0 |
-d_2 |
-st_1 |
-5198.248535 |
-0.022892 |
-0.212270 |
--0.026279 |
-
-
-| pseudobulk_3 |
-ct_0 |
-t_0 |
-d_0 |
-st_1 |
-5204.543457 |
-0.067001 |
--0.053795 |
-0.041648 |
-
-
-| pseudobulk_4 |
-ct_0 |
-t_0 |
-d_1 |
-st_2 |
-5056.311523 |
-0.009684 |
-0.001823 |
-0.020882 |
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Out[35]:
-
-
-
-
-
-
- |
-chrom |
-start |
-end |
-strand |
-gene_start |
-gene_end |
-gene_length |
-gene_mask_start |
-gene_mask_end |
-dataset |
-pearson |
-size_factor_pearson |
-
-
-
-
-| gene_0 |
-chr1 |
-28157080 |
-28681368 |
-+ |
-28320920 |
-28845208 |
-524288 |
-163840 |
-524288 |
-train |
-0.042477 |
--0.036051 |
-
-
-| gene_1 |
-chr19 |
-39309177 |
-39833465 |
-- |
-39145337 |
-39669625 |
-524288 |
-163840 |
-524288 |
-train |
-0.041681 |
--0.075098 |
-
-
-| gene_2 |
-chr1 |
-77971786 |
-78496074 |
-- |
-77807946 |
-78332234 |
-524288 |
-163840 |
-524288 |
-train |
--0.070010 |
-0.220900 |
-
-
-| gene_3 |
-chr8 |
-143257853 |
-143782141 |
-- |
-143094013 |
-143618301 |
-524288 |
-163840 |
-524288 |
-val |
--0.104826 |
-0.128605 |
-
-
-| gene_4 |
-chr16 |
-1939128 |
-2463416 |
-- |
-1775288 |
-2299576 |
-524288 |
-163840 |
-524288 |
-train |
--0.082712 |
--0.001255 |
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/docs/tutorials/3-finetune.ipynb b/docs/tutorials/3-finetune.ipynb
index 2c55b93..eecb02e 100644
--- a/docs/tutorials/3-finetune.ipynb
+++ b/docs/tutorials/3-finetune.ipynb
@@ -1507,7 +1507,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"id": "d0fdaa9d",
"metadata": {},
"outputs": [
@@ -2579,7 +2579,7 @@
"source": [
"! CUDA_VISIBLE_DEVICES=0 decima finetune \\\n",
"--name finetune_test_0 \\\n",
- "--model 0 \\\n",
+ "--model v1_rep0 \\\n",
"--device 0 \\\n",
"--matrix-file {ad_file_path} \\\n",
"--h5-file {h5_file_path} \\\n",
diff --git a/docs/tutorials/4-modisco.ipynb b/docs/tutorials/4-modisco.ipynb
index 24c7c5a..3944dd6 100644
--- a/docs/tutorials/4-modisco.ipynb
+++ b/docs/tutorials/4-modisco.ipynb
@@ -640,7 +640,7 @@
" --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n",
" --transform \"specificity\" \\\n",
" --batch-size 1 \\\n",
- " --model 0 \\\n",
+ " --model v1_rep0 \\\n",
" --num-workers 8 \\\n",
" -o example/modisco_neurons"
]
@@ -1511,7 +1511,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "3b3faf06",
"metadata": {},
"outputs": [
@@ -1547,7 +1547,7 @@
" --top-n-markers 50 \\\n",
" --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n",
" --batch-size 1 \\\n",
- " --model 0 \\\n",
+ " --model v1_rep0 \\\n",
" --num-workers 8 \\\n",
" -o example/modisco_subcommands/modisco_neurons_0"
]
@@ -1562,7 +1562,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"id": "dffdc854",
"metadata": {},
"outputs": [
@@ -1598,7 +1598,7 @@
" --top-n-markers 50 \\\n",
" --tasks \"cell_type.str.contains('neuron') and organ == 'CNS' and disease == 'healthy'\" \\\n",
" --batch-size 1 \\\n",
- " --model 1 \\\n",
+ " --model v1_rep1 \\\n",
" --num-workers 8 \\\n",
" -o example/modisco_subcommands/modisco_neurons_1"
]
diff --git a/docs/tutorials/5-gene-expression-prediction.ipynb b/docs/tutorials/5-gene-expression-prediction.ipynb
index da560b2..25413ab 100644
--- a/docs/tutorials/5-gene-expression-prediction.ipynb
+++ b/docs/tutorials/5-gene-expression-prediction.ipynb
@@ -46,10 +46,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
+ "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n",
+ " warnings.warn(\n",
+ "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n",
+ " warnings.warn(\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmhcelik\u001b[0m (\u001b[33mmhcw\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
- "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact metadata:latest, 3122.32MB. 1 files... \n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact 'metadata:latest', 3122.32MB. 1 files...\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
- "Done. 0:0:1.6 (1995.7MB/s)\n"
+ "Done. 00:00:05.9 (533.2MB/s)\n"
]
},
{
@@ -1868,7 +1872,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "decima",
+ "display_name": "decima2",
"language": "python",
"name": "python3"
},
@@ -1882,7 +1886,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.12"
+ "version": "3.11.14"
}
},
"nbformat": 4,
diff --git a/setup.cfg b/setup.cfg
index 467b7e1..d52eec3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -71,7 +71,8 @@ install_requires =
pyarrow
safetensors
tangermeme>=1.0.0
- modisco-lite @ git+https://github.com/MuhammedHasan/tfmodisco-lite.git@faster-modisco
+ faster-modisco-lite>=3.0.0
+
[options.packages.find]
where = src
diff --git a/src/decima/__init__.py b/src/decima/__init__.py
index ebf1885..c94d98c 100644
--- a/src/decima/__init__.py
+++ b/src/decima/__init__.py
@@ -1,5 +1,5 @@
import sys
-from decima.constants import NUM_CELLS, DECIMA_CONTEXT_SIZE
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA
from decima.core.result import DecimaResult
from decima.interpret.attributions import predict_attributions_seqlet_calling
from decima.vep import predict_variant_effect
@@ -25,6 +25,7 @@
"DecimaResult",
"predict_variant_effect",
"predict_attributions_seqlet_calling",
- "NUM_CELLS",
"DECIMA_CONTEXT_SIZE",
+ "DEFAULT_ENSEMBLE",
+ "MODEL_METADATA",
]
diff --git a/src/decima/cli/attributions.py b/src/decima/cli/attributions.py
index f2dab90..783271a 100644
--- a/src/decima/cli/attributions.py
+++ b/src/decima/cli/attributions.py
@@ -17,7 +17,8 @@
"""
import click
-from decima.cli.callback import parse_genes, parse_model, parse_attributions
+from decima.constants import DEFAULT_ENSEMBLE
+from decima.cli.callback import parse_genes, parse_model, parse_attributions, parse_metadata
from decima.interpret.attributions import (
plot_attributions,
predict_save_attributions,
@@ -47,16 +48,16 @@
"--model",
type=str,
required=False,
- default=0,
+ default=DEFAULT_ENSEMBLE,
callback=parse_model,
help="Model to use for attribution analysis either replicate number or path to the model.",
show_default=True,
)
@click.option(
"--metadata",
- type=click.Path(exists=True),
default=None,
- help="Path to the metadata anndata file. If not provided, the default metadata will be downloaded from wandb.",
+ callback=parse_metadata,
+ help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.",
show_default=True,
)
@click.option(
@@ -196,16 +197,16 @@ def cli_attributions_predict(
"--model",
type=str,
required=False,
- default="ensemble",
+ default=DEFAULT_ENSEMBLE,
callback=parse_model,
help="Model to use for attribution analysis either replicate number or path to the model.",
show_default=True,
)
@click.option(
"--metadata",
- type=click.Path(exists=True),
+ callback=parse_metadata,
default=None,
- help="Path to the metadata anndata file. Default: None.",
+ help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.",
)
@click.option(
"--method", type=str, required=False, default="inputxgradient", help="Method to use for attribution analysis."
@@ -329,7 +330,12 @@ def cli_attributions(
"--off-tasks", type=str, required=False, help="Optional query string to filter cell types to contrast against."
)
@click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.")
-@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--metadata",
+ callback=parse_metadata,
+ default=None,
+ help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.",
+)
@click.option(
"--genes",
type=str,
@@ -427,7 +433,12 @@ def cli_attributions_recursive_seqlet_calling(
callback=parse_genes,
help="Comma-separated list of gene symbols or IDs to analyze.",
)
-@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--metadata",
+ callback=parse_metadata,
+ default=None,
+ help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used.",
+)
@click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.")
@click.option("--seqlogo-window", type=int, default=50, help="Window size for sequence logo plots")
@click.option("--custom-genome", is_flag=True, help="Use custom genome")
diff --git a/src/decima/cli/callback.py b/src/decima/cli/callback.py
index 1a24cc5..4af9770 100644
--- a/src/decima/cli/callback.py
+++ b/src/decima/cli/callback.py
@@ -1,15 +1,17 @@
import click
from pathlib import Path
+from decima.constants import MODEL_METADATA, ENSEMBLE_MODELS, DEFAULT_ENSEMBLE
def parse_model(ctx, param, value):
+ if isinstance(value, int):
+ value = str(value)
+
if value is None:
return None
elif isinstance(value, str):
- if value == "ensemble":
- return "ensemble"
- elif value in ["0", "1", "2", "3"]:
- return int(value)
+ if value in MODEL_METADATA:
+ return value
paths = value.split(",")
for path in paths:
@@ -32,7 +34,7 @@ def parse_genes(ctx, param, value):
def validate_save_replicates(ctx, param, value):
if value:
- if ctx.params["model"] == "ensemble":
+ if ctx.params["model"] in ENSEMBLE_MODELS:
return value
elif isinstance(ctx.params["model"], list) and (len(ctx.params["model"]) > 1):
return value
@@ -43,6 +45,32 @@ def validate_save_replicates(ctx, param, value):
return value
+def parse_metadata(ctx, param, value):
+ if value is None:
+ if "model" in ctx.params:
+ model = ctx.params["model"]
+ else:
+ model = DEFAULT_ENSEMBLE
+
+ if isinstance(model, list):
+ model = model[0]
+ if Path(model).exists():
+ raise click.ClickException(
+ f"File path passed for model {model} but metadata filepath is not provided. Also, provide the metadata filepath."
+ )
+ else:
+ return model
+ elif isinstance(value, str):
+ if value in MODEL_METADATA:
+ return value
+ elif Path(value).exists():
+ return value
+ else:
+ raise click.ClickException(
+ f"Invalid name for the metadata dataset: {value}. Check if the name is correct or the metadata file exists."
+ )
+
+
def parse_attributions(ctx, param, value):
value = value.split(",")
for i in value:
diff --git a/src/decima/cli/download.py b/src/decima/cli/download.py
index a145cd1..726a43b 100644
--- a/src/decima/cli/download.py
+++ b/src/decima/cli/download.py
@@ -10,6 +10,7 @@
"""
import click
+from decima.constants import DEFAULT_ENSEMBLE
from decima.cli.callback import parse_model
from decima.hub.download import (
cache_decima_data,
@@ -27,7 +28,11 @@ def cli_cache():
@click.command()
@click.option(
- "--model", type=str, default="ensemble", help="Model to download. Default: ensemble.", callback=parse_model
+ "--model",
+ type=str,
+ default=DEFAULT_ENSEMBLE,
+ help=f"Model to download. Default: {DEFAULT_ENSEMBLE}.",
+ callback=parse_model,
)
@click.option(
"--download-dir",
@@ -41,21 +46,34 @@ def cli_download_weights(model, download_dir):
@click.command()
+@click.option(
+ "--metadata",
+ default=DEFAULT_ENSEMBLE,
+ help=f"Model to download metadata for using wandb. Default: {DEFAULT_ENSEMBLE}.",
+ callback=parse_model,
+)
@click.option(
"--download-dir",
type=click.Path(),
default=".",
help="Directory to download the metadata. Default: current directory.",
)
-def cli_download_metadata(download_dir):
- """Download pre-trained Decima metadata."""
- download_decima_metadata(str(download_dir))
+def cli_download_metadata(metadata=DEFAULT_ENSEMBLE, download_dir=None):
+ """Download pre-trained Decima metadata for a given model."""
+ download_decima_metadata(metadata, str(download_dir))
@click.command()
+@click.option(
+ "--model-name",
+ type=str,
+ default=DEFAULT_ENSEMBLE,
+ help=f"Model to download. Default: {DEFAULT_ENSEMBLE}.",
+ callback=parse_model,
+)
@click.option(
"--download-dir", type=click.Path(), default=".", help="Directory to download the data. Default: current directory."
)
-def cli_download(download_dir):
+def cli_download(model_name=DEFAULT_ENSEMBLE, download_dir="."):
"""Download model weights and metadata for Decima."""
- download_decima(str(download_dir))
+ download_decima(model_name, str(download_dir))
diff --git a/src/decima/cli/finetune.py b/src/decima/cli/finetune.py
index 3105797..70105bb 100755
--- a/src/decima/cli/finetune.py
+++ b/src/decima/cli/finetune.py
@@ -68,7 +68,7 @@ def cli_finetune(
Args:
name: Name of the run for logging and checkpointing
- model: Model path or replication number (0-3)
+ model: Borzoi model path or replication number (0-3)
device: Device to use for training. Default: "0"
matrix_file: Path to the matrix file containing training data
h5_file: Path to the H5 file containing sequences
diff --git a/src/decima/cli/modisco.py b/src/decima/cli/modisco.py
index 264f285..c69800f 100644
--- a/src/decima/cli/modisco.py
+++ b/src/decima/cli/modisco.py
@@ -21,7 +21,8 @@
import click
from typing import List, Optional, Union
-from decima.cli.callback import parse_model, parse_genes, parse_attributions
+from decima.constants import DEFAULT_ENSEMBLE
+from decima.cli.callback import parse_model, parse_genes, parse_attributions, parse_metadata
from decima.interpret.modisco import (
predict_save_modisco_attributions,
modisco_patterns,
@@ -45,8 +46,20 @@
default=None,
help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.",
)
-@click.option("--model", type=str, default=0, help="Model to use for the prediction.", callback=parse_model)
-@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--model",
+ type=str,
+ default=DEFAULT_ENSEMBLE,
+ help=f"Model to use for the prediction. Default: {DEFAULT_ENSEMBLE}.",
+ callback=parse_model,
+ show_default=True,
+)
+@click.option(
+ "--metadata",
+ default=None,
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.",
+)
@click.option(
"--method",
type=click.Choice(["saliency", "inputxgradient", "integratedgradients"]),
@@ -59,7 +72,7 @@
type=click.Choice(["specificity", "aggregate"]),
default="specificity",
show_default=True,
- help="Transform to use for attribution analysis.",
+ help="Transform to use for attribution analysis. Available options: 'specificity', 'aggregate'. Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns.",
)
@click.option("--batch-size", type=int, default=1, show_default=True, help="Batch size for the prediction.")
@click.option(
@@ -88,7 +101,7 @@ def cli_modisco_attributions(
output_prefix: str,
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
- model: Optional[Union[str, int]] = 0,
+ model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE,
metadata: Optional[str] = None,
method: str = "saliency",
transform: str = "specificity",
@@ -144,7 +157,12 @@ def cli_modisco_attributions(
help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.",
)
@click.option("--tss-distance", type=int, default=10_000, show_default=True, help="TSS distance for the prediction.")
-@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--metadata",
+ default=None,
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.",
+)
@click.option(
"--genes",
type=str,
@@ -279,7 +297,12 @@ def cli_modisco_reports(
@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files.")
@click.option("--modisco-h5", type=click.Path(exists=True), required=True, help="Path to the modisco HDF5 file.")
-@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--metadata",
+ default=None,
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.",
+)
@click.option("--trim-threshold", type=float, default=0.2, show_default=True, help="Trim threshold.")
def cli_modisco_seqlet_bed(
output_prefix: str,
@@ -313,12 +336,17 @@ def cli_modisco_seqlet_bed(
@click.option(
"--model",
type=str,
- default="ensemble",
+ default=DEFAULT_ENSEMBLE,
show_default=True,
help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files. Default: `ensemble`.",
callback=parse_model,
)
-@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.")
+@click.option(
+ "--metadata",
+ default=None,
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.",
+)
@click.option(
"--method",
type=click.Choice(["saliency", "inputxgradient", "integratedgradients"]),
@@ -326,6 +354,13 @@ def cli_modisco_seqlet_bed(
show_default=True,
help="Method to use for attribution analysis.",
)
+@click.option(
+ "--transform",
+ type=click.Choice(["specificity", "aggregate"]),
+ default="specificity",
+ show_default=True,
+ help="Transform to use for attribution analysis. Available options: 'specificity', 'aggregate'. Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns.",
+)
@click.option("--batch-size", type=int, default=1, show_default=True, help="Batch size for the prediction.")
@click.option(
"--genes",
@@ -406,9 +441,10 @@ def cli_modisco(
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
tss_distance: int = 10_000,
- model: Optional[Union[str, int]] = "ensemble",
+ model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE,
metadata: Optional[str] = None,
method: str = "saliency",
+ transform: str = "specificity",
batch_size: int = 1,
genes: Optional[str] = None,
top_n_markers: Optional[int] = None,
@@ -445,6 +481,7 @@ def cli_modisco(
model=model,
metadata_anndata=metadata,
method=method,
+ transform=transform,
batch_size=batch_size,
genes=genes,
top_n_markers=top_n_markers,
diff --git a/src/decima/cli/predict_genes.py b/src/decima/cli/predict_genes.py
index 32a5934..317706a 100644
--- a/src/decima/cli/predict_genes.py
+++ b/src/decima/cli/predict_genes.py
@@ -8,7 +8,8 @@
import click
from pathlib import Path
-from decima.cli.callback import parse_model, parse_genes, validate_save_replicates
+from decima.constants import DEFAULT_ENSEMBLE
+from decima.cli.callback import parse_model, parse_genes, validate_save_replicates, parse_metadata
from decima.tools.inference import predict_gene_expression
@@ -25,15 +26,16 @@
"-m",
"--model",
type=str,
- default="ensemble",
+ default=DEFAULT_ENSEMBLE,
+ show_default=True,
callback=parse_model,
- help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to checkpoint files",
+ help=f"`0`, `1`, `2`, `3`, `{DEFAULT_ENSEMBLE}` or a path or a comma-separated list of paths to checkpoint files",
)
@click.option(
"--metadata",
- type=click.Path(exists=True),
default=None,
- help="Path to the metadata anndata file. Default: None.",
+ callback=parse_metadata,
+ help="Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.",
)
@click.option(
"--device",
diff --git a/src/decima/cli/query_cell.py b/src/decima/cli/query_cell.py
index a6841c8..f86ba31 100644
--- a/src/decima/cli/query_cell.py
+++ b/src/decima/cli/query_cell.py
@@ -18,15 +18,20 @@
"""
import click
+from decima.constants import DEFAULT_ENSEMBLE
+from decima.cli.callback import parse_metadata
from decima.core.result import DecimaResult
@click.command()
@click.argument("query", default="")
@click.option(
- "--metadata-anndata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file."
+ "--metadata",
+ default=None,
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. Default: {DEFAULT_ENSEMBLE}.",
)
-def cli_query_cell(query="", metadata_anndata=None):
+def cli_query_cell(query="", metadata=None):
"""
Query a cell using query string
@@ -42,7 +47,7 @@ def cli_query_cell(query="", metadata_anndata=None):
...
"""
- result = DecimaResult.load(metadata_anndata)
+ result = DecimaResult.load(metadata)
df = result.cell_metadata
if query != "":
diff --git a/src/decima/cli/vep.py b/src/decima/cli/vep.py
index 795e74e..7fc290b 100644
--- a/src/decima/cli/vep.py
+++ b/src/decima/cli/vep.py
@@ -21,8 +21,8 @@
"""
import click
-from decima.constants import DECIMA_CONTEXT_SIZE
-from decima.cli.callback import parse_model, validate_save_replicates
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE
+from decima.cli.callback import parse_model, validate_save_replicates, parse_metadata
from decima.utils.dataframe import ensemble_predictions
from decima.vep import predict_variant_effect
@@ -46,15 +46,15 @@
@click.option(
"--model",
type=str,
- default="ensemble",
+ default=DEFAULT_ENSEMBLE,
callback=parse_model,
help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files to perform variant effect prediction. Default: `ensemble`.",
)
@click.option(
"--metadata",
- type=click.Path(exists=True),
default=None,
- help="Path to the metadata anndata file. Default: None.",
+ callback=parse_metadata,
+ help=f"Path to the metadata anndata file or name of the model. If not provided, the compabilite metadata for the model will be used. Default: {DEFAULT_ENSEMBLE}.",
)
@click.option(
"--device", type=str, default=None, help="Device to use. Default: None which automatically selects the best device."
diff --git a/src/decima/constants.py b/src/decima/constants.py
index ca2e086..af70f35 100644
--- a/src/decima/constants.py
+++ b/src/decima/constants.py
@@ -1,13 +1,51 @@
"""Decima constants."""
+import json
import os
-DECIMA_CONTEXT_SIZE = 524288
+# constants for all models
+DECIMA_CONTEXT_SIZE = 524_288
SUPPORTED_GENOMES = {"hg38"}
-NUM_CELLS = 8856
-if "DECIMA_ENSEMBLE_MODELS_NAMES" in os.environ:
- ENSEMBLE_MODELS_NAMES = os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"].split(",")
-else:
- ENSEMBLE_MODELS_NAMES = ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"]
+# EDIT: following metadata to add new models;
+# metadata of models models
+# models has dict as values
+# ensemble models have list of model names as values and fetched metadata from the models
+# following fields are required in the metadata:
+# - name of the models on wandb
+# - number of cells of the model
+# - metadata name in the wandb
+# - model_path [optional] to the local model path
+# - metadata_path [optional] to the local metadata path
+MODEL_METADATA = {
+ "v1_rep0": {"name": "rep0", "num_tasks": 8856, "metadata": "metadata"},
+ "v1_rep1": {"name": "rep1", "num_tasks": 8856, "metadata": "metadata"},
+ "v1_rep2": {"name": "rep2", "num_tasks": 8856, "metadata": "metadata"},
+ "v1_rep3": {"name": "rep3", "num_tasks": 8856, "metadata": "metadata"},
+ "ensemble": ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"],
+}
+MODEL_METADATA["rep0"] = MODEL_METADATA["v1_rep0"]
+MODEL_METADATA["rep1"] = MODEL_METADATA["v1_rep1"]
+MODEL_METADATA["rep2"] = MODEL_METADATA["v1_rep2"]
+MODEL_METADATA["rep3"] = MODEL_METADATA["v1_rep3"]
+MODEL_METADATA[0] = MODEL_METADATA["v1_rep0"]
+MODEL_METADATA[1] = MODEL_METADATA["v1_rep1"]
+MODEL_METADATA[2] = MODEL_METADATA["v1_rep2"]
+MODEL_METADATA[3] = MODEL_METADATA["v1_rep3"]
+MODEL_METADATA["0"] = MODEL_METADATA["v1_rep0"]
+MODEL_METADATA["1"] = MODEL_METADATA["v1_rep1"]
+MODEL_METADATA["2"] = MODEL_METADATA["v1_rep2"]
+MODEL_METADATA["3"] = MODEL_METADATA["v1_rep3"]
+
+# default version
+DEFAULT_ENSEMBLE = "ensemble"
+
+# overwrite model metadata from environment variables
+if "MODEL_METADATA" in os.environ:
+ MODEL_METADATA = json.loads(os.environ["MODEL_METADATA"])
+
+if "DEFAULT_ENSEMBLE" in os.environ:
+ DEFAULT_ENSEMBLE = os.environ["DEFAULT_ENSEMBLE"]
+
+ENSEMBLE_MODELS = [k for k, v in MODEL_METADATA.items() if isinstance(v, list)]
diff --git a/src/decima/core/attribution.py b/src/decima/core/attribution.py
index 015a5dc..43d7938 100644
--- a/src/decima/core/attribution.py
+++ b/src/decima/core/attribution.py
@@ -20,7 +20,7 @@
from grelu.sequence.format import convert_input_type, strings_to_one_hot
from grelu.visualize import plot_attributions
-from decima.constants import DECIMA_CONTEXT_SIZE
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA
from decima.core.result import DecimaResult
from decima.interpret.attributer import DecimaAttributer
from decima.utils.sequence import one_hot_to_seq
@@ -188,11 +188,11 @@ def from_seq(
inputs: Union[str, torch.Tensor, np.ndarray],
tasks: Optional[list] = None,
off_tasks: Optional[list] = None,
- model: Optional[Union[str, int]] = 0,
+ model: Optional[Union[str, int]] = MODEL_METADATA[DEFAULT_ENSEMBLE][0],
transform: str = "specificity",
method: str = "inputxgradient",
device: Optional[str] = "cpu",
- result: Optional[DecimaResult] = None,
+ result: Optional[str] = None,
gene: Optional[str] = "",
chrom: Optional[str] = None,
start: Optional[int] = None,
@@ -218,6 +218,7 @@ def from_seq(
transform: Transformation to apply to attributions
method: Method to use for attribution analysis available options: "saliency", "inputxgradient", "integratedgradients".
device: Device to use for attribution analysis
+ result: Result object or path to result object or name of the model to load the result for.
gene: Gene name
chrom: Chromosome name
start: Start position
@@ -259,9 +260,7 @@ def from_seq(
else:
raise ValueError("`inputs` must be a string, torch.Tensor, or np.ndarray")
- if result is None:
- result = DecimaResult.load()
-
+ result = DecimaResult.load(result or model)
tasks, off_tasks = result.query_tasks(tasks, off_tasks)
attrs = (
@@ -762,9 +761,7 @@ def _load_attribution(
pattern_type=pattern_type,
)
- def _get_metadata(
- self, genes: List[str], metadata_anndata: Optional[DecimaResult] = None, custom_genome: bool = False
- ):
+ def _get_metadata(self, genes: List[str], metadata_anndata: Optional[str] = None, custom_genome: bool = False):
if custom_genome:
chroms = genes
starts = [0] * len(genes)
@@ -773,7 +770,10 @@ def _get_metadata(
else:
ends = [DECIMA_CONTEXT_SIZE] * len(genes)
else:
- result = DecimaResult.load(metadata_anndata)
+ model_name = self.model_name
+ if isinstance(model_name, list):
+ model_name = model_name[0]
+ result = DecimaResult.load(metadata_anndata or model_name)
chroms = result.gene_metadata.loc[genes].chrom
if self.tss_distance is not None:
tss_pos = np.where(
@@ -791,7 +791,7 @@ def _get_metadata(
def load_attribution(
self,
gene: str,
- metadata_anndata: Optional[DecimaResult] = None,
+ metadata_anndata: Optional[str] = None,
custom_genome: bool = False,
threshold: float = 5e-4,
min_seqlet_len: int = 4,
@@ -871,7 +871,7 @@ def _recursive_seqlet_calling(
def recursive_seqlet_calling(
self,
genes: Optional[List[str]] = None,
- metadata_anndata: Optional[DecimaResult] = None,
+ metadata_anndata: Optional[str] = None,
custom_genome: bool = False,
threshold: float = 5e-4,
min_seqlet_len: int = 4,
diff --git a/src/decima/core/metadata.py b/src/decima/core/metadata.py
index 51677a4..0d1361b 100644
--- a/src/decima/core/metadata.py
+++ b/src/decima/core/metadata.py
@@ -101,9 +101,6 @@ class CellMetadata:
disease: str
study: str
dataset: str
- region: Optional[str]
- subregion: Optional[str]
- celltype_coarse: Optional[str]
n_cells: int
total_counts: float
n_genes: int
@@ -111,6 +108,12 @@ class CellMetadata:
train_pearson: float
val_pearson: float
test_pearson: float
+ region: Optional[str] = field(default=None)
+ subregion: Optional[str] = field(default=None)
+ celltype_coarse: Optional[str] = field(default=None)
+ co_term: Optional[str] = field(default=None)
+ co_name: Optional[str] = field(default=None)
+ frac_nan: Optional[float] = field(default=None)
@classmethod
def from_series(cls, name: str, series: pd.Series) -> "CellMetadata":
diff --git a/src/decima/core/result.py b/src/decima/core/result.py
index 2c14d3b..a7c3ac3 100644
--- a/src/decima/core/result.py
+++ b/src/decima/core/result.py
@@ -7,7 +7,7 @@
from grelu.sequence.format import intervals_to_strings, strings_to_one_hot
-from decima.constants import DECIMA_CONTEXT_SIZE
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE, MODEL_METADATA
from decima.hub import load_decima_metadata, load_decima_model
from decima.core.metadata import GeneMetadata, CellMetadata
from decima.tools.evaluate import marker_zscores
@@ -59,11 +59,12 @@ def __init__(self, anndata):
self._model = None
@classmethod
- def load(cls, anndata_path: Optional[Union[str, anndata.AnnData]] = None):
+ def load(cls, anndata_name_or_path: Optional[Union[str, anndata.AnnData]] = None):
"""Load a DecimaResult object from an anndata file or a path to an anndata file.
Args:
- anndata_path: Path to anndata file or anndata object
+ anndata_name_or_path: Name of the model or path to anndata file or anndata object
+ model: Model name or path to model checkpoint. If not provided, the default model will be loaded.
Returns:
DecimaResult object
@@ -74,16 +75,19 @@ def load(cls, anndata_path: Optional[Union[str, anndata.AnnData]] = None):
... "path/to/anndata.h5ad"
... ) # Load custom anndata object from file
"""
- if anndata_path is None:
- return cls(load_decima_metadata())
- elif isinstance(anndata_path, str):
- return cls(anndata.read_h5ad(anndata_path))
- elif isinstance(anndata_path, anndata.AnnData):
- return cls(anndata_path)
- elif isinstance(anndata_path, DecimaResult):
- return anndata_path
+ if isinstance(anndata_name_or_path, list):
+ anndata_name_or_path = anndata_name_or_path[0]
+
+ if (anndata_name_or_path is None) or (anndata_name_or_path in MODEL_METADATA):
+ return cls(load_decima_metadata(name_or_path=anndata_name_or_path))
+ elif isinstance(anndata_name_or_path, str):
+ return cls(anndata.read_h5ad(anndata_name_or_path))
+ elif isinstance(anndata_name_or_path, anndata.AnnData):
+ return cls(anndata_name_or_path)
+ elif isinstance(anndata_name_or_path, DecimaResult):
+ return anndata_name_or_path
else:
- raise ValueError(f"Invalid anndata path: {anndata_path}")
+ raise ValueError(f"Invalid anndata path: {anndata_name_or_path}")
@property
def model(self):
@@ -92,7 +96,7 @@ def model(self):
self.load_model()
return self._model
- def load_model(self, model: Optional[Union[str, int]] = 0, device: str = "cpu"):
+ def load_model(self, model: Optional[Union[str, int]] = MODEL_METADATA[DEFAULT_ENSEMBLE][0], device: str = "cpu"):
"""Load the trained model from a checkpoint path.
Args:
@@ -172,7 +176,7 @@ def predicted_expression_matrix(
Returns:
pd.DataFrame: Predicted expression matrix (cells x genes)
"""
- model_name = "preds" if (model_name is None) or (model_name == "ensemble") else model_name
+ model_name = "preds" if (model_name is None) or (model_name in MODEL_METADATA) else model_name
if genes is None:
return pd.DataFrame(self.anndata.layers[model_name], index=self.cells, columns=self.genes)
else:
@@ -222,7 +226,7 @@ def prepare_one_hot(
Returns:
torch.Tensor: One-hot encoding of the gene
"""
- assert gene in self.genes, f"{gene} is not in the anndata object"
+ assert gene in self.genes, f"{gene} is not in the anndata object. See avaliable genes with `result.genes`."
gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding)
if variants is None:
@@ -249,11 +253,7 @@ def gene_sequence(self, gene: str, stranded: bool = True, genome: str = "hg38")
Returns:
str: Sequence for the gene
"""
- try:
- assert gene in self.genes, f"{gene} is not in the anndata object"
- except AssertionError:
- print(gene)
- print(self.genes)
+ assert gene in self.genes, f"{gene} is not in the anndata object. See avaliable genes with `result.genes`."
gene_meta = self.gene_metadata.loc[gene]
if not stranded:
gene_meta = {"chrom": gene_meta.chrom, "start": gene_meta.start, "end": gene_meta.end}
diff --git a/src/decima/data/dataset.py b/src/decima/data/dataset.py
index 0c1a341..f73df53 100644
--- a/src/decima/data/dataset.py
+++ b/src/decima/data/dataset.py
@@ -8,7 +8,7 @@
- VariantDataset: Dataset for variant effect prediction.
"""
-from typing import List
+from typing import List, Optional
import warnings
import torch
import h5py
@@ -22,7 +22,7 @@
from grelu.sequence.format import strings_to_one_hot
from grelu.sequence.utils import reverse_complement
-from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES
+from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS, MODEL_METADATA
from decima.data.read_hdf5 import _extract_center, index_genes
from decima.core.result import DecimaResult
from decima.utils.io import read_fasta_gene_mask
@@ -186,11 +186,11 @@ class GeneDataset(Dataset):
def __init__(
self,
genes=None,
- metadata_anndata=None,
- max_seq_shift=0,
- seed=0,
- augment_mode="random",
- genome="hg38",
+ metadata_anndata: Optional[str] = None,
+ max_seq_shift: int = 0,
+ seed: int = 0,
+ augment_mode: str = "random",
+ genome: str = "hg38",
):
super().__init__()
@@ -599,24 +599,24 @@ class VariantDataset(Dataset):
def __init__(
self,
variants,
- metadata_anndata=None,
- max_seq_shift=0,
- seed=0,
+ metadata_anndata: Optional[str] = None,
+ max_seq_shift: int = 0,
+ seed: int = 0,
include_cols=None,
- gene_col=None,
- min_from_end=0,
- distance_type="tss",
- min_distance=0,
- max_distance=float("inf"),
- model_name=None,
- reference_cache=True,
- genome="hg38",
+ gene_col: Optional[str] = None,
+ min_from_end: int = 0,
+ distance_type: str = "tss",
+ min_distance: int = 0,
+ max_distance: float = float("inf"),
+ model_name: Optional[str] = None,
+ reference_cache: bool = True,
+ genome: str = "hg38",
):
super().__init__()
self.reference_cache = reference_cache
self.genome = genome
- self.result = DecimaResult.load(metadata_anndata)
+ self.result = DecimaResult.load(metadata_anndata or model_name)
self.variants = self._overlap_genes(
variants,
@@ -647,8 +647,8 @@ def __init__(
if (model_name is None) or (not reference_cache):
self.model_names = list() # no reference caching
- elif model_name == "ensemble":
- self.model_names = ENSEMBLE_MODELS_NAMES
+ elif model_name in ENSEMBLE_MODELS:
+ self.model_names = MODEL_METADATA[model_name]
else:
self.model_names = [model_name]
diff --git a/src/decima/data/read_hdf5.py b/src/decima/data/read_hdf5.py
index d588369..440b306 100644
--- a/src/decima/data/read_hdf5.py
+++ b/src/decima/data/read_hdf5.py
@@ -3,6 +3,8 @@
import torch
from grelu.sequence.format import BASE_TO_INDEX_HASH, indices_to_one_hot
+from decima.constants import DECIMA_CONTEXT_SIZE
+
def count_genes(h5_file, key=None):
with h5py.File(h5_file, "r") as f:
@@ -42,7 +44,7 @@ def _extract_center(x, seq_len, shift=0):
return x[..., start : start + seq_len]
-def extract_gene_data(h5_file, gene, seq_len=524288, merge=True):
+def extract_gene_data(h5_file, gene, seq_len=DECIMA_CONTEXT_SIZE, merge=True):
gene_idx = get_gene_idx(h5_file, key=None, gene=gene)
with h5py.File(h5_file, "r") as f:
diff --git a/src/decima/hub/__init__.py b/src/decima/hub/__init__.py
index b6e7611..e50c657 100644
--- a/src/decima/hub/__init__.py
+++ b/src/decima/hub/__init__.py
@@ -1,4 +1,5 @@
import os
+import json
from typing import Union, Optional, List
import warnings
import wandb
@@ -6,18 +7,20 @@
from tempfile import TemporaryDirectory
import anndata
from grelu.resources import get_artifact, DEFAULT_WANDB_HOST
+from decima.constants import DEFAULT_ENSEMBLE, ENSEMBLE_MODELS, MODEL_METADATA
from decima.model.lightning import LightningModel, EnsembleLightningModel
def login_wandb():
"""Login to wandb either as anonymous or as a user."""
+ host = os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST)
try:
- wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), anonymous="never", timeout=0)
+ wandb.login(host=host, anonymous="never", timeout=0)
except wandb.errors.UsageError: # login anonymously if not logged in already
- wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), relogin=True, anonymous="must", timeout=0)
+ wandb.login(host=host, relogin=True, anonymous="must", timeout=0)
-def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[str] = None):
+def load_decima_model(model: Union[str, int, List[str]] = DEFAULT_ENSEMBLE, device: Optional[str] = None):
"""Load a pre-trained Decima model from wandb or local path.
Args:
@@ -37,74 +40,80 @@ def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[st
if isinstance(model, LightningModel):
return model
- elif model == "ensemble":
- return EnsembleLightningModel([load_decima_model(i, device) for i in range(4)])
+ elif model in ENSEMBLE_MODELS:
+ return EnsembleLightningModel(
+ [load_decima_model(model_name, device) for model_name in MODEL_METADATA[model]],
+ name=model,
+ )
elif isinstance(model, List):
if len(model) == 1:
return load_decima_model(model[0], device)
else:
- return EnsembleLightningModel([load_decima_model(path, device) for path in model])
-
- elif model in {0, 1, 2, 3}:
- model_name = f"rep{model}"
+ return EnsembleLightningModel([load_decima_model(path, device) for path in model], name=model)
# Load directly from a path
- elif isinstance(model, str):
- if Path(model).exists():
- if model.endswith("ckpt"):
- return LightningModel.load_from_checkpoint(model, map_location=device)
- else:
- return LightningModel.load_safetensor(model, device=device)
+ if model in MODEL_METADATA:
+ model_name = MODEL_METADATA[model]["name"]
+ if "model_path" in MODEL_METADATA[model]: # if model path exist in metadata load it from the path
+ return load_decima_model(MODEL_METADATA[model]["model_path"], device)
+ elif isinstance(model, str) and Path(model).exists():
+ if model.endswith("ckpt"):
+ return LightningModel.load_from_checkpoint(model, map_location=device)
else:
- model_name = model
-
+ return LightningModel.load_safetensor(model, device=device)
else:
raise ValueError(
- f"Invalid model: {model} it needs to be either a string of model_names on wandb, "
- "an integer of replicate number {0, 1, 2, 3}, a path to a local model or a list of paths."
+ f"Invalid model: {model} it needs to be either a string of model_names on wandb ("
+ f"{list(MODEL_METADATA.keys())}), path to a local model, or a list of paths."
)
-
- # If left with a model name, load from environment/wandb
- if model_name.upper() in os.environ:
- if Path(os.environ[model_name.upper()]).exists():
- return LightningModel.load_safetensor(os.environ[model_name.upper()], device=device)
- else:
- warnings.warn(
- f"Model `{model_name}` provided in environment variables, "
- f"but not found in `{os.environ[model_name.upper()]}` "
- f"Trying to download `{model_name}` from wandb."
- )
-
- art = get_artifact(model_name, project="decima")
+ # load model from wandb
+ art = get_artifact(
+ model_name,
+ project="decima",
+ host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST),
+ )
with TemporaryDirectory() as d:
art.download(d)
return LightningModel.load_safetensor(Path(d) / f"{model_name}.safetensors", device=device)
-def load_decima_metadata(path: Optional[str] = None):
+def load_decima_metadata(name_or_path: Optional[str] = None):
"""Load the Decima metadata from wandb.
Args:
- path: Path to local metadata file. If None, downloads from wandb.
+ name_or_path: Path to local metadata file or name of the model to load metadata for using wandb. If None, default model's metadata will be downloaded from wandb.
Returns:
An AnnData object containing the Decima metadata.
"""
- if path is not None:
- return anndata.read_h5ad(path)
+ if name_or_path is not None:
+ if Path(name_or_path).exists():
+ return anndata.read_h5ad(name_or_path)
+
+ name_or_path = name_or_path or DEFAULT_ENSEMBLE
+
+ if name_or_path in ENSEMBLE_MODELS:
+ name_or_path = MODEL_METADATA[name_or_path][0]
+
+ if name_or_path in MODEL_METADATA:
+ metadata = MODEL_METADATA[name_or_path]
- if "DECIMA_METADATA" in os.environ:
- if Path(os.environ["DECIMA_METADATA"]).exists():
- return anndata.read_h5ad(os.environ["DECIMA_METADATA"])
+ if "metadata_path" in metadata:
+ if Path(metadata["metadata_path"]).exists():
+ return anndata.read_h5ad(metadata["metadata_path"])
else:
warnings.warn(
- f"Metadata `{os.environ['DECIMA_METADATA']}` provided in environment variables, "
- f"but not found in `{os.environ['DECIMA_METADATA']}` "
+ f"Metadata `{metadata['metadata_path']}` provided in environment variables, "
+ f"but not found in `{metadata['metadata_path']}` "
f"Trying to download `metadata` from wandb."
)
- art = get_artifact("metadata", project="decima")
+ art = get_artifact(
+ metadata["metadata"],
+ project="decima",
+ host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST),
+ )
with TemporaryDirectory() as d:
art.download(d)
- return anndata.read_h5ad(Path(d) / "metadata.h5ad")
+ return anndata.read_h5ad(Path(d) / f"{metadata['metadata']}.h5ad")
diff --git a/src/decima/hub/download.py b/src/decima/hub/download.py
index 9e2fe72..6788e74 100644
--- a/src/decima/hub/download.py
+++ b/src/decima/hub/download.py
@@ -1,8 +1,10 @@
+import os
from pathlib import Path
from typing import Union
import logging
import genomepy
-from grelu.resources import get_artifact
+from grelu.resources import get_artifact, DEFAULT_WANDB_HOST
+from decima.constants import DEFAULT_ENSEMBLE, ENSEMBLE_MODELS, MODEL_METADATA
from decima.hub import login_wandb, load_decima_model, load_decima_metadata
@@ -18,7 +20,7 @@ def cache_hg38():
def cache_decima_weights():
"""Download pre-trained Decima model weights from wandb."""
logger.info("Downloading Decima model weights...")
- for rep in range(4):
+ for rep in MODEL_METADATA[DEFAULT_ENSEMBLE]:
load_decima_model(rep)
@@ -36,7 +38,7 @@ def cache_decima_data():
cache_decima_metadata()
-def download_decima_weights(model_name: Union[str, int], download_dir: str):
+def download_decima_weights(model: Union[str, int] = DEFAULT_ENSEMBLE, download_dir: str = "."):
"""Download pre-trained Decima model weights from wandb.
Args:
@@ -46,40 +48,44 @@ def download_decima_weights(model_name: Union[str, int], download_dir: str):
Returns:
Path to the downloaded model weights.
"""
- if "ensemble" == model_name:
- return [download_decima_weights(model, download_dir) for model in range(4)]
-
- if model_name in {0, 1, 2, 3}:
- model_name = f"rep{model_name}"
+ if model in ENSEMBLE_MODELS:
+ return [download_decima_weights(model, download_dir) for model in MODEL_METADATA[model]]
+ model_name = MODEL_METADATA[model]["name"]
download_dir = Path(download_dir)
download_dir.mkdir(parents=True, exist_ok=True)
- logger.info(f"Downloading Decima model weights for {model_name} to {download_dir / f'{model_name}.safetensors'}")
+ logger.info(f"Downloading Decima model weights for {model} to {download_dir / f'{model_name}.safetensors'}")
- art = get_artifact(model_name, project="decima")
+ art = get_artifact(model_name, project="decima", host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST))
art.download(str(download_dir))
return download_dir / f"{model_name}.safetensors"
-def download_decima_metadata(download_dir: str):
+def download_decima_metadata(metadata: str = DEFAULT_ENSEMBLE, download_dir: str = "."):
"""Download pre-trained Decima model data from wandb.
Args:
download_dir: Directory to download the metadata.
+ metadata: Name of the model to download metadata for using wandb.
Returns:
Path to the downloaded metadata.
"""
- art = get_artifact("metadata", project="decima")
+ metadata = metadata or DEFAULT_ENSEMBLE
+ if metadata in ENSEMBLE_MODELS:
+ metadata = MODEL_METADATA[metadata][0]
+
+ metadata_name = MODEL_METADATA[metadata]["metadata"]
+ art = get_artifact(metadata_name, project="decima", host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST))
download_dir = Path(download_dir)
download_dir.mkdir(parents=True, exist_ok=True)
- logger.info(f"Downloading Decima metadata to {download_dir / 'metadata.h5ad'}.")
+ logger.info(f"Downloading Decima metadata to {download_dir / f'{metadata_name}.h5ad'}.")
art.download(str(download_dir))
- return download_dir / "metadata.h5ad"
+ return download_dir / f"{metadata_name}.h5ad"
-def download_decima(download_dir: str):
+def download_decima(model: str = DEFAULT_ENSEMBLE, download_dir: str = "."):
"""Download all required data for Decima.
Args:
@@ -92,6 +98,6 @@ def download_decima(download_dir: str):
download_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Downloading Decima model weights and metadata to {download_dir}:")
- download_decima_weights("ensemble", download_dir)
- download_decima_metadata(download_dir)
+ download_decima_weights(model, download_dir)
+ download_decima_metadata(model, download_dir)
return download_dir
diff --git a/src/decima/interpret/attributions.py b/src/decima/interpret/attributions.py
index 10388f9..31b5922 100644
--- a/src/decima/interpret/attributions.py
+++ b/src/decima/interpret/attributions.py
@@ -40,9 +40,11 @@
from torch.utils.data import DataLoader
from pyfaidx import Faidx
+from decima.constants import DEFAULT_ENSEMBLE, MODEL_METADATA, ENSEMBLE_MODELS
from decima.core.attribution import AttributionResult
from decima.core.result import DecimaResult
from decima.data.dataset import GeneDataset, SeqDataset
+from decima.hub import load_decima_model
from decima.interpret.attributer import DecimaAttributer
from decima.utils import get_compute_device, _get_on_off_tasks, _get_genes
from decima.utils.io import AttributionWriter
@@ -53,7 +55,7 @@ def predict_save_attributions(
output_prefix: str,
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
- model: Optional[int] = 0,
+ model: Optional[int] = DEFAULT_ENSEMBLE,
metadata_anndata: Optional[str] = None,
method: str = "inputxgradient",
transform: str = "specificity",
@@ -119,9 +121,9 @@ def predict_save_attributions(
... genome="hg38",
... )
"""
- if (model == "ensemble") or isinstance(model, (list, tuple)):
- if model == "ensemble":
- models = [0, 1, 2, 3]
+ if (model in ENSEMBLE_MODELS) or isinstance(model, (list, tuple)):
+ if model in ENSEMBLE_MODELS:
+ models = MODEL_METADATA[model]
else:
models = model
return [
@@ -154,17 +156,17 @@ def predict_save_attributions(
device = get_compute_device(device)
logger.info(f"Using device: {device}")
- logger.info("Loading model and metadata to compute attributions...")
- result = DecimaResult.load(metadata_anndata)
+ logger.info(f"Loading model {model} and metadata to compute attributions...")
+ model = load_decima_model(model, device=device)
+ result = DecimaResult.load(metadata_anndata or model.name)
tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
+ attributer = DecimaAttributer(model, tasks, off_tasks, method, transform)
- with QCLogger(str(output_prefix) + ".warnings.qc.log", metadata_anndata=metadata_anndata) as qc:
+ with QCLogger(str(output_prefix) + ".warnings.qc.log", metadata_anndata=result) as qc:
if result.ground_truth is not None:
qc.log_correlation(tasks, off_tasks, plot=True)
- attributer = DecimaAttributer.load_decima_attributer(model, tasks, off_tasks, method, transform, device=device)
-
if (genes is not None) and (seqs is not None):
raise ValueError("Only one of `genes` or `seqs` arguments must be provided not both.")
elif seqs is not None:
@@ -295,8 +297,6 @@ def recursive_seqlet_calling(
logger = logging.getLogger("decima")
logger.info("Loading model and metadata to compute attributions...")
- result = DecimaResult.load(metadata_anndata)
-
if isinstance(attributions, (str, Path)):
attributions_files = [Path(attributions).as_posix()]
else:
@@ -305,6 +305,8 @@ def recursive_seqlet_calling(
with AttributionResult(
attributions_files, tss_distance, correct_grad=False, num_workers=num_workers, agg_func=agg_func
) as ar:
+ result = DecimaResult.load(metadata_anndata or ar.model_name)
+
if top_n_markers is not None:
tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks)
@@ -338,7 +340,7 @@ def predict_attributions_seqlet_calling(
seqs: Optional[Union[pd.DataFrame, np.ndarray, torch.Tensor]] = None,
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
- model: Optional[Union[str, int]] = "ensemble",
+ model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE,
metadata_anndata: Optional[str] = None,
method: str = "inputxgradient",
transform: str = "specificity",
diff --git a/src/decima/interpret/modisco.py b/src/decima/interpret/modisco.py
index 3a4ddc2..4147fb6 100644
--- a/src/decima/interpret/modisco.py
+++ b/src/decima/interpret/modisco.py
@@ -21,12 +21,12 @@
import h5py
import numpy as np
import pandas as pd
-import modiscolite
+import fastermodiscolite
from tqdm import tqdm
from grelu.resources import get_meme_file_path
from grelu.interpret.motifs import trim_pwm
-from decima.constants import DECIMA_CONTEXT_SIZE
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE
from decima.core.result import DecimaResult
from decima.utils import _get_on_off_tasks, _get_genes
from decima.utils.motifs import motif_start_end
@@ -38,7 +38,7 @@ def predict_save_modisco_attributions(
output_prefix: str,
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
- model: Optional[int] = 0,
+ model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE,
metadata_anndata: Optional[str] = None,
method: str = "saliency",
transform: str = "specificity",
@@ -156,7 +156,7 @@ def modisco_patterns(
tasks: Tasks to analyze either list of task names or query string to filter cell types to analyze attributions for (e.g. 'cell_type == 'classical monocyte''). If not provided, all tasks will be analyzed.
off_tasks: Off tasks to analyze either list of task names or query string to filter cell types to contrast against (e.g. 'cell_type == 'classical monocyte''). If not provided, all tasks will be used as off tasks.
tss_distance: Distance from TSS to analyze for pattern discovery default is 10000. Controls the genomic window size around TSS for seqlet detection and motif discovery.
- metadata_anndata: Path to metadata anndata file or DecimaResult object. If not provided, the default metadata will be used from the attribution files.
+ metadata_anndata: Name of the model or path to metadata anndata file or DecimaResult object. If not provided, the compatible metadata of the saved attribution files will be used.
genes: Genes to analyze for pattern discovery if not provided, all genes will be used. Can be list of gene symbols or IDs to focus analysis on specific genes.
top_n_markers: Top n markers to analyze for pattern discovery if not provided, all markers will be analyzed. Useful for focusing on the most important marker genes for the specified tasks.
correct_grad: Whether to correct gradient for attribution analysis default is True. Applies gradient correction for better attribution quality before pattern discovery.
@@ -206,23 +206,27 @@ def modisco_patterns(
... )
"""
logger = logging.getLogger("decima")
- logger.info("Loading metadata")
- result = DecimaResult.load(metadata_anndata)
if isinstance(attributions, (str, Path)):
attributions_files = [Path(attributions).as_posix()]
else:
attributions_files = attributions
- tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
- all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks)
-
- with AttributionResult(attributions_files, tss_distance, correct_grad, num_workers=1, agg_func="mean") as ar:
- sequences, attributions = ar.load(all_genes)
+ with AttributionResult(
+ attributions_files, tss_distance, correct_grad, num_workers=num_workers, agg_func="mean"
+ ) as ar:
genome = ar.genome
model_names = ar.model_name
- pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco(
+ metadata_anndata = metadata_anndata or model_names[0]
+ logger.info(f"Loading metadata for model {metadata_anndata}...")
+ result = DecimaResult.load(metadata_anndata)
+
+ tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
+ all_genes = _get_genes(result, genes, top_n_markers, tasks, off_tasks)
+ sequences, attributions = ar.load(all_genes)
+
+ pos_patterns, neg_patterns = fastermodiscolite.tfmodisco.TFMoDISco(
hypothetical_contribs=attributions.transpose(0, 2, 1),
one_hot=sequences.transpose(0, 2, 1),
sliding_window_size=sliding_window_size,
@@ -258,7 +262,7 @@ def modisco_patterns(
verbose=True,
)
h5_path = Path(output_prefix).with_suffix(".modisco.h5").as_posix()
- modiscolite.io.save_hdf5(
+ fastermodiscolite.io.save_hdf5(
h5_path,
pos_patterns,
neg_patterns,
@@ -310,7 +314,7 @@ def modisco_reports(
"""
output_dir = Path(f"{output_prefix}_report")
output_dir.mkdir(parents=True, exist_ok=True)
- modiscolite.report.report_motifs(
+ fastermodiscolite.report.report_motifs(
modisco_h5,
output_dir.as_posix(),
img_path_suffix,
@@ -328,7 +332,7 @@ def modisco_reports(
def modisco_seqlet_bed(
output_prefix: str,
modisco_h5: str,
- metadata_anndata: str = None,
+ metadata_anndata: Optional[str] = None,
trim_threshold: float = 0.2,
):
"""Extract seqlet locations from MoDISco results and save as BED format file.
@@ -351,11 +355,13 @@ def modisco_seqlet_bed(
... trim_threshold=0.15,
... )
"""
- result = DecimaResult.load(metadata_anndata)
df = list()
with h5py.File(modisco_h5, "r") as f:
+ model_name = f.attrs["model_names"].split(",")[0]
+ result = DecimaResult.load(metadata_anndata or model_name)
+
tss_distance = f.attrs["tss_distance"]
genes = [gene.decode("utf-8") for gene in f["genes"][:]]
genes_idx = dict(enumerate(genes))
@@ -420,7 +426,7 @@ def modisco(
output_prefix: str,
tasks: Optional[List[str]] = None,
off_tasks: Optional[List[str]] = None,
- model: Optional[Union[str, int]] = 0,
+ model: Optional[Union[str, int]] = DEFAULT_ENSEMBLE,
tss_distance: int = 1000,
metadata_anndata: Optional[str] = None,
genes: Optional[List[str]] = None,
@@ -429,6 +435,7 @@ def modisco(
num_workers: int = 4,
genome: str = "hg38",
method: str = "saliency",
+ transform: Optional[str] = "specificity",
batch_size: int = 2,
device: Optional[str] = None,
# tfmodisco parameters
@@ -492,6 +499,7 @@ def modisco(
num_workers: Number of workers for parallel processing default is 4. Increasing number of workers will speed up the process but requires more memory.
genome: Genome reference to use default is "hg38". Can be genome name or path to custom genome fasta file.
method: Method to use for attribution analysis default is "saliency". Available options: "saliency", "inputxgradient", "integratedgradients". For MoDISco, "saliency" is often preferred for pattern discovery.
+ transform: Transform to use for attribution analysis default is "specificity". Available options: "specificity", "aggregate". Specificity transform is recommended for MoDISco to highlight cell-type-specific patterns.
batch_size: Batch size for attribution analysis default is 2. Increasing batch size may speed up computation but requires more memory.
device: Device to use for computation (e.g. 'cuda', 'cpu'). If not provided, the best available device will be used automatically.
sliding_window_size: Sliding window size.
@@ -544,6 +552,7 @@ def modisco(
genes=genes,
top_n_markers=top_n_markers,
method=method,
+ transform=transform,
batch_size=batch_size,
correct_grad_bigwig=correct_grad,
device=device,
diff --git a/src/decima/model/lightning.py b/src/decima/model/lightning.py
index b94dd0e..1c20ac7 100644
--- a/src/decima/model/lightning.py
+++ b/src/decima/model/lightning.py
@@ -19,6 +19,7 @@
from torchmetrics import MetricCollection
import safetensors
+from decima.constants import DEFAULT_ENSEMBLE
from decima.utils import get_compute_device
from .decima_model import DecimaModel
from .loss import TaskWisePoissonMultinomialLoss
@@ -515,7 +516,7 @@ def load_safetensor(cls, path: str, device: str = "cpu"):
class EnsembleLightningModel(LightningModel):
- def __init__(self, models: List[LightningModel], name="ensemble"):
+ def __init__(self, models: List[LightningModel], name: str = DEFAULT_ENSEMBLE):
super().__init__(
name=name,
model_params=models[0].model_params,
@@ -524,7 +525,7 @@ def __init__(self, models: List[LightningModel], name="ensemble"):
)
self.models = nn.ModuleList(models)
self.reset_transform()
- self.name = "ensemble"
+ self.name = DEFAULT_ENSEMBLE
def forward(self, x: Tensor) -> Tensor:
return torch.concat([model(x) for model in self.models], dim=0)
diff --git a/src/decima/tools/inference.py b/src/decima/tools/inference.py
index c31234f..448a8c8 100644
--- a/src/decima/tools/inference.py
+++ b/src/decima/tools/inference.py
@@ -1,6 +1,8 @@
-import anndata
import logging
+from typing import Optional
+import anndata
import numpy as np
+from decima.constants import DEFAULT_ENSEMBLE
from decima.data.dataset import GeneDataset
from decima.hub import load_decima_model
from decima.utils import get_compute_device
@@ -8,15 +10,15 @@
def predict_gene_expression(
genes=None,
- model="ensemble",
- metadata_anndata=None,
- device=None,
- batch_size=1,
- num_workers=4,
+ model=DEFAULT_ENSEMBLE,
+ metadata_anndata: Optional[str] = None,
+ device: Optional[str] = None,
+ batch_size: int = 1,
+ num_workers: int = 4,
max_seq_shift=0,
- genome="hg38",
- save_replicates=False,
- float_precision="32",
+ genome: str = "hg38",
+ save_replicates: bool = False,
+ float_precision: str = "32",
):
"""Predict gene expression for a list of genes
@@ -42,10 +44,13 @@ def predict_gene_expression(
device = get_compute_device(device)
logger.info(f"Using device: {device} and genome: {genome} for prediction.")
- logger.info("Making predictions")
+ logger.info(f"Loading model {model}...")
model = load_decima_model(model, device=device)
- ds = GeneDataset(genes=genes, metadata_anndata=metadata_anndata, max_seq_shift=max_seq_shift, genome=genome)
+ logger.info("Making predictions")
+ ds = GeneDataset(
+ genes=genes, metadata_anndata=metadata_anndata or model.name, max_seq_shift=max_seq_shift, genome=genome
+ )
preds = model.predict_on_dataset(
ds, device=device, batch_size=batch_size, num_workers=num_workers, float_precision=float_precision
)
@@ -68,7 +73,10 @@ def predict_gene_expression(
ad.layers[f"preds_{model.name}"] = pred.T
logger.info("Evaluating performance")
- evaluate_gene_expression_predictions(ad)
+ if ad.X is not None:
+ evaluate_gene_expression_predictions(ad)
+ else:
+ logger.warning("No ground truth expression matrix found in the metadata. Skipping evaluation.")
return ad
diff --git a/src/decima/utils/__init__.py b/src/decima/utils/__init__.py
index a5ce9b7..c016533 100644
--- a/src/decima/utils/__init__.py
+++ b/src/decima/utils/__init__.py
@@ -77,9 +77,12 @@ def get_compute_device(device: Optional[str] = None) -> torch.device:
torch.device: The selected device for computation
"""
if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
+ return 0 if torch.cuda.is_available() else "cpu"
+ elif device == "cuda":
+ return 0
+ elif isinstance(device, int) or (device == "cpu"):
+ return device
elif isinstance(device, str) and device.isdigit():
- device = int(device)
-
- return torch.device(device)
+ return int(device)
+ else:
+ raise ValueError(f"Invalid device: {device}")
diff --git a/src/decima/utils/io.py b/src/decima/utils/io.py
index c8d5a4d..e695cda 100644
--- a/src/decima/utils/io.py
+++ b/src/decima/utils/io.py
@@ -183,7 +183,7 @@ class AttributionWriter:
path: Output HDF5 file path.
genes: Gene names to write.
model_name: Model identifier for metadata.
- metadata_anndata: Gene metadata source. None uses default Decima data.
+ metadata_anndata: Gene metadata source or path to the metadata anndata file. If not provided, the compatible metadata for the model will be used.
genome: Reference genome version.
bigwig: Create BigWig file for genome browser.
correct_grad_bigwig: Correct gradient bigwig for bias.
@@ -221,7 +221,7 @@ def __init__(
self.bigwig = bigwig
self.model_name = model_name
self.idx = {g: i for i, g in enumerate(self.genes)}
- self.result = DecimaResult.load(metadata_anndata)
+ self.result = DecimaResult.load(metadata_anndata or model_name)
self.correct_grad_bigwig = correct_grad_bigwig
self.custom_genes = custom_genes
diff --git a/src/decima/vep/__init__.py b/src/decima/vep/__init__.py
index 7b5b057..53fb49a 100644
--- a/src/decima/vep/__init__.py
+++ b/src/decima/vep/__init__.py
@@ -7,7 +7,7 @@
import pandas as pd
from grelu.transforms.prediction_transforms import Aggregate
-from decima.constants import SUPPORTED_GENOMES
+from decima.constants import SUPPORTED_GENOMES, DEFAULT_ENSEMBLE
from decima.model.metrics import WarningType
from decima.utils import get_compute_device
from decima.utils.dataframe import chunk_df, ChunkDataFrameWriter
@@ -19,7 +19,7 @@
def _predict_variant_effect(
df_variant: Union[pd.DataFrame, str],
tasks: Optional[Union[str, List[str]]] = None,
- model: Union[int, str] = "ensemble",
+ model: Union[str, int] = DEFAULT_ENSEMBLE,
metadata_anndata: Optional[str] = None,
batch_size: int = 1,
num_workers: int = 16,
@@ -61,8 +61,6 @@ def _predict_variant_effect(
raise ValueError(f"Genome {genome} not supported. Currently only hg38 is supported.")
include_cols = include_cols or list()
- model = load_decima_model(model=model, device=device)
-
try:
dataset = VariantDataset(
df_variant,
@@ -83,8 +81,6 @@ def _predict_variant_effect(
else:
raise e
- model = load_decima_model(model=model)
-
if tasks is not None:
tasks = dataset.result.query_cells(tasks)
@@ -120,7 +116,7 @@ def predict_variant_effect(
df_variant: Union[pd.DataFrame, str],
output_pq: Optional[str] = None,
tasks: Optional[Union[str, List[str]]] = None,
- model: Union[int, str, List[str]] = "ensemble",
+ model: Union[int, str, List[str]] = DEFAULT_ENSEMBLE,
metadata_anndata: Optional[str] = None,
chunksize: int = 10_000,
batch_size: int = 1,
@@ -142,7 +138,7 @@ def predict_variant_effect(
df_variant (pd.DataFrame or str): DataFrame with variant information or path to variant file
output_pq (str, optional): Path to save the parquet file. Defaults to None.
tasks (str, optional): Tasks to predict. Defaults to None.
- model (int, optional): Model to use. Defaults to "ensemble".
+ model (int, optional): Model to use. Defaults to DEFAULT_ENSEMBLE.
metadata_anndata (str, optional): Path to anndata file. Defaults to None.
chunksize (int, optional): Number of variants to predict in each chunk. Defaults to 10_000.
batch_size (int, optional): Batch size. Defaults to 1.
diff --git a/tests/conftest.py b/tests/conftest.py
index 27ec0c8..82d5336 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -86,7 +86,7 @@ def attribution_h5_file(tmp_path, attribution_data):
f.create_dataset('attribution', data=attribution_data['attributions'])
f.create_dataset('gene_mask_start', data=attribution_data['gene_mask_start'])
f.create_dataset('gene_mask_end', data=attribution_data['gene_mask_end'])
- f.attrs['model_name'] = 'test_model'
+ f.attrs['model_name'] = 'v1_rep0'
f.attrs['genome'] = 'hg38'
return h5_path
diff --git a/tests/test_attribution.py b/tests/test_attribution.py
index 48acd4a..d0d33f8 100644
--- a/tests/test_attribution.py
+++ b/tests/test_attribution.py
@@ -14,7 +14,7 @@ def test_AttributionResult(attribution_h5_file, attribution_data):
with AttributionResult(str(attribution_h5_file), tss_distance=10_000, num_workers=1) as ar:
assert len(ar.genes) == 10
assert ar.genes == attribution_data['genes']
- assert ar.model_name == 'test_model'
+ assert ar.model_name == 'v1_rep0'
assert ar.genome == 'hg38'
assert ar.genes == attribution_data['genes']
@@ -73,7 +73,7 @@ def test_AttributionResult(attribution_h5_file, attribution_data):
with AttributionResult([str(attribution_h5_file), str(attribution_h5_file)], tss_distance=10_000) as ar:
assert len(ar.genes) == 10
assert ar.genes == attribution_data['genes']
- assert ar.model_name == ['test_model', 'test_model']
+ assert ar.model_name == ['v1_rep0', 'v1_rep0']
assert ar.genome == 'hg38'
with AttributionResult(str(attribution_h5_file), tss_distance=1_000_000) as ar:
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 60ba05c..4439c03 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -6,7 +6,7 @@
from decima.cli import main
from conftest import device
-from decima.constants import DECIMA_CONTEXT_SIZE
+from decima.constants import DECIMA_CONTEXT_SIZE, DEFAULT_ENSEMBLE
def test_cli_main():
@@ -254,7 +254,7 @@ def test_cli_vep_all_tasks_ensemble_custom_genome(tmp_path):
"vep",
"-v", "tests/data/variants.tsv",
"-o", str(output_file),
- "--model", "ensemble",
+ "--model", DEFAULT_ENSEMBLE,
"--device", device,
"--max-distance", "20000",
"--chunksize", "5",
@@ -277,7 +277,7 @@ def test_cli_vep_all_tasks_ensemble(tmp_path):
"vep",
"-v", "tests/data/variants.tsv",
"-o", str(output_file),
- "--model", "ensemble",
+ "--model", DEFAULT_ENSEMBLE,
"--device", device,
"--max-distance", "20000",
"--chunksize", "5",
diff --git a/tests/test_interpret_attribution.py b/tests/test_interpret_attribution.py
index 96014e3..e2bca28 100644
--- a/tests/test_interpret_attribution.py
+++ b/tests/test_interpret_attribution.py
@@ -214,9 +214,9 @@ def test_predict_save_attributions_single_gene(tmp_path):
@pytest.mark.long_running
def test_predict_save_attributions_single_gene_list_models(tmp_path):
# download models
- download_decima_weights(0, str(tmp_path))
- download_decima_weights(1, str(tmp_path))
- download_decima_metadata(str(tmp_path))
+ download_decima_weights("v1_rep0", str(tmp_path))
+ download_decima_weights("v1_rep1", str(tmp_path))
+ download_decima_metadata("v1_rep0", str(tmp_path))
output_prefix = tmp_path / "SPI1"
predict_attributions_seqlet_calling(
diff --git a/tests/test_io.py b/tests/test_io.py
index 23acc01..36a5a7f 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -35,14 +35,14 @@ def test_AttributionWriter(tmp_path):
nucleotide_indices = np.random.randint(0, 4, DECIMA_CONTEXT_SIZE)
seqs[nucleotide_indices, np.arange(DECIMA_CONTEXT_SIZE)] = 1
- with AttributionWriter(str(h5_file), genes, "test_model", bigwig=False) as writer:
+ with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=False) as writer:
writer.add("STRADA", seqs, attrs)
assert h5_file.exists()
with h5py.File(h5_file, "r") as f:
assert set(f.keys()) == {"genes", "attribution", "sequence", "gene_mask_start", "gene_mask_end"}
- assert f.attrs["model_name"] == "test_model"
+ assert f.attrs["model_name"] == "v1_rep0"
assert f["attribution"].shape == (1, 4, DECIMA_CONTEXT_SIZE)
np.testing.assert_array_equal(convert_input_type(f["sequence"][0], "one_hot", input_type="indices"), seqs)
np.testing.assert_array_almost_equal(f["attribution"][0], attrs)
@@ -51,13 +51,13 @@ def test_AttributionWriter(tmp_path):
assert f["gene_mask_end"][0] == 223490
h5_file = tmp_path / "test_bigwig.h5"
- with AttributionWriter(str(h5_file), genes, "test_model", bigwig=True) as writer:
+ with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=True) as writer:
writer.add("STRADA", seqs, attrs)
assert h5_file.with_suffix(".bigwig").exists()
h5_file = tmp_path / "test_bigwig_custom.h5"
- with AttributionWriter(str(h5_file), genes, "test_model", bigwig=True, custom_genes=True) as writer:
+ with AttributionWriter(str(h5_file), genes, "v1_rep0", bigwig=True, custom_genes=True) as writer:
writer.add("STRADA", seqs, attrs, gene_mask_start=100, gene_mask_end=200)
with h5py.File(h5_file, "r") as f:
diff --git a/tests/test_lightning.py b/tests/test_lightning.py
index d1af0b0..40287f1 100644
--- a/tests/test_lightning.py
+++ b/tests/test_lightning.py
@@ -1,6 +1,6 @@
import pytest
import torch
-from decima.constants import DECIMA_CONTEXT_SIZE, NUM_CELLS
+from decima.constants import DECIMA_CONTEXT_SIZE, MODEL_METADATA, DEFAULT_ENSEMBLE
from decima.data.dataset import VariantDataset
from decima.model.lightning import LightningModel, GeneMaskLightningModel
from decima.model.metrics import WarningType
@@ -10,31 +10,34 @@
@pytest.fixture
def lightning_model():
- model = LightningModel(model_params={'n_tasks': NUM_CELLS, 'init_borzoi': False}, name='v1_rep0').to(device)
+ model_name = "v1_rep0"
+ metadata = MODEL_METADATA[model_name]
+ model = LightningModel(model_params={'n_tasks': metadata['num_tasks'], 'init_borzoi': False}, name=model_name).to(device)
return model
@pytest.mark.long_running
def test_LightningModel_predict_step(lightning_model):
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
seq = torch.randn(1, 5, DECIMA_CONTEXT_SIZE).to(device)
preds = lightning_model.predict_step(seq, 0)
- assert preds.shape == (1, NUM_CELLS, 1)
+ assert preds.shape == (1, metadata['num_tasks'], 1)
batch = {"seq": seq, "warning": [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME]}
results = lightning_model.predict_step(batch, 1)
- assert results["expression"].shape == (1, NUM_CELLS, 1)
+ assert results["expression"].shape == (1, metadata['num_tasks'], 1)
assert results["warnings"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME]
batch = {
"seq": seq.to(device),
"warning": [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME],
- "pred_expr": {"v1_rep0": torch.ones((1, NUM_CELLS), device=device)}
+ "pred_expr": {"v1_rep0": torch.ones((1, metadata['num_tasks']), device=device)}
}
results = lightning_model.predict_step(batch, 1)
- assert results["expression"].shape == (1, NUM_CELLS, 1)
- assert (results["expression"] == torch.ones((1, NUM_CELLS, 1), device=device)).all()
+ assert results["expression"].shape == (1, metadata['num_tasks'], 1)
+ assert (results["expression"] == torch.ones((1, metadata['num_tasks'], 1), device=device)).all()
assert results["warnings"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME]
@@ -42,7 +45,10 @@ def test_LightningModel_predict_step(lightning_model):
def test_LightningModel_predict_on_dataset(lightning_model, df_variant):
dataset = VariantDataset(df_variant, model_name="v1_rep0")
results = lightning_model.predict_on_dataset(dataset)
- assert results["expression"].shape == (82, NUM_CELLS)
+
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
+
+ assert results["expression"].shape == (82, metadata["num_tasks"])
assert results["warnings"]['unknown'] == 0
assert results["warnings"]['allele_mismatch_with_reference_genome'] == 13
@@ -51,7 +57,8 @@ def test_LightningModel_predict_on_dataset(lightning_model, df_variant):
def test_LightningModel_predict_on_dataset_ensemble(lightning_model, df_variant):
dataset = VariantDataset(df_variant)
results = lightning_model.predict_on_dataset(dataset)
- assert results["expression"].shape == (82, NUM_CELLS)
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
+ assert results["expression"].shape == (82, metadata["num_tasks"])
assert results["warnings"]['unknown'] == 0
assert results["warnings"]['allele_mismatch_with_reference_genome'] == 13
@@ -59,9 +66,10 @@ def test_LightningModel_predict_on_dataset_ensemble(lightning_model, df_variant)
@pytest.mark.long_running
def test_GeneMaskLightningModel_forward():
seq = torch.randn(1, 4, DECIMA_CONTEXT_SIZE).to(device)
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
model = GeneMaskLightningModel(
gene_mask_start=200_000, gene_mask_end=300_000,
- model_params={"n_tasks": NUM_CELLS, "init_borzoi": False}, name="v1_rep0"
+ model_params={"n_tasks": metadata["num_tasks"], "init_borzoi": False}, name=metadata["name"]
).to(device)
preds = model(seq)
- assert preds.shape == (1, NUM_CELLS, 1)
+ assert preds.shape == (1, metadata["num_tasks"], 1)
diff --git a/tests/test_predict_gene_expression.py b/tests/test_predict_gene_expression.py
index 66954e9..2dd54c6 100644
--- a/tests/test_predict_gene_expression.py
+++ b/tests/test_predict_gene_expression.py
@@ -1,4 +1,5 @@
import pytest
+from decima.constants import DEFAULT_ENSEMBLE
from decima.tools.inference import predict_gene_expression
from conftest import device
@@ -19,7 +20,7 @@ def test_predict_gene_expression():
ad = predict_gene_expression(
genes=["SPI1", "GATA1"],
- model="ensemble", device=device,
+ model=DEFAULT_ENSEMBLE, device=device,
save_replicates=True,
)
diff --git a/tests/test_sequence.py b/tests/test_sequence.py
index c0876aa..e2a964e 100644
--- a/tests/test_sequence.py
+++ b/tests/test_sequence.py
@@ -1,8 +1,8 @@
-
+from decima.constants import DECIMA_CONTEXT_SIZE
from decima.utils.sequence import prepare_mask_gene
def test_mask_gene():
mask = prepare_mask_gene(100, 200)
- assert mask.shape == (1, 524288)
+ assert mask.shape == (1, DECIMA_CONTEXT_SIZE)
assert mask[0, 150].item() == 1.0
diff --git a/tests/test_vep.py b/tests/test_vep.py
index 650ffd0..45ad91b 100644
--- a/tests/test_vep.py
+++ b/tests/test_vep.py
@@ -5,6 +5,7 @@
import pyarrow.parquet as pq
from scipy.stats import pearsonr
+from decima.constants import DEFAULT_ENSEMBLE, DECIMA_CONTEXT_SIZE, MODEL_METADATA
from decima.core.result import DecimaResult
from decima.hub import load_decima_model
from decima.data.dataset import VariantDataset
@@ -92,9 +93,10 @@ def test_VariantDataset(df_variant):
]
assert len(dataset) == 82 * 2
- assert dataset[0]['seq'].shape == (5, 524288)
+ assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE)
- assert dataset[0]['pred_expr']['v1_rep0'].shape == (8856,)
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
+ assert dataset[0]['pred_expr']['v1_rep0'].shape == (metadata['num_tasks'],)
assert not dataset[0]['pred_expr']['v1_rep0'].isnan().any()
assert dataset[1]['pred_expr']['v1_rep0'].isnan().all()
assert not dataset[2]['pred_expr']['v1_rep0'].isnan().any()
@@ -113,7 +115,7 @@ def test_VariantDataset(df_variant):
assert cols.tolist() == [38435, 38435] # should be the same for both
for i in range(len(dataset)):
- assert dataset[i]['seq'].shape == (5, 524288)
+ assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE)
rows, cols = np.where(dataset[162]['seq'] != dataset[163]['seq'])
assert cols.min() == 505705 # the positions before should not be effected.
@@ -122,11 +124,11 @@ def test_VariantDataset(df_variant):
dataset = VariantDataset(df_variant, max_seq_shift=100)
assert len(dataset) == 82 * 2 * 201
- assert dataset[0]['seq'].shape == (5, 524288)
+ assert dataset[0]['seq'].shape == (5, DECIMA_CONTEXT_SIZE)
for i in range(20):
assert dataset[i]["warning"] == []
- assert dataset[i]['seq'].shape == (5, 524288)
+ assert dataset[i]['seq'].shape == (5, DECIMA_CONTEXT_SIZE)
assert dataset[44 * 2 * 201]["warning"] == [WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME]
@@ -138,58 +140,60 @@ def test_VariantDataset(df_variant):
@pytest.mark.long_running
def test_VariantDataset_dataloader(df_variant):
- dataset = VariantDataset(df_variant, model_name="ensemble")
+ dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE)
dl = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=0, collate_fn=dataset.collate_fn)
batches = iter(dl)
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
batch = next(batches)
- assert batch["seq"].shape == (64, 5, 524288)
+ assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE)
assert batch["warning"] == []
- assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856)
+ assert batch["pred_expr"]["v1_rep0"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep1"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep2"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep3"].shape == (64, metadata['num_tasks'])
batch = next(batches)
- assert batch["seq"].shape == (64, 5, 524288)
+ assert batch["seq"].shape == (64, 5, DECIMA_CONTEXT_SIZE)
assert len(batch["warning"]) > 0
assert WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME in batch["warning"]
- assert batch["pred_expr"]["v1_rep0"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep1"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep2"].shape == (64, 8856)
- assert batch["pred_expr"]["v1_rep3"].shape == (64, 8856)
+ assert batch["pred_expr"]["v1_rep0"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep1"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep2"].shape == (64, metadata['num_tasks'])
+ assert batch["pred_expr"]["v1_rep3"].shape == (64, metadata['num_tasks'])
@pytest.mark.long_running
def test_VariantDataset_dataloader_vcf():
df_variant = next(read_vcf_chunks("tests/data/test.vcf", 10000))
- dataset = VariantDataset(df_variant, model_name="ensemble", max_distance=20000)
+ dataset = VariantDataset(df_variant, model_name=DEFAULT_ENSEMBLE, max_distance=20000)
dl = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=0, collate_fn=dataset.collate_fn)
batches = iter(dl)
+ metadata = MODEL_METADATA[MODEL_METADATA[DEFAULT_ENSEMBLE][0]]
batch = next(batches)
- assert batch["seq"].shape == (8, 5, 524288)
+ assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE)
assert batch["warning"] == []
- assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856)
+ assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks'])
batch = next(batches)
- assert batch["seq"].shape == (8, 5, 524288)
+ assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE)
assert batch["warning"] == []
- assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856)
+ assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks'])
batch = next(batches)
- assert batch["seq"].shape == (8, 5, 524288)
+ assert batch["seq"].shape == (8, 5, DECIMA_CONTEXT_SIZE)
assert len(batch["warning"]) > 0
- assert batch["pred_expr"]['v1_rep0'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep1'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep2'].shape == (8, 8856)
- assert batch["pred_expr"]['v1_rep3'].shape == (8, 8856)
+ assert batch["pred_expr"]['v1_rep0'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep1'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep2'].shape == (8, metadata['num_tasks'])
+ assert batch["pred_expr"]['v1_rep3'].shape == (8, metadata['num_tasks'])
@pytest.mark.long_running
@@ -198,7 +202,8 @@ def test_predict_variant_effect(df_variant):
query = "cell_type == 'CD8-positive, alpha-beta T cell'"
cells = DecimaResult.load().query_cells(query)
- df, warnings, num_variants = _predict_variant_effect(df_variant, model=0, tasks=query, device=device, max_distance=5000)
+ model = load_decima_model(0, device)
+ df, warnings, num_variants = _predict_variant_effect(df_variant, model=model, tasks=query, device=device, max_distance=5000)
assert num_variants == 4
assert df.shape == (4, 273)
@@ -225,7 +230,7 @@ def test_predict_variant_effect_save(df_variant, tmp_path):
predict_variant_effect(
df_variant,
output_pq=str(output_file),
- model="ensemble",
+ model=DEFAULT_ENSEMBLE,
tasks=query,
device=device,
max_distance=5000,
@@ -305,7 +310,7 @@ def test_predict_variant_effect_vcf_ensemble(tmp_path):
predict_variant_effect(
"tests/data/test.vcf",
output_pq=str(output_file),
- model="ensemble",
+ model=DEFAULT_ENSEMBLE,
device=device,
max_distance=20000,
)
@@ -321,7 +326,7 @@ def test_predict_variant_effect_vcf_ensemble_replicates(tmp_path):
predict_variant_effect(
"tests/data/test.vcf",
output_pq=str(output_file),
- model="ensemble",
+ model=DEFAULT_ENSEMBLE,
device=device,
max_distance=20000,
save_replicates=True,