diff --git a/.gitignore b/.gitignore index 4ed3c760b..659948f91 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,9 @@ leaderboard/rtd_token.txt # locally pre-trained models pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model data/physionet.org/ +examples/chestxray_vae_synthetic.png +examples/chestxray_vae_comparison.png +examples/chestxray_vae_conditional.png +examples/vae_model.pth +uv.lock +data/mimic-iv-clinical-database-demo-2.2/ diff --git a/examples/ChestXray-image-generation-GAN.ipynb b/examples/ChestXray-image-generation-GAN.ipynb index 954e86a2b..4a298667c 100644 --- a/examples/ChestXray-image-generation-GAN.ipynb +++ b/examples/ChestXray-image-generation-GAN.ipynb @@ -1,8 +1,16 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "cf3e101f", + "metadata": {}, + "source": [ + "# Chest X-Ray Image Generation using GAN" + ] + }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "9a77c2ec", "metadata": {}, "outputs": [], @@ -12,108 +20,146 @@ }, { "cell_type": "markdown", - "id": "53feb87e", + "id": "303f1068", "metadata": {}, "source": [ - "### STEP 1: load the chest Xray data" + "### Load Libraries" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "13913032", + "execution_count": null, + "id": "a9b89c8f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/chaoqiy2/miniconda3/envs/moltext/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Statistics of COVID19CXRDataset:\n", - "Number of samples: 21165\n", - "Number of classes: 4\n", - "Class distribution: Counter({'Normal': 10192, 'Lung Opacity': 6012, 'COVID': 3616, 'Viral Pneumonia': 1345})\n" + "/home/ubuntu/PyHealth/pyhealth/trainer.py:12: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from tqdm.autonotebook import trange\n", + "/home/ubuntu/PyHealth/pyhealth/sampler/sage_sampler.py:3: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " import pkg_resources\n" ] } ], "source": [ + "from pyhealth.datasets import split_by_visit, get_dataloader\n", + "from pyhealth.trainer import Trainer\n", "from pyhealth.datasets import COVID19CXRDataset\n", + "from pyhealth.models import VAE\n", + "from pyhealth.processors import ImageProcessor\n", + "from torchvision import transforms\n", + "from pyhealth.processors import SequenceProcessor\n", "\n", - "root = \"/srv/local/data/COVID-19_Radiography_Dataset\"\n", - "base_dataset = COVID19CXRDataset(root, refresh_cache=False)\n", - "\n", - "base_dataset.stat()" + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", - "id": "048ff2ee", + "id": "53feb87e", + "metadata": {}, + "source": [ + "## STEP 1: load the chest Xray data\n", + "\n", + "We also prepare the data:\n", + "- resize images to 128x128\n", + "- split train/test/validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3804b253", "metadata": {}, + "outputs": [], "source": [ - "### 
STEP 2: set task and processing the data" + "# Download command (uncomment to run)\n", + "# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database\n", + "# !unzip ~/Downloads/covid19-radiography-database.zip -d ~/Downloads/COVID-19_Radiography_Dataset" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "9487b2c6", + "execution_count": 2, + "id": "13913032", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Initializing covid19_cxr dataset from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset (dev mode: False)\n", + "Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv\n", + "Collecting global event dataframe...\n", + "Collected dataframe with shape: (21165, 6)\n", + "Dataset: covid19_cxr\n", + "Dev mode: False\n", + "Number of patients: 21165\n", + "Number of events: 21165\n", + "Setting task COVID19CXRClassification for covid19_cxr base dataset...\n", + "Generating samples with 1 worker(s)...\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "Generating samples for COVID19CXRClassification: 100%|███████████████████████████████████████████| 21165/21165 [00:00<00:00, 2783009.72it/s]\n" + "Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:08<00:00, 2637.68it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Processing samples: 100%|██████████| 21165/21165 [01:18<00:00, 270.15it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 21165 samples for task COVID19CXRClassification\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "\n" ] } ], "source": [ - "from torchvision import transforms\n", + "image_size = 128\n", + "covid19cxr_path = \"~/Downloads/COVID-19_Radiography_Dataset\"\n", "\n", - "sample_dataset = base_dataset.set_task()\n", + "base_dataset = COVID19CXRDataset(covid19cxr_path)\n", "\n", - "# the transformation automatically normalize the pixel intensity into [0, 1]\n", - "transform = transforms.Compose([\n", - " transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)), # only use the first channel\n", - " transforms.Resize((128, 128)),\n", - "])\n", + "base_dataset.stats()\n", "\n", - "def encode(sample):\n", - " sample[\"path\"] = transform(sample[\"path\"])\n", - " return sample\n", "\n", - "sample_dataset.set_transform(encode)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "1601d7d1", - "metadata": {}, - "outputs": [], - "source": [ - "from pyhealth.datasets import split_by_visit, get_dataloader\n", + "# Step 2: Set task with custom image processing for GAN\n", + "image_processor = ImageProcessor(image_size=image_size, mode=\"RGB\") # Resize to 128x128 for GAN\n", "\n", - "# split dataset\n", - "train_dataset, val_dataset, test_dataset = split_by_visit(\n", - " sample_dataset, [0.8, 0.1, 0.1]\n", - ")\n", - "train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)\n", - "val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)\n", - "test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)" + "sample_dataset = base_dataset.set_task(input_processors={\"image\": image_processor})" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "d4d8b012", + "execution_count": null, + "id": "split_data", "metadata": {}, "outputs": [ { @@ -126,8 +172,19 @@ } ], "source": [ + "\n", + "# split dataset\n", + "train_dataset, val_dataset, test_dataset = split_by_visit(\n", + " sample_dataset, [0.8, 0.1, 0.1]\n", + ")\n", + "\n", + 
"train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)\n", + "val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)\n", + "test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)\n", + "\n", "data = next(iter(train_dataloader))\n", - "print (data[\"path\"][0].shape)\n", + "\n", + "print(data[\"image\"][0].shape)\n", "\n", "print(\n", " \"loader size: train/val/test\",\n", @@ -139,7 +196,7 @@ }, { "cell_type": "markdown", - "id": "a897d167", + "id": "step3", "metadata": {}, "source": [ "### STEP3: define the GAN model" @@ -147,8 +204,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "b93a449c", + "execution_count": 4, + "id": "gan_model", "metadata": {}, "outputs": [], "source": [ @@ -157,13 +214,13 @@ "model = GAN(\n", " input_channel=3,\n", " input_size=128,\n", - " hidden_dim = 256,\n", + " hidden_dim=256,\n", ")" ] }, { "cell_type": "markdown", - "id": "e6155755", + "id": "step4", "metadata": {}, "source": [ "### STEP4: training the GAN model in an adversarial way" @@ -171,708 +228,708 @@ }, { "cell_type": "code", - "execution_count": 24, - "id": "a06f823a", + "execution_count": 5, + "id": "training", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:56<00:00, 1.18it/s]\n" + "100%|██████████| 67/67 [00:07<00:00, 9.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 0 --- loss of G: 235.46342170238495, loss of D: 38.44307478144765\n" + "epoch: 0 --- loss of G: 187.00772655010223, loss of D: 17.679536825045943\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.70it/s]\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ - "epoch: 1 --- loss of G: 189.38715118169785, loss of D: 30.993362290784717\n" + "epoch: 1 --- loss of G: 421.60688638687134, loss of D: 0.9925004122778773\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.17it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 2 --- loss of G: 239.08964610099792, loss of D: 6.970357734709978\n" + "epoch: 2 --- loss of G: 510.64223623275757, loss of D: 0.3214332648785785\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:54<00:00, 1.22it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.63it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 3 --- loss of G: 282.8131763935089, loss of D: 4.137967556715012\n" + "epoch: 3 --- loss of G: 567.562891960144, loss of D: 0.15068832962424494\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:55<00:00, 1.20it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.35it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 4 --- loss of G: 255.01798963546753, loss of D: 5.110567823052406\n" + "epoch: 4 --- loss of G: 599.3213233947754, loss of D: 0.08608292078133672\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:54<00:00, 1.23it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.19it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 5 --- loss of G: 304.87264013290405, 
loss of D: 8.64646109007299\n" + "epoch: 5 --- loss of G: 588.9778165817261, loss of D: 0.12099041216424666\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:52<00:00, 1.28it/s]\n" + "100%|██████████| 67/67 [00:07<00:00, 8.99it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 6 --- loss of G: 500.5840277671814, loss of D: 24.204148830845952\n" + "epoch: 6 --- loss of G: 603.5574598312378, loss of D: 0.12727271695621312\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:56<00:00, 1.19it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.30it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 7 --- loss of G: 451.33461332321167, loss of D: 0.7537129627307877\n" + "epoch: 7 --- loss of G: 639.0430097579956, loss of D: 0.1332307325792499\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:55<00:00, 1.21it/s]\n" + "100%|██████████| 67/67 [00:13<00:00, 5.15it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 8 --- loss of G: 431.7444968223572, loss of D: 6.4659804629627615\n" + "epoch: 8 --- loss of G: 651.0378975868225, loss of D: 0.9335501801688224\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:56<00:00, 1.18it/s]\n" + "100%|██████████| 67/67 [02:29<00:00, 2.23s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 9 --- loss of G: 596.1160368919373, loss of D: 7.6383128159234275\n" + "epoch: 9 --- loss of G: 616.8124380111694, loss of D: 
3.84032856952399\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [03:13<00:00, 2.89s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 10 --- loss of G: 484.0970947742462, loss of D: 15.179329166712705\n" + "epoch: 10 --- loss of G: 519.6118609905243, loss of D: 9.224071164149791\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [01:48<00:00, 1.62s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 11 --- loss of G: 407.63574838638306, loss of D: 0.530137675232254\n" + "epoch: 11 --- loss of G: 562.777735710144, loss of D: 3.7825616598129272\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.87it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 12 --- loss of G: 312.33895611763, loss of D: 3.4384627752006054\n" + "epoch: 12 --- loss of G: 336.34571504592896, loss of D: 20.344450883567333\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 11.03it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 13 --- loss of G: 312.2976408004761, loss of D: 2.7867361521348357\n" + "epoch: 13 --- loss of G: 294.52232551574707, loss of D: 19.83913780748844\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.38it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 14 --- loss of G: 264.729043006897, loss of D: 7.06029037386179\n" + "epoch: 14 --- loss of G: 140.42432856559753, loss of D: 39.276702120900154\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.17it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.78it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 15 --- loss of G: 225.16066551208496, loss of D: 11.562116228044033\n" + "epoch: 15 --- loss of G: 99.24431133270264, loss of D: 38.27878451347351\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 16 --- loss of G: 218.4423906803131, loss of D: 8.810364093631506\n" + "epoch: 16 --- loss of G: 81.83258920907974, loss of D: 40.19685423374176\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.17it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 17 --- loss of G: 312.99179339408875, loss of D: 11.620514743030071\n" + "epoch: 17 --- loss of G: 83.60888206958771, loss of D: 40.768057465553284\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 18 --- loss of G: 211.89375686645508, loss of D: 9.615357603877783\n" + "epoch: 18 --- loss of G: 96.50049769878387, loss of D: 36.23655703663826\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.17it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.50it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 19 --- loss of G: 171.22825515270233, loss of D: 25.423519730567932\n" + "epoch: 19 --- loss of G: 105.58138465881348, loss of D: 35.10594001412392\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:56<00:00, 1.18it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.88it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 20 --- loss of G: 210.7448410987854, loss of D: 24.356860101222992\n" + "epoch: 20 --- loss of G: 114.39195239543915, loss of D: 32.3329062461853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.81it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 21 --- loss of G: 244.63870346546173, loss of D: 11.02416418865323\n" + "epoch: 21 --- loss of G: 101.80236661434174, loss of D: 38.6760116815567\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 22 --- loss of G: 243.14930987358093, loss of D: 22.47604788094759\n" + "epoch: 22 --- loss of G: 109.04553139209747, loss of D: 32.825062185525894\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 23 --- loss of G: 274.94408202171326, loss of D: 26.34657647460699\n" + "epoch: 23 --- loss of G: 123.76316118240356, loss of D: 30.759442001581192\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 24 --- loss of G: 234.09665405750275, loss of D: 16.731902796775103\n" + "epoch: 24 --- loss of G: 105.20807480812073, loss of D: 36.456340968608856\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.38it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 25 --- loss of G: 306.30480122566223, loss of D: 16.174514853628352\n" + "epoch: 25 --- loss of G: 137.6276512145996, loss of D: 25.502176865935326\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.41it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 26 --- loss of G: 330.1878414154053, loss of D: 34.08231335878372\n" + "epoch: 26 --- loss of G: 149.49159610271454, loss of D: 28.30062609910965\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 27 --- loss of G: 254.93080186843872, loss of D: 14.502326969988644\n" + "epoch: 27 --- loss of G: 145.72875905036926, loss of D: 31.580776512622833\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.45it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 28 --- loss of G: 276.89978194236755, loss of D: 17.942093975842\n" + "epoch: 28 --- loss of G: 130.92576503753662, loss of D: 32.87881879508495\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 29 --- loss of G: 207.1793919801712, loss of D: 22.071087922900915\n" + "epoch: 29 --- loss of G: 131.46568083763123, loss of D: 31.840122565627098\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.37it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 30 --- loss of G: 158.61903989315033, loss of D: 27.629727222025394\n" + "epoch: 30 --- loss of G: 84.15838372707367, loss of D: 42.73091375827789\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.69it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 31 --- loss of G: 202.9431939125061, loss of D: 25.77581850066781\n" + "epoch: 31 --- loss of G: 74.97820204496384, loss of D: 40.54647704958916\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.68it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 32 --- loss of G: 324.6885280609131, loss of D: 36.55593832582235\n" + "epoch: 32 --- loss of G: 101.26055026054382, loss of D: 29.263713508844376\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.31it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 33 --- loss of G: 221.39485549926758, loss of D: 11.331106215715408\n" + "epoch: 33 --- loss of G: 86.86687338352203, loss of D: 37.598598927259445\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 34 --- loss of G: 147.92780953645706, loss of D: 38.99323396384716\n" + "epoch: 34 --- loss of G: 84.66891765594482, loss of D: 40.19908273220062\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.33it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 35 --- loss of G: 192.14132618904114, loss of D: 28.10118254646659\n" + "epoch: 35 --- loss of G: 79.72011703252792, loss of D: 41.72318208217621\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 36 --- loss of G: 235.25207090377808, loss of D: 31.65054650232196\n" + "epoch: 36 --- loss of G: 73.43266141414642, loss of D: 40.948941975831985\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 37 --- loss of G: 157.7527117729187, loss of D: 31.56289905309677\n" + "epoch: 37 --- loss of G: 72.37873077392578, loss of D: 41.60090494155884\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:59<00:00, 1.13it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.73it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 38 --- loss of G: 203.90822780132294, loss of D: 24.688751220703125\n" + "epoch: 38 --- loss of G: 61.18730026483536, loss of D: 44.2846662402153\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.48it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 39 --- loss of G: 193.7863022685051, loss of D: 27.719643944874406\n" + "epoch: 39 --- loss of G: 62.92646849155426, loss of D: 43.42814892530441\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 40 --- loss of G: 148.90731024742126, loss of D: 24.35784476995468\n" + "epoch: 40 --- loss of G: 70.7225307226181, loss of D: 38.10763677954674\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.27it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 41 --- loss of G: 138.36136281490326, loss of D: 32.37918431684375\n" + "epoch: 41 --- loss of G: 79.05009615421295, loss of D: 38.92615833878517\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.65it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 42 --- loss of G: 144.17725467681885, loss of D: 24.131414487957954\n" + "epoch: 42 --- loss of G: 77.47582876682281, loss of D: 37.80526325106621\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:59<00:00, 1.14it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.84it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 43 --- loss of G: 137.79231882095337, loss of D: 28.492403000593185\n" + "epoch: 43 --- loss of G: 89.28383368253708, loss of D: 36.453640162944794\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.69it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 44 --- loss of G: 130.22422045469284, loss of D: 29.230164095759392\n" + "epoch: 44 --- loss of G: 83.84232759475708, loss of D: 39.26784712076187\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.16it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.62it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 45 --- loss of G: 178.59211683273315, loss of D: 18.080337017774582\n" + "epoch: 45 --- loss of G: 86.17833715677261, loss of D: 37.44430324435234\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 46 --- loss of G: 208.3746610879898, loss of D: 37.55503672361374\n" + "epoch: 46 --- loss of G: 88.58943152427673, loss of D: 35.40870875120163\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 47 --- loss of G: 149.65763175487518, loss of D: 29.568245857954025\n" + "epoch: 47 --- loss of G: 88.73302459716797, loss of D: 36.97438296675682\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:57<00:00, 1.17it/s]\n" + "100%|██████████| 67/67 [00:06<00:00, 10.50it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 48 --- loss of G: 229.93753325939178, loss of D: 25.85401887446642\n" + "epoch: 48 --- loss of G: 71.31993687152863, loss of D: 44.25900250673294\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:58<00:00, 1.15it/s]" + "100%|██████████| 67/67 [00:06<00:00, 10.71it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 49 --- loss of G: 215.52084922790527, loss of D: 15.89515457302332\n" + "epoch: 49 --- loss of G: 70.42304891347885, loss of D: 37.90644350647926\n" ] }, { @@ -889,10 +946,11 @@ "\n", "# Loss function\n", "loss = torch.nn.BCELoss()\n", + "\n", "opt_G = 
torch.optim.AdamW(model.generator.parameters(), lr=1e-3)\n", "opt_D = torch.optim.AdamW(model.discriminator.parameters(), lr=1e-4)\n", "\n", - "device = \"cuda:4\"\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "model.to(device)\n", "\n", "curve_D, curve_G = [], []\n", @@ -903,35 +961,39 @@ " for batch in tqdm(train_dataloader):\n", " \n", " \"\"\" train discriminator \"\"\"\n", + " \n", " opt_D.zero_grad()\n", - "\n", - " real_imgs = torch.stack(batch[\"path\"], dim=0).to(device)\n", + " \n", + " real_imgs = batch[\"image\"].to(device)\n", + " \n", " batch_size = real_imgs.shape[0]\n", + " \n", " fake_imgs = model.generate_fake(batch_size, device)\n", " \n", " real_loss = loss(model.discriminator(real_imgs), torch.ones(batch_size, 1).to(device))\n", " fake_loss = loss(model.discriminator(fake_imgs.detach()), torch.zeros(batch_size, 1).to(device))\n", " loss_D = (real_loss + fake_loss) / 2\n", - "\n", + " \n", " loss_D.backward()\n", " opt_D.step()\n", " \n", " \"\"\" train generator \"\"\"\n", + " \n", " opt_G.zero_grad()\n", " loss_G = loss(model.discriminator(fake_imgs), torch.ones(batch_size, 1).to(device))\n", - "\n", + " \n", " loss_G.backward()\n", " opt_G.step()\n", " \n", " curve_G[-1] += loss_G.item()\n", " curve_D[-1] += loss_D.item()\n", - "\n", - " print (f\"epoch: {epoch} --- loss of G: {curve_G[-1]}, loss of D: {curve_D[-1]}\")" + " \n", + " print(f\"epoch: {epoch} --- loss of G: {curve_G[-1]}, loss of D: {curve_D[-1]}\")" ] }, { "cell_type": "markdown", - "id": "dd1b7851", + "id": "exp2", "metadata": {}, "source": [ "### EXP 2: synthesize random images" @@ -939,13 +1001,13 @@ }, { "cell_type": "code", - "execution_count": 31, - "id": "e544672a", + "execution_count": 6, + "id": "generate", "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ "
" ] @@ -960,24 +1022,19 @@ "import matplotlib.pyplot as plt\n", "\n", "model.eval()\n", + "\n", "with torch.no_grad():\n", " fake_imgs = model.generate_fake(1, device).detach().cpu()\n", - " plt.imshow(fake_imgs[0][0], cmap=\"gray\")\n", + " plt.imshow(fake_imgs[0].permute(1, 2, 0).clamp(0, 1)) # RGB image\n", + " plt.title(\"Generated Chest X-Ray\")\n", + " plt.axis('off')\n", " plt.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "267e356d", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -991,7 +1048,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.16" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/examples/chestXray_image_generation_VAE.ipynb b/examples/chestXray_image_generation_VAE.ipynb new file mode 100644 index 000000000..57c3a2075 --- /dev/null +++ b/examples/chestXray_image_generation_VAE.ipynb @@ -0,0 +1,788 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Chest X-Ray Image Generation using VAE\n", + "\n", + "This notebook illustrates how to use the VAE module to generate X-Ray images, including the new features for conditional generation based on patient features.\n", + "\n", + "We will take the COVID-19 CXR dataset as starting point. This dataset is freely available on Kaggle and contains images of Chest X-Rays from COVID-19 patients." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Data\n", + "\n", + "Data is available from Kaggle. 
If it is not already available locally, download it with the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Download command (uncomment to run)\n", + "# !curl -L -o ~/Downloads/covid19-radiography-database.zip https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database\n", + "# !unzip ~/Downloads/covid19-radiography-database.zip -d ~/Downloads/COVID-19_Radiography_Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data with PyHealth Datasets\n", + "\n", + "Use the COVID19CXRDataset to load this data. For custom datasets, see the `BaseImageDataset` class." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/PyHealth/pyhealth/trainer.py:12: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from tqdm.autonotebook import trange\n", + "/home/ubuntu/PyHealth/pyhealth/sampler/sage_sampler.py:3: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. 
Refrain from using this package or pin to Setuptools<81.\n", + " import pkg_resources\n" + ] + } + ], + "source": [ + "from pyhealth.datasets import split_by_visit, get_dataloader\n", + "from pyhealth.trainer import Trainer\n", + "from pyhealth.datasets import COVID19CXRDataset\n", + "from pyhealth.models import VAE\n", + "from pyhealth.processors import ImageProcessor\n", + "from torchvision import transforms\n", + "from pyhealth.processors import SequenceProcessor\n", + "\n", + "\n", + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Data and resize images to 128x128\n", + "\n", + "Assuming data is available locally, we load it and apply an ImageProcessor step to resize all the Chest X-Ray images to 128x128." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Initializing covid19_cxr dataset from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset (dev mode: False)\n", + "Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv\n", + "Setting task COVID19CXRClassification for covid19_cxr base dataset...\n", + "Generating samples with 1 worker(s)...\n", + "Collecting global event dataframe...\n", + "Collected dataframe with shape: (21165, 6)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:08<00:00, 2397.69it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Processing samples: 100%|██████████| 21165/21165 
[00:33<00:00, 630.01it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 21165 samples for task COVID19CXRClassification\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Step 1: Load data\n", + "covid19cxr_path = \"/home/ubuntu/Downloads/COVID-19_Radiography_Dataset\"\n", + "base_dataset = COVID19CXRDataset(covid19cxr_path)\n", + "image_size = 128\n", + "\n", + "# Step 2: Set task with custom image processing for VAE\n", + "image_processor = ImageProcessor(image_size=image_size, mode=\"L\") # Resize to 128x128 for VAE\n", + "sample_dataset = base_dataset.set_task(input_processors={\"image\": image_processor})\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Map disease metadata\n", + "\n", + "We prepare the disease status as an input processor. This will allow us to later generate images conditioned by the disease status - for example, different images for COVID patients vs normal." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: covid19_cxr from /home/ubuntu/Downloads/COVID-19_Radiography_Dataset/covid19_cxr-metadata-pyhealth.csv\n" + ] + } + ], + "source": [ + "covid_metadata = base_dataset.load_table('covid19_cxr').collect()\n", + "vocab = {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}\n", + "\n", + "# Add processor for the conditional feature\n", + "sample_dataset.input_processors[\"disease\"] = SequenceProcessor()\n", + "sample_dataset.input_processors[\"disease\"].code_vocab = vocab" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split Train/Validation/Test\n", + "\n", + "We apply a simple train/validation/test split, to evaluate the performances of the model during training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data keys: dict_keys(['image', 'disease'])\n", + "Image shape: torch.Size([1, 64, 64])\n", + "Dataset sizes - Train: 12699, Val: 4233, Test: 4233\n" + ] + } + ], + "source": [ + "# Split dataset\n", + "train_dataset, val_dataset, test_dataset = split_by_visit(\n", + " sample_dataset, [0.6, 0.2, 0.2]\n", + ")\n", + "train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)\n", + "val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)\n", + "test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)\n", + "\n", + "# Check data\n", + "data = next(iter(train_dataloader))\n", + "print(\"Data keys:\", data.keys())\n", + "print(\"Image shape:\", data[\"image\"][0].shape)\n", + "print(f\"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic VAE Training (Image Generation)\n", + "\n", + "Here we train a standard VAE for image generation.\n", + "\n", + "Note how we are passing the disease state as conditional_feature_key - this will help the model learn to distinguish images based on the disease state.\n", + "\n", + "Note: since this notebook is just a demo, we will only run 5 epochs. Image generation will work much better if the VAE is trained for longer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: No embedding created for field due to lack of compatible processor: image\n", + "VAE(\n", + " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(\n", + " (disease): Embedding(4, 256, padding_idx=0)\n", + " ))\n", + " (encoder1): Sequential(\n", + " (0): ResBlock2D(\n", + " (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ELU(alpha=1.0)\n", + " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (downsampler): Sequential(\n", + " (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (1): ResBlock2D(\n", + " (conv1): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ELU(alpha=1.0)\n", + " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (downsampler): Sequential(\n", + " (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (2): 
ResBlock2D(\n", + " (conv1): Conv2d(64, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ELU(alpha=1.0)\n", + " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (downsampler): Sequential(\n", + " (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " )\n", + " (mu): Linear(in_features=256, out_features=256, bias=True)\n", + " (log_std2): Linear(in_features=256, out_features=256, bias=True)\n", + " (decoder1): Sequential(\n", + " (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2))\n", + " (1): ReLU()\n", + " (2): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))\n", + " (3): ReLU()\n", + " (4): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))\n", + " (5): ReLU()\n", + " (6): ConvTranspose2d(32, 1, kernel_size=(6, 6), stride=(2, 2))\n", + " (7): Sigmoid()\n", + " )\n", + ")\n", + "Metrics: ['kl_divergence', 'mse', 'mae']\n", + "Device: cuda\n", + "\n", + "Training:\n", + "Batch size: 256\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: kl_divergence\n", + "Monitor criterion: min\n", + "Epochs: 5\n", + "Patience: None\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0 / 5: 0%| | 0/50 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Get real and reconstructed images\n", + "X, X_rec, _ = trainer.inference(test_dataloader)\n", + 
"\n", + "# Plot comparison\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax1.imshow(X[0].reshape(image_size, image_size), cmap=\"gray\")\n", + "ax1.set_title(\"Real Chest X-Ray\")\n", + "ax1.axis('off')\n", + "\n", + "ax2.imshow(X_rec[0].reshape(image_size, image_size), cmap=\"gray\")\n", + "ax2.set_title(\"Reconstructed by VAE\")\n", + "ax2.axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(\"chestxray_vae_comparison.png\", dpi=150)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment 2: Random Image Generation\n", + "\n", + "Once the VAE has been trained, we can use it to generate new random chest X-Ray images. Results will be better if the model is trained for more epochs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Generate synthetic images\n", + "model = trainer.model\n", + "model.eval()\n", + "\n", + "with torch.no_grad():\n", + " # Sample from latent space\n", + " z = torch.randn(1, model.hidden_dim).to(model.device)\n", + " \n", + " # Reshape for decoder (add spatial dims)\n", + " z = z.unsqueeze(2).unsqueeze(3)\n", + " \n", + " # Generate image\n", + " generated = model.decoder(z).detach().cpu().numpy()\n", + " \n", + " # Plot\n", + " plt.figure(figsize=(5, 5))\n", + " plt.imshow(generated[0].reshape(image_size, image_size), cmap=\"gray\")\n", + " plt.title(\"Generated Chest X-Ray\")\n", + " plt.axis('off')\n", + " plt.savefig(\"chestxray_vae_synthetic.png\", dpi=150)\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## New Feature: Conditional VAE\n", + "\n", + "We can use the VAE to generate random images, conditioned on patient features, such as the disease state. For example, we can generate images for COVID patients, and compare to normal patients.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAewAAAGrCAYAAAACd6S0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAARPNJREFUeJzt3Xm8VWW9x/Hf5siZJ+BAghokCIiKA6kpGaQ4AWqlWaApmlOZ5pSaNmiaXodAAnHsOpveMitNr0PZKxt8VdeRaxYoYI4MweFwzkH08Lt/eM+OzVnfw/7hcXi2n/fr1euVz1lnrWc961nrYcN3/XbO3d0AAMAHWq/3uwMAAGDDWLABAEgACzYAAAlgwQYAIAEs2AAAJIAFGwCABLBgAwCQABZsAAASwIINAEACWLDfR0OGDLFp06b16D6nTZtmtbW1PbpPFO/SSy+1kSNH2tq1a9/vrmxQLpez8847L//fN954o+VyOVu4cOEGf/e3v/2t5XI5++1vf/uu9S/i3biX3gsLFy60XC5nN954Y4/u97zzzrNcLtej+8S/TZs2zYYMGRL+vTfffNO22GILmzNnzkYdd6MW7AULFtjXvvY1Gz58uFVXV1t1dbWNGjXKTjzxRHv66ac3qiMfVPfdd1/BQ+39snr1apsxY4btuuuu1tDQYJWVlTZ8+HD72te+Zv/4xz/e1761tbXZeeedV/TD+1vf+pZ82N9xxx2Wy+Vs9uzZ3e5jyJAhlsvl8v+rqamxXXbZxW6++eaNOIOesXLlSrvkkkvsrLPOsl69Cm+tD/L1686cOXN6fDEpZdOnT7dcLmcPP/yw3Oa6666zXC5nv/zlL9/DnmWbNm1awX1UX19v22+/vf3gBz+wN9544/3uXsnp3bu3nXbaafb973/fVq9eHd+BB91zzz1eXV3t9fX1/pWvfMWvvvpqv/baa/20007zIUOGeC6X84ULF0Z3+4F14okn+kYMU1EGDx7sRx555Aa3W7JkiY8ZM8bNzCdPnuxXXHGFX3/99f6Nb3zDt9hiC+/du3d+2yOPPNJramrelf521z8z8+9+97tFbd/e3u5Dhw71ESNG+BtvvJFvX758uW+66aa+8847e0dHR7f7GDx4sO+www5+yy23+C233OKXXnqpDx8+3M3Mr7322ndyOhttxowZXl9f7+3t7QXtkev3Xlr/mr311lve3t7ua9euzbdts802Pm7cuC6/29HR4e3t7Ru8Tu+VYu+ld9vLL7/svXr18qOOOkpuM378eO/Xr5+vWbPG165d6+3t7f7WW2/1aD+++93vFvXcOvLII72ioiJ/H82aNcvHjx/vZuZf+MIXerRPpWTNmjW+evXqjfrd5cuXe3l5uf/oRz8K/25oJZo/f77X1NT41ltv7a+88kqXn7/55ps+c+ZMf/HFF8Mdea+sWrUqtP0HYcGeNGmS9+rVy3/60592+dnq1av99NNPz/93Cgu2u/uDDz7oZubnnXdevu3444/3srIyf+KJJzb4+4MHD/ZJkyYVtC1evNhra2t96623LrofPWn06NF++OGHd2mPXL/3UjHXTC3YHzQflAXb3X2vvfbyhoaGzAf6Sy+95L169fITTjghtM/ocyuyYK//vOjo6PCPf/zjbmb+8ssvh46L4kyePNn32GOP8O+FVqLjjjvOzcwfe+yx0EH+9re/+cEHH+x9+vTxiooKHzNmjP/iF78o2OaGG25wM/Pf//73fuqpp3pTU5NXV1f7Zz7zGV+8eHGXfd53333+yU9+0qurq722ttYnTpzoc+fOLdimczLOnz/f999/f6+trfWDDjrI3d1/97vf+SGHHOJbbLGFl5eX++abb+6nnHKKt7W1Ffy+mXX5X6eOjg6fMWOGjxo1yisqKnzAgAF+3HHH+b
/+9a+Cfqxdu9YvuOAC32yzzbyqqsrHjx/vc+fOLeoh89hjj7mZ+bHHHlvMUOfP+aWXXvKDDjrIa2pqvKmpyU8//fQuf4ovtv9/+ctffJ999vF+/fp5ZWWlDxkyJP8JYsGCBZljVMziPXXqVK+oqPC///3v/sc//tFzuZyfdtppRZ1n1oLt7v7xj3/cy8vLC9qKudb/+Z//6Wbmjz/+eJd9fv/73/devXr5Sy+9JPvzwgsvuJn5jTfeWNAevX7u7r/+9a/zc7uhocEPPPBAf/bZZwu26Xwgz5s3z4888khvaGjw+vp6nzZtmre2thZsu3r1aj/llFO8qanJa2tr/YADDvB//vOfXa5T5z24YMECd397jNe/rp2L9yOPPOJm5o888kjBsf7rv/7Ld9ppJ6+srPR+/fr5YYcd1mXcInP0sssu891228379u3rlZWVvtNOO/lPfvKTLmNW7IK9atUqP+2003zzzTf38vJyHz58uF922WUFf6vg/vYfZk488US/++67fZtttvHy8nIfNWqU33///Rs8Ruc43nXXXV1+dvnll7uZ+aOPPuru/75/brjhhi7js7HPLfd3tmC7u59xxhluZv6HP/zB3f99vz366KO+8847e0VFhX/sYx/zm266qcvvLl++3L/+9a/nx3jo0KH+H//xHwV/G6PmT3fjsWjRIp80aZLX1NT4oEGDfPbs2e7u/vTTT/unP/1pr66u9o9+9KN+2223denT888/74cccoj36dPHq6qqfNddd/V77723YJvOPt15551+4YUX+mabbeYVFRW+5557+rx587qM2+DBgwvaip2r7u4zZ870XC7ny5Yty/y5ElqwBw0a5MOGDQsdYO7cud7Q0OCjRo3ySy65xGfPnu2f+tSnPJfL+c9+9rP8dp2TfMcdd/Q999zTZ82a5aeffrqXlZX5oYceWrDPm2++2XO5nO+3334+a9Ysv+SSS3zIkCHe2NiYf9i4//uve4YOHepHHnmkX3311X7zzTe7u/tJJ53kEydO9IsuusivueYa//KXv+xlZWV+yCGH5H//j3/8o++9995uZvm/MrrlllvyPz/mmGN8k0028WOPPdavvvpqP+uss7ympsZ33nlnX7NmTX67b33rW25mPnHiRJ89e7YfffTRPmjQIG9qatrgQ+acc85xM/Pf/e53RY33kUce6ZWVlb7NNtv40Ucf7VdddZUffPDBbmY+Z86cgm2L6f/rr7/uffr0yT/YrrvuOj/33HPzn2JXrVrlV111lZuZf/azn82P0VNPPbXBvr722mvep08fHz9+vG+33Xa+xRZbeEtLS1HnmbVgv/nmm77pppv6Rz7ykYL2Yq71ypUrvaqqKvPT7qhRo3zPPffstj+33nqrm5k//fTTBe3R6/fQQw/5Jpts4sOHD/dLL73Uzz//fG9qavI+ffoUzO3OB/KOO+7on/vc53zOnDl+zDHHuJn5mWeeWbDPww8/3M3Mp06d6rNnz/bPfe5zPnr06A0u2HfffbdvvvnmPnLkyPx1ffDBB909+4Hb+fs777yzz5gxw88++2yvqqryIUOG+PLly/PbRebo5ptv7l/96ld99uzZPn36dN9ll13czLo8bItZsNeuXet77rmn53I5P+aYY3z27Nl+wAEHuJn5KaecUrCtmfn222/vAwcO9AsuuMCvuOIK33LLLb26utqXLl3a7XGam5u9srLSDz744C4/22mnnXzw4MH5PyCoBeqdPLfc3/mC/dnPftbNzJ977jl3f3t8R4wY4R/5yEf8nHPO8dmzZ/tOO+3kuVyu4INSa2urjx492vv16+fnnHOOX3311X7EEUd4Lpfzr3/96/ntogt2ZWWljxo1yk844QS/8sorfffdd89vN2jQIP/GN77hs2bN8m222cbLysr8hRdeyP/+a6+95h/5yEe8rq7Ozz33XJ8+fbpvv/323qtXr4I1qLNPO+64o48ZM8ZnzJjh5513nldXV/suu+zSZdzWX7
CLnavu7r///e/dzPyee+6R1yZL0Qt2c3Ozm5l/5jOf6fKz5cuX+5IlS/L/W/dPe3vttZdvt912BX89tHbtWt999919q622yrd13uwTJkwo+NPuqaee6mVlZb5ixQp3d29pafHGxsYun1hee+01b2hoKGjv/IR89tlnd+nz+n8idXe/+OKLPZfL+aJFi/Jt6q/EH330UTezLn+a++///u+C9sWLF3t5eblPmjSp4Lw6H+Qbesh03jjrPvC603nO3/ve9wraOydhtP933323m5n/5S9/kcfcmL8S73TNNdfkP739/Oc/L/r3Bg8e7Pvss09+zj3zzDP+pS99Kf/JaF3FXuspU6b4oEGDCj4JPP74410eIFk6/1C2/h84otdvhx128AEDBhT8yfupp57yXr16+RFHHJFv63wgH3300V2O169fv/x/P/nkk25m/tWvfrVgu6lTp25wwXbXfyW+/gN3zZo1PmDAAN92220L/g3/3nvvdTPz73znO/m2Yueoe9drt2bNGt922227/AGqmAX75z//uZuZX3jhhQXthxxyiOdyOZ8/f36+zcy8vLy8oO2pp55yM/NZs2Z1exx3989//vNeWVnpzc3N+bbnnnvOzcy/+c1v5tvUAvVOn1vRBbvzPpo/f75fdNFFnsvlfPTo0fntOv+2Zd0/eC5evNgrKioK/pB7wQUXeE1Njf/jH/8oOM7ZZ5/tZWVl+X8ujS7YZuYXXXRRvm358uVeVVXluVzO77jjjnx75xivO69POeWUgr/VcH97HfnYxz7mQ4YMyd/vnX3aeuutC7I1M2fOdDPzZ555pqBP6y/Yxc5Vd/dXXnnFzcwvueSSLj/rTtEp8ZUrV5qZZb4yNH78eOvfv3/+f1deeaWZmf3rX/+y3/zmN3booYdaS0uLLV261JYuXWrLli2zfffd1+bNm2cvv/xywb6OO+64gtcR9thjD+vo6LBFixaZmdlDDz1kK1assClTpuT3t3TpUisrK7Ndd93VHnnkkS79+8pXvtKlraqqKv//W1tbbenSpbb77rubu9sTTzyxwfH4yU9+Yg0NDbb33nsX9GPMmDFWW1ub78fDDz9sa9assZNOOqngvE455ZQNHsPs3+NeV1dX1PadTjjhhIL/3mOPPeyFF14I97+xsdHMzO6991578803Q30oRlNTk5mZVVdX2yc/+cnQ7z744IP5ObfddtvZLbfcYkcddZRddtllBdsVe62POOIIe+WVVwrm0G233WZVVVV28MEHd9uXZcuW2SabbNLl/ohcv1dffdWefPJJmzZtmvXt2zffPnr0aNt7773tvvvu6/I7Wdd52bJl+eN2/s7JJ59csF2x869Yf/3rX23x4sX21a9+1SorK/PtkyZNspEjR9qvfvWrovq+7hw1K7x2y5cvt+bmZttjjz3s8ccfD/fxvvvus7Kysi5jcfrpp5u72/3331/QPmHCBBs6dGj+v0ePHm319fVd+pjl8MMPt9WrV9vPfvazfNvtt99uZmaHHXZYUf19N55bWVpbW/P30bBhw+ycc86x3Xbbze6+++6C7UaNGmV77LFH/r/79+9vI0aM6PJc2WOPPaxPnz4Fz5UJEyZYR0eH/e53v9uoPpqZHXPMMfn/39jYaCNGjLCamho79NBD8+0jRoywxsbGgj7dd999tssuuxQ8X2pra+24446zhQsX2rPPPltwnKOOOsrKy8vz/915zhu67pG52qdPHzMzW7p0abf7XN8mxW7Y+cBZtWpVl59dc8011tLSYq+//rodfvjh+fb58+ebu9u3v/1t+/a3v52538WLF9tmm22W/++PfvSjBT/vPLHly5ebmdm8efPMzGzPPffM3F99fX3Bf2+yySa2+eabd9nuxRdftO985zv2y1/+Mr/vTs3NzZn7Xte8efOsubnZBgwYkPnzxYsXm5nl/6Cx1VZbFfy8f//++XPrTuf5tLS05BfPDamsrLT+/fsXtPXp06fgPIvt/7hx4+zggw
+2888/32bMmGHjx4+3z3zmMzZ16lSrqKgoqj9KS0uLnXzyyTZixAh7/vnn7ayzzrLrr78+//Pm5mZrb2/P/3d5eXnBQrbrrrvahRdeaB0dHTZ37ly78MILbfny5QU3m1nx13rvvfe2gQMH2m233WZ77bWXrV271n784x/bQQcdFP4DU6fI9eucKyNGjOjys6233toeeOABa21ttZqamnx7d/dLfX29LVq0yHr16lWw8KhjvBPd9X3kyJH2+9//vqCtmDlq9vYfFC+88EJ78sknC14z2ph3jBctWmSDBg3qci233nrrgnPotP7Yqj5m2X///a1v3752++23598P//GPf2zbb7+9bbPNNhv8/XfruZWlsrLS7rnnHjMzq6iosI997GOZxy5mPObNm2dPP/10l2vbqfO5sjF9XH+fDQ0Ntvnmm3eZCw0NDQV9WrRoke26665d9rnudd92223z7Rtag5TIXHV3+bPuFL1gNzQ02MCBA23u3LldftY5GOsXXOgsHnHGGWfYvvvum7nfYcOGFfx3WVlZ5nadJ9i5z1tuucU23XTTLtttsknhKVVUVHR5J7ajo8P23ntv+9e//mVnnXWWjRw50mpqauzll1+2adOmFVX0Yu3atTZgwAC77bbbMn+uJmzUyJEjzczsmWeeKfjTbXfUGK6r2P7ncjn76U9/ao899pjdc8899sADD9jRRx9tP/jBD+yxxx57R0Vazj33XHvttdfsz3/+s91xxx12+eWX21FHHWVjx441M7Ovf/3rdtNNN+W3HzduXMG7201NTTZhwgQzM9t3331t5MiRNnnyZJs5c6addtppZha71mVlZTZ16lS77rrrbM6cOfaHP/zBXnnllYI/hCr9+vWzt956y1paWgoWhI25fhEbul8+qIqZo48++qgdeOCB9qlPfcrmzJljAwcOtN69e9sNN9yQ/7T6fvSxmLHt3bu3HXrooXbdddfZ66+/bi+++KLNmzfPLr300qKO/W49t7KUlZXl76MNbZdl3fFYu3at7b333nbmmWdmbjt8+HAz0wtVR0dH6NjvxvzfmH1G52rn4t/5N4zFKnrBNnv7r7euv/56+/Of/2y77LLLBrffcsstzeztyVvMhChG5yeFAQMGbPQ+n3nmGfvHP/5hN910kx1xxBH59oceeqjLtmpiDR061B5++GEbO3ZswV+FrG/w4MFm9vafPDvHw8xsyZIlRf1J/YADDrCLL77Ybr311h594Bfb/06f+MQn7BOf+IR9//vft9tvv90OO+wwu+OOO+yYY47ZqE87f/3rX+3KK6+0k046yXbaaScbMWKE3XnnnXbCCSfYE088YZtssomdeeaZBYvlhv5GYtKkSTZu3Di76KKL7Pjjj7eamprQtTZ7+6/Ff/CDH9g999xj999/v/Xv31/+YXNdnQvzggULbPTo0fn2yPXrnCt///vfu/zsueees6ampoJP18UYPHiwrV271p5//vmCT79Zx8hS7LVdt+/r/+3X3//+9/zPI+666y6rrKy0Bx54oOBvc2644Ybwvjr7+PDDD3f5Q9Vzzz2X/3lPOuyww+zqq6+2O++80xYsWGC5XM6mTJmy0fuLzuX3w9ChQ23VqlUbfDZ33ssrVqwoaF//bzl6wuDBg+U91fnzdyo6VxcsWGBm//6UX6xQpbMzzzzTqqur7eijj7bXX3+9y8/X/xPIgAEDbPz48XbNNdfYq6++2mX7JUuWhDpr9vYnqfr6ervooosy/021mH12/glq3f66u82cObPLtp0PyPUn1qGHHmodHR12wQUXdPmdt956K7/9hAkTrHfv3jZr1qyC411xxRUb7KeZ2W677Wb77befXX/99fbzn/+8y8/XrFljZ5xxRlH72pj+L1++vMt13WGHHczM8n/tU11dbWZdx0jp6Oiw448/3gYOHJg/fk1Njc2aNcvmzp1rM2bMMLO3/81swoQJ+f+NGTNmg/s+66yzbNmyZXbdddeZWexam73975
SjR4+266+/3u666y774he/2OVvbbLstttuZvb2H0TWby/2+g0cONB22GEHu+mmmwrGcu7cufbggw/axIkTN9iP9e2///5mZvbDH/6woL3Y+VdTU1PUdf34xz9uAwYMsKuvvrrgrwPvv/9++9vf/maTJk0qus+dysrKLJfLFXzqWrhwYeY4FmPixInW0dHRpYrejBkzLJfL5ceqp4wdO9aGDBlit956q9155502bty4zL9qLlZ0Lr8fDj30UPvTn/5kDzzwQJefrVixwt566y0ze3uRLCsr6/Jv2htbsrM7EydOtD//+c/2pz/9Kd/W2tpq1157rQ0ZMsRGjRr1jo8Rnav/8z//Y7lcLv/cKFboE/ZWW21lt99+u02ZMsVGjBhhhx12mG2//fbm7rZgwQK7/fbbrVevXgWT8sorr7RPfvKTtt1229mxxx5rW265pb3++uv2pz/9yV566SV76qmnQh2ur6+3q666yr70pS/ZTjvtZF/84hetf//+9uKLL9qvfvUrGzt27AbLWo4cOdKGDh1qZ5xxhr388stWX19vd911V+Yn3s5F4uSTT7Z9993XysrK7Itf/KKNGzfOjj/+eLv44ovtySeftH322cd69+5t8+bNs5/85Cc2c+ZMO+SQQ6x///52xhln2MUXX2yTJ0+2iRMn2hNPPGH3339/0X8dcvPNN9s+++xjn/vc5+yAAw6wvfbay2pqamzevHl2xx132KuvvmqXX355aByL7f9NN91kc+bMsc9+9rM2dOhQa2lpseuuu87q6+vzC0hVVZWNGjXK7rzzThs+fLj17dvXtt1224J/F1rXD3/4Q3v88cftrrvuKvikc+CBB9qBBx5o559/vn3hC1/I/DezDdl///1t2223tenTp9uJJ54YutadjjjiiPwiWsxfh5u9/bdJ2267rT388MN29NFHF/wscv0uu+wy23///W233XazL3/5y9be3m6zZs2yhoaGjSqRu8MOO9iUKVNszpw51tzcbLvvvrv9+te/tvnz5xf1+2PGjLGrrrrKLrzwQhs2bJgNGDAgMz/Su3dvu+SSS+yoo46ycePG2ZQpU+z111+3mTNn2pAhQ+zUU08N933SpEk2ffp022+//Wzq1Km2ePFiu/LKK23YsGEbVQL5gAMOsE9/+tN27rnn2sKFC2377be3Bx980H7xi1/YKaec0uXf+d+pXC5nU6dOtYsuusjMzL73ve+9o/1tzFx+r33jG9+wX/7ylzZ58mSbNm2ajRkzxlpbW+2ZZ56xn/70p7Zw4UJramqyhoYG+/znP2+zZs2yXC5nQ4cOtXvvvXej/427O2effbb9+Mc/tv33399OPvlk69u3r9100022YMECu+uuu7r808PGiM7Vhx56yMaOHWv9+vWLHSiUKf9/8+fP96985Ss+bNgwr6ys9KqqKh85cqSfcMIJ/uSTT3bZ/vnnn/cjjjjCN910U+/du7dvttlmPnny5ILKT52vlKz/+pCK/z/yyCO+7777ekNDg1dWVvrQoUN92rRp/te//jW/TXdVv5599lmfMGGC19bWelNTkx977LH51zbWfaXgrbfe8pNOOsn79+/vuVyuy6sS1157rY8ZM8arqqq8rq7Ot9tuOz/zzDMLKsF1dHT4+eef7wMHDgwXTunU1tbml19+ue+8885eW1vr5eXlvtVWW/lJJ51U8OqJOmf1mseG+v/444/7lClT/KMf/Wi+uMrkyZMLxtn97XfWx4wZ4+Xl5d2+4vXPf/7Ta2trffLkyZk/X7RokdfU1PiBBx7Y7Xiowinu7jfeeGPBdSz2Wnd69dVXvayszIcPH95tH9Y3ffp0r62tzXz1ptjr5+7+8MMP+9ixY72qqsrr6+v9gAMOkIVTlixZUtCe9WpWe3u7n3zyyd6vXz+vqakpunCK+9uvS06aNMnr6urciiiccuedd/qOO+7oFRUV3rdv324Lp6wva47+6Ec/8q222sorKip85MiRfsMNN2RuV+
y91NLS4qeeeqoPGjTIe/fu7VtttVW3hVPWF62o9r//+79uZl5RUZH5al93hUKyFDuX3+l72OtT99u4ceO6vPbX0tLi3/zmN33YsGFeXl7uTU1Nvvvuu/vll19eUJ9iyZIlfvDBB3t1dbX36dPHjz/+eJ87d27R4zFu3DjfZpttiuprZ+GUxsZGr6ys9F122UUWTlm/2Im6Ruu/1lXsXF2xYoWXl5f79ddf36XvG5Jz/4CnU4D32NKlS23gwIH2ne98R77dkKW5udm23HJLu/TSS+3LX/7yu9hDAKm64oor7NJLL7Xnn3++qPzQuvh6TWA9N954o3V0dNiXvvSl0O81NDTYmWeeaZdddlkSX68J4L315ptv2vTp0+1b3/pWeLE2M+MTNvD/fvOb39izzz5r3/72t+3Tn/50QdELAHi/sWAD/2/8+PH2xz/+0caOHWu33nprQUEfAHi/sWADAJAA/g0bAIAEsGADAJAAFmwAABIQqnSWEvVtXqoq0OrVqzPbVVlK9U1VvXv3zmxX1XRUsXt1XFXbWW2vCtmrr8pU/VTHVdur15o6y5gWu591v6qxmP6o/asa3OrVCnUd1/82uA0dV305yvrfKNYpOq/U9VX77476HRVzUXNXbb9mzZpQe1tbW2Z7a2trZnvWNwma/fsrTovdjzqvdb85bl1qrneW4SyW2l7NdTXOak6sWzJ2XereU+OvtlfP0Oj+1T2pzlfdk+/FF8S81/iEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPi6ntVVZq6paUls12li1UiVVH7UcnNaJpaJShVwlQlUtX4qCSs6o/qv0raqmSoSp6qRLPqpxoHlahV5xVN/qrEsRJN6av+d0cl0RU1t6JvCChqP+pNBiU6R9W1VHNUpdnVNVZzNzqH1Dio66j6qY6r5lA0Fa+2V/1U47NixYrM9j59+mS2v/DCC5ntpYhP2AAAJIAFGwCABLBgAwCQABZsAAASwIINAEACSjYlHq0HrJKwqg6xSimrdLTav6o1rZKbqqZ0NPUdTV+r/ata3yqZq9LgKnmq+q/GQbVHRRPQKpmrErLR/URr0XeXHlepY/U70WNE69Src1bXMjrWqj8qea/201OJ/2g/1T2m7km1/+j1Vf1X10U949SzRo2/6o/6HojoWw8p4xM2AAAJYMEGACABLNgAACSABRsAgASwYAMAkICSTYm3tbVltqtkoqKSlSoNrlLWqj2aHo8mN1XiUiUro7XKe6r/0fNS7eq6qHa1H5VOVyl3dX2j56sSwdGa4d3N82j9dCWa7Ffbq7GLppfVOasx7an69Yq6x9Q9E61Jru4xlR5X+1f9ib5horZX469S8Wo/SrR2fcr4hA0AQAJYsAEASAALNgAACWDBBgAgASzYAAAkoGRT4ioxqurRqgSoShGrZGI0sRhNO6uEr0qGqu1VwjR6XtHxUdur8VeJUdUerWGu5km0XnL07QOVnI2OmxqH7voTTQurduWNN94I9Un1p6dS4mpMo3Xwo296qBS02l71U81dRdVaV+MTrc2uxq21tTWzPToO6lmm0ulq+1LEJ2wAABLAgg0AQAJYsAEASAALNgAACWDBBgAgASWbElcJxGjaViVG1fYq3a2SldF0tEpuqjS42v+qVasy2+vq6jLbVWI3So2DOi+VbI0medX26riKmj/ROtbRNHg0Md2daLJfzS2V/lXbR99AUPuJ1kJXc0KljtU9rNLI6lmj0vJRanyi/VdzVI1PT9X6XrlyZWa7uvfUXFf9idZ4TxmfsAEASAALNgAACWDBBgAgASzYAAAkgAUbAIAElGxKXCVho6lgVXNbJTFVu0pWRtPd0TS1Oq5KgyvRmt5q/NX5Rus6q/0o0f2r/keTyCpZHL2Oqj/R/ZvFE+fR1K4SrUkerRGt7j31RkRjY2Nme1tbW2a7muvqvNQcVbW71fir7dX5Rt+QUftXKXR1XtF6+tH0uLIxb0qk6sNzpgAAJI
wFGwCABLBgAwCQABZsAAASwIINAEACSjYlrpKMKlmp9FTtbrV9TU1NZrtKdKr+q/1H08XRNHu0hrZKsEZrgCuqn9H6x2p7db49lQaPJrVVP1X9abN4re/oGwjqGqvUsTrnaO1rVbtbba/GNJq+VseNvhERrZWt5lz0zYrovRGdJyqFHr0u3c3pDws+YQMAkAAWbAAAEsCCDQBAAliwAQBIAAs2AAAJKNmUeLQub0/Vmlbtqia5Sko2NDSE9h+ljqvaVZI0mu5WSU+VzFXbq3Z1XJXAjSZko/Mhuh91XlVVVaHtu0vXq7SwSudGRWuAt7e3h/ajzlmluNU9r8ahp+ZudDyjb2iodHpPvWmgRN/0iH6vQLRef7SGecr4hA0AQAJYsAEASAALNgAACWDBBgAgASzYAAAkoGRT4opKRKp2lQCN7kclHFUt8Wgd5Wi7SuCqNLKi0tfRGtpKNDGqEsE9lUJXVHJWUeel5o9K+Krr1V1CWZ2bOnY0dRw9bvRaKtF67up81XFV2jla7z5aW1uNW/R7AtS9qtTV1YX2r2rFq+3Vs09paWnJbO+pN2dSwCdsAAASwIINAEACWLABAEgACzYAAAlgwQYAIAElmxJXCURVf1dtrxKgKnkaTZUr0RS6qsurRGunR1P00RrjKnGstm9tbc1sV2ltVa9apdajiVo1/uq40XkSTQSr62sWrymt0r/q2qgxVX1VqWm1vep/9Lyi9eVVf6I1tKP16NX4K2r/6o0C9UyMpuijotdRjX/0DY2U8QkbAIAEsGADAJAAFmwAABLAgg0AQAJYsAEASEDJpsRV4lIlEFWqVrWrOrvR/tTW1ma2R2toq9S6SlCq7dX5qhR0tJ60ShCr8YkmeVWCNZoGj9Z1VvtR11f1J9pPlfztrhZ6NAUdvZfUNYim0NU5qDFSc12dl7rHonXkFdUfNZ4qra1Ea79HxyFaU13tX52Xeqao/atnEylxAADwgcKCDQBAAliwAQBIAAs2AAAJYMEGACABJZsSV8nBaLJSJRZVu0qPq1rlKvmo9q9SwWp7lXZWCV+1H5WcVftXoklVlY5W7Sp5qo6rqOSyur7RutpqPKP7Ue3dXRc1RtEUd5S6J9W1VOem5q7ajzqu2r+q866Oq/ajxlm9oaFE39BQc13tR52vmqNqPFUaXJ2v6r/af/R8SxGfsAEASAALNgAACWDBBgAgASzYAAAkgAUbAIAElGxKXCUio4lOlZqOUvtRCc1oUlLtX52vSjur46o0e3t7e2a7Simr+sTResDRBGtUtJa4Sncr0eSvmieqn931J3pslfJV115dA9Wu+qrS6dFUebSmt5qj0TmhtlfnFa3RrcYtmsqO1k5X/VFvsKjjRp/RajzV2w2liE/YAAAkgAUbAIAEsGADAJAAFmwAABLAgg0AQAJKNl6nkpIqEalqfUfr76o0tdqPSqSq7VW6W52v6me07q8SrSWuqIRpa2trZns0WRxNwkaTwtG619F2db7RZLeZHoueehNAzS3VV0Wlf1VKXKWXo9urcVDnFU1x19bWZrarOafS49H0dZTafzT9rsZB9VONP/iEDQBAEliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPiKmEarTWt0ssqKakSnWp7leZVSUmV3FRJVbW9SqqqtLOiEscqUavOq62tLbNdJUlV8leNv+qP2l7tX1Hbq7cPVHJWzbdoDfzuEtnqGOoc1BxV10z1VbWrezW6/2haW4mmo6P3fPT7AJRoDW317Ium8aNvOKhnRLSWuOpn9JmVMj5hAwCQABZsAAASwIINAEACWLABAEgACzYAAAko2ZS4qqGtEoUqzRutra2SqqpOcLSfKmWt+h9NVqpErWpXSVK1vRofJZqCVmlw1R+VYFXjptqj499T/VHjo5Lg3f2OEq1ZrdLR9fX1me2qBrWao6q+vBqj6PbRaxMdT7V/db7qWqpxU+el+q/6E312qJS7Om60pn00PV6K+IQNAEACWL
ABAEgACzYAAAlgwQYAIAEs2AAAJODDE6/7f6qebjRBGa03rNLg0URqtI6yovqjkp7R+sfR9Hg0qarS4Gr8VUpfUcnfaHJWtata7qr/6rhq/LtLzqo+Red0tHa0umaK2k/0TYzoPRMd0+j3E0TrwkffHFD7V9dRib6JEU3RR+dV9NlRij48ZwoAQMJYsAEASAALNgAACWDBBgAgASzYAAAkoGRT4qoGuEp0qvq1KvlYV1eX2R5NeqqkquqP2r/aXiVJVUq8ra0tsz1aS1wlSVWCVSV81f5VkjSaXI4mTFV/VOJV7b+9vT2zXSWU1XWP1rE202Otalara6nupWjCXrVH+6Pu+ZaWlsz2aJpdtatngaphrqj+R9P7qp/RZ1O0drfqp5on6tmntlf7j74hkzI+YQMAkAAWbAAAEsCCDQBAAliwAQBIAAs2AAAJKNmUeDT5qLaPpsdVylelhVUSViUlVf+jSd5Vq1Zltqv+q/S4Sqqqfq5evTqzXSWXVbI1mvZfuXJlZnu0vnK0DrTqpxpntb26vtHErpk+52htbXXN1P6jdfzV/lWaWs0t9UaESuor0WdHNMGv9tNTNbej4x+d0yqNHx03JfpWQiniEzYAAAlgwQYAIAEs2AAAJIAFGwCABLBgAwCQgJJNiUeTidHUt0rnquOqRKfav6ISoCodrfqj9tNdujiLSl+rhG+07q9K8qrxV2n26PVS1Lip9miCNVqHO5r87e4Y6tpEa1mr/ag3IhR1XHWN1b0UrVkdraGtxiGaglbUvaSocYg+m9T4KGr/qqa6evNB9Ue9BRD9PoCUfXjOFACAhLFgAwCQABZsAAASwIINAEACWLABAEhAyabEo/V01fbR2t0qMdrQ0JDZrpKn0ZS1Sv5GU/HResZqPFX/o8nZaAI3eh2jtcRVGl/tR12XaA12lfxV/ekuna7mnBJ940LNCZUSV/uJ1k+P3tvRuR59s0Jdm+ibEtFrHK0BrkTr2qvzis6f6PhEzytlfMIGACABLNgAACSABRsAgASwYAMAkAAWbAAAElCyKfFo/V21fbQ2dZRKRKr9R2tiR2tZq/2/+eabme3RZKjaT0VFRWa7ui4qGarS1NFxiNaZVklhlSxW56XGP5p07o6q7VxXV5fZHn0zQY1FdOzUnFBUrWnVruaKGlM1d1W7SlOrcY7O0Z6qg6/GITpH1X7U9wGocVbbq3lILXEAAPCBwoINAEACWLABAEgACzYAAAlgwQYAIAElmxJXydOeEk0mqmRltNa0ao+mr1VCUyVPVf9VEla1q/1H0+Yqkaquu2pX+4nW+o7Oh+j1VYndjUnIqjkRHSOVvlY1w5csWZLZHj3naJo6muBX1HFVP6O106P3RvSZop4RSnQ/Kt2ttlfzR4mOQyniEzYAAAlgwQYAIAEs2AAAJIAFGwCABLBgAwCQgJKN16mUcjQJq2qJq4SpSlb2VF1hlaitra3NbI/W4lZJTJUAVe1q/2rcov1RSVs1buq40VR8tNa36k+0VrlKHKv+VFdXZ7ab6b6qsY6mrNWcqK+vD22vxq6n6uOrez5at72ysjKzXV1LldKP1sdX/YleXzWeqv9qbkVrhivLly/PbFfjo45biviEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPi0XSxSnGrhKlKLKq6uU1NTaH+qISmotLgqv+qn6pdna/qf3Sco4ndaE1vla5X80SNp0rIRmuhq2RxNF0fTaF3J5pcj9Z5V7Wj1TmoNHVra2vouCrlrpL3ao5G52JP1TxX93B0TqvxUdddjc+KFSsy29U4RJ8d0Zrq6l4qRXzCBgAgASzYAAAkgAUbAIAEsGADAJAAFmwAABJQsinxaP1alWRU7dGa1YpKeqqErEpQRlPZ0RS92j6ajo7WpY4ml9X20RryKnmqxqempiazXVH7Udc9msZX18UsniCPpqDVmwbRGt1qe5XUV8dV95jaPnqPKdFrpuaiao++2aKul+qPum
ei33OgRL8/IPoGSyniEzYAAAlgwQYAIAEs2AAAJIAFGwCABLBgAwCQgJJNiSvRRKHaXiUxVXu0HnM01aySlSrRqWplq3a1/2jd5Wg95miqPFpPWokmhXvq7YBo/emNqWOtzkEl41U6N5oijtZJV6L7UfekmtPR81Lbq3spWgdfidaXV/2Mfm+Buseam5tD+2lra8tsj9YeV+NciviEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPiKtkaTYyq2s5qPyr9qxKdav+KSoCqpKc6L9VPlfqO1lqPJjd7KukZ3Y8aNzUOdXV1me1qvkXrQEdT4mr/KoFrZlZVVZXZ3traGupTtFZ2tF589E2D6JsAKhUffWMhmtRX10w9U9Rxo7W7o99/EK0xrs4rOq9Wr16d2U4tcT5hAwCQBBZsAAASwIINAEACWLABAEgACzYAAAko2ZS4Siy2t7eHtlcJ0Gi9W0UlPVW6W+0/mn5X/VfnG01BR5PC0YSvOm40ja/ao/NBba+o667GQfVT9Uclwbv7nerq6sx2lbyPJtrVHFXnFr3HGhsbM9tV/1WSPpqKV6L16KPt0Rrg0WdE9PpG565Kj0efWR8mfMIGACABLNgAACSABRsAgASwYAMAkAAWbAAAElCyKfGWlpbM9mj9YJVYVLWgo3WFV61aldleX18f6k80VR5Na6v+RxOdavto4lX1XyVhowlT1c+amprMdpUSV9urfqr9RFPo3Z1vdIyix45eY9VeW1ub2a7S3eoNkGh6Wc11tR81buoZEX1DQF2vaGpd6anvV1Dt6l6NjrNK+3+Y0uN8wgYAIAEs2AAAJIAFGwCABLBgAwCQABZsAAASULIp8Wh9YpVYVO0q6RlN/6rtVVJS1YhWiVGVqFW1wdX4qP6o/URT32qc1X5U8jRaSzy6vaKuY0/V246OT3fJ7mif1BhFU8Qq5ave3FBzV/VH3ZPquNG0drTuv7qW0br80e8zUP1U20ffuFi9enVm+8qVK0P7j16X6DiXIj5hAwCQABZsAAASwIINAEACWLABAEgACzYAAAko2XidSkRGa31H6ytHa01H0+YqKakSl6qesUqqRrdXydBoAjc6ztH6wdHtVdI5mtiNJlujKXHVz+5S4iq1q35HzS3Vp2i6W42dOjc1F9WbEurNCnVt1P4VNQ7qfBV17XvqmRWdW9HvS+jfv39m+9KlSzPb+/Xrl9munhF1dXWZ7aqGfCniEzYAAAlgwQYAIAEs2AAAJIAFGwCABLBgAwCQABZsAAASULKvdalXPBT1CoN6xSBa0D76pRrRVzbUKzDqlQfVT9Uf9aqOGmfVHn11JXodo1+Goa6jGofq6urQcdX5qvNS+1Gvyantu/vyFXXO0S9yUV8Cob7MQ81Ftf/ol3Coe0O1q9e91L2h+qnGLfoFNWrcoq9CRr/IKPolItFnR/T6qnmlntHRL+pJGZ+wAQBIAAs2AAAJYMEGACABLNgAACSABRsAgASUbEpcJRlVajf65Q0qGRr9sgqVoFTJR3Ve6ks4VP9VSjmaGI1+MYFKgHb3ZRVZol+soK6XSuaqcVNJZ9V/1a7moRJN4HYnOhejiX/V155KO6v9qGumRJP90S9+UalmNZ5qezWHoql1dQ+rcYt+YVF0rkefleq6q3ErRXzCBgAgASzYAAAkgAUbAIAEsGADAJAAFmwAABJQsilxRSU0VXpZJRNVwlElLqMp9Gg942h/VHtra2tmezRZrJKh0XS0SgqrFH004RtNd0fT6WrcovtRNmaco2lkdQ7qGKrufG1tbWj7nqrzru4lddxo7etouj5ao1tdl+ibMNE3NNT5KuoNiugbF9E3H9QzohTxCRsAgASwYAMAkAAWbAAAEsCCDQBAAliwAQBIQMmmxFVCU4nWzVVqampC+1HJWZV8VHWCVeJSJUNVDW21vUoKq+Oq8Y/WHldJYZXwVfuJJnlVCl0lkVU/VZJXtav9qPZobXyzeAI+eu3VHI
3W4o6+mRCtYR6dK9GUu+q/Gp9o2lk9U6LnG71X1fxRz5SlS5dmtkdrjKt2NZ6liE/YAAAkgAUbAIAEsGADAJAAFmwAABLAgg0AQAJKNiUerUOsqKSkau+pWuLRpK1KUKq60ao9WntcJT17MtWcJVpvWKX3FZUS76la4irhq5K8Kp2utu8ucaxSwep3VKJdnUP02qg5oe5VdQ3UPaPuDXW+Ku3c0tKS2R5NWat2dU8qahyi6fdoGl+9SRJ9I0Wl69WzJjofStGH50wBAEgYCzYAAAlgwQYAIAEs2AAAJIAFGwCABJRsSnzVqlWZ7SrZquoiq+Sm2r/aPpq0VYlXlehUCc1of9T2Kmmr+qOSsyoBqkTT1CrZGh1/NR9UelxRCVk1bmr/0frT3aXEo3XqVZ+iqWYlmlqPvrGgrqXaXlH3QPR7C1T/1Tj3VP+jcyX65oa6l6L3vNpepco/TPiEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPiikrCqqSnao8mRqM1zFW6WB1XWbFiRWa7SmKqmtWq/62trZnt0drd6rzUcaO1vtX+o/WY1bhFk7ZqezV/1HyI1pbvjvodlVJWYxdNTauxUNdepZGjbzioNwfUmEZrmEdrXKtxU+erxkfdw2r/qp9qHNR+1Jszqp/qzZZoLfEPEz5hAwCQABZsAAASwIINAEACWLABAEgACzYAAAko2ZS4SlYqKlmp6vWq9qqqqsx2lfJVCU2VLlb1dFXiVfVTJWdV0lNR46z2H01fq6SqSuaqcVbJ1uj1jdZvjtZRVvMhmuTtrg509M0HlchXc1Edu6eupTpubW1tZrt6k0HNXbW96r8an+g9rMYhetyWlpbM9sbGxsx21U+Vylbt6tmn0vvRZ5OaD+rZV4r4hA0AQAJYsAEASAALNgAACWDBBgAgASzYAAAkoGRT4ioJqxKXKs0brQesEpEqWakSmtHkY0+lwdX5qv5Ha3ErKrEbrWesRNPXKoGr+qPao3Wd1XVU26txVv0xi9cAj9Zzj9brV/tRx1VzWtWmVmMUTcWr81JzV9Vgjz6b1PbRN0OitbijzybVz556a0AdN/pGUMr4hA0AQAJYsAEASAALNgAACWDBBgAgASzYAAAkoGRT4ioxGk0UquRjT9XHjaZ/VZ1jdb4qDa76o85LbR+tfxwdh2iCWB03mlpX46DqJUcTsuq8VIJb1VFW59VdzXN1zVRf1THUnIveM2pORN9AUGOq0trRGtqqn9F0d7T+uzpf1d5dHfksaq6oORe995YsWRLqj6Le5InW608Zn7ABAEgACzYAAAlgwQYAIAEs2AAAJIAFGwCABJRsSjxap1YlXlX9WrV/lUhtaWnJbG9oaAgdVyVVlWh6PJqmVuMQrcGu0vvRxG40URtNWUcTyj2V5O2p1H13+1K/E6233tbWltmu6qqre0Zdg2gqOPqGg0pHq3s4OtfV/qPp/ei9oa5vT6WvV65cGeqPesZFv7dAjVsp4hM2AAAJYMEGACABLNgAACSABRsAgASwYAMAkICSTYmrpKGikpiq/nG0hraqQa36qRK1KompEp1q/yrh21O1waMpbtVPlR6PJp2jaW21vUq2qv2ofqrzUvtRCVn1FkB9fX1mu5keI3XtoyniaKpZnZvqp7oG0fr76riqn9F68eqeV6l4RaXK1bMpOhejKW61fWNjY2a7emZF72FF7acU8QkbAIAEsGADAJAAFmwAABLAgg0AQAJYsAEASEDJpsRVglUlClUyVG2vkpvRhKM6rkqSqv1H60Cr81L1m9X+VdpcJUnV/lVCOVrTW+1H9V+ltdVx1XVX4xCdD+q8ovOwO9Ea0SotrKgxVcdV26sUtEpfq/2r7ZubmzPb1T0T7afaj2pX90x0HNQzJXpvR1Pc6o2F6PcEqHtYbR+ty58yPmEDAJAAFmwAABLAgg0AQAJYsAEASAALNgAACSjZlLhKDqqkpKISmsuWLctsr62tzWxva2vLbF
cp32hqWiU0Veo4mvRUKWiVJFVJWEVtr/qv2qO1x1W72n90HFTyN5rejyZhu0unqz6pFG70jYjoXFTbR1PTaozUvRQdB9X/6NxV/VTtapxVP6NpbfVMUXNd1WZX1FsG0Tdb1DP0w4RP2AAAJIAFGwCABLBgAwCQABZsAAASwIINAEACSjYlHk2AqjrEKulZX18f2o+qDR5NaKoUsUqnr1y5MnRcNW7quCp5Gq0lrqjtVXtPpcFVnWaVOG5sbMxsV6L1pFWCOHq9uttXNKGu5no0NR2tp6+uveqnmqOqn+q40f4r6t6OPoOi9f1VP1WKW41z9HqpeaXuJdWurle0Xn/K+IQNAEACWLABAEgACzYAAAlgwQYAIAEs2AAAJKBkU+KKSlCqJKZKOK5atSqzXaWFo+noaJ1m1U+Vdlbbq6RqT6WyVWJXJVWjdayjiWCVslbzobq6OrNdjZt6m0BR+1HjoOpqdzffoqlalQpWx1bXMlorW10DNUbRN0Ci6e7o9xOoe0Cdr9peXXv1DIrOaXXPqLR2tFa52l6Nf/Qthg8TPmEDAJAAFmwAABLAgg0AQAJYsAEASAALNgAACSjZlLhKYqrEaLRestq/Ss6qRKRKVqoEaDTNrtpVklel6FXKWm1fV1eX2a7GOZreV8lZdb2iNbdVfxTVH3XcaGo9mrpX42ymx1QdW6WF1bGjKWjV12haWB03Whe+u7HLot4caGtry2xX46zuyWgqOzrnVMpd9b+n3jBR59va2hpqV/deKeITNgAACWDBBgAgASzYAAAkgAUbAIAEsGADAJCAkk2Jq1R2ND0erYscrd0drU+sqMSoSq33VL1hdb7RtLM6rmpX46OSsNHrEu2POl8131SyVc3bhoaG0PZqHMz0nFOpYEUdI/rGhboG0VrTak6oaxZNQUfT79G5VVNTk9kerXne3t4e2r/aXl1ftb1KcatxVs/W6PirN2pKEZ+wAQBIAAs2AAAJYMEGACABLNgAACSABRsAgASUbEpcJVJVu0rt9lSqPJoSVwlNlRhVydPGxsbMdtVPdVw1PiqZq9pVWjtav1ntX6W11f6j1131X42bqjOt+lNbW5vZrupJq+1VnWwzfe2jKWL1BkK0Bng0Jd5TaedoPXq1HzWe6s0KdW3UnFNzurm5ObM9OqejtdzV+Le0tGS2R7+PQV33aOq+FPEJGwCABLBgAwCQABZsAAASwIINAEACWLABAEhAyabEle5qLGdRSU+VcIzWAFf9qaioyGxXydy6urrMdlXfV22vEpoq7azGR22vEqMqCav2r9pV/6O1x1X/VVpbXS/Vn2gaXF0vNR+6S86qvkbfrFDXTKWmVcJenYNKBauUtTovtX30Xo3WbVfXQNW+Vs8U1d7U1BTavxp/lTZX95iqDR59k0RdF9UefUOmFPEJGwCABLBgAwCQABZsAAASwIINAEACWLABAEhAyabEVcJUJTqjCU0lWidY1etVyVmV9FTHra+vz2yP1j9WCVyVFFbJULX/aE3ydzulr8ZTjX+USryq8VHU9iptbqavWXSOqjcQFNUnlfKNXmNVSzw6d6N1/9W9FL3H1JyI1vpWbziocVPpbjUOqv/RtwyU6PcKRLdPGZ+wAQBIAAs2AAAJYMEGACABLNgAACSABRsAgASUbEpcJVJVKlglNFVdXpVMVIlIlbhU26skpjquSryq7fv27ZvZrsahpqYmtH/Vf5W+jiZz1fmq6xtN5irRutHRtLwSTRD3ZH1lNdYqjazS0Wos1BsR0br/qj67So9H9x99A0GlrxV1b3SX+I/sR10v1U81txYvXhzaXs1ddVz1TFT7j77JkzI+YQMAkAAWbAAAEsCCDQBAAliwAQBIAAs2AAAJKNmUuEoUqgRltA6uSpiqRKc6rkrzquP26dMn1J9okrehoSGzPVpXWJ2Xquus9hNN9UfrZKvto/uJ1kiPJmej49Bd+l3dG+oaq2vZU7Wmo29ERNPIas5FU9/RWuvRNz
eU6LNGvdERTV+r7dW4Rb9HYeXKlZntSnTeliI+YQMAkAAWbAAAEsCCDQBAAliwAQBIAAs2AAAJKNmUeDTxGk2AqhrjAwYMCO1HpbVVXWTVf5UkVcdVSVKVAK2qqspsV8lQtX9FHVddl2idZjVuPZWAjtaKV9dLba+o43aXElfXLDpH29raQsfuqTRydK5Exyh6L0Xbo/2JpsqjNbdV2jyaclfjpvajjqva1TyM1uVPGZ+wAQBIAAs2AAAJYMEGACABLNgAACSABRsAgASUbEpcJQejCVmVSFXJR5VejtayjtboVqnyaD/V9irZqo4bTeyqhGk0savGR52X2r/qfzTJG627rOpeq4Ss2r67ZLE6N9UefTNBbf/GG29ktqtzUGOt9q/GtKdqlat+RkXHWW2vvidApffVmx7quO3t7Znt6tkR/V6E6Jsb0fR7KeITNgAACWDBBgAgASzYAAAkgAUbAIAEsGADAJCAkk2JNzc3Z7ZHU8oqqaqSiSoJ29DQEDqu6qdKqqokpkp0quOqpG1dXV1muxoflcxVqelofWWV+laiyeVoSl+dVzTNrsa/J8dHzV2VIlZpYbV9tDZ49M2N6P7VXFfJ+2iddzXWapyjNbfVNVbXRc1dddxoqlw9W6PPJnXdo6L191PGJ2wAABLAgg0AQAJYsAEASAALNgAACWDBBgAgASWbElep2tWrV2e2qxR3NLWrkp4qidmvX7/Q/lXyNJrKjtbWVoladdwodVy1/2gCN1qXWu1HUf1R4x9Nd0frLneXnI0m7NU5qIS9olLf0Tc3ogn7aFJfpaPV+Ubv1WgaPFrvPlqPvr6+PrNdPbNUf9T4qPFUtdCXLVuW2a6QEgcAAB8oLNgAACSABRsAgASwYAMAkAAWbAAAElCyKXGVbO3fv39mu0pcquSmSkSq+r4qGaqSmzU1NZnt6rxUMlT1X1H9jNZpVjXMVWI3Wv9YXa+e2n80savaVX/U+KjksuqPmj8bQ+1L9UnN9WideiWa7lZzNzon1L2troG6N1QKWvVHjY+6txsbGzPb1b2hxk1R6evos3LVqlWh7dX4q+8zUONZiviEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACSgZFPi0Xq3KjGqEpoq5auSodHkpkpQqrrFitq/2o8ah2jt62gyN5rKVvtXiV01/moc1H6i1H5U/6OJ1+h86O4Y0Wuj5rqq169Ek/dqTFUaOfoGRfRNCXUNoil6RV2vaD19NQ6qZriizlc9W9WzUo2z6me0Nnsp4hM2AAAJYMEGACABLNgAACSABRsAgASwYAMAkICSTYmrWtwqCVtVVZXZrhKRqu6vSlxG09q1tbWZ7SpBqfavkrwq8arqB6txi6bf1bhFk57R9Hi0ZrjqjxofdR17Kv0eravdXY3x6JxW1DGiCfvomxIqjazuGXXN1HGjNbSj6WW1vXpmqeui7kn1DIq+BaDGQT2D1HWP1vGvr6/PbFdvH5ASBwAAHygs2AAAJIAFGwCABLBgAwCQABZsAAAS8KFLiSsqWTlo0KDM9oaGhsx2lXBUKXSVlFQJ1miKWCU9oylxlSxW7SrRqc5XJWejieNoGjw6ztH9REVrj6v+qDrWZj1Xv15dMzWHVNpZ7T96bdT+VT/V+UbT4NHEf7SuffTejl6v6PcoRNPj0drm0WecqlVeiviEDQBAAliwAQBIAAs2AAAJYMEGACABLNgAACQg5yp6BwAAPjD4hA0AQAJYsAEASAALNgAACWDBBgAgASzYAAAkgAUbAIAEsGADAJAAFmwAABLAgg0AQAL+D4b0nQYYsk47AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "# Generate conditional images for each disease\n", + "model.eval()\n", + "vocab = sample_dataset.input_processors[\"disease\"].code_vocab\n", + "diseases = list(vocab.keys())\n", + "\n", + "with torch.no_grad():\n", + " for disease in diseases:\n", + " # Sample from latent space\n", + " z = torch.randn(1, model.hidden_dim).to(model.device)\n", + "\n", + " # Create conditional data\n", + " condition_indices = [vocab[disease]]\n", + " cond_data = {\"disease\": torch.tensor(condition_indices, dtype=torch.long).unsqueeze(0)}\n", + "\n", + " # Get embeddings\n", + " cond_embeddings = model.embedding_model(cond_data)\n", + " cond_vec = cond_embeddings[\"disease\"].mean(dim=1)\n", + "\n", + " # Reshape for decoder\n", + " z_reshaped = z.unsqueeze(2).unsqueeze(3)\n", + " z_reshaped = z_reshaped + cond_vec.unsqueeze(2).unsqueeze(3)\n", + "\n", + " # Generate image\n", + " generated = model.decoder(z_reshaped).detach().cpu().numpy()\n", + "\n", + " # Save individual image\n", + " filename = f\"conditional_vae_{disease.lower().replace(' ', '_')}.png\"\n", + " plt.figure(figsize=(5, 5))\n", + " plt.imshow(generated[0].reshape(image_size, image_size), cmap=\"gray\")\n", + " plt.title(f\"Generated Chest X-Ray (Conditional on {disease})\")\n", + " plt.axis('off')\n", + " plt.show()\n", + " plt.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/covid19cxr_conformal.ipynb b/examples/covid19cxr_conformal.ipynb new file mode 100644 index 000000000..686bf94e4 --- /dev/null +++ 
b/examples/covid19cxr_conformal.ipynb @@ -0,0 +1,459 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conformal Prediction for COVID-19 Chest X-Ray Classification\n", + "\n", + "This example demonstrates:\n", + "1. Training a ResNet-18 model on COVID-19 CXR dataset\n", + "2. Conventional conformal prediction using LABEL\n", + "3. Covariate shift adaptive conformal prediction using CovariateLabel\n", + "4. Comparison of coverage and efficiency between the two methods" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/PyHealth/pyhealth/sampler/sage_sampler.py:3: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " import pkg_resources\n", + "/home/ubuntu/PyHealth/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "This is a warning of potentially slow compute. 
You could uncomment this line and use the Python implementation instead of Cython.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.calib.predictionset import LABEL\n", + "from pyhealth.calib.predictionset.covariate import CovariateLabel\n", + "from pyhealth.calib.utils import extract_embeddings\n", + "from pyhealth.datasets import (\n", + " COVID19CXRDataset,\n", + " get_dataloader,\n", + " split_by_sample_conformal,\n", + ")\n", + "from pyhealth.models import TorchvisionModel\n", + "from pyhealth.trainer import Trainer, get_metrics_fn\n", + "\n", + "# Set random seed for reproducibility\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 1: Load and prepare dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "STEP 1: Loading COVID-19 CXR Dataset\n", + "================================================================================\n", + "No config path provided, using default config\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Raw COVID-19 CXR dataset files not found in ~/Downloads/COVID-19_Radiography_Dataset. 
Please download the dataset from https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database and extract the contents to the specified root directory.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mValueError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33m=\u001b[39m\u001b[33m\"\u001b[39m * \u001b[32m80\u001b[39m)\n\u001b[32m 5\u001b[39m root = \u001b[33m\"\u001b[39m\u001b[33m~/Downloads/COVID-19_Radiography_Dataset\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m base_dataset = \u001b[43mCOVID19CXRDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 7\u001b[39m sample_dataset = base_dataset.set_task(cache_dir=\u001b[33m\"\u001b[39m\u001b[33m../../covid19cxr_cache\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 9\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mTotal samples: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(sample_dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/covid19_cxr.py:96\u001b[39m, in \u001b[36mCOVID19CXRDataset.__init__\u001b[39m\u001b[34m(self, root, dataset_name, config_path)\u001b[39m\n\u001b[32m 94\u001b[39m config_path = Path(\u001b[34m__file__\u001b[39m).parent / \u001b[33m\"\u001b[39m\u001b[33mconfigs\u001b[39m\u001b[33m\"\u001b[39m / \u001b[33m\"\u001b[39m\u001b[33mcovid19_cxr.yaml\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 95\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._check_raw_data_exists(root):\n\u001b[32m---> 
\u001b[39m\u001b[32m96\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 97\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mRaw COVID-19 CXR dataset files not found in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mroot\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 98\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mPlease download the dataset from \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 99\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mhttps://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 100\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mand extract the contents to the specified root directory.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 101\u001b[39m )\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os.path.exists(os.path.join(root, \u001b[33m\"\u001b[39m\u001b[33mcovid19_cxr-metadata-pyhealth.csv\u001b[39m\u001b[33m\"\u001b[39m)):\n\u001b[32m 103\u001b[39m \u001b[38;5;28mself\u001b[39m.prepare_metadata(root)\n", + "\u001b[31mValueError\u001b[39m: Raw COVID-19 CXR dataset files not found in ~/Downloads/COVID-19_Radiography_Dataset. Please download the dataset from https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database and extract the contents to the specified root directory." 
+ ] + } + ], + "source": [ + "print(\"=\" * 80)\n", + "print(\"STEP 1: Loading COVID-19 CXR Dataset\")\n", + "print(\"=\" * 80)\n", + "\n", + "root = \"~/Downloads/COVID-19_Radiography_Dataset\"\n", + "base_dataset = COVID19CXRDataset(root)\n", + "sample_dataset = base_dataset.set_task(cache_dir=\"../../covid19cxr_cache\")\n", + "\n", + "print(f\"Total samples: {len(sample_dataset)}\")\n", + "print(f\"Task mode: {sample_dataset.output_schema}\")\n", + "\n", + "# Split into train/val/cal/test\n", + "# For conformal prediction, we need a separate calibration set\n", + "train_data, val_data, cal_data, test_data = split_by_sample_conformal(\n", + " dataset=sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15]\n", + ")\n", + "\n", + "print(f\"Train: {len(train_data)}\")\n", + "print(f\"Val: {len(val_data)}\")\n", + "print(f\"Cal: {len(cal_data)} (for conformal calibration)\")\n", + "print(f\"Test: {len(test_data)}\")\n", + "\n", + "# Create data loaders\n", + "train_loader = get_dataloader(train_data, batch_size=32, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=32, shuffle=False)\n", + "cal_loader = get_dataloader(cal_data, batch_size=32, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=32, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 2: Train ResNet-18 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 2: Training ResNet-18 Model\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Initialize ResNet-18 with pretrained weights\n", + "resnet = TorchvisionModel(\n", + " dataset=sample_dataset,\n", + " model_name=\"resnet18\",\n", + " model_config={\"weights\": \"DEFAULT\"},\n", + ")\n", + "\n", + "# Train the model\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "trainer = Trainer(model=resnet, device=device)\n", + "\n", + "print(f\"Training on 
device: {device}\")\n", + "trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=5,\n", + " monitor=\"accuracy\",\n", + ")\n", + "\n", + "print(\"✓ Model training completed\")\n", + "\n", + "# Evaluate base model on test set\n", + "print(\"\\nBase model performance on test set:\")\n", + "y_true_base, y_prob_base, loss_base = trainer.inference(test_loader)\n", + "base_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_base, y_prob_base, metrics=[\"accuracy\", \"f1_weighted\"]\n", + ")\n", + "for metric, value in base_metrics.items():\n", + " print(f\" {metric}: {value:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 3: Conventional Conformal Prediction with LABEL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 3: Conventional Conformal Prediction (LABEL)\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Target miscoverage rate of 10% (90% coverage)\n", + "alpha = 0.1\n", + "print(f\"Target miscoverage rate: {alpha} (90% coverage)\")\n", + "\n", + "# Create LABEL predictor\n", + "label_predictor = LABEL(model=resnet, alpha=alpha)\n", + "\n", + "# Calibrate on calibration set\n", + "print(\"Calibrating LABEL predictor...\")\n", + "label_predictor.calibrate(cal_dataset=cal_data)\n", + "\n", + "# Evaluate on test set\n", + "print(\"Evaluating LABEL predictor on test set...\")\n", + "y_true_label, y_prob_label, _, extra_label = Trainer(model=label_predictor).inference(\n", + " test_loader, additional_outputs=[\"y_predset\"]\n", + ")\n", + "\n", + "label_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_label,\n", + " y_prob_label,\n", + " metrics=[\"accuracy\", \"miscoverage_ps\"],\n", + " y_predset=extra_label[\"y_predset\"],\n", + ")\n", + "\n", + "# Calculate average set size\n", + "predset_label = (\n", + " torch.tensor(extra_label[\"y_predset\"])\n", + 
" if isinstance(extra_label[\"y_predset\"], np.ndarray)\n", + " else extra_label[\"y_predset\"]\n", + ")\n", + "avg_set_size_label = predset_label.float().sum(dim=1).mean().item()\n", + "\n", + "# Extract scalar values from metrics (handle both scalar and array returns)\n", + "miscoverage_label = label_metrics[\"miscoverage_ps\"]\n", + "if isinstance(miscoverage_label, np.ndarray):\n", + " miscoverage_label = float(\n", + " miscoverage_label.item()\n", + " if miscoverage_label.size == 1\n", + " else miscoverage_label.mean()\n", + " )\n", + "else:\n", + " miscoverage_label = float(miscoverage_label)\n", + "\n", + "print(\"\\nLABEL Results:\")\n", + "print(f\" Accuracy: {label_metrics['accuracy']:.4f}\")\n", + "print(f\" Empirical miscoverage: {miscoverage_label:.4f}\")\n", + "print(f\" Average set size: {avg_set_size_label:.2f}\")\n", + "print(f\" Target miscoverage: {alpha:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 4: Covariate Shift Adaptive Conformal Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 4: Covariate Shift Adaptive Conformal Prediction\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Extract embeddings from the model\n", + "# For TorchvisionModel, we extract features from avgpool layer (before fc)\n", + "print(\"Extracting embeddings from calibration set...\")\n", + "cal_embeddings = extract_embeddings(resnet, cal_data, batch_size=32, device=device)\n", + "print(f\" Cal embeddings shape: {cal_embeddings.shape}\")\n", + "\n", + "print(\"Extracting embeddings from test set...\")\n", + "test_embeddings = extract_embeddings(resnet, test_data, batch_size=32, device=device)\n", + "print(f\" Test embeddings shape: {test_embeddings.shape}\")\n", + "\n", + "# Create CovariateLabel predictor\n", + "print(\"\\nCreating CovariateLabel predictor...\")\n", + "covariate_predictor = 
CovariateLabel(model=resnet, alpha=alpha)\n", + "\n", + "# Calibrate with embeddings (KDEs will be fitted automatically)\n", + "print(\"Calibrating CovariateLabel predictor...\")\n", + "print(\" - Fitting KDEs for covariate shift correction...\")\n", + "covariate_predictor.calibrate(\n", + " cal_dataset=cal_data, cal_embeddings=cal_embeddings, test_embeddings=test_embeddings\n", + ")\n", + "print(\"✓ Calibration completed\")\n", + "\n", + "# Evaluate on test set\n", + "print(\"Evaluating CovariateLabel predictor on test set...\")\n", + "y_true_cov, y_prob_cov, _, extra_cov = Trainer(model=covariate_predictor).inference(\n", + " test_loader, additional_outputs=[\"y_predset\"]\n", + ")\n", + "\n", + "cov_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_cov,\n", + " y_prob_cov,\n", + " metrics=[\"accuracy\", \"miscoverage_ps\"],\n", + " y_predset=extra_cov[\"y_predset\"],\n", + ")\n", + "\n", + "# Calculate average set size\n", + "predset_cov = (\n", + " torch.tensor(extra_cov[\"y_predset\"])\n", + " if isinstance(extra_cov[\"y_predset\"], np.ndarray)\n", + " else extra_cov[\"y_predset\"]\n", + ")\n", + "avg_set_size_cov = predset_cov.float().sum(dim=1).mean().item()\n", + "\n", + "# Extract scalar values from metrics (handle both scalar and array returns)\n", + "miscoverage_cov = cov_metrics[\"miscoverage_ps\"]\n", + "if isinstance(miscoverage_cov, np.ndarray):\n", + " miscoverage_cov = float(\n", + " miscoverage_cov.item() if miscoverage_cov.size == 1 else miscoverage_cov.mean()\n", + " )\n", + "else:\n", + " miscoverage_cov = float(miscoverage_cov)\n", + "\n", + "print(\"\\nCovariateLabel Results:\")\n", + "print(f\" Accuracy: {cov_metrics['accuracy']:.4f}\")\n", + "print(f\" Empirical miscoverage: {miscoverage_cov:.4f}\")\n", + "print(f\" Average set size: {avg_set_size_cov:.2f}\")\n", + "print(f\" Target miscoverage: {alpha:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 5: Compare Methods" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 5: Comparison of Methods\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(f\"\\nTarget: {1-alpha:.0%} coverage (max {alpha:.0%} miscoverage)\")\n", + "print(\"\\n{:<40} {:<15} {:<15}\".format(\"Metric\", \"LABEL\", \"CovariateLabel\"))\n", + "print(\"-\" * 70)\n", + "\n", + "# Coverage (1 - miscoverage)\n", + "label_coverage = 1 - miscoverage_label\n", + "cov_coverage = 1 - miscoverage_cov\n", + "print(\n", + " \"{:<40} {:<15.2%} {:<15.2%}\".format(\n", + " \"Empirical Coverage\", label_coverage, cov_coverage\n", + " )\n", + ")\n", + "\n", + "# Miscoverage\n", + "print(\n", + " \"{:<40} {:<15.4f} {:<15.4f}\".format(\n", + " \"Empirical Miscoverage\",\n", + " miscoverage_label,\n", + " miscoverage_cov,\n", + " )\n", + ")\n", + "\n", + "# Average set size (smaller is better for efficiency)\n", + "print(\n", + " \"{:<40} {:<15.2f} {:<15.2f}\".format(\n", + " \"Average Set Size\", avg_set_size_label, avg_set_size_cov\n", + " )\n", + ")\n", + "\n", + "# Efficiency (inverse of average set size)\n", + "efficiency_label = 1.0 / avg_set_size_label\n", + "efficiency_cov = 1.0 / avg_set_size_cov\n", + "print(\n", + " \"{:<40} {:<15.4f} {:<15.4f}\".format(\n", + " \"Efficiency (1/avg_set_size)\", efficiency_label, efficiency_cov\n", + " )\n", + ")\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"Summary\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nKey Observations:\")\n", + "print(\"1. Both methods achieve near-target coverage guarantees\")\n", + "print(\"2. LABEL: Standard conformal prediction\")\n", + "print(\"3. 
CovariateLabel: Adapts to distribution shift between cal and test\")\n", + "print(\"\\nWhen to use CovariateLabel:\")\n", + "print(\" - When test distribution differs from calibration distribution\")\n", + "print(\" - When you have access to test embeddings/features\")\n", + "print(\" - When you want more robust coverage under distribution shift\")\n", + "print(\"\\nWhen to use LABEL:\")\n", + "print(\" - When cal and test distributions are similar (exchangeable)\")\n", + "print(\" - Simpler method, no need to fit KDEs\")\n", + "print(\" - Computationally more efficient\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 6: Visualize Prediction Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 6: Example Predictions\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Show first 5 test examples\n", + "n_examples = 5\n", + "print(f\"\\nShowing first {n_examples} test examples:\")\n", + "print(\"-\" * 80)\n", + "\n", + "for i in range(min(n_examples, len(y_true_label))):\n", + " true_class = int(y_true_label[i])\n", + "\n", + " # LABEL prediction set\n", + " if isinstance(predset_label, np.ndarray):\n", + " label_set = np.where(predset_label[i])[0]\n", + " else:\n", + " label_set = torch.where(predset_label[i])[0].cpu().numpy()\n", + "\n", + " # CovariateLabel prediction set\n", + " if isinstance(predset_cov, np.ndarray):\n", + " cov_set = np.where(predset_cov[i])[0]\n", + " else:\n", + " cov_set = torch.where(predset_cov[i])[0].cpu().numpy()\n", + "\n", + " print(f\"\\nExample {i+1}:\")\n", + " print(f\" True class: {true_class}\")\n", + " print(f\" LABEL set: {label_set.tolist()} (size: {len(label_set)})\")\n", + " print(f\" CovariateLabel set: {cov_set.tolist()} (size: {len(cov_set)})\")\n", + " print(f\" Correct in LABEL? {true_class in label_set}\")\n", + " print(f\" Correct in CovariateLabel? 
{true_class in cov_set}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"Example completed successfully!\")\n", + "print(\"=\" * 80)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/datasets_overview.ipynb b/examples/datasets_overview.ipynb index aa99fd677..885eb0595 100644 --- a/examples/datasets_overview.ipynb +++ b/examples/datasets_overview.ipynb @@ -424,9 +424,9 @@ "\n", "**Description**: Chest X-ray images for COVID-19 classification.\n", "\n", - "**Source/Link**: Custom or public sources (check PyHealth docs).\n", + "**Source/Link**: [Custom or public sources (check PyHealth docs).](https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database)\n", "\n", - "**Download Method**: Varies; often from Kaggle or GitHub.\n", + "**Download Method**: curl, direct download\n", "\n", "**Restrictions**: Public datasets.\n", "\n", @@ -467,67 +467,6 @@ "**Example Usage**:" ] }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[31mInit signature:\u001b[39m\n", - "SampleDataset(\n", - " samples: List[Dict],\n", - " input_schema: Dict[str, Union[str, Type[pyhealth.processors.base_processor.FeatureProcessor]]],\n", - " output_schema: Dict[str, Union[str, Type[pyhealth.processors.base_processor.FeatureProcessor]]],\n", - " dataset_name: Optional[str] = \u001b[38;5;28;01mNone\u001b[39;00m,\n", - " task_name: Optional[str] = \u001b[38;5;28;01mNone\u001b[39;00m,\n", - " input_processors: Optional[Dict[str, 
pyhealth.processors.base_processor.FeatureProcessor]] = \u001b[38;5;28;01mNone\u001b[39;00m,\n", - " output_processors: Optional[Dict[str, pyhealth.processors.base_processor.FeatureProcessor]] = \u001b[38;5;28;01mNone\u001b[39;00m,\n", - ") -> \u001b[38;5;28;01mNone\u001b[39;00m\n", - "\u001b[31mDocstring:\u001b[39m \n", - "Sample dataset class for handling and processing data samples.\n", - "\n", - "Attributes:\n", - " samples (List[Dict]): List of data samples.\n", - " input_schema (Dict[str, Union[str, Type[FeatureProcessor], Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]):\n", - " Schema for input data. Values can be string aliases, processor classes, or tuples of (spec, kwargs_dict).\n", - " output_schema (Dict[str, Union[str, Type[FeatureProcessor], Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]):\n", - " Schema for output data. Values can be string aliases, processor classes, or tuples of (spec, kwargs_dict).\n", - " dataset_name (Optional[str]): Name of the dataset.\n", - " task_name (Optional[str]): Name of the task.\n", - "\u001b[31mInit docstring:\u001b[39m\n", - "Initializes the SampleDataset with samples and schemas.\n", - "\n", - "Args:\n", - " samples (List[Dict]): List of data samples.\n", - " input_schema (Dict[str, Union[str, Type[FeatureProcessor], Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]):\n", - " Schema for input data. Values can be string aliases, processor classes, or tuples of (spec, kwargs_dict) for instantiation.\n", - " output_schema (Dict[str, Union[str, Type[FeatureProcessor], Tuple[Union[str, Type[FeatureProcessor]], Dict[str, Any]]]]):\n", - " Schema for output data. 
Values can be string aliases, processor classes, or tuples of (spec, kwargs_dict) for instantiation.\n", - " dataset_name (Optional[str], optional): Name of the dataset.\n", - " Defaults to None.\n", - " task_name (Optional[str], optional): Name of the task.\n", - " Defaults to None.\n", - " input_processors (Optional[Dict[str, FeatureProcessor]],\n", - " optional): Pre-fitted input processors. If provided, these\n", - " will be used instead of creating new ones from input_schema.\n", - " Defaults to None.\n", - " output_processors (Optional[Dict[str, FeatureProcessor]],\n", - " optional): Pre-fitted output processors. If provided, these\n", - " will be used instead of creating new ones from output_schema.\n", - " Defaults to None.\n", - "\u001b[31mFile:\u001b[39m ~/PyHealth/pyhealth/datasets/sample_dataset.py\n", - "\u001b[31mType:\u001b[39m type\n", - "\u001b[31mSubclasses:\u001b[39m " - ] - } - ], - "source": [ - "SampleDataset?" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/examples/timeseries_mimic4.ipynb b/examples/timeseries_mimic4.ipynb index f5d9e012b..1532d2d8d 100644 --- a/examples/timeseries_mimic4.ipynb +++ b/examples/timeseries_mimic4.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "8f384611", "metadata": {}, "outputs": [ @@ -10,83 +10,40 @@ "name": "stdout", "output_type": "stream", "text": [ - "Memory usage Starting MIMIC4Dataset init: 668.4 MB\n", + "Memory usage Starting MIMIC4Dataset init: 645.9 MB\n", "Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: True)\n", - "Using default EHR config: /home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml\n", - "Memory usage Before initializing mimic4_ehr: 668.4 MB\n", + "Using default EHR config: /home/ubuntu/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml\n", + "Memory usage Before initializing 
mimic4_ehr: 645.9 MB\n", "Initializing mimic4_ehr dataset from /srv/local/data/physionet.org/files/mimiciv/2.2/ (dev mode: False)\n", - "Scanning table: diagnoses_icd from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Scanning table: procedures_icd from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/procedures_icd.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Scanning table: prescriptions from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/prescriptions.csv.gz\n", - "Scanning table: labevents from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/labevents.csv.gz\n", - "Joining with table: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/d_labitems.csv.gz\n", - "Scanning table: patients from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/patients.csv.gz\n", - "Scanning table: admissions from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv.gz\n", - "Original path does not exist. 
Using alternative: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/admissions.csv\n", - "Scanning table: icustays from /srv/local/data/physionet.org/files/mimiciv/2.2/icu/icustays.csv.gz\n", - "Memory usage After initializing mimic4_ehr: 30751.7 MB\n", - "Memory usage After EHR dataset initialization: 30751.7 MB\n", - "Memory usage Before combining data: 30751.7 MB\n", - "Combining data from ehr dataset\n", - "Creating combined dataframe\n", - "Memory usage After combining data: 30751.7 MB\n", - "Memory usage Completed MIMIC4Dataset init: 30751.7 MB\n", - "Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset...\n", - "Generating samples with 10 worker(s)...\n", - "Collecting global event dataframe...\n", - "Dev mode enabled: limiting to 1000 patients\n", - "Collected dataframe with shape: (458197, 47)\n", - "Generating samples for InHospitalMortalityMIMIC4 with 10 workers\n" + "Scanning table: diagnoses_icd from /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Collecting samples for InHospitalMortalityMIMIC4 from 10 workers: 100%|██████████| 1000/1000 [00:03<00:00, 304.76it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Caching samples to ../../test_cache_mortality_m4/InHospitalMortalityMIMIC4.parquet\n", - "Failed to cache samples: failed to determine supertype of list[datetime[μs]] and object\n", - "Label mortality vocab: {0: 0, 1: 1}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "Processing samples: 100%|██████████| 723/723 [00:00<00:00, 1726.48it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generated 723 samples for task InHospitalMortalityMIMIC4\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "ename": "FileNotFoundError", + "evalue": "Neither path exists: 
/srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz or /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MIMIC4Dataset\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m dataset = \u001b[43mMIMIC4Dataset\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[43m \u001b[49m\u001b[43mehr_root\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m/srv/local/data/physionet.org/files/mimiciv/2.2/\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 5\u001b[39m \u001b[43m \u001b[49m\u001b[43mehr_tables\u001b[49m\u001b[43m=\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdiagnoses_icd\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mprocedures_icd\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mprescriptions\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mlabevents\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 6\u001b[39m \u001b[43m \u001b[49m\u001b[43mdev\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 7\u001b[39m \u001b[43m)\u001b[49m\n\u001b[32m 9\u001b[39m 
\u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtasks\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m InHospitalMortalityMIMIC4\n\u001b[32m 11\u001b[39m task = InHospitalMortalityMIMIC4()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/mimic4.py:239\u001b[39m, in \u001b[36mMIMIC4Dataset.__init__\u001b[39m\u001b[34m(self, ehr_root, note_root, cxr_root, ehr_tables, note_tables, cxr_tables, ehr_config_path, note_config_path, cxr_config_path, dataset_name, dev)\u001b[39m\n\u001b[32m 237\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ehr_root:\n\u001b[32m 238\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInitializing MIMIC4EHRDataset with tables: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mehr_tables\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m (dev mode: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdev\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m239\u001b[39m \u001b[38;5;28mself\u001b[39m.sub_datasets[\u001b[33m\"\u001b[39m\u001b[33mehr\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[43mMIMIC4EHRDataset\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 240\u001b[39m \u001b[43m \u001b[49m\u001b[43mroot\u001b[49m\u001b[43m=\u001b[49m\u001b[43mehr_root\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 241\u001b[39m \u001b[43m \u001b[49m\u001b[43mtables\u001b[49m\u001b[43m=\u001b[49m\u001b[43mehr_tables\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 242\u001b[39m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mehr_config_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 243\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 244\u001b[39m log_memory_usage(\u001b[33m\"\u001b[39m\u001b[33mAfter EHR dataset initialization\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 246\u001b[39m \u001b[38;5;66;03m# Initialize Notes dataset if 
root is provided\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/mimic4.py:59\u001b[39m, in \u001b[36mMIMIC4EHRDataset.__init__\u001b[39m\u001b[34m(self, root, tables, dataset_name, config_path, **kwargs)\u001b[39m\n\u001b[32m 57\u001b[39m default_tables = [\u001b[33m\"\u001b[39m\u001b[33mpatients\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33madmissions\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33micustays\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 58\u001b[39m tables = tables + default_tables\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[32m 60\u001b[39m \u001b[43m \u001b[49m\u001b[43mroot\u001b[49m\u001b[43m=\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 61\u001b[39m \u001b[43m \u001b[49m\u001b[43mtables\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtables\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 62\u001b[39m \u001b[43m \u001b[49m\u001b[43mdataset_name\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdataset_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 63\u001b[39m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 64\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 65\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 66\u001b[39m log_memory_usage(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mAfter initializing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:139\u001b[39m, in \u001b[36mBaseDataset.__init__\u001b[39m\u001b[34m(self, root, tables, dataset_name, config_path, dev)\u001b[39m\n\u001b[32m 
133\u001b[39m \u001b[38;5;28mself\u001b[39m.dev = dev\n\u001b[32m 135\u001b[39m logger.info(\n\u001b[32m 136\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInitializing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.dataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m dataset from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.root\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m (dev mode: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.dev\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 137\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m139\u001b[39m \u001b[38;5;28mself\u001b[39m.global_event_df = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mload_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 141\u001b[39m \u001b[38;5;66;03m# Cached attributes\u001b[39;00m\n\u001b[32m 142\u001b[39m \u001b[38;5;28mself\u001b[39m._collected_global_event_df = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:192\u001b[39m, in \u001b[36mBaseDataset.load_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 186\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mload_data\u001b[39m(\u001b[38;5;28mself\u001b[39m) -> pl.LazyFrame:\n\u001b[32m 187\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Loads data from the specified tables.\u001b[39;00m\n\u001b[32m 188\u001b[39m \n\u001b[32m 189\u001b[39m \u001b[33;03m Returns:\u001b[39;00m\n\u001b[32m 190\u001b[39m \u001b[33;03m pl.LazyFrame: A concatenated lazy frame of all tables.\u001b[39;00m\n\u001b[32m 191\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m192\u001b[39m frames = 
[\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mload_table\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtable\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlower\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m table \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.tables]\n\u001b[32m 193\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m pl.concat(frames, how=\u001b[33m\"\u001b[39m\u001b[33mdiagonal\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:222\u001b[39m, in \u001b[36mBaseDataset.load_table\u001b[39m\u001b[34m(self, table_name)\u001b[39m\n\u001b[32m 219\u001b[39m csv_path = clean_path(csv_path)\n\u001b[32m 221\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mScanning table: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtable_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcsv_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m222\u001b[39m df = \u001b[43mscan_csv_gz_or_csv_tsv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcsv_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 224\u001b[39m \u001b[38;5;66;03m# Convert column names to lowercase before calling preprocess_func\u001b[39;00m\n\u001b[32m 225\u001b[39m df = df.rename(_to_lower)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:94\u001b[39m, in \u001b[36mscan_csv_gz_or_csv_tsv\u001b[39m\u001b[34m(path)\u001b[39m\n\u001b[32m 91\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mOriginal path does not exist. 
Using alternative: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00malt_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 92\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m scan_file(alt_path)\n\u001b[32m---> \u001b[39m\u001b[32m94\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mNeither path exists: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m or \u001b[39m\u001b[38;5;132;01m{\u001b[39;00malt_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mFileNotFoundError\u001b[39m: Neither path exists: /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz or /srv/local/data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv" ] } ], "source": [ "from pyhealth.datasets import MIMIC4Dataset\n", "\n", + "ehr_root = \"/srv/local/data/physionet.org/files/mimiciv/2.2/\"\n", + "ehr_root = \"~/Downloads/\"\n", + "\n", "dataset = MIMIC4Dataset(\n", - " ehr_root=\"/srv/local/data/physionet.org/files/mimiciv/2.2/\",\n", + " ehr_root=ehr_root,\n", " ehr_tables=[\"diagnoses_icd\", \"procedures_icd\", \"prescriptions\", \"labevents\"],\n", " dev=True,\n", ")\n", @@ -108,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "ca618021", "metadata": {}, "outputs": [ @@ -391,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "a66e489f", "metadata": {}, "outputs": [ @@ -418,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "3b642c1d", "metadata": {}, "outputs": [ @@ -879,11 +836,43 @@ " optimizer_params={\"lr\": 1e-4} # Using learning rate of 1e-4\n", ")\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d09d115d", + "metadata": {}, + "outputs": [], + "source": [ + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"id": "0757b9fe", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.tasks import MortalityPredictionMIMIC3\n", + "from pyhealth.models import RNN\n", + "from pyhealth.trainer import Trainer\n", + "\n", + "# Load healthcare data\n", + "dataset = MIMIC3Dataset(root=\"data/\", tables=[\"diagnoses_icd\", \"procedures\"])\n", + "samples = dataset.set_task(MortalityPredictionMIMIC3())\n", + "\n", + "# Train model\n", + "model = RNN(dataset=samples)\n", + "trainer = Trainer(model=model)\n", + "trainer.train(train_dataloader, val_dataloader, epochs=50)" + ] + } ], "metadata": { "kernelspec": { - "display_name": "medical_coding_demo", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -897,7 +886,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/examples/timeseries_vae_modeling.ipynb b/examples/timeseries_vae_modeling.ipynb new file mode 100644 index 000000000..17a49fcf9 --- /dev/null +++ b/examples/timeseries_vae_modeling.ipynb @@ -0,0 +1,1009 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Time-Series Modeling with VAE\n", + "\n", + "This notebook demonstrates how to use the enhanced VAE model for time-series data analysis and generation. Unlike image VAEs that work with spatial patterns, time-series VAEs model sequential patterns in medical data.\n", + "\n", + "What does it mean to encode a time-series with a Variational Auto-Encoder?\n", + "\n", + "The first part of the VAE is an encoder based on a Gated Recurrent Unit (GRU) neural network. This type of network is able to encode sequences of events into an embedding space.\n", + "\n", + "The events are treated as an ordered sequence - we do not take the timestamps into account yet, just the order of events. These events are aligned across all the patients. 
Since not all patients will have the same number of events, we add some null \"padding\" events to the patients that have fewer events than the maximum.\n", + "\n", + "The Encoder part of the VAE will encode these sequences of events into embeddings, which represent the whole history of events for a patient. We can use these embeddings as input for further modelling tasks, or to compare patients. For example, patients that have a similar sequence of events should cluster closely in the embedding space, compared to other patients.\n", + "\n", + "Finally, the Decoder part of the VAE reconstructs the original sequences of events from these embeddings.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import split_by_visit, get_dataloader\n", + "from pyhealth.trainer import Trainer\n", + "from pyhealth.models import VAE\n", + "from pyhealth.datasets import MIMIC4Dataset\n", + "from pyhealth.tasks import MortalityPredictionMIMIC4\n", + "\n", + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load MIMIC4 Demo Dataset\n", + "\n", + "We'll use the MIMIC4 demo dataset to demonstrate time-series VAE on real medical sequences.\n", + "\n", + "**Setup Instructions:**\n", + "1. Download MIMIC4 demo data from: https://physionet.org/files/mimic-iv-demo/2.2/\n", + "2. Create a `data/mimic4_demo` directory in your project root\n", + "3. Extract the downloaded files so that the `hosp/` and `icu/` subdirectories sit directly under `data/mimic4_demo/`\n", + "4. 
Update the `ehr_root` path below if needed" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage Starting MIMIC4Dataset init: 812.4 MB\n", + "Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions'] (dev mode: True)\n", + "Using default EHR config: /home/ubuntu/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml\n", + "Memory usage Before initializing mimic4_ehr: 812.4 MB\n", + "Initializing mimic4_ehr dataset from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/ (dev mode: False)\n", + "Scanning table: diagnoses_icd from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/diagnoses_icd.csv.gz\n", + "Joining with table: /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz\n", + "Scanning table: procedures_icd from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/procedures_icd.csv.gz\n", + "Joining with table: /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz\n", + "Scanning table: prescriptions from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/prescriptions.csv.gz\n", + "Scanning table: patients from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/patients.csv.gz\n", + "Scanning table: admissions from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz\n", + "Scanning table: icustays from /home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/icu/icustays.csv.gz\n", + "Memory usage After initializing mimic4_ehr: 815.3 MB\n", + "Memory usage After EHR dataset initialization: 815.3 MB\n", + "Memory usage Before combining data: 815.3 MB\n", + "Combining data from ehr dataset\n", + "Creating combined dataframe\n", + "Memory usage After combining data: 815.3 MB\n", + "Memory usage Completed MIMIC4Dataset init: 
815.3 MB\n", + "Setting task MortalityPredictionMIMIC4 for mimic4 base dataset...\n", + "Generating samples with 2 worker(s)...\n", + "Collecting global event dataframe...\n", + "Dev mode enabled: limiting to 1000 patients\n", + "Collected dataframe with shape: (23830, 38)\n", + "Generating samples for MortalityPredictionMIMIC4 with 2 workers\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Collecting samples for MortalityPredictionMIMIC4 from 2 workers: 100%|██████████| 100/100 [00:00<00:00, 184.14it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label mortality vocab: {0: 0, 1: 1}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Processing samples: 100%|██████████| 108/108 [00:00<00:00, 20409.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 108 samples for task MortalityPredictionMIMIC4\n", + "MIMIC4 demo dataset loaded\n", + "Number of samples: 108\n", + "Input features: ['conditions', 'procedures', 'drugs']\n", + "Output features: ['mortality']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Load MIMIC4 demo dataset\n", + "# Download demo data from: https://physionet.org/files/mimic-iv-demo/2.2/\n", + "# and place in a local directory, then update ehr_root below\n", + "ehr_root = \"/home/ubuntu/PyHealth/data/mimic-iv-clinical-database-demo-2.2/\" # Update this path to your local MIMIC4 demo data\n", + "\n", + "dataset = MIMIC4Dataset(\n", + " ehr_root=ehr_root,\n", + " ehr_tables=[\"diagnoses_icd\", \"procedures_icd\", \"prescriptions\"],\n", + " dev=True,\n", + ")\n", + "\n", + "# Set task for time-series modeling\n", + "task = MortalityPredictionMIMIC4()\n", + "ts_dataset = dataset.set_task(task, num_workers=2)\n", + "\n", + "print(\"MIMIC4 demo dataset loaded\")\n", + "print(f\"Number of samples: {len(ts_dataset)}\")\n", + "print(f\"Input features: 
{list(ts_dataset.input_schema.keys())}\")\n", + "print(f\"Output features: {list(ts_dataset.output_schema.keys())}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and Train Time-Series VAE\n", + "\n", + "The VAE will learn to encode patient trajectories into a latent space and reconstruct them." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time-series VAE created\n", + "Input type: timeseries\n", + "Has embedding model: True\n", + "Has RNN encoder: True\n", + "Latent dimension: 64\n" + ] + } + ], + "source": [ + "# Create time-series VAE model\n", + "ts_model = VAE(\n", + " dataset=ts_dataset,\n", + " feature_keys=[\"conditions\"], # Single sequence feature for VAE\n", + " label_key=\"mortality\",\n", + " mode=\"binary\", # Binary classification for mortality prediction\n", + " input_type=\"timeseries\", # Key parameter for time-series mode\n", + " hidden_dim=64, # Latent dimension for medical sequences\n", + ")\n", + "\n", + "print(\"Time-series VAE created\")\n", + "print(f\"Input type: {ts_model.input_type}\")\n", + "print(f\"Has embedding model: {hasattr(ts_model, 'embedding_model')}\")\n", + "print(f\"Has RNN encoder: {hasattr(ts_model, 'encoder_rnn')}\")\n", + "print(f\"Latent dimension: {ts_model.hidden_dim}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Understanding the Time-Series VAE Architecture\n", + "\n", + "The time-series VAE differs from image VAEs:\n", + "\n", + "1. **EmbeddingModel**: Converts categorical sequences to dense vectors\n", + "2. **RNN Encoder**: Processes sequential embeddings, capturing temporal patterns\n", + "3. **Latent Space**: Fixed-size representation of the entire sequence\n", + "4. 
**Linear Decoder**: Reconstructs the sequence's compressed representation\n", + "\n", + "This architecture can learn patterns like \"diabetes → metformin → insulin\" or \"asthma → albuterol → steroids\"." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VAE(\n", + " (embedding_model): EmbeddingModel(embedding_layers=ModuleDict(\n", + " (conditions): Embedding(865, 64, padding_idx=0)\n", + " (procedures): Embedding(218, 64, padding_idx=0)\n", + " (drugs): Embedding(486, 64, padding_idx=0)\n", + " ))\n", + " (encoder_rnn): GRU(64, 64, batch_first=True)\n", + " (mu): Linear(in_features=64, out_features=64, bias=True)\n", + " (log_std2): Linear(in_features=64, out_features=64, bias=True)\n", + " (decoder_linear): Linear(in_features=64, out_features=64, bias=True)\n", + ")\n", + "Metrics: []\n", + "Device: cuda\n", + "\n", + "Training time-series VAE...\n", + "Training:\n", + "Batch size: 32\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.0001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: loss\n", + "Monitor criterion: min\n", + "Epochs: 10\n", + "Patience: None\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0 / 10: 100%|██████████| 4/4 [00:00<00:00, 10.86it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-0, step-4 ---\n", + "loss: 600.1951\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 712.44it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-0, step-4 ---\n", + "loss: 596.4567\n", + "New best loss score (596.4567) at epoch-0, step-4\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1 / 10: 100%|██████████| 4/4 [00:00<00:00, 290.50it/s]" + ] 
+ }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-1, step-8 ---\n", + "loss: 570.5898\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 731.70it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-1, step-8 ---\n", + "loss: 581.0278\n", + "New best loss score (581.0278) at epoch-1, step-8\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 2 / 10: 100%|██████████| 4/4 [00:00<00:00, 288.28it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-2, step-12 ---\n", + "loss: 585.5507\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-2, step-12 ---\n", + "loss: 591.7324\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 3 / 10: 100%|██████████| 4/4 [00:00<00:00, 291.74it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-3, step-16 ---\n", + "loss: 571.2698\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 740.78it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-3, step-16 ---\n", + "loss: 569.4090\n", + "New best loss score (569.4090) at epoch-3, step-16\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 4 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.35it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-4, step-20 ---\n", + "loss: 575.8144\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + 
"Evaluation: 100%|██████████| 4/4 [00:00<00:00, 740.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-4, step-20 ---\n", + "loss: 593.6073\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 5 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-5, step-24 ---\n", + "loss: 554.4583\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 730.30it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-5, step-24 ---\n", + "loss: 575.9142\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 6 / 10: 100%|██████████| 4/4 [00:00<00:00, 289.78it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-6, step-28 ---\n", + "loss: 567.8894\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.44it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-6, step-28 ---\n", + "loss: 577.6712\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 7 / 10: 100%|██████████| 4/4 [00:00<00:00, 294.15it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-7, step-32 ---\n", + "loss: 590.4861\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 737.75it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-7, step-32 ---\n", + "loss: 557.5274\n", + "New best loss score (557.5274) at epoch-7, step-32\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "\n", + "Epoch 8 / 10: 100%|██████████| 4/4 [00:00<00:00, 291.25it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-8, step-36 ---\n", + "loss: 577.2799\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 751.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-8, step-36 ---\n", + "loss: 552.2935\n", + "New best loss score (552.2935) at epoch-8, step-36\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Epoch 9 / 10: 100%|██████████| 4/4 [00:00<00:00, 293.05it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Train epoch-9, step-40 ---\n", + "loss: 536.7056\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 743.14it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eval epoch-9, step-40 ---\n", + "loss: 547.2967\n", + "New best loss score (547.2967) at epoch-9, step-40\n", + "Loaded best model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training completed!\n" + ] + } + ], + "source": [ + "# Prepare data for training\n", + "train_dataloader = get_dataloader(ts_dataset, batch_size=32, shuffle=True)\n", + "\n", + "# Create trainer\n", + "trainer = Trainer(\n", + " model=ts_model, \n", + " device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n", + " metrics=[] # VAE is unsupervised, no classification metrics needed\n", + ")\n", + "\n", + "# Train the model (reduced epochs for demo)\n", + "print(\"Training time-series VAE...\")\n", + "trainer.train(\n", + " train_dataloader=train_dataloader,\n", + " val_dataloader=train_dataloader, # Using same data for demo\n", + " epochs=10,\n", + " 
monitor=\"loss\",\n", + " monitor_criterion=\"min\",\n", + " optimizer_params={\"lr\": 1e-4},\n", + ")\n", + "\n", + "print(\"Training completed!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate Reconstruction Performance\n", + "\n", + "Check how well the VAE reconstructs the original sequences.\n", + "\n", + "**What the outputs represent:**\n", + "- `y_prob`: Reconstructed patient trajectory embeddings (VAE's attempt to recreate the input)\n", + "- `y_true`: Original RNN hidden states summarizing each patient's diagnosis sequence\n", + "- `loss`: VAE training objective: the summed reconstruction error (MSE for time-series input) plus the KL-divergence term\n", + "\n", + "The `y_true` values are 64-dimensional vectors that represent compressed summaries of patient medical histories, capturing temporal patterns like disease progression (e.g., hypertension → diabetes → kidney disease)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluation: 100%|██████████| 4/4 [00:00<00:00, 715.26it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation Results:\n", + "loss: 541.9971\n", + "\n", + "Reconstruction shape: torch.Size([32, 64])\n", + "Original shape: torch.Size([32, 64])\n", + "Loss: 654.7109\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Evaluate on training data\n", + "eval_results = trainer.evaluate(train_dataloader)\n", + "print(\"Evaluation Results:\")\n", + "for metric, value in eval_results.items():\n", + " print(f\"{metric}: {value:.4f}\")\n", + "\n", + "# Get reconstruction examples\n", + "data_batch = next(iter(train_dataloader))\n", + "with torch.no_grad():\n", + " output = ts_model(**data_batch)\n", + " \n", + "print(f\"\\nReconstruction shape: {output['y_prob'].shape}\")\n", + "print(f\"Original shape: {output['y_true'].shape}\")\n", + 
"print(f\"Loss: {output['loss'].item():.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate New Medical Sequences\n", + "\n", + "Sample from the latent space to generate new patient trajectories and convert them to human-understandable medical codes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated sequence representations:\n", + "Shape: torch.Size([3, 64])\n", + "Sample values: [ 0.53002435 1.3221712 -0.48090675 -0.45167932 -0.49437132]\n", + "\n", + "Converting to medical codes for generated sequence 0:\n", + "tensor([ 0.0000e+00, -7.3698e-03, -3.6019e+00, 4.2856e+00, 4.9865e+00,\n", + " -3.4686e+00, -2.6077e+00, 3.6966e-01, 3.1678e-01, 1.8775e+00,\n", + " 2.7686e+00, 9.8846e+00, 1.0820e+01, 4.1149e+00, 1.1023e+01,\n", + " -4.6142e+00, 6.1783e+00, -1.9797e+00, 3.3450e+00, 4.8513e-01,\n", + " 4.9730e+00, 4.7466e+00, 6.4790e+00, 4.3559e+00, -4.6351e+00,\n", + " 5.0865e-01, 2.9454e+00, 4.5632e+00, -6.7101e+00, 1.6942e+00,\n", + " 9.0058e+00, 1.1524e+00, -4.4482e+00, 3.0030e+00, -6.6758e+00,\n", + " 1.5273e+01, -9.0026e-01, -5.9171e-01, -8.6225e+00, -6.0907e+00,\n", + " 3.5220e+00, 9.3690e+00, -5.0509e-01, 1.1058e+01, -2.0280e+00,\n", + " 9.4648e+00, -1.8569e+00, 1.9413e+00, 6.4853e+00, -8.4831e+00,\n", + " -9.4427e-01, -2.8588e+00, 7.2913e-01, -9.2253e+00, -3.6401e+00,\n", + " 4.1965e+00, 6.1821e+00, 4.8190e-02, -1.0670e+00, -2.3349e+00,\n", + " -2.1616e+00, 2.2229e+00, 1.0739e+00, 1.5211e+00, 2.5068e+00,\n", + " 6.0805e-01, -3.9276e+00, -4.7792e+00, 7.1451e+00, 6.6988e+00,\n", + " 8.2785e+00, -7.7417e+00, 1.3240e+00, -2.9780e+00, 1.1269e+01,\n", + " 1.8551e+00, -1.1236e+01, -1.6523e+00, -4.8508e+00, 6.5365e+00,\n", + " -1.8226e+00, -2.5214e+00, -1.8810e+00, 7.5196e+00, 3.4994e+00,\n", + " -6.2609e-01, 3.3908e+00, 8.6278e+00, -8.5831e+00, -5.9723e+00,\n", + " -9.4077e-01, -3.6397e+00, -7.3617e+00, 
2.4734e+00, 4.7124e+00,\n", + " 5.7872e+00, 1.1024e+00, -8.8864e+00, 6.5827e+00, 7.5309e+00,\n", + " 7.7808e-01, -3.8766e+00, 4.2168e+00, -1.0284e+00, -4.5411e+00,\n", + " -1.0719e+00, 3.6934e+00, -5.7031e+00, 4.2286e+00, -2.7739e+00,\n", + " 5.0000e+00, -1.1350e+00, -1.9651e+00, 1.6179e+00, 7.9560e+00,\n", + " 4.4798e+00, -1.6526e+00, -1.6522e+00, -3.5508e+00, 3.9089e+00,\n", + " 2.7357e+00, -3.3986e+00, 6.5241e+00, 1.7675e-01, -5.8271e+00,\n", + " 9.2375e+00, 2.4214e+00, -3.7925e+00, 1.3557e+00, 1.7812e+00,\n", + " -3.0547e-01, 6.5939e-01, 1.0546e+00, 3.9847e+00, 1.8091e+00,\n", + " -4.0046e+00, -5.5650e+00, 5.2020e+00, 7.5734e+00, -5.5784e+00,\n", + " -1.1468e+01, 8.2059e+00, -7.4497e+00, -6.2369e+00, -7.5726e+00,\n", + " -4.4698e+00, 2.5128e+00, -9.4868e+00, -2.7561e+00, -1.1705e+00,\n", + " 3.3567e+00, 3.7004e+00, 2.6886e+00, 3.2938e-01, -2.1804e+00,\n", + " 5.4503e+00, -7.9096e-01, 7.1233e+00, 6.3597e+00, -1.5332e+00,\n", + " 5.3480e-02, 2.1162e-01, 3.2930e+00, -1.8011e+00, -2.7989e+00,\n", + " 2.7531e+00, -1.1935e+00, -5.0806e+00, 3.6173e+00, -1.6692e+00,\n", + " -5.6118e+00, -3.0442e+00, 1.3044e+01, 1.8823e+00, 3.4506e+00,\n", + " -8.7729e+00, 4.5545e+00, -3.8803e+00, 1.7737e+00, -2.0081e+00,\n", + " -5.4540e+00, -6.3734e+00, -5.1804e-01, 1.9105e+00, -5.7511e+00,\n", + " -1.1010e+00, 1.4860e+00, 1.7801e+00, 7.2562e+00, 3.6735e-01,\n", + " 9.1315e+00, 7.1706e+00, -3.4692e+00, 3.5518e+00, -3.0591e+00,\n", + " -1.3145e+00, -6.0018e+00, -5.7727e+00, -3.2378e+00, -5.1966e+00,\n", + " -4.1021e+00, -3.1030e+00, 8.1722e+00, 5.2892e+00, -8.7676e+00,\n", + " 1.9973e+00, 4.6449e+00, 1.3642e+00, 8.9750e-01, 1.6920e-01,\n", + " 2.3763e+00, 3.4316e-01, 4.2113e+00, -1.6114e-01, -1.8713e+00,\n", + " -1.4838e+00, 2.5215e+00, 1.4842e+00, 6.3901e+00, 4.1662e+00,\n", + " 6.0793e-01, 1.7254e+00, 8.2590e+00, 2.6119e+00, -1.7054e+00,\n", + " 6.1178e+00, -2.9691e+00, -4.0823e+00, 4.7501e-01, 8.9351e-01,\n", + " -1.0580e+00, -1.4648e+00, 2.0005e+00, 4.4008e-01, -1.1741e+01,\n", + " 
-4.2694e-02, -7.7522e-01, 1.9731e+00, 5.9338e+00, 6.1365e+00,\n", + " -7.7485e+00, 1.3402e+01, 1.2785e+01, -1.6353e+00, -3.6195e+00,\n", + " -1.2166e+00, 3.3741e+00, -1.0241e+01, 6.4584e+00, -1.1245e+00,\n", + " 5.6069e-01, 4.5313e+00, 7.2805e+00, 5.9941e-01, -4.1223e+00,\n", + " 2.3223e-01, -6.0213e+00, 1.2135e+00, -4.0671e+00, -3.0445e+00,\n", + " -4.7401e+00, -5.2669e+00, 8.4796e+00, -5.4145e+00, -4.7753e+00,\n", + " -5.2986e-01, -3.6312e+00, 1.0913e+00, -3.6791e+00, 2.4095e+00,\n", + " 4.3596e-01, -1.4798e+00, -1.4894e+00, 1.8673e-01, 2.4698e-01,\n", + " -2.4453e-01, -3.0582e+00, 2.9348e+00, 2.3117e-01, -2.4103e+00,\n", + " -2.8271e+00, -3.9102e+00, -3.6497e+00, -2.4924e+00, 1.9910e+00,\n", + " -6.5865e+00, -3.6936e+00, -3.4439e-01, -7.0813e-02, 8.2621e-01,\n", + " 1.4598e-01, -6.6775e+00, 3.9390e+00, -8.5732e-02, 4.4681e+00,\n", + " -5.1826e+00, -6.4087e+00, 2.4561e-01, 1.2722e+00, -2.3493e+00,\n", + " 2.2532e+00, 3.5215e+00, -1.1521e+00, -5.0680e+00, 1.1928e+01,\n", + " 2.4243e+00, -4.3520e+00, -8.1776e+00, -1.0246e+00, -6.5926e+00,\n", + " -4.7592e+00, -6.3312e-01, -1.0127e+01, -8.8611e-01, -8.8121e+00,\n", + " 2.8199e-01, 1.7753e+00, -3.0299e+00, -5.9531e+00, 8.6943e-01,\n", + " 3.5751e+00, -1.1387e+01, -2.3700e+00, 4.8945e+00, 1.6143e+00,\n", + " 3.5814e+00, -7.8166e+00, -2.4447e+00, 2.9132e+00, 2.6121e+00,\n", + " 5.3359e+00, 6.1179e+00, -1.4136e+01, 2.7990e-01, 5.0767e+00,\n", + " 3.2156e+00, 2.9253e+00, 7.9484e+00, 2.6591e+00, 8.8704e+00,\n", + " 3.1986e+00, -4.0668e-01, 4.6038e+00, 2.6145e+00, 3.5966e+00,\n", + " 1.0976e+00, -4.4431e+00, -5.4594e+00, -9.8451e-01, 3.2744e+00,\n", + " 5.2823e+00, -7.2685e+00, 3.1512e+00, 2.7413e-01, -2.7417e+00,\n", + " 2.4061e+00, 1.0575e+00, -8.0792e+00, 1.5948e+00, 5.0245e-02,\n", + " -7.0174e+00, -2.5981e+00, 1.7484e+00, 5.5191e+00, 2.9680e+00,\n", + " 6.1405e+00, -1.4361e-01, 3.6003e+00, 5.1398e+00, -3.4543e+00,\n", + " 5.1749e+00, -4.8572e+00, -1.4890e+00, 1.3004e+00, 5.2664e+00,\n", + " -3.8776e+00, -2.9888e+00, 
-7.6976e-01, 3.3617e+00, 8.8238e+00,\n", + " 5.4425e+00, 6.1032e-01, -2.8261e+00, -4.6408e+00, -2.9985e+00,\n", + " -7.4634e+00, -2.6308e+00, 5.1425e-01, -8.0916e-01, 3.9377e+00,\n", + " -6.8549e+00, -6.0009e+00, -4.6179e-01, 7.0597e-01, -2.6481e+00,\n", + " 1.0844e+00, 1.0792e+01, -1.3607e+00, 1.6815e+00, -5.9647e+00,\n", + " -3.4795e-01, -5.7725e+00, 8.8240e+00, 1.2161e+01, 1.4274e+00,\n", + " 5.6612e+00, -2.3413e+00, -1.8727e+00, 4.8331e+00, -1.7449e+00,\n", + " 9.4813e+00, -4.2424e+00, -4.4759e+00, -3.1925e-01, -1.0234e+00,\n", + " -2.3237e-01, 3.3495e+00, -3.8373e-01, -9.8905e-01, 4.4709e+00,\n", + " 6.2089e-01, 4.7296e+00, -4.6643e-01, -2.9335e+00, -3.2276e-01,\n", + " -8.1376e+00, -3.3459e+00, -9.3769e-02, -1.1267e+01, -4.3118e-01,\n", + " -3.3150e+00, -6.2661e-01, 4.7186e+00, -3.7383e+00, -6.2295e+00,\n", + " 5.6783e+00, 1.5134e+00, 3.3611e+00, 3.0486e+00, 1.1139e+00,\n", + " -2.3641e+00, -2.5694e+00, 4.0144e+00, -7.9107e-01, -4.3788e-01,\n", + " 3.9177e+00, 1.1759e+01, 1.4001e+00, -3.9649e-01, -1.1905e+01,\n", + " -1.9034e+00, -8.7804e+00, 1.0761e+01, 3.0643e+00, 3.6976e+00,\n", + " -2.7494e+00, -2.5087e+00, 9.4577e-01, -2.5091e+00, 5.5027e+00,\n", + " -1.1978e+00, -6.2105e+00, 2.2060e+00, -7.4665e+00, 3.0630e+00,\n", + " -5.2187e+00, -4.4273e+00, -2.1675e+00, 5.1272e+00, 5.5543e+00,\n", + " 3.7251e-01, 6.5773e+00, -9.4365e-01, -4.3358e+00, -4.5703e+00,\n", + " -5.3149e+00, 1.0750e+01, 4.3271e+00, -1.6464e-01, -9.9271e+00,\n", + " 4.0325e+00, 5.2930e+00, 4.1603e-01, -9.8539e+00, 9.0224e-02,\n", + " -4.6029e+00, -6.9055e-01, -8.8508e-01, 8.0016e-02, -5.1926e+00,\n", + " -2.0259e-01, 2.5811e-01, 1.0760e+01, -3.1546e+00, -1.2496e+00,\n", + " 1.0318e+00, -4.8848e+00, -1.1945e+01, 8.8947e-01, -7.6019e+00,\n", + " 9.7675e+00, 7.3018e+00, 7.1488e+00, 1.9355e+00, -3.6839e+00,\n", + " 9.3663e+00, -1.6682e+00, -2.4040e+00, -3.5577e+00, -4.3686e+00,\n", + " -2.9042e+00, 7.6098e+00, -2.9299e+00, 5.0070e+00, -1.1131e+00,\n", + " 2.0663e+00, -1.8503e+00, 2.0247e+00, 
2.2851e+00, -4.4661e+00,\n", + " 3.9783e+00, -1.1042e+00, 6.6821e+00, 1.5496e+00, 7.2286e+00,\n", + " 7.1757e+00, 1.5642e+00, -3.8730e+00, -6.7046e+00, 4.8879e+00,\n", + " 6.2288e-01, -1.1968e+00, -6.9655e-01, 9.2160e+00, 8.1812e+00,\n", + " -3.2031e+00, 6.3572e+00, 6.4712e+00, -4.6608e+00, -1.7319e+00,\n", + " -4.2632e+00, -4.7678e+00, 9.9987e+00, 1.4257e+01, 1.6364e+01,\n", + " -1.1067e+01, -1.7239e+00, 5.2205e+00, -5.6597e+00, 4.5558e+00,\n", + " -3.2298e+00, 6.6794e+00, 3.8081e+00, 6.7543e+00, -1.6579e+00,\n", + " -3.5063e+00, -7.2064e-01, -9.8697e-01, 1.8629e+00, -1.6241e+00,\n", + " -3.0487e+00, -3.4707e+00, -3.9215e-01, -4.1220e+00, 1.9885e-01,\n", + " -2.1967e+00, -4.9302e-01, 5.2048e+00, -3.4484e+00, -2.8122e+00,\n", + " -5.5603e+00, -3.5931e-01, 4.3701e+00, -5.1773e+00, -4.5849e-01,\n", + " 3.2636e+00, 4.9086e+00, 8.1466e+00, 5.8811e+00, 7.5002e+00,\n", + " 3.9989e-02, 6.6855e+00, 9.9836e+00, -3.3854e+00, 7.8459e+00,\n", + " -1.6212e-01, -5.1125e+00, 5.8930e+00, 8.9414e+00, -1.2838e+00,\n", + " 9.6739e+00, 1.4556e+00, -4.1944e+00, 6.6870e+00, 6.9083e+00,\n", + " -4.6303e+00, -2.2641e+00, 2.0928e+00, 1.6327e+00, -5.0583e+00,\n", + " 2.3900e+00, -1.3623e+00, 3.1129e+00, -2.0396e+00, 2.0124e+00,\n", + " -2.8472e+00, 6.2345e+00, 5.9359e-01, -4.3222e+00, -8.0300e+00,\n", + " 1.0396e+01, 2.1617e+00, 3.4943e+00, 3.0572e+00, -3.5476e+00,\n", + " 5.3472e+00, 2.4073e+00, 5.9503e-01, 5.2738e-01, -3.0144e+00,\n", + " -5.1299e+00, 2.1250e+00, 1.3509e+00, -2.8952e-01, -6.6030e+00,\n", + " -3.6746e+00, -3.7765e+00, -3.2359e+00, 1.3805e+00, -2.2590e+00,\n", + " -1.5889e+00, 4.8569e+00, 1.0156e+00, 3.9435e+00, -2.0284e+00,\n", + " -1.0859e+01, -2.6059e+00, -6.2794e+00, 8.0509e-01, 2.3528e+00,\n", + " -1.6507e+00, 1.1961e+00, 9.9826e+00, 4.8255e+00, -5.6086e+00,\n", + " -6.5066e+00, -7.2207e+00, 6.8090e+00, 3.7589e-01, -6.0209e+00,\n", + " -8.2463e-01, -7.4694e-01, 8.1887e+00, -2.9548e+00, -1.1866e+00,\n", + " -3.5102e+00, -4.6932e+00, 7.2601e+00, 2.2765e+00, 
9.4055e+00,\n", + " 2.7460e+00, 1.4907e+00, 4.7538e-01, 2.9329e+00, 1.2719e+00,\n", + " -2.7927e+00, -1.2591e+00, -1.7836e+00, 2.8065e+00, 5.3934e-01,\n", + " 1.3045e+00, 1.2268e+00, -5.6013e+00, 3.1139e+00, 3.8369e+00,\n", + " 3.3059e-01, 2.2407e+00, 1.2366e+00, -6.2229e+00, -3.7157e+00,\n", + " 6.4204e+00, 1.0573e+01, -3.9952e-01, 1.4235e+01, -1.0884e+01,\n", + " -2.9518e+00, 1.1950e+00, 1.0031e+00, 2.0499e+00, 1.6345e+00,\n", + " -3.8371e+00, -2.6782e+00, 1.9266e+00, -3.5774e+00, -7.1045e+00,\n", + " -7.6499e-01, -1.9117e+00, -1.0165e+01, -7.5704e-01, -4.4083e-01,\n", + " -2.7491e-01, -1.7376e+00, -7.6327e+00, -4.6055e+00, 5.0006e+00,\n", + " 5.3006e+00, -2.9642e-01, 6.4890e+00, 8.0494e-01, -7.4606e-01,\n", + " 5.7291e+00, 1.4712e+00, 1.0860e-01, -7.0036e-01, -2.1391e+00,\n", + " -2.0974e+00, 3.6578e+00, -2.5981e+00, 7.4232e+00, -3.8255e+00,\n", + " 3.1794e+00, 4.3983e-01, 5.2267e+00, -3.5482e+00, 9.6503e+00,\n", + " 1.6477e-01, -6.9726e+00, 6.3025e+00, -9.1959e+00, 8.4988e+00,\n", + " -2.7036e+00, 5.2478e+00, 1.0211e+01, -2.6642e-01, 3.9743e+00,\n", + " -1.0607e+01, 1.4849e+01, -4.1829e-01, 3.5922e+00, -1.2574e+00,\n", + " -8.8174e-01, -4.4337e-01, 1.6383e+00, -4.3945e+00, 3.5651e-01,\n", + " -1.9686e+00, 2.1606e-02, 4.2785e+00, -1.2995e+00, 8.8444e-01,\n", + " 4.1780e+00, -1.0253e+01, -6.7567e+00, -5.7410e+00, 3.1234e+00,\n", + " -4.5726e+00, -1.7247e+00, 9.8022e+00, 5.5747e+00, 4.9975e+00,\n", + " 3.2196e+00, -3.8409e+00, 4.8316e+00, -6.6928e+00, -4.7451e+00,\n", + " 1.9149e+00, 2.1189e+00, 7.8778e+00, -3.7329e-02, 1.4081e+00,\n", + " -1.3638e+00, -8.6539e+00, 1.6721e-01, -9.8377e+00, -2.7099e+00,\n", + " -2.7666e-01, -5.0738e-02, -3.5834e+00, 5.4019e+00, -8.2766e+00,\n", + " -3.6412e+00, 3.1811e+00, 1.5183e-01, -9.9372e-01, 3.6720e+00,\n", + " -4.4105e-01, -1.4014e+00, 7.7510e+00, 5.7498e+00, 3.5791e+00,\n", + " -5.4780e-01, -3.8520e+00, -2.1077e+00, 6.4482e+00, 1.0814e+00,\n", + " -7.5035e-01, 4.4298e+00, -8.8133e+00, -2.7812e+00, -5.7001e-01,\n", + " 
-2.6631e+00, -3.3242e+00, 1.3150e+00, -2.7915e+00, 1.7820e+00,\n", + " 4.7887e-01, 1.1222e+00, -6.1201e+00, 1.3118e+01, 9.0041e+00,\n", + " 4.9772e+00, -7.0921e+00, 2.8077e-01, 3.6044e+00, 1.8358e+00,\n", + " -2.2833e+00, -6.1981e+00, 1.5658e+01, 3.6380e+00, 1.1059e+00,\n", + " -1.2764e-01, 6.4083e-01, -1.8858e+00, 2.6034e+00, -3.1585e+00,\n", + " -1.9641e+00, 4.2007e+00, 6.5796e-02, -6.1856e-01, -2.3976e+00,\n", + " -8.2248e-01, 2.6342e+00, 1.0856e+00, 3.1940e+00, -7.0448e+00,\n", + " 1.8103e+00, 5.3263e+00, 9.3208e+00, 6.1828e+00, 4.8899e+00,\n", + " -3.5434e+00, 3.8317e+00, 8.1334e-01, 1.4348e+00, -5.5027e+00,\n", + " 5.6279e+00, -1.4003e+01, 1.5963e+00, -3.8147e+00, 3.2877e+00,\n", + " -2.3182e+00, -1.9225e+00, 4.8933e+00, 1.3743e+00, 7.1876e+00,\n", + " 1.7507e+00, -1.3171e+01, -1.0366e+00, 6.4689e+00, 4.3425e+00,\n", + " 3.5808e+00, -5.5930e+00, -3.2690e-01, 5.2676e+00, 2.4258e+00],\n", + " device='cuda:0')\n" + ] + }, + { + "ename": "TypeError", + "evalue": "topk() got multiple values for argument 'k'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[18]\u001b[39m\u001b[32m, line 33\u001b[39m\n\u001b[32m 31\u001b[39m top_k = \u001b[32m3\u001b[39m\n\u001b[32m 32\u001b[39m \u001b[38;5;28mprint\u001b[39m(similarities)\n\u001b[32m---> \u001b[39m\u001b[32m33\u001b[39m top_similarities, top_indices = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtopk\u001b[49m\u001b[43m(\u001b[49m\u001b[43msimilarities\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 35\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m pos \u001b[38;5;129;01min\u001b[39;00m 
\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mmin\u001b[39m(\u001b[32m5\u001b[39m, seq_embeds.shape[\u001b[32m0\u001b[39m])): \u001b[38;5;66;03m# Show first 5 positions\u001b[39;00m\n\u001b[32m 36\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mPosition \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpos\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m:\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mTypeError\u001b[39m: topk() got multiple values for argument 'k'" + ] + } + ], + "source": [ + "# Generate new sequences by sampling from latent space\n", + "ts_model.eval()\n", + "with torch.no_grad():\n", + " # Sample random latent vectors\n", + " latent_samples = torch.randn(3, ts_model.hidden_dim).to(ts_model.device)\n", + " \n", + " # Decode to get sequence representations\n", + " generated_sequences = ts_model.decoder(latent_samples)\n", + " \n", + " print(\"Generated sequence representations:\")\n", + " print(f\"Shape: {generated_sequences.shape}\")\n", + " print(f\"Sample values: {generated_sequences[0, :5].cpu().numpy()}\")\n", + " \n", + " # Convert embeddings to human-understandable medical codes\n", + " # Find closest codes in embedding space\n", + " \n", + " # Get all code embeddings from the embedding model\n", + " conditions_vocab = list(ts_dataset.input_processors['conditions'].code_vocab.keys())\n", + " all_codes = conditions_vocab # Use conditions vocabulary\n", + " code_embeddings = ts_model.embedding_model.embedding_layers['conditions'].weight.data # [vocab_size, embed_dim]\n", + " \n", + " print(f\"\\nConverting to medical codes for generated sequence 0:\")\n", + " \n", + " # The generated sequence is a single embedding vector, not a sequence\n", + " seq_embed = generated_sequences[0] # [embed_dim]\n", + " \n", + " # Compute dot-product similarity with all code embeddings (unnormalized, i.e. not cosine similarity)\n", + " similarities = torch.matmul(seq_embed, code_embeddings.t()) # [vocab_size]\n", + " \n", + " # Get top 3 
most similar codes\n", + " top_k = 3\n", + " top_similarities, top_indices = torch.topk(similarities, top_k, dim=0)\n", + " \n", + " codes = [all_codes[idx] for idx in top_indices.cpu().numpy()]\n", + " sims = top_similarities.cpu().numpy()\n", + " print(f\"Top {top_k} similar medical codes: {codes}\")\n", + " print(f\"Similarities: {sims}\")\n", + " \n", + " print(\"\\nNote: These represent the most likely medical codes for the generated sequence.\")\n", + " print(\"In practice, you might use beam search or other decoding strategies for better results.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Key Insights\n", + "\n", + "### How Time-Series VAE Works:\n", + "1. **Input Processing**: Categorical sequences (diagnoses, procedures) are embedded using the EmbeddingModel\n", + "2. **Sequence Encoding**: RNN processes the embedded sequence to capture temporal patterns\n", + "3. **Latent Compression**: Variable-length sequences become fixed-size latent vectors\n", + "4. **Reconstruction**: Decoder attempts to recreate the embedded sequence representation\n", + "5. 
**Code Generation**: Generated embeddings are mapped back to medical codes using nearest neighbor search\n", + "\n", + "### Medical Applications:\n", + "- **Trajectory Analysis**: Understand typical patient progression patterns from real MIMIC4 data\n", + "- **Synthetic Data**: Generate realistic patient histories for research and model training\n", + "- **Anomaly Detection**: Identify unusual treatment sequences in clinical practice\n", + "- **Outcome Prediction**: Learn sequence patterns that correlate with mortality and other outcomes\n", + "- **Data Augmentation**: Create additional training samples for underrepresented conditions\n", + "\n", + "### Key Improvements in This Version:\n", + "- **Real Data**: Uses MIMIC4 demo dataset instead of synthetic data for more realistic modeling\n", + "- **Multiple Sequences**: Models both diagnoses and procedures simultaneously\n", + "- **Human-Readable Output**: Converts generated embeddings back to interpretable medical codes\n", + "- **Clinical Relevance**: Focuses on in-hospital mortality prediction task\n", + "\n", + "### Differences from Image VAE:\n", + "- **Temporal vs Spatial**: Captures time-ordered dependencies instead of spatial patterns\n", + "- **Variable Length**: Handles sequences of different lengths\n", + "- **Categorical Data**: Works with medical codes, diagnoses, treatments\n", + "- **Generation**: Creates new realistic patient trajectories with interpretable medical codes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} 
diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 3390453ff..a7c0cb5c7 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -126,7 +126,7 @@ def __init__( if len(set(tables)) != len(tables): logger.warning("Duplicate table names in tables list. Removing duplicates.") tables = list(set(tables)) - self.root = root + self.root = os.path.expanduser(root) self.tables = tables self.dataset_name = dataset_name or self.__class__.__name__ self.config = load_yaml_config(config_path) diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py index 42f5671b5..bba2cdf3f 100644 --- a/pyhealth/datasets/covid19_cxr.py +++ b/pyhealth/datasets/covid19_cxr.py @@ -92,6 +92,14 @@ def __init__( if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "covid19_cxr.yaml" + root = os.path.expanduser(root) + if not self._check_raw_data_exists(root): + raise ValueError( + f"Raw COVID-19 CXR dataset files not found in {root}. " + "Please download the dataset from " + "https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database " + "and extract the contents to the specified root directory." + ) if not os.path.exists(os.path.join(root, "covid19_cxr-metadata-pyhealth.csv")): self.prepare_metadata(root) default_tables = ["covid19_cxr"] @@ -149,6 +157,40 @@ def prepare_metadata(self, root: str) -> None: df.to_csv(os.path.join(root, "covid19_cxr-metadata-pyhealth.csv"), index=False) return + def _check_raw_data_exists(self, root: str) -> bool: + """Check if required raw data files exist. + + Args: + root: Root directory containing the dataset files. + + Returns: + bool: True if all required files exist, False otherwise. 
+ """ + required_files = [ + "COVID.metadata.xlsx", + "Lung_Opacity.metadata.xlsx", + "Normal.metadata.xlsx", + "Viral Pneumonia.metadata.xlsx", + ] + return all(os.path.exists(os.path.join(root, f)) for f in required_files) + + def _check_raw_data_exists(self, root: str) -> bool: + """Check if required raw data files exist. + + Args: + root: Root directory containing the dataset files. + + Returns: + bool: True if all required files exist, False otherwise. + """ + required_files = [ + "COVID.metadata.xlsx", + "Lung_Opacity.metadata.xlsx", + "Normal.metadata.xlsx", + "Viral Pneumonia.metadata.xlsx", + ] + return all(os.path.exists(os.path.join(root, f)) for f in required_files) + def set_task( self, task: BaseTask | None = None, diff --git a/pyhealth/models/vae.py b/pyhealth/models/vae.py index e46c395e1..7078ec5d1 100644 --- a/pyhealth/models/vae.py +++ b/pyhealth/models/vae.py @@ -8,16 +8,16 @@ import torch.nn.functional as F from pyhealth.datasets import BaseSignalDataset -from pyhealth.models import BaseModel, ResBlock2D +from pyhealth.models import BaseModel, ResBlock2D, EmbeddingModel class VAE(BaseModel): - """VAE model (take 128x128 or 64x64 or 32x32 images) + """VAE model for images or time-series data. Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." - Note: - We use CNN models as the encoder and decoder layers for now. + Supports both image generation/reconstruction and time-series modeling. + Image mode takes 128x128, 64x64, or 32x32 images. Args: dataset: the dataset to train the model. It is used to query certain @@ -26,129 +26,245 @@ class VAE(BaseModel): e.g. ["conditions", "procedures"]. label_key: key in samples to use as label (e.g., "drugs"). mode: one of "binary", "multiclass", or "multilabel". - embedding_dim: the embedding dimension. Default is 128. - hidden_dim: the hidden dimension. Default is 128. - **kwargs: other parameters for the Deepr layer. 
- - Examples: + input_type: 'image' for CNN-based VAE on images, 'timeseries' for RNN-based on sequences. Default 'image'. + input_channel: number of input channels (for images). Required if input_type='image'. + input_size: input image size (for images, e.g. 128). Required if input_type='image'. + hidden_dim: the latent dimension. Default is 128. + conditional_feature_keys: list of feature keys to use as conditions for generation (optional). + **kwargs: other parameters. """ def __init__( self, - dataset: BaseSignalDataset, + dataset, feature_keys: List[str], label_key: str, - input_channel: int, - input_size: int, mode: str, + input_type: str = "image", + input_channel: Optional[int] = None, + input_size: Optional[int] = None, hidden_dim: int = 128, + conditional_feature_keys: Optional[List[str]] = None, **kwargs, ): - super(VAE, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=label_key, - mode=mode, - ) + super(VAE, self).__init__(dataset=dataset) + self.input_type = input_type self.hidden_dim = hidden_dim + self.conditional_feature_keys = conditional_feature_keys + self.mode = mode + self.feature_keys = feature_keys + self.label_key = label_key - # encoder part - if input_size == 128: - self.encoder1 = nn.Sequential( - ResBlock2D(input_channel, 16, 2, True, True), - ResBlock2D(16, 64, 2, True, True), - ResBlock2D(64, 256, 2, True, True), - ) - self.mu = nn.Linear(256 * 2 * 2, self.hidden_dim) # for mu - self.log_std2 = nn.Linear(256 * 2 * 2, self.hidden_dim) # for log (sigma^2) - - self.decoder1 = nn.Sequential( - nn.ConvTranspose2d(self.hidden_dim, 256, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), - nn.Sigmoid(), - ) - - elif input_size == 64: - self.encoder1 = 
nn.Sequential( - ResBlock2D(input_channel, 16, 2, True, True), - ResBlock2D(16, 64, 2, True, True), - ResBlock2D(64, 256, 2, True, True), - ) - self.mu = nn.Linear(256, self.hidden_dim) # for mu - self.log_std2 = nn.Linear(256, self.hidden_dim) # for log (sigma^2) - - self.decoder1 = nn.Sequential( - nn.ConvTranspose2d(self.hidden_dim, 128, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), - nn.Sigmoid(), - ) - - elif input_size == 32: - self.encoder1 = nn.Sequential( - ResBlock2D(input_channel, 16, 2, True, True), - ResBlock2D(16, 64, 2, True, True), - # ResBlock2D(64, 256, 2, True, True), - ) - self.mu = nn.Linear(64 * 2 * 2, self.hidden_dim) # for mu - self.log_std2 = nn.Linear(64 * 2 * 2, self.hidden_dim) # for log (sigma^2) - - self.decoder1 = nn.Sequential( - nn.ConvTranspose2d(self.hidden_dim, 64, kernel_size=5, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), - nn.ReLU(), - nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), - nn.Sigmoid(), - ) - + # These will be lazily initialized when we see actual tensor sizes + self.cond_proj: Optional[nn.Linear] = None # for conditional metadata → latent + self.ts_proj: Optional[nn.Linear] = None # for timeseries concatenated features → hidden_dim + + if input_type == "image": + assert input_channel is not None and input_size is not None, \ + "For image mode, input_channel and input_size must be provided" + + # Embedding model for conditional features only (if used) + if conditional_feature_keys: + self.embedding_model = EmbeddingModel(dataset, embedding_dim=hidden_dim) + + # ----- Encoder ----- + if input_size == 128: + self.encoder1 = nn.Sequential( + ResBlock2D(input_channel, 16, 2, True, True), + ResBlock2D(16, 64, 2, True, True), + ResBlock2D(64, 256, 2, True, True), + ) + self.mu = 
nn.Linear(256 * 2 * 2, self.hidden_dim) + self.log_std2 = nn.Linear(256 * 2 * 2, self.hidden_dim) + + self.decoder1 = nn.Sequential( + nn.ConvTranspose2d(self.hidden_dim, 256, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), + nn.Sigmoid(), + ) + + elif input_size == 64: + self.encoder1 = nn.Sequential( + ResBlock2D(input_channel, 16, 2, True, True), + ResBlock2D(16, 64, 2, True, True), + ResBlock2D(64, 256, 2, True, True), + ) + self.mu = nn.Linear(256, self.hidden_dim) + self.log_std2 = nn.Linear(256, self.hidden_dim) + + self.decoder1 = nn.Sequential( + nn.ConvTranspose2d(self.hidden_dim, 128, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), + nn.Sigmoid(), + ) + + elif input_size == 32: + self.encoder1 = nn.Sequential( + ResBlock2D(input_channel, 16, 2, True, True), + ResBlock2D(16, 64, 2, True, True), + # ResBlock2D(64, 256, 2, True, True), + ) + self.mu = nn.Linear(64 * 2 * 2, self.hidden_dim) + self.log_std2 = nn.Linear(64 * 2 * 2, self.hidden_dim) + + self.decoder1 = nn.Sequential( + nn.ConvTranspose2d(self.hidden_dim, 64, kernel_size=5, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), + nn.ReLU(), + nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2), + nn.Sigmoid(), + ) + else: + raise ValueError("Unsupported input_size for image mode") + + elif input_type == "timeseries": + # Embedding model for sequence features + self.embedding_model = EmbeddingModel(dataset, embedding_dim=hidden_dim) + self.encoder_rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True) + self.mu = 
nn.Linear(hidden_dim, hidden_dim) + self.log_std2 = nn.Linear(hidden_dim, hidden_dim) + self.decoder_linear = nn.Linear(hidden_dim, hidden_dim) + else: + raise ValueError("input_type must be 'image' or 'timeseries'") + + # ------------------------------------------------------------- + # ENCODER + # ------------------------------------------------------------- def encoder(self, x) -> Tuple[torch.Tensor, torch.Tensor]: - h = self.encoder1(x) - batch_size = h.shape[0] - h = h.view(batch_size, -1) - return self.mu(h), torch.sqrt(torch.exp(self.log_std2(h))) - - def sampling(self, mu, std) -> torch.Tensor: # reparameterization trick + if self.input_type == "image": + h = self.encoder1(x) + batch_size = h.shape[0] + h = h.view(batch_size, -1) + mu = self.mu(h) + std = torch.sqrt(torch.exp(self.log_std2(h))) + + elif self.input_type == "timeseries": + # x is dict of embedded features from embedding_model + embedded_list = [] + for key, emb in x.items(): + if emb.dim() == 3: # (batch, seq, emb) + _, h_seq = self.encoder_rnn(emb) # h_seq: (1, batch, hidden_dim) + h = h_seq.squeeze(0) # (batch, hidden_dim) + else: + h = emb # (batch, emb_dim) + embedded_list.append(h) + + h = torch.cat(embedded_list, dim=-1) if len(embedded_list) > 1 else embedded_list[0] + + # Project to hidden_dim once, with a learnable layer + if h.shape[-1] != self.hidden_dim: + if self.ts_proj is None: + self.ts_proj = nn.Linear(h.shape[-1], self.hidden_dim).to(h.device) + h = self.ts_proj(h) + + mu = self.mu(h) + std = torch.sqrt(torch.exp(self.log_std2(h))) + + return mu, std + + # ------------------------------------------------------------- + # SAMPLING + # ------------------------------------------------------------- + def sampling(self, mu, std) -> torch.Tensor: eps = torch.randn_like(std) return mu + eps * std + # ------------------------------------------------------------- + # DECODER + # ------------------------------------------------------------- def decoder(self, z) -> torch.Tensor: - 
x_hat = self.decoder1(z) + if self.input_type == "image": + x_hat = self.decoder1(z) + elif self.input_type == "timeseries": + x_hat = self.decoder_linear(z) # (batch, hidden_dim) return x_hat - - @staticmethod - def loss_function(y, x, mu, std): - ERR = F.binary_cross_entropy(y, x, reduction='sum') - KLD = -0.5 * torch.sum(1 + torch.log(std**2) - mu**2 - std**2) + + # ------------------------------------------------------------- + # LOSS + # ------------------------------------------------------------- + def loss_function(self, y, x, mu, std): + if self.input_type == "image": + ERR = F.binary_cross_entropy(y, x, reduction='sum') + elif self.input_type == "timeseries": + ERR = F.mse_loss(y, x, reduction='sum') + + # KL divergence term + KLD = -0.5 * torch.sum(1 + torch.log(std ** 2) - mu ** 2 - std ** 2) return ERR + KLD - + + # ------------------------------------------------------------- + # FORWARD + # ------------------------------------------------------------- def forward(self, **kwargs) -> Dict[str, torch.Tensor]: - - # concat the info within one batch (batch, channel, height, width) - # if the input is a list of numpy array, we need to convert it to tensor - if isinstance(kwargs[self.feature_keys[0]][0], np.ndarray): - x = torch.tensor( - np.array(kwargs[self.feature_keys[0]]).astype("float16"), device=self.device - ).float() - else: - x = torch.stack(kwargs[self.feature_keys[0]], dim=0).to(self.device) - - mu, std = self.encoder(x) - z = self.sampling(mu, std) - z = z.unsqueeze(2).unsqueeze(3) - x_rec = self.decoder(z) - + + if self.input_type == "image": + # x: image tensor + x = kwargs[self.feature_keys[0]].to(self.device) # (batch, C, H, W) + + mu, std = self.encoder(x) + z = self.sampling(mu, std) # (batch, hidden_dim) + + # Conditional embeddings (if any) + if self.conditional_feature_keys: + cond_raw = {k: kwargs[k] for k in self.conditional_feature_keys} + cond_emb = self.embedding_model(cond_raw) # dict: key -> tensor + + # Pool temporal dims if 
needed and concat + cond_list = [ + emb.mean(dim=1) if emb.dim() == 3 else emb + for emb in cond_emb.values() + ] + cond_vec = torch.cat(cond_list, dim=-1) if len(cond_list) > 1 else cond_list[0] + + # Project cond_vec to latent dim with a learnable layer + if cond_vec.shape[-1] != self.hidden_dim: + if self.cond_proj is None: + self.cond_proj = nn.Linear(cond_vec.shape[-1], self.hidden_dim).to(cond_vec.device) + cond_vec = self.cond_proj(cond_vec) + + # Simple conditioning: shift latent by conditional vector + z = z + cond_vec + + # Prepare for ConvTranspose decoder + z = z.unsqueeze(2).unsqueeze(3) # (batch, hidden_dim, 1, 1) + x_rec = self.decoder(z) + + elif self.input_type == "timeseries": + # Embed all feature_keys first + embedded = self.embedding_model({k: kwargs[k] for k in self.feature_keys}) + mu, std = self.encoder(embedded) + z = self.sampling(mu, std) + x_rec = self.decoder(z) + + # For reconstruction target x: re-aggregate exactly as in encoder + embedded_list = [] + for key, emb in embedded.items(): + if emb.dim() == 3: # (batch, seq, emb) + _, h_seq = self.encoder_rnn(emb) + h = h_seq.squeeze(0) # (batch, hidden_dim) + else: + h = emb + embedded_list.append(h) + + x = torch.cat(embedded_list, dim=-1) if len(embedded_list) > 1 else embedded_list[0] + if x.shape[-1] != self.hidden_dim: + if self.ts_proj is None: + self.ts_proj = nn.Linear(x.shape[-1], self.hidden_dim).to(x.device) + x = self.ts_proj(x) + loss = self.loss_function(x_rec, x, mu, std) results = { "loss": loss, @@ -158,6 +274,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: return results + if __name__ == "__main__": from pyhealth.datasets import SampleSignalDataset, get_dataloader from pyhealth.datasets import COVID19CXRDataset @@ -188,12 +305,13 @@ def encode(sample): # model model = VAE( dataset=sample_dataset, - input_channel=3, - input_size=128, feature_keys=["path"], label_key="path", mode="regression", - hidden_dim = 256, + input_type="image", + input_channel=3, + 
input_size=128, + hidden_dim=256, ).to("cuda") # data batch diff --git a/pyproject.toml b/pyproject.toml index ceedcdd0b..339c71429 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "openpyxl~=3.1.0" ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_covid19_cxr.py b/tests/core/test_covid19_cxr.py new file mode 100644 index 000000000..2d2b079cf --- /dev/null +++ b/tests/core/test_covid19_cxr.py @@ -0,0 +1,88 @@ +""" +Unit tests for the COVID19CXRDataset class. + +The tests simply check that the dataset initialization works as expected, +creating mock files and ensuring COVID19CXRDataset behaves correctly when +raw data files are present or absent. + +Author: + Giovanni M. Dall'Olio +""" +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch + +from pyhealth.datasets import COVID19CXRDataset + + +class TestCOVID19CXRDataset(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_check_raw_data_exists_returns_false_when_files_missing(self): + """Test _check_raw_data_exists returns False when required files are missing.""" + dataset = COVID19CXRDataset.__new__(COVID19CXRDataset) # Create without __init__ + result = dataset._check_raw_data_exists(self.temp_dir) + self.assertFalse(result) + + def test_check_raw_data_exists_returns_true_when_files_present(self): + """Test _check_raw_data_exists returns True when all required files are present.""" + # Create dummy files + required_files = [ + "COVID.metadata.xlsx", + "Lung_Opacity.metadata.xlsx", + "Normal.metadata.xlsx", + "Viral Pneumonia.metadata.xlsx", + ] + for filename in required_files: + with open(os.path.join(self.temp_dir, filename), 'w') as f: + f.write("dummy content") + + dataset = COVID19CXRDataset.__new__(COVID19CXRDataset) + result = 
dataset._check_raw_data_exists(self.temp_dir) + self.assertTrue(result) + + def test_init_raises_value_error_when_raw_data_missing(self): + """Test that __init__ raises ValueError with correct message when raw data is missing.""" + with self.assertRaises(ValueError) as context: + COVID19CXRDataset(root=self.temp_dir) + + expected_message = ( + f"Raw COVID-19 CXR dataset files not found in {self.temp_dir}. " + "Please download the dataset from " + "https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database " + "and extract the contents to the specified root directory." + ) + self.assertEqual(str(context.exception), expected_message) + + @patch('pyhealth.datasets.covid19_cxr.COVID19CXRDataset._check_raw_data_exists') + @patch('pyhealth.datasets.base_dataset.BaseDataset.__init__') + def test_init_works_when_raw_data_present(self, mock_base_init, mock_check): + """Test that __init__ works when raw data is present.""" + mock_check.return_value = True + + # Mock the metadata file as existing + metadata_path = os.path.join(self.temp_dir, "covid19_cxr-metadata-pyhealth.csv") + with open(metadata_path, 'w') as f: + f.write("""path,url,label +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-1.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-2.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-3.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-167.png,https://github.com/ml-workgroup/covid-19-image-repository/tree/master/png,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-336.png,https://eurorad.org,COVID 
+/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-967.png,https://github.com/ieee8023/covid-chestxray-dataset,COVID""" + ) + + dataset = COVID19CXRDataset(root=self.temp_dir) + + # Check that BaseDataset.__init__ was called + mock_base_init.assert_called_once() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_vae.py b/tests/core/test_vae.py new file mode 100644 index 000000000..294886a32 --- /dev/null +++ b/tests/core/test_vae.py @@ -0,0 +1,235 @@ +import unittest +import torch +import os + +from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.models import VAE + + +class TestVAE(unittest.TestCase): + """Test cases for the VAE model.""" + + def test_model_initialization_image(self): + """Test VAE initialization for image mode.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "image": torch.randn(1, 128, 128), + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "image": torch.randn(1, 128, 128), + "label": 0, + }, + ] + + input_schema = {"image": "tensor"} + output_schema = {"label": "binary"} + + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_image", + ) + + model = VAE( + dataset=dataset, + feature_keys=["image"], + label_key="label", + mode="binary", + input_type="image", + input_channel=1, # assuming grayscale + input_size=128, # use 128 + hidden_dim=64, + ) + + self.assertIsInstance(model, VAE) + self.assertEqual(model.input_type, "image") + self.assertEqual(model.hidden_dim, 64) + + def test_model_forward_image(self): + """Test VAE forward pass for image mode.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "image": torch.rand(1, 128, 128), # dummy image 0-1 + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "image": torch.rand(1, 128, 128), + "label": 0, + }, + ] + + 
input_schema = {"image": "tensor"} + output_schema = {"label": "binary"} + + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_image", + ) + + model = VAE( + dataset=dataset, + feature_keys=["image"], + label_key="label", + mode="binary", + input_type="image", + input_channel=1, + input_size=128, + hidden_dim=64, + ) + + train_loader = get_dataloader(dataset, batch_size=1, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_model_initialization_timeseries(self): + """Test VAE initialization for timeseries mode.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86"], + "label": 1.0, # dummy + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-33"], + "label": 0.5, + }, + ] + + input_schema = {"conditions": "sequence"} + output_schema = {"label": "regression"} + + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_ts", + ) + + model = VAE( + dataset=dataset, + feature_keys=["conditions"], + label_key="label", + mode="regression", + input_type="timeseries", + hidden_dim=64, + ) + + self.assertIsInstance(model, VAE) + self.assertEqual(model.input_type, "timeseries") + self.assertTrue(hasattr(model, "embedding_model")) + + def test_model_forward_timeseries(self): + """Test VAE forward pass for timeseries mode.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86"], + "label": 1.0, + }, + ] + + input_schema = {"conditions": "sequence"} + output_schema = {"label": "regression"} + + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_ts", + ) + + 
model = VAE( + dataset=dataset, + feature_keys=["conditions"], + label_key="label", + mode="regression", + input_type="timeseries", + hidden_dim=64, + ) + + train_loader = get_dataloader(dataset, batch_size=1, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_conditional_image_vae(self): + """Test VAE with conditional features for image mode.""" + samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "image": torch.rand(1, 128, 128), + "conditions": ["cond-33"], + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "image": torch.randn(1, 128, 128), + "conditions": ["cond-86"], + "label": 0, + }, + ] + + input_schema = {"image": "tensor", "conditions": "sequence"} + output_schema = {"label": "binary"} + + dataset = SampleDataset( + samples=samples, + input_schema=input_schema, + output_schema=output_schema, + dataset_name="test_cond", + ) + + model = VAE( + dataset=dataset, + feature_keys=["image"], + label_key="label", + mode="binary", + input_type="image", + input_channel=1, + input_size=128, + hidden_dim=64, + conditional_feature_keys=["conditions"], + ) + + self.assertTrue(hasattr(model, "embedding_model")) + + train_loader = get_dataloader(dataset, batch_size=1, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model(**data_batch) + + self.assertIn("loss", ret) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file