diff --git a/README.md b/README.md
index cd0c8f5..bfdaf8a 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,27 @@
+**EXTREMELY IMPORTANT NOTE ABOUT DATA**: The data for CpGPT is hosted on an AWS S3 bucket with a requester-pays configuration. This means that you will be charged by AWS for the data you download. **The total cost of downloading all datasets will *exceed* $150.00 USD at the time of writing.** We will do our best to estimate the download cost for you, but we cannot guarantee the accuracy of our estimate. To estimate the cost of downloading part of the bucket, you can use the following commands:
+
+```bash
+S3_BUCKET='s3://cpgpt-lucascamillo-public'
+
+# Set this to the S3 path you want to estimate the cost for.
+BUCKET_PREFIX="${S3_BUCKET}//"
+# For example: BUCKET_PREFIX="${S3_BUCKET}/data/cpgcorpus/raw/"
+
+# Measure the total size (in bytes) of the data before downloading.
+TOTAL_SIZE_BYTES=$(aws s3 ls --recursive --summarize --request-payer requester "$BUCKET_PREFIX" | grep -oP 'Total Size: \K[0-9]+')
+
+# Data transfer out of S3 is likely billed at ~$0.09 per GB downloaded.
+TOTAL_SIZE_GB=$(python -c "print(round($TOTAL_SIZE_BYTES / 1e9, 2))")
+EXPECTED_COST=$(python -c "print(round($TOTAL_SIZE_BYTES / 1e9 * 0.09, 2))")
+
+# Print the estimated cost.
+echo "Estimated download cost: \$$EXPECTED_COST USD for $TOTAL_SIZE_GB GB"
+```
+
+
 ## 📋 Table of Contents
 
 - [📖 Overview](#-overview)
@@ -42,17 +63,31 @@ CpGPT is a foundation model for DNA methylation, trained on genome-wide DNA meth
 
 ### Installation Instructions
 
-We recommend using `poetry` for installation:
+Typical installation (`pip`):
+
+```bash
+pip install cpgpt
+```
+
+For local installation:
 
 ```bash
 # Clone the repository
 git clone https://github.com/lcamillo/CpGPT.git
 cd CpGPT
 
-# Install poetry if not available
-pip install poetry
+# Install poetry if you don't already have it
+curl -sSL https://install.python-poetry.org | python3 -
+
+# Install conda if you don't already have it
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
+bash ~/miniconda.sh
+
+# Create a conda environment
+conda create -y -n cpgpt python=3.12 pygraphviz
+conda activate cpgpt
 
-# Install dependencies with Poetry
+# Install dependencies in the conda environment
 poetry install
 ```
 
@@ -140,7 +175,7 @@ You'll need to input:
 Verify your setup with this command that lists the contents (without downloading):
 
 ```bash
-aws s3 ls s3://cpgpt-lucascamillo-public/data/cpgcorpus/raw/ --requester-payer requester
+aws s3 ls --request-payer requester s3://cpgpt-lucascamillo-public/data/cpgcorpus/raw/
 ```
 
 You should see a list of GSE folders if your configuration is correct.
@@ -152,10 +187,10 @@ You should see a list of GSE folders if your configuration is correct.
Download the Full Corpus
 
-To download the entire CpGCorpus from our S3 bucket, run the following command:
+To download the entire CpGCorpus from our S3 bucket, run the following command. **SINCE THE BUCKET USES A REQUESTER-PAYS CONFIGURATION, YOU WILL BE CHARGED BY AWS FOR THE DATA YOU DOWNLOAD: APPROXIMATELY $100.00 USD FOR THE FULL CORPUS AT THE TIME OF WRITING.**
 
 ```bash
-aws s3 sync s3://cpgpt-lucascamillo-public/data/cpgcorpus/raw ./data/cpgcorpus/raw --requester-payer requester
+aws s3 sync s3://cpgpt-lucascamillo-public/data/cpgcorpus/raw ./data/cpgcorpus/raw --request-payer requester
 ```
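If you only need part of the corpus, it is cheaper to check the size of a single GSE folder and sync just that folder. A minimal sketch, assuming the per-dataset layout listed above (`data/cpgcorpus/raw/<GSE_ID>/`) and using `GSE182215` from the tutorial as an example:

```bash
# Example: download a single GSE dataset instead of the full corpus.
GSE_ID='GSE182215'
PREFIX="s3://cpgpt-lucascamillo-public/data/cpgcorpus/raw/${GSE_ID}"

# First check how much data (and therefore how much cost) this one dataset involves.
aws s3 ls --recursive --summarize --request-payer requester "${PREFIX}/"

# Then sync only that dataset.
aws s3 sync "${PREFIX}" "./data/cpgcorpus/raw/${GSE_ID}" --request-payer requester
```

Note that the `--request-payer requester` flag is required for every command that touches the bucket; without it, AWS rejects requests to requester-pays buckets with an access-denied error.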
diff --git a/tutorials/quick_setup.ipynb b/tutorials/quick_setup.ipynb index 46bdf54..1d093d5 100644 --- a/tutorials/quick_setup.ipynb +++ b/tutorials/quick_setup.ipynb @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -98,7 +98,7 @@ "42" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -162,8 +162,14 @@ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mInitializing class CpGPTInferencer.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mUsing device: cuda.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mUsing dependencies directory: ../dependencies\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mUsing data directory: ../data\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mThere are 18 CpGPT models available such as age, age_cot, average_adultweight, etc.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mUsing data directory: ../data\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mThere are 19 CpGPT models available such as age, age_cot, average_adultweight, etc.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mThere are 2088 GSE datasets available such as GSE100184, GSE100208, GSE100209, etc.\u001b[0m\n" ] } @@ -184,12 +190,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The already-processed dependencies contain the sequence embeddings for both human (`s3://cpgpt-lucascamillo-public/dependencies/human`) and several mammalian species (`s3://cpgpt-lucascamillo-public/dependencies/mammalian`). Here, let's use the human as an example:" + "The already-processed dependencies contain the sequence embeddings for both human (`s3://cpgpt-lucascamillo-public/dependencies/human`) and several mammalian species (`s3://cpgpt-lucascamillo-public/dependencies/mammalian`). Here, let's use the human as an example:\n", + "\n", + "**WARNING**: This operation uses a requester-pays AWS copy operation. Please be mindful that this will likely incur additional costs up to approximately $4.00." 
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -198,16 +206,6 @@ "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mDependencies for human already exist at ../dependencies/human (skipping download).\u001b[0m\n" ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -230,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -274,16 +272,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mModel checkpoint already exists at ../dependencies/model/weights/cancer.ckpt (skipping download).\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mModel config already exists at ../dependencies/model/config/cancer.yaml (skipping download).\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mModel vocabulary already exists at ../dependencies/model/vocab/cancer.json (skipping download).\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mDownloading model checkpoint to ../dependencies/model/weights/cancer.ckpt.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mDownloading model config to ../dependencies/model/config/cancer.yaml\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mDownloading model vocabulary to ../dependencies/model/vocab/cancer.json\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mSuccessfully downloaded model 'cancer'.\u001b[0m\n" ] } @@ -302,21 +300,21 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mLoaded CpGPT model config.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mInstantiated CpGPT model from config.\u001b[0m\n" + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mLoaded CpGPT model config.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mInstantiated CpGPT model from config.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mUsing device: cuda.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mLoading checkpoint from: ../dependencies/model/weights/cancer.ckpt\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mCheckpoint loaded into the model.\u001b[0m\n" @@ -358,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ 
-367,16 +365,6 @@ "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTInferencer\u001b[0m: \u001b[1mDataset GSE182215 already exists at ../data/cpgcorpus/raw/GSE182215 (skipping download).\u001b[0m\n" ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -392,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -625,7 +613,7 @@ "[5 rows x 485578 columns]" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -652,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -662,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -895,7 +883,7 @@ "[5 rows x 19948 columns]" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -907,10 +895,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ + "os.makedirs(os.path.dirname(ARROW_DF_FILTERED_PATH), exist_ok=True)\n", "df.to_feather(ARROW_DF_FILTERED_PATH)" ] }, @@ -930,7 +919,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -939,7 +928,13 @@ "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mDNALLMEmbedder\u001b[0m: \u001b[1mInitializing class DNALLMEmbedder.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mDNALLMEmbedder\u001b[0m: \u001b[1mGenome files will be stored under ../dependencies/human/genomes.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mDNALLMEmbedder\u001b[0m: \u001b[1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mDNALLMEmbedder\u001b[0m: \u001b[1mDNA embeddings will be stored under ../dependencies/human/dna_embeddings and subdirectories.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mDNALLMEmbedder\u001b[0m: \u001b[1mEnsembl metadata dictionary loaded successfully\u001b[0m\n" ] } @@ -950,7 +945,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -958,7 +953,13 @@ "output_type": "stream", "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mIlluminaMethylationProber\u001b[0m: \u001b[1mInitializing class IlluminaMethylationProber.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mIlluminaMethylationProber\u001b[0m: \u001b[1mIllumina methylation manifest files will be stored under ../dependencies/human/manifests.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mIlluminaMethylationProber\u001b[0m: \u001b[1mIllumina methylation manifest files will be stored under ../dependencies/human/manifests.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mIlluminaMethylationProber\u001b[0m: \u001b[1mIllumina metadata dictionary loaded successfully.\u001b[0m\n" ] } @@ -969,7 +970,7 @@ }, { "cell_type": "code", - "execution_count": 
15, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -978,10 +979,47 @@ "text": [ "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mInitializing class CpGPTDataSaver.\u001b[0m\n", "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mDataset folders will be stored under ../data/tutorials/processed/quick_setup.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mLoaded existing dataset metrics.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mLoaded existing genomic locations.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mStarting file processing.\u001b[0m\n", - "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1m1 files already processed. Skipping those.\u001b[0m\n" + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mNo existing dataset metrics found. Please process files.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mNo existing genomic locations found. Please process files.\u001b[0m\n", + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mStarting file processing.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "af449af84ee449598860f2cdaeb62507", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[33m\u001b[1mNo species column found. Defaulting to homo_sapiens.\u001b[0m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[1m\u001b[34mcpgpt\u001b[0m\u001b[1m\u001b[0m: \u001b[36mCpGPTDataSaver\u001b[0m: \u001b[1mFile processing completed.\u001b[0m\n"
      ]
     }
    ],
@@ -1009,7 +1047,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
@@ -1085,7 +1123,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -1093,6 +1131,7 @@
      "output_type": "stream",
      "text": [
       "Using 16bit Automatic Mixed Precision (AMP)\n",
+      "You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.\n",
       "GPU available: True (cuda), used: True\n",
       "TPU available: False, using: 0 TPU cores\n",
       "HPU available: False, using: 0 HPUs\n"
@@ -1112,14 +1151,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "You are using a CUDA device ('NVIDIA A10G') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
+      "You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
      ]
     },
     {
@@ -1140,7 +1179,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7053e95c364b41fea56a7ba86aa50639",
+       "model_id": "0fbd4a31791746fb9313a571f662de06",
        "version_major": 2,
        "version_minor": 0
       },
@@ -1155,7 +1194,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/data/miniforge3/envs/cpgpt/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.\n"
+      "/home/ubuntu/miniconda3/envs/cpgpt/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=29` in the `DataLoader` to improve performance.\n"
      ]
     },
     {
@@ -1180,22 +1219,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'sample_embedding': tensor([[-0.0580, -0.0656, -0.0079,  ..., -0.0076, -0.1528, -0.0810],\n",
-       "         [-0.0750, -0.0756, -0.0059,  ..., -0.0241, -0.1455, -0.0859],\n",
-       "         [ 0.0325, -0.0363, -0.1413,  ..., -0.0782, -0.1645, -0.1452],\n",
+       "{'sample_embedding': tensor([[-0.0580, -0.0655, -0.0080,  ..., -0.0076, -0.1528, -0.0811],\n",
+       "         [-0.0750, -0.0754, -0.0058,  ..., -0.0240, -0.1456, -0.0862],\n",
+       "         [ 0.0325, -0.0363, -0.1416,  ..., -0.0783, -0.1642, -0.1454],\n",
        "         ...,\n",
-       "         [-0.0838, -0.0999, -0.0362,  ...,  0.0130, -0.0523, -0.1371],\n",
-       "         [-0.0597, -0.0813, -0.0225,  ...,  0.0113, -0.1043, -0.0967],\n",
-       "         [-0.0248, -0.0809, -0.0548,  ...,  0.0250, -0.0796, -0.0924]])}"
+       "         [-0.0838, -0.0998, -0.0364,  ...,  0.0130, -0.0521, -0.1371],\n",
+       "         [-0.0598, -0.0812, -0.0226,  ...,  0.0112, -0.1045, -0.0966],\n",
+       "         [-0.0247, -0.0808, -0.0547,  ...,  0.0251, -0.0795, -0.0924]])}"
       ]
      },
-     "execution_count": 19,
+     "execution_count": 22,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1213,7 +1252,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [
     {
@@ -1234,7 +1273,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "8c6701aaadf441ac991829def24e9af9",
+       "model_id": "05b45273313c437aaeb47d549bde987d",
        "version_major": 2,
        "version_minor": 0
       },
@@ -1267,53 +1306,53 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'pred_conditions': tensor([[ 0.3667],\n",
-       "         [ 0.0154],\n",
-       "         [ 6.9766],\n",
-       "         [ 0.4402],\n",
-       "         [-0.6748],\n",
-       "         [-0.2629],\n",
-       "         [-0.7241],\n",
-       "         [-2.7090],\n",
-       "         [-0.1346],\n",
-       "         [-0.4250],\n",
-       "         [-0.0770],\n",
-       "         [ 0.1849],\n",
-       "         [-0.0803],\n",
-       "         [ 0.4863],\n",
+       "{'pred_conditions': tensor([[ 0.3674],\n",
+       "         [ 0.0183],\n",
+       "         [ 6.9727],\n",
+       "         [ 0.4380],\n",
+       "         [-0.6777],\n",
+       "         [-0.2605],\n",
+       "         [-0.7271],\n",
+       "         [-2.7109],\n",
+       "         [-0.1348],\n",
+       "         [-0.4270],\n",
+       "         [-0.0753],\n",
+       "         [ 0.1816],\n",
+       "         [-0.0773],\n",
+       "         [ 0.4890],\n",
        "         [ 2.5410],\n",
-       "         [-1.4512],\n",
-       "         [-3.1914],\n",
-       "         [ 3.1328],\n",
-       "         [-0.7524],\n",
-       "         [-1.4854],\n",
+       "         [-1.4492],\n",
+       "         [-3.1934],\n",
+       "         [ 3.1348],\n",
+       "         [-0.7500],\n",
+       "         [-1.4834],\n",
        "         [-1.8594],\n",
        "         [-1.4404],\n",
-       "         [-2.0391],\n",
-       "         [-1.4297],\n",
-       "         [-2.4863],\n",
-       "         [-2.0703],\n",
-       "         [-2.4922],\n",
-       "         [-2.4121],\n",
-       "         [-2.3438],\n",
+       "         [-2.0371],\n",
+       "         [-1.4287],\n",
+       "         [-2.4883],\n",
+       "         [-2.0684],\n",
+       "         [-2.4902],\n",
+       "         [-2.4180],\n",
+       "         [-2.3418],\n",
        "         [-1.7637],\n",
-       "         [-1.4941],\n",
-       "         [-2.4941],\n",
-       "         [-0.4998],\n",
-       "         [-0.5352],\n",
-       "         [-1.9775],\n",
-       "         [-3.3359],\n",
-       "         [-1.0107],\n",
-       "         [-0.5669]], dtype=torch.float16)}"
+       "         [-1.4990],\n",
+       "         [-2.4922],\n",
+       "         [-0.5005],\n",
+       "         [-0.5342],\n",
+       "         [-1.9814],\n",
+       "         [-3.3320],\n",
+       "         [-1.0078],\n",
+       "         [-0.5684]], dtype=torch.float16)}"
       ]
      },
-     "execution_count": 21,
+     "execution_count": 24,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1338,7 +1377,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
@@ -1347,7 +1386,7 @@
        "['cg00000292', 'cg00002426', 'cg00003994', 'cg00005847', 'cg00008493']"
       ]
      },
-     "execution_count": 22,
+     "execution_count": 25,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1361,7 +1400,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
@@ -1370,7 +1409,7 @@
        "['16:28878778', '3:57757815', '7:15686236', '2:176164344', '14:93347430']"
       ]
      },
-     "execution_count": 23,
+     "execution_count": 26,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1384,7 +1423,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
@@ -1405,7 +1444,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5295ca43bd854f46874501e6f1d6c4ce",
+       "model_id": "be16312fa89e400782a2ca4190498182",
        "version_major": 2,
        "version_minor": 0
       },
@@ -1447,23 +1486,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'pred_meth': tensor([[0.8501, 0.8633, 0.0503,  ..., 0.0348, 0.9292, 0.7095],\n",
-       "         [0.8706, 0.8833, 0.0492,  ..., 0.0340, 0.9419, 0.7275],\n",
-       "         [0.3799, 0.4067, 0.2776,  ..., 0.3308, 0.4133, 0.2937],\n",
+       "{'pred_meth': tensor([[0.8501, 0.8628, 0.0503,  ..., 0.0348, 0.9292, 0.7095],\n",
+       "         [0.8711, 0.8833, 0.0492,  ..., 0.0340, 0.9419, 0.7280],\n",
+       "         [0.3799, 0.4065, 0.2776,  ..., 0.3313, 0.4131, 0.2937],\n",
        "         ...,\n",
-       "         [0.7925, 0.7881, 0.0529,  ..., 0.0367, 0.9351, 0.7075],\n",
-       "         [0.8247, 0.8291, 0.0523,  ..., 0.0337, 0.9351, 0.7031],\n",
-       "         [0.6494, 0.4927, 0.0576,  ..., 0.0337, 0.9385, 0.7085]],\n",
+       "         [0.7925, 0.7886, 0.0529,  ..., 0.0368, 0.9351, 0.7075],\n",
+       "         [0.8247, 0.8286, 0.0523,  ..., 0.0337, 0.9351, 0.7031],\n",
+       "         [0.6494, 0.4927, 0.0576,  ..., 0.0337, 0.9375, 0.7085]],\n",
        "        dtype=torch.float16)}"
       ]
      },
-     "execution_count": 25,
+     "execution_count": 28,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1482,7 +1521,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 29,
    "metadata": {},
    "outputs": [
     {
@@ -1503,7 +1542,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7acc5665c097411682686d6c4e3e1eed",
+       "model_id": "e4e3cae7e9ac4f2fa33e2990b5e78c63",
        "version_major": 2,
        "version_minor": 0
       },
@@ -1541,23 +1580,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 30,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'pred_meth': tensor([[0.8516, 0.8486, 0.0482,  ..., 0.0344, 0.9248, 0.7080],\n",
-       "         [0.8691, 0.8696, 0.0484,  ..., 0.0330, 0.9434, 0.7383],\n",
-       "         [0.4485, 0.4104, 0.3074,  ..., 0.3823, 0.3540, 0.2976],\n",
+       "{'pred_meth': tensor([[0.8486, 0.8452, 0.0482,  ..., 0.0344, 0.9233, 0.7075],\n",
+       "         [0.8672, 0.8662, 0.0484,  ..., 0.0326, 0.9434, 0.7412],\n",
+       "         [0.4385, 0.4065, 0.2893,  ..., 0.3835, 0.3711, 0.2949],\n",
        "         ...,\n",
-       "         [0.7856, 0.7690, 0.0512,  ..., 0.0355, 0.9326, 0.6919],\n",
-       "         [0.8271, 0.8198, 0.0520,  ..., 0.0332, 0.9331, 0.6934],\n",
-       "         [0.6523, 0.4868, 0.0560,  ..., 0.0331, 0.9297, 0.7061]],\n",
+       "         [0.7832, 0.7661, 0.0508,  ..., 0.0352, 0.9321, 0.6914],\n",
+       "         [0.8291, 0.8237, 0.0523,  ..., 0.0335, 0.9326, 0.6953],\n",
+       "         [0.6528, 0.4924, 0.0550,  ..., 0.0326, 0.9282, 0.7065]],\n",
        "        dtype=torch.float16)}"
       ]
      },
-     "execution_count": 27,
+     "execution_count": 30,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1583,7 +1622,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [
     {
@@ -1604,7 +1643,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9f9ff4bf0f164c078749d80cd707aaae",
+       "model_id": "9c86847d114e45ccb15bdce75c0d67cd",
        "version_major": 2,
        "version_minor": 0
       },
@@ -1639,31 +1678,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'attention_weights': tensor([[[0.0011, 0.0010, 0.0009,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "{'attention_weights': tensor([[[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "          [0.0010, 0.0010, 0.0011,  ...,    nan,    nan,    nan],\n",
+       "          [0.0010, 0.0010, 0.0011,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],\n",
        " \n",
-       "         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0011, 0.0010,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "         [[0.0011, 0.0017, 0.0009,  ...,    nan,    nan,    nan],\n",
+       "          [0.0008, 0.0027, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "          [0.0010, 0.0010, 0.0011,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],\n",
        " \n",
-       "         [[0.0121, 0.0115, 0.0097,  ...,    nan,    nan,    nan],\n",
-       "          [0.0114, 0.0123, 0.0098,  ...,    nan,    nan,    nan],\n",
-       "          [0.0105, 0.0110, 0.0222,  ...,    nan,    nan,    nan],\n",
+       "         [[0.0143, 0.0129, 0.0143,  ...,    nan,    nan,    nan],\n",
+       "          [0.0126, 0.0218, 0.0130,  ...,    nan,    nan,    nan],\n",
+       "          [0.0140, 0.0148, 0.0152,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
@@ -1671,36 +1710,36 @@
        " \n",
        "         ...,\n",
        " \n",
-       "         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0011, 0.0011,  ...,    nan,    nan,    nan],\n",
+       "         [[0.0011, 0.0011, 0.0009,  ...,    nan,    nan,    nan],\n",
+       "          [0.0010, 0.0011, 0.0010,  ...,    nan,    nan,    nan],\n",
        "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],\n",
        " \n",
-       "         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
-       "          [0.0010, 0.0011, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "         [[0.0011, 0.0010, 0.0009,  ...,    nan,    nan,    nan],\n",
        "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "          [0.0010, 0.0010, 0.0011,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],\n",
        " \n",
-       "         [[0.0011, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
+       "         [[0.0011, 0.0010, 0.0011,  ...,    nan,    nan,    nan],\n",
        "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
        "          [0.0010, 0.0010, 0.0010,  ...,    nan,    nan,    nan],\n",
        "          ...,\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan],\n",
        "          [   nan,    nan,    nan,  ...,    nan,    nan,    nan]]]),\n",
-       " 'meth': tensor([[0.8164, 0.0323, 0.0501,  ...,    nan,    nan,    nan],\n",
-       "         [0.9417, 0.0925, 0.0540,  ...,    nan,    nan,    nan],\n",
-       "         [0.6220, 0.3588, 0.4511,  ...,    nan,    nan,    nan],\n",
+       " 'meth': tensor([[0.9530, 0.9361, 0.0252,  ...,    nan,    nan,    nan],\n",
+       "         [0.1304, 0.0519, 0.0876,  ...,    nan,    nan,    nan],\n",
+       "         [0.0592, 0.1145, 0.4986,  ...,    nan,    nan,    nan],\n",
        "         ...,\n",
-       "         [0.1128, 0.0809, 0.1357,  ...,    nan,    nan,    nan],\n",
-       "         [0.7781, 0.0938, 0.0571,  ...,    nan,    nan,    nan],\n",
-       "         [0.0147, 0.6906, 0.6015,  ...,    nan,    nan,    nan]]),\n",
+       "         [0.2004, 0.0160, 0.1507,  ...,    nan,    nan,    nan],\n",
+       "         [0.0799, 0.0552, 0.0529,  ...,    nan,    nan,    nan],\n",
+       "         [0.8575, 0.2002, 0.0232,  ...,    nan,    nan,    nan]]),\n",
        " 'mask_na': tensor([[False, False, False,  ...,  True,  True,  True],\n",
        "         [False, False, False,  ...,  True,  True,  True],\n",
        "         [False, False, False,  ...,  True,  True,  True],\n",
@@ -1708,37 +1747,26 @@
        "         [False, False, False,  ...,  True,  True,  True],\n",
        "         [False, False, False,  ...,  True,  True,  True],\n",
        "         [False, False, False,  ...,  True,  True,  True]]),\n",
-       " 'chroms': tensor([[ 2,  2,  2,  ..., -1, -1, -1],\n",
-       "         [16, 16, 16,  ..., -1, -1, -1],\n",
-       "         [ 0,  0,  0,  ..., -1, -1, -1],\n",
+       " 'chroms': tensor([[ 7,  7,  7,  ..., -1, -1, -1],\n",
+       "         [18, 18, 18,  ..., -1, -1, -1],\n",
+       "         [ 8,  8,  8,  ..., -1, -1, -1],\n",
        "         ...,\n",
-       "         [ 2,  2,  2,  ..., -1, -1, -1],\n",
-       "         [16, 16, 16,  ..., -1, -1, -1],\n",
-       "         [ 7,  7,  7,  ..., -1, -1, -1]], dtype=torch.int32),\n",
-       " 'positions': tensor([[  394198,   746800,  2140242,  ...,       -1,       -1,       -1],\n",
-       "         [ 1394034,  2756719,  2964113,  ...,       -1,       -1,       -1],\n",
-       "         [15410502, 24319359, 24902090,  ...,       -1,       -1,       -1],\n",
+       "         [ 8,  8,  8,  ..., -1, -1, -1],\n",
+       "         [19, 19, 19,  ..., -1, -1, -1],\n",
+       "         [10, 10, 10,  ..., -1, -1, -1]], dtype=torch.int32),\n",
+       " 'positions': tensor([[   52120,   179983,   401506,  ...,       -1,       -1,       -1],\n",
+       "         [  392130,  1313259,  5261595,  ...,       -1,       -1,       -1],\n",
+       "         [ 3917223,  7035962,  7484743,  ...,       -1,       -1,       -1],\n",
        "         ...,\n",
-       "         [  451100,   536298,   805725,  ...,       -1,       -1,       -1],\n",
-       "         [  625585,  1720432,  4859955,  ...,       -1,       -1,       -1],\n",
-       "         [  234647,   276727,   689597,  ...,       -1,       -1,       -1]],\n",
+       "         [  731823,  1455948,  2338275,  ...,       -1,       -1,       -1],\n",
+       "         [ 6348498, 19116462, 21945658,  ...,       -1,       -1,       -1],\n",
+       "         [  475153,   859982,  1041020,  ...,       -1,       -1,       -1]],\n",
        "        dtype=torch.int32)}"
       ]
      },
-     "execution_count": 29,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
-    },
-    {
-     "ename": "",
-     "evalue": "",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
-      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
-      "\u001b[1;31mClick here for more info. \n",
-      "\u001b[1;31mView Jupyter log for further details."
-     ]
     }
    ],
    "source": [
@@ -1762,7 +1790,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.8"
+   "version": "3.12.9"
   }
  },
  "nbformat": 4,