From a335dcfbac303d178144256e5f9b100039166cb1 Mon Sep 17 00:00:00 2001
From: Giovanni Dall'Olio
Date: Fri, 21 Nov 2025 17:31:54 +0000
Subject: [PATCH 1/9] started notebook

---
 examples/chestXray_image_generation_VAE.ipynb | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)
 create mode 100644 examples/chestXray_image_generation_VAE.ipynb

diff --git a/examples/chestXray_image_generation_VAE.ipynb b/examples/chestXray_image_generation_VAE.ipynb
new file mode 100644
index 000000000..1742c903b
--- /dev/null
+++ b/examples/chestXray_image_generation_VAE.ipynb
@@ -0,0 +1,19 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fe32019b",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

From c931d82c6d019e0e41027111dbef5b6eaeee5d4b Mon Sep 17 00:00:00 2001
From: Giovanni Dall'Olio
Date: Fri, 21 Nov 2025 17:44:42 +0000
Subject: [PATCH 2/9] created new base_image_dataset class + tests

---
 pyhealth/datasets/__init__.py           |  1 +
 pyhealth/datasets/base_image_dataset.py | 84 +++++++++++++++++++++++++
 pyhealth/datasets/covid19_cxr.py        | 28 ++++++++-
 tests/core/test_base_image_dataset.py   | 83 ++++++++++++++++++++++++
 tests/core/test_covid19_cxr.py          | 80 +++++++++++++++++++++++
 5 files changed, 274 insertions(+), 2 deletions(-)
 create mode 100644 pyhealth/datasets/base_image_dataset.py
 create mode 100644 tests/core/test_base_image_dataset.py
 create mode 100644 tests/core/test_covid19_cxr.py

diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index ced02afd7..a68660c22 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs):
 
 
 from .base_dataset import BaseDataset
+from .base_image_dataset import BaseImageDataset
 from .cardiology import CardiologyDataset
 from .chestxray14 import ChestXray14Dataset
 from .covid19_cxr import COVID19CXRDataset
diff --git a/pyhealth/datasets/base_image_dataset.py b/pyhealth/datasets/base_image_dataset.py
new file mode 100644
index 000000000..fabb47b8e
--- /dev/null
+++ b/pyhealth/datasets/base_image_dataset.py
@@ -0,0 +1,84 @@
+import logging
+from typing import Any, Dict, Optional
+
+from pyhealth.datasets.base_dataset import BaseDataset
+from pyhealth.processors.image_processor import ImageProcessor
+from pyhealth.tasks.base_task import BaseTask
+from pyhealth.datasets.sample_dataset import SampleDataset
+
+logger = logging.getLogger(__name__)
+
+
+class BaseImageDataset(BaseDataset):
+    """Base class for image datasets in PyHealth.
+
+    This class provides common functionality for loading and processing
+    image data, including a default image processing setup.
+
+    Args:
+        root: Root directory of the raw data containing the dataset files.
+        dataset_name: Optional name of the dataset. Defaults to "base_image".
+        config_path: Optional path to the configuration file.
+
+    Attributes:
+        root: Root directory of the raw data.
+        dataset_name: Name of the dataset.
+        config_path: Path to the configuration file.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        dataset_name: Optional[str] = None,
+        config_path: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            root=root,
+            dataset_name=dataset_name or "base_image",
+            config_path=config_path,
+            **kwargs,
+        )
+
+    def set_task(
+        self,
+        task: BaseTask | None = None,
+        num_workers: int = 1,
+        cache_dir: str | None = None,
+        cache_format: str = "parquet",
+        input_processors: Dict[str, Any] | None = None,
+        output_processors: Dict[str, Any] | None = None,
+    ) -> SampleDataset:
+        """Set the task for the dataset with default image processing.
+
+        If no image processor is provided, defaults to ImageProcessor with
+        image_size=299 and mode="L" (grayscale).
+
+        Args:
+            task: The task to set.
+            num_workers: Number of workers for processing.
+            cache_dir: Directory for caching.
+            cache_format: Format for caching.
+            input_processors: Input processors.
+            output_processors: Output processors.
+
+        Returns:
+            SampleDataset: The processed sample dataset.
+        """
+        if input_processors is None or "image" not in input_processors:
+            image_processor = ImageProcessor(
+                image_size=299,  # Default image size
+                mode="L",  # Grayscale by default
+            )
+            if input_processors is None:
+                input_processors = {}
+            input_processors["image"] = image_processor
+
+        return super().set_task(
+            task,
+            num_workers,
+            cache_dir,
+            cache_format,
+            input_processors,
+            output_processors,
+        )
\ No newline at end of file
diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py
index 42f5671b5..a043b4b36 100644
--- a/pyhealth/datasets/covid19_cxr.py
+++ b/pyhealth/datasets/covid19_cxr.py
@@ -11,12 +11,12 @@
 from pyhealth.tasks.base_task import BaseTask
 
 from ..tasks import COVID19CXRClassification
-from .base_dataset import BaseDataset
+from .base_image_dataset import BaseImageDataset
 
 logger = logging.getLogger(__name__)
 
 
-class COVID19CXRDataset(BaseDataset):
+class COVID19CXRDataset(BaseImageDataset):
     """Base image dataset for COVID-19 Radiography Database.
 
     Dataset is available at:
@@ -92,6 +92,13 @@ def __init__(
         if config_path is None:
             logger.info("No config path provided, using default config")
             config_path = Path(__file__).parent / "configs" / "covid19_cxr.yaml"
+        if not self._check_raw_data_exists(root):
+            raise ValueError(
+                f"Raw COVID-19 CXR dataset files not found in {root}. "
+                "Please download the dataset from "
+                "https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database "
+                "and extract the contents to the specified root directory."
+            )
         if not os.path.exists(os.path.join(root, "covid19_cxr-metadata-pyhealth.csv")):
             self.prepare_metadata(root)
         default_tables = ["covid19_cxr"]
@@ -149,6 +156,23 @@ def prepare_metadata(self, root: str) -> None:
         df.to_csv(os.path.join(root, "covid19_cxr-metadata-pyhealth.csv"), index=False)
         return
 
+    def _check_raw_data_exists(self, root: str) -> bool:
+        """Check if required raw data files exist.
+
+        Args:
+            root: Root directory containing the dataset files.
+
+        Returns:
+            bool: True if all required files exist, False otherwise.
+ """ + required_files = [ + "COVID.metadata.xlsx", + "Lung_Opacity.metadata.xlsx", + "Normal.metadata.xlsx", + "Viral Pneumonia.metadata.xlsx", + ] + return all(os.path.exists(os.path.join(root, f)) for f in required_files) + def set_task( self, task: BaseTask | None = None, diff --git a/tests/core/test_base_image_dataset.py b/tests/core/test_base_image_dataset.py new file mode 100644 index 000000000..4d751591c --- /dev/null +++ b/tests/core/test_base_image_dataset.py @@ -0,0 +1,83 @@ +""" +Unit tests for the BaseImageDataset class. + +Author: + Kilo Code +""" +import unittest +from unittest.mock import patch, MagicMock + +from pyhealth.datasets import BaseImageDataset +from pyhealth.processors import ImageProcessor + + +class TestBaseImageDataset(unittest.TestCase): + def setUp(self): + # Mock the BaseDataset __init__ to avoid dependencies + with patch('pyhealth.datasets.base_dataset.BaseDataset.__init__', return_value=None): + self.dataset = BaseImageDataset(root="/fake/root") + + def test_set_task_adds_image_processor_when_missing(self): + """Test that set_task adds ImageProcessor when 'image' key is missing from input_processors.""" + # Mock BaseDataset.set_task to capture arguments + with patch('pyhealth.datasets.base_dataset.BaseDataset.set_task') as mock_super_set_task: + mock_super_set_task.return_value = MagicMock() + + result = self.dataset.set_task() + + # Check that super().set_task was called + mock_super_set_task.assert_called_once() + args, kwargs = mock_super_set_task.call_args + + # input_processors is args[4] + input_processors = args[4] + self.assertIsNotNone(input_processors) + self.assertIn('image', input_processors) + self.assertIsInstance(input_processors['image'], ImageProcessor) + + # Check default values + image_proc = input_processors['image'] + self.assertEqual(image_proc.image_size, 299) + self.assertEqual(image_proc.mode, "L") + + def test_set_task_does_not_override_existing_image_processor(self): + """Test that set_task does not override existing ImageProcessor.""" + custom_processor = ImageProcessor(image_size=128, mode="RGB") + + with patch('pyhealth.datasets.base_dataset.BaseDataset.set_task') as mock_super_set_task: + mock_super_set_task.return_value = MagicMock() + + result = self.dataset.set_task(input_processors={'image': custom_processor}) + + mock_super_set_task.assert_called_once() + args, kwargs = mock_super_set_task.call_args + + # input_processors is args[4] + input_processors = args[4] + self.assertIsNotNone(input_processors) + self.assertIn('image', input_processors) + self.assertIs(input_processors['image'], custom_processor) + + def test_set_task_preserves_other_processors(self): + """Test that set_task preserves other processors while adding image processor.""" + other_processor = MagicMock() + + with patch('pyhealth.datasets.base_dataset.BaseDataset.set_task') as mock_super_set_task: + mock_super_set_task.return_value = MagicMock() + + result = self.dataset.set_task(input_processors={'other': other_processor}) + + mock_super_set_task.assert_called_once() + args, kwargs = mock_super_set_task.call_args + + # input_processors is args[4] + input_processors = args[4] + self.assertIsNotNone(input_processors) + self.assertIn('image', input_processors) + self.assertIn('other', input_processors) + self.assertIs(input_processors['other'], other_processor) + self.assertIsInstance(input_processors['image'], ImageProcessor) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/core/test_covid19_cxr.py 
b/tests/core/test_covid19_cxr.py new file mode 100644 index 000000000..41eba3a1d --- /dev/null +++ b/tests/core/test_covid19_cxr.py @@ -0,0 +1,80 @@ +""" +Unit tests for the COVID19CXRDataset class. + +Author: + Kilo Code +""" +import os +import shutil +import tempfile +import unittest +from unittest.mock import patch + +from pyhealth.datasets import COVID19CXRDataset +from pyhealth.tasks import COVID19CXRClassification + + +class TestCOVID19CXRDataset(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_check_raw_data_exists_returns_false_when_files_missing(self): + """Test _check_raw_data_exists returns False when required files are missing.""" + dataset = COVID19CXRDataset.__new__(COVID19CXRDataset) # Create without __init__ + result = dataset._check_raw_data_exists(self.temp_dir) + self.assertFalse(result) + + def test_check_raw_data_exists_returns_true_when_files_present(self): + """Test _check_raw_data_exists returns True when all required files are present.""" + # Create dummy files + required_files = [ + "COVID.metadata.xlsx", + "Lung_Opacity.metadata.xlsx", + "Normal.metadata.xlsx", + "Viral Pneumonia.metadata.xlsx", + ] + for filename in required_files: + with open(os.path.join(self.temp_dir, filename), 'w') as f: + f.write("dummy content") + + dataset = COVID19CXRDataset.__new__(COVID19CXRDataset) + result = dataset._check_raw_data_exists(self.temp_dir) + self.assertTrue(result) + + def test_init_raises_value_error_when_raw_data_missing(self): + """Test that __init__ raises ValueError with correct message when raw data is missing.""" + with self.assertRaises(ValueError) as context: + COVID19CXRDataset(root=self.temp_dir) + + expected_message = ( + f"Raw COVID-19 CXR dataset files not found in {self.temp_dir}. " + "Please download the dataset from " + "https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database " + "and extract the contents to the specified root directory." 
+ ) + self.assertEqual(str(context.exception), expected_message) + + @patch('pyhealth.datasets.covid19_cxr.COVID19CXRDataset._check_raw_data_exists') + @patch('pyhealth.datasets.base_dataset.BaseDataset.__init__') + def test_init_works_when_raw_data_present(self, mock_base_init, mock_check): + """Test that __init__ works when raw data is present.""" + mock_check.return_value = True + + # Mock the metadata file as existing + metadata_path = os.path.join(self.temp_dir, "covid19_cxr-metadata-pyhealth.csv") + with open(metadata_path, 'w') as f: + f.write("path\nfake_path.png") + + dataset = COVID19CXRDataset(root=self.temp_dir) + + # Check that BaseDataset.__init__ was called + mock_base_init.assert_called_once() + + # Removed test_default_task as it requires instantiation + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From bda111f5d46ed8bfa4632124232a698135f9f653 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:00:41 +0000 Subject: [PATCH 3/9] updated tests --- examples/chestXray_image_generation_VAE.ipynb | 19 ------------------- tests/core/test_base_image_dataset.py | 2 +- tests/core/test_covid19_cxr.py | 15 +++++++++------ 3 files changed, 10 insertions(+), 26 deletions(-) delete mode 100644 examples/chestXray_image_generation_VAE.ipynb diff --git a/examples/chestXray_image_generation_VAE.ipynb b/examples/chestXray_image_generation_VAE.ipynb deleted file mode 100644 index 1742c903b..000000000 --- a/examples/chestXray_image_generation_VAE.ipynb +++ /dev/null @@ -1,19 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "fe32019b", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/core/test_base_image_dataset.py b/tests/core/test_base_image_dataset.py index 4d751591c..38c7f7d3e 100644 --- a/tests/core/test_base_image_dataset.py +++ b/tests/core/test_base_image_dataset.py @@ -2,7 +2,7 @@ Unit tests for the BaseImageDataset class. Author: - Kilo Code + Giovanni M. Dall'Olio, GMD Bioinformatics """ import unittest from unittest.mock import patch, MagicMock diff --git a/tests/core/test_covid19_cxr.py b/tests/core/test_covid19_cxr.py index 41eba3a1d..1ca8a911e 100644 --- a/tests/core/test_covid19_cxr.py +++ b/tests/core/test_covid19_cxr.py @@ -2,7 +2,7 @@ Unit tests for the COVID19CXRDataset class. Author: - Kilo Code + Giovanni M. 
Dall'Olio, GMD Bioinformatics """ import os import shutil @@ -11,7 +11,6 @@ from unittest.mock import patch from pyhealth.datasets import COVID19CXRDataset -from pyhealth.tasks import COVID19CXRClassification class TestCOVID19CXRDataset(unittest.TestCase): @@ -66,15 +65,19 @@ def test_init_works_when_raw_data_present(self, mock_base_init, mock_check): # Mock the metadata file as existing metadata_path = os.path.join(self.temp_dir, "covid19_cxr-metadata-pyhealth.csv") with open(metadata_path, 'w') as f: - f.write("path\nfake_path.png") + f.write("""path,url,label +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-1.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-2.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-3.png,https://sirm.org/category/senza-categoria/covid-19/,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-167.png,https://github.com/ml-workgroup/covid-19-image-repository/tree/master/png,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-336.png,https://eurorad.org,COVID +/home/ubuntu/Downloads/COVID-19_Radiography_Dataset//COVID/images/COVID-967.png,https://github.com/ieee8023/covid-chestxray-dataset,COVID""" + ) dataset = COVID19CXRDataset(root=self.temp_dir) # Check that BaseDataset.__init__ was called mock_base_init.assert_called_once() - # Removed test_default_task as it requires instantiation - - if __name__ == "__main__": unittest.main() \ No newline at end of file From d0bf7067c1ad5d8caa9175a001efb07cc5084776 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:07:09 +0000 Subject: [PATCH 4/9] supporting path expansion --- pyhealth/datasets/covid19_cxr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py index a043b4b36..e612b863f 100644 --- a/pyhealth/datasets/covid19_cxr.py +++ b/pyhealth/datasets/covid19_cxr.py @@ -89,6 +89,9 @@ def __init__( dataset_name: Optional[str] = None, config_path: Optional[str] = None, ) -> None: + # Expand user path (e.g., ~/Downloads -> /home/user/Downloads) + root = os.path.expanduser(root) + if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "covid19_cxr.yaml" From b4b091fa284527f6dfd15cf5f4720b4df6df5227 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:12:45 +0000 Subject: [PATCH 5/9] expanding path in base class --- pyhealth/datasets/base_image_dataset.py | 3 +++ pyhealth/datasets/covid19_cxr.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_image_dataset.py b/pyhealth/datasets/base_image_dataset.py index fabb47b8e..78c578e2d 100644 --- a/pyhealth/datasets/base_image_dataset.py +++ b/pyhealth/datasets/base_image_dataset.py @@ -33,6 +33,9 @@ def __init__( config_path: Optional[str] = None, **kwargs, ) -> None: + # Expand user path (e.g., ~/Downloads -> /home/user/Downloads) + root = os.path.expanduser(root) + super().__init__( root=root, dataset_name=dataset_name or "base_image", diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py index e612b863f..a043b4b36 100644 --- a/pyhealth/datasets/covid19_cxr.py +++ b/pyhealth/datasets/covid19_cxr.py @@ -89,9 +89,6 @@ def __init__( dataset_name: Optional[str] = None, 
config_path: Optional[str] = None, ) -> None: - # Expand user path (e.g., ~/Downloads -> /home/user/Downloads) - root = os.path.expanduser(root) - if config_path is None: logger.info("No config path provided, using default config") config_path = Path(__file__).parent / "configs" / "covid19_cxr.yaml" From 3c8879dabee47f80d911dd269ca715f55dcc2fe0 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:22:39 +0000 Subject: [PATCH 6/9] added openpyxl for reading excel files --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ceedcdd0b..339c71429 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "pandas~=2.3.1", "pandarallel~=1.6.5", "pydantic~=2.11.7", + "openpyxl~=3.1.0" ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 495ad31b06e4eb64d898d6e0379bb2a801d92b46 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:23:17 +0000 Subject: [PATCH 7/9] attribution --- tests/core/test_base_image_dataset.py | 2 +- tests/core/test_covid19_cxr.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/core/test_base_image_dataset.py b/tests/core/test_base_image_dataset.py index 38c7f7d3e..e845ec857 100644 --- a/tests/core/test_base_image_dataset.py +++ b/tests/core/test_base_image_dataset.py @@ -2,7 +2,7 @@ Unit tests for the BaseImageDataset class. Author: - Giovanni M. Dall'Olio, GMD Bioinformatics + Giovanni M. Dall'Olio """ import unittest from unittest.mock import patch, MagicMock diff --git a/tests/core/test_covid19_cxr.py b/tests/core/test_covid19_cxr.py index 1ca8a911e..466c9bfa0 100644 --- a/tests/core/test_covid19_cxr.py +++ b/tests/core/test_covid19_cxr.py @@ -1,8 +1,12 @@ """ Unit tests for the COVID19CXRDataset class. +The tests simply check that the dataset initialization works as expected, +creating mock files and ensuring COVID19CXRDataset behaves correctly when +raw data files are present or absent. + Author: - Giovanni M. Dall'Olio, GMD Bioinformatics + Giovanni M. Dall'Olio """ import os import shutil From 74f9a1448dee9c780fa1fe5223507e78c26d36d1 Mon Sep 17 00:00:00 2001 From: Giovanni Dall'Olio Date: Fri, 21 Nov 2025 18:24:05 +0000 Subject: [PATCH 8/9] transformed script into notebook --- examples/covid19cxr_conformal.ipynb | 459 ++++++++++++++++++++++++++++ examples/covid19cxr_conformal.py | 316 ------------------- 2 files changed, 459 insertions(+), 316 deletions(-) create mode 100644 examples/covid19cxr_conformal.ipynb delete mode 100644 examples/covid19cxr_conformal.py diff --git a/examples/covid19cxr_conformal.ipynb b/examples/covid19cxr_conformal.ipynb new file mode 100644 index 000000000..686bf94e4 --- /dev/null +++ b/examples/covid19cxr_conformal.ipynb @@ -0,0 +1,459 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conformal Prediction for COVID-19 Chest X-Ray Classification\n", + "\n", + "This example demonstrates:\n", + "1. Training a ResNet-18 model on COVID-19 CXR dataset\n", + "2. Conventional conformal prediction using LABEL\n", + "3. Covariate shift adaptive conformal prediction using CovariateLabel\n", + "4. Comparison of coverage and efficiency between the two methods" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/PyHealth/pyhealth/sampler/sage_sampler.py:3: UserWarning: pkg_resources is deprecated as an API. 
See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " import pkg_resources\n", + "/home/ubuntu/PyHealth/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "This is a warning of potentially slow compute. You could uncomment this line and use the Python implementation instead of Cython.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.calib.predictionset import LABEL\n", + "from pyhealth.calib.predictionset.covariate import CovariateLabel\n", + "from pyhealth.calib.utils import extract_embeddings\n", + "from pyhealth.datasets import (\n", + " COVID19CXRDataset,\n", + " get_dataloader,\n", + " split_by_sample_conformal,\n", + ")\n", + "from pyhealth.models import TorchvisionModel\n", + "from pyhealth.trainer import Trainer, get_metrics_fn\n", + "\n", + "# Set random seed for reproducibility\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 1: Load and prepare dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "STEP 1: Loading COVID-19 CXR Dataset\n", + "================================================================================\n", + "No config path provided, using default config\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Raw COVID-19 CXR dataset files not found in ~/Downloads/COVID-19_Radiography_Dataset. 
Please download the dataset from https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database and extract the contents to the specified root directory.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mValueError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33m=\u001b[39m\u001b[33m\"\u001b[39m * \u001b[32m80\u001b[39m)\n\u001b[32m 5\u001b[39m root = \u001b[33m\"\u001b[39m\u001b[33m~/Downloads/COVID-19_Radiography_Dataset\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m base_dataset = \u001b[43mCOVID19CXRDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 7\u001b[39m sample_dataset = base_dataset.set_task(cache_dir=\u001b[33m\"\u001b[39m\u001b[33m../../covid19cxr_cache\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 9\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mTotal samples: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(sample_dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/covid19_cxr.py:96\u001b[39m, in \u001b[36mCOVID19CXRDataset.__init__\u001b[39m\u001b[34m(self, root, dataset_name, config_path)\u001b[39m\n\u001b[32m 94\u001b[39m config_path = Path(\u001b[34m__file__\u001b[39m).parent / \u001b[33m\"\u001b[39m\u001b[33mconfigs\u001b[39m\u001b[33m\"\u001b[39m / \u001b[33m\"\u001b[39m\u001b[33mcovid19_cxr.yaml\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 95\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m._check_raw_data_exists(root):\n\u001b[32m---> \u001b[39m\u001b[32m96\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 97\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mRaw COVID-19 CXR dataset files not found in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mroot\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 98\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mPlease download the dataset from \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 99\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mhttps://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 100\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mand extract the contents to the specified root directory.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 101\u001b[39m )\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os.path.exists(os.path.join(root, \u001b[33m\"\u001b[39m\u001b[33mcovid19_cxr-metadata-pyhealth.csv\u001b[39m\u001b[33m\"\u001b[39m)):\n\u001b[32m 103\u001b[39m \u001b[38;5;28mself\u001b[39m.prepare_metadata(root)\n", + "\u001b[31mValueError\u001b[39m: Raw COVID-19 CXR dataset files not found in ~/Downloads/COVID-19_Radiography_Dataset. Please download the dataset from https://www.kaggle.com/api/v1/datasets/download/tawsifurrahman/covid19-radiography-database and extract the contents to the specified root directory." 
+ ] + } + ], + "source": [ + "print(\"=\" * 80)\n", + "print(\"STEP 1: Loading COVID-19 CXR Dataset\")\n", + "print(\"=\" * 80)\n", + "\n", + "root = \"~/Downloads/COVID-19_Radiography_Dataset\"\n", + "base_dataset = COVID19CXRDataset(root)\n", + "sample_dataset = base_dataset.set_task(cache_dir=\"../../covid19cxr_cache\")\n", + "\n", + "print(f\"Total samples: {len(sample_dataset)}\")\n", + "print(f\"Task mode: {sample_dataset.output_schema}\")\n", + "\n", + "# Split into train/val/cal/test\n", + "# For conformal prediction, we need a separate calibration set\n", + "train_data, val_data, cal_data, test_data = split_by_sample_conformal(\n", + " dataset=sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15]\n", + ")\n", + "\n", + "print(f\"Train: {len(train_data)}\")\n", + "print(f\"Val: {len(val_data)}\")\n", + "print(f\"Cal: {len(cal_data)} (for conformal calibration)\")\n", + "print(f\"Test: {len(test_data)}\")\n", + "\n", + "# Create data loaders\n", + "train_loader = get_dataloader(train_data, batch_size=32, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=32, shuffle=False)\n", + "cal_loader = get_dataloader(cal_data, batch_size=32, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=32, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 2: Train ResNet-18 model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 2: Training ResNet-18 Model\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Initialize ResNet-18 with pretrained weights\n", + "resnet = TorchvisionModel(\n", + " dataset=sample_dataset,\n", + " model_name=\"resnet18\",\n", + " model_config={\"weights\": \"DEFAULT\"},\n", + ")\n", + "\n", + "# Train the model\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "trainer = Trainer(model=resnet, device=device)\n", + "\n", + "print(f\"Training on device: {device}\")\n", + "trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=5,\n", + " monitor=\"accuracy\",\n", + ")\n", + "\n", + "print(\"✓ Model training completed\")\n", + "\n", + "# Evaluate base model on test set\n", + "print(\"\\nBase model performance on test set:\")\n", + "y_true_base, y_prob_base, loss_base = trainer.inference(test_loader)\n", + "base_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_base, y_prob_base, metrics=[\"accuracy\", \"f1_weighted\"]\n", + ")\n", + "for metric, value in base_metrics.items():\n", + " print(f\" {metric}: {value:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 3: Conventional Conformal Prediction with LABEL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 3: Conventional Conformal Prediction (LABEL)\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Target miscoverage rate of 10% (90% coverage)\n", + "alpha = 0.1\n", + "print(f\"Target miscoverage rate: {alpha} (90% coverage)\")\n", + "\n", + "# Create LABEL predictor\n", + "label_predictor = LABEL(model=resnet, alpha=alpha)\n", + "\n", + "# Calibrate on calibration set\n", + "print(\"Calibrating LABEL predictor...\")\n", + "label_predictor.calibrate(cal_dataset=cal_data)\n", + "\n", + "# Evaluate on test set\n", + "print(\"Evaluating LABEL predictor on test set...\")\n", + "y_true_label, y_prob_label, _, extra_label = 
Trainer(model=label_predictor).inference(\n", + " test_loader, additional_outputs=[\"y_predset\"]\n", + ")\n", + "\n", + "label_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_label,\n", + " y_prob_label,\n", + " metrics=[\"accuracy\", \"miscoverage_ps\"],\n", + " y_predset=extra_label[\"y_predset\"],\n", + ")\n", + "\n", + "# Calculate average set size\n", + "predset_label = (\n", + " torch.tensor(extra_label[\"y_predset\"])\n", + " if isinstance(extra_label[\"y_predset\"], np.ndarray)\n", + " else extra_label[\"y_predset\"]\n", + ")\n", + "avg_set_size_label = predset_label.float().sum(dim=1).mean().item()\n", + "\n", + "# Extract scalar values from metrics (handle both scalar and array returns)\n", + "miscoverage_label = label_metrics[\"miscoverage_ps\"]\n", + "if isinstance(miscoverage_label, np.ndarray):\n", + " miscoverage_label = float(\n", + " miscoverage_label.item()\n", + " if miscoverage_label.size == 1\n", + " else miscoverage_label.mean()\n", + " )\n", + "else:\n", + " miscoverage_label = float(miscoverage_label)\n", + "\n", + "print(\"\\nLABEL Results:\")\n", + "print(f\" Accuracy: {label_metrics['accuracy']:.4f}\")\n", + "print(f\" Empirical miscoverage: {miscoverage_label:.4f}\")\n", + "print(f\" Average set size: {avg_set_size_label:.2f}\")\n", + "print(f\" Target miscoverage: {alpha:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 4: Covariate Shift Adaptive Conformal Prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 4: Covariate Shift Adaptive Conformal Prediction\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Extract embeddings from the model\n", + "# For TorchvisionModel, we extract features from avgpool layer (before fc)\n", + "print(\"Extracting embeddings from calibration set...\")\n", + "cal_embeddings = extract_embeddings(resnet, cal_data, batch_size=32, device=device)\n", + "print(f\" Cal embeddings shape: {cal_embeddings.shape}\")\n", + "\n", + "print(\"Extracting embeddings from test set...\")\n", + "test_embeddings = extract_embeddings(resnet, test_data, batch_size=32, device=device)\n", + "print(f\" Test embeddings shape: {test_embeddings.shape}\")\n", + "\n", + "# Create CovariateLabel predictor\n", + "print(\"\\nCreating CovariateLabel predictor...\")\n", + "covariate_predictor = CovariateLabel(model=resnet, alpha=alpha)\n", + "\n", + "# Calibrate with embeddings (KDEs will be fitted automatically)\n", + "print(\"Calibrating CovariateLabel predictor...\")\n", + "print(\" - Fitting KDEs for covariate shift correction...\")\n", + "covariate_predictor.calibrate(\n", + " cal_dataset=cal_data, cal_embeddings=cal_embeddings, test_embeddings=test_embeddings\n", + ")\n", + "print(\"✓ Calibration completed\")\n", + "\n", + "# Evaluate on test set\n", + "print(\"Evaluating CovariateLabel predictor on test set...\")\n", + "y_true_cov, y_prob_cov, _, extra_cov = Trainer(model=covariate_predictor).inference(\n", + " test_loader, additional_outputs=[\"y_predset\"]\n", + ")\n", + "\n", + "cov_metrics = get_metrics_fn(\"multiclass\")(\n", + " y_true_cov,\n", + " y_prob_cov,\n", + " metrics=[\"accuracy\", \"miscoverage_ps\"],\n", + " y_predset=extra_cov[\"y_predset\"],\n", + ")\n", + "\n", + "# Calculate average set size\n", + "predset_cov = (\n", + " torch.tensor(extra_cov[\"y_predset\"])\n", + " if isinstance(extra_cov[\"y_predset\"], np.ndarray)\n", + " else extra_cov[\"y_predset\"]\n", + ")\n", + 
"avg_set_size_cov = predset_cov.float().sum(dim=1).mean().item()\n", + "\n", + "# Extract scalar values from metrics (handle both scalar and array returns)\n", + "miscoverage_cov = cov_metrics[\"miscoverage_ps\"]\n", + "if isinstance(miscoverage_cov, np.ndarray):\n", + " miscoverage_cov = float(\n", + " miscoverage_cov.item() if miscoverage_cov.size == 1 else miscoverage_cov.mean()\n", + " )\n", + "else:\n", + " miscoverage_cov = float(miscoverage_cov)\n", + "\n", + "print(\"\\nCovariateLabel Results:\")\n", + "print(f\" Accuracy: {cov_metrics['accuracy']:.4f}\")\n", + "print(f\" Empirical miscoverage: {miscoverage_cov:.4f}\")\n", + "print(f\" Average set size: {avg_set_size_cov:.2f}\")\n", + "print(f\" Target miscoverage: {alpha:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 5: Compare Methods" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 5: Comparison of Methods\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(f\"\\nTarget: {1-alpha:.0%} coverage (max {alpha:.0%} miscoverage)\")\n", + "print(\"\\n{:<40} {:<15} {:<15}\".format(\"Metric\", \"LABEL\", \"CovariateLabel\"))\n", + "print(\"-\" * 70)\n", + "\n", + "# Coverage (1 - miscoverage)\n", + "label_coverage = 1 - miscoverage_label\n", + "cov_coverage = 1 - miscoverage_cov\n", + "print(\n", + " \"{:<40} {:<15.2%} {:<15.2%}\".format(\n", + " \"Empirical Coverage\", label_coverage, cov_coverage\n", + " )\n", + ")\n", + "\n", + "# Miscoverage\n", + "print(\n", + " \"{:<40} {:<15.4f} {:<15.4f}\".format(\n", + " \"Empirical Miscoverage\",\n", + " miscoverage_label,\n", + " miscoverage_cov,\n", + " )\n", + ")\n", + "\n", + "# Average set size (smaller is better for efficiency)\n", + "print(\n", + " \"{:<40} {:<15.2f} {:<15.2f}\".format(\n", + " \"Average Set Size\", avg_set_size_label, avg_set_size_cov\n", + " )\n", + ")\n", + "\n", + "# Efficiency (inverse of average set size)\n", + "efficiency_label = 1.0 / avg_set_size_label\n", + "efficiency_cov = 1.0 / avg_set_size_cov\n", + "print(\n", + " \"{:<40} {:<15.4f} {:<15.4f}\".format(\n", + " \"Efficiency (1/avg_set_size)\", efficiency_label, efficiency_cov\n", + " )\n", + ")\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"Summary\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nKey Observations:\")\n", + "print(\"1. Both methods achieve near-target coverage guarantees\")\n", + "print(\"2. LABEL: Standard conformal prediction\")\n", + "print(\"3. 
CovariateLabel: Adapts to distribution shift between cal and test\")\n", + "print(\"\\nWhen to use CovariateLabel:\")\n", + "print(\" - When test distribution differs from calibration distribution\")\n", + "print(\" - When you have access to test embeddings/features\")\n", + "print(\" - When you want more robust coverage under distribution shift\")\n", + "print(\"\\nWhen to use LABEL:\")\n", + "print(\" - When cal and test distributions are similar (exchangeable)\")\n", + "print(\" - Simpler method, no need to fit KDEs\")\n", + "print(\" - Computationally more efficient\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## STEP 6: Visualize Prediction Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 6: Example Predictions\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Show first 5 test examples\n", + "n_examples = 5\n", + "print(f\"\\nShowing first {n_examples} test examples:\")\n", + "print(\"-\" * 80)\n", + "\n", + "for i in range(min(n_examples, len(y_true_label))):\n", + " true_class = int(y_true_label[i])\n", + "\n", + " # LABEL prediction set\n", + " if isinstance(predset_label, np.ndarray):\n", + " label_set = np.where(predset_label[i])[0]\n", + " else:\n", + " label_set = torch.where(predset_label[i])[0].cpu().numpy()\n", + "\n", + " # CovariateLabel prediction set\n", + " if isinstance(predset_cov, np.ndarray):\n", + " cov_set = np.where(predset_cov[i])[0]\n", + " else:\n", + " cov_set = torch.where(predset_cov[i])[0].cpu().numpy()\n", + "\n", + " print(f\"\\nExample {i+1}:\")\n", + " print(f\" True class: {true_class}\")\n", + " print(f\" LABEL set: {label_set.tolist()} (size: {len(label_set)})\")\n", + " print(f\" CovariateLabel set: {cov_set.tolist()} (size: {len(cov_set)})\")\n", + " print(f\" Correct in LABEL? {true_class in label_set}\")\n", + " print(f\" Correct in CovariateLabel? {true_class in cov_set}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"Example completed successfully!\")\n", + "print(\"=\" * 80)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/covid19cxr_conformal.py b/examples/covid19cxr_conformal.py deleted file mode 100644 index 4adb2dd8d..000000000 --- a/examples/covid19cxr_conformal.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -Conformal Prediction for COVID-19 Chest X-Ray Classification. - -This example demonstrates: -1. Training a ResNet-18 model on COVID-19 CXR dataset -2. Conventional conformal prediction using LABEL -3. Covariate shift adaptive conformal prediction using CovariateLabel -4. 
Comparison of coverage and efficiency between the two methods -""" - -import numpy as np -import torch - -from pyhealth.calib.predictionset import LABEL -from pyhealth.calib.predictionset.covariate import CovariateLabel -from pyhealth.calib.utils import extract_embeddings -from pyhealth.datasets import ( - COVID19CXRDataset, - get_dataloader, - split_by_sample_conformal, -) -from pyhealth.models import TorchvisionModel -from pyhealth.trainer import Trainer, get_metrics_fn - -# Set random seed for reproducibility -torch.manual_seed(42) -np.random.seed(42) - -# ============================================================================ -# STEP 1: Load and prepare dataset -# ============================================================================ -print("=" * 80) -print("STEP 1: Loading COVID-19 CXR Dataset") -print("=" * 80) - -root = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset" -base_dataset = COVID19CXRDataset(root) -sample_dataset = base_dataset.set_task(cache_dir="../../covid19cxr_cache") - -print(f"Total samples: {len(sample_dataset)}") -print(f"Task mode: {sample_dataset.output_schema}") - -# Split into train/val/cal/test -# For conformal prediction, we need a separate calibration set -train_data, val_data, cal_data, test_data = split_by_sample_conformal( - dataset=sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] -) - -print(f"Train: {len(train_data)}") -print(f"Val: {len(val_data)}") -print(f"Cal: {len(cal_data)} (for conformal calibration)") -print(f"Test: {len(test_data)}") - -# Create data loaders -train_loader = get_dataloader(train_data, batch_size=32, shuffle=True) -val_loader = get_dataloader(val_data, batch_size=32, shuffle=False) -cal_loader = get_dataloader(cal_data, batch_size=32, shuffle=False) -test_loader = get_dataloader(test_data, batch_size=32, shuffle=False) - -# ============================================================================ -# STEP 2: Train ResNet-18 model -# ============================================================================ -print("\n" + "=" * 80) -print("STEP 2: Training ResNet-18 Model") -print("=" * 80) - -# Initialize ResNet-18 with pretrained weights -resnet = TorchvisionModel( - dataset=sample_dataset, - model_name="resnet18", - model_config={"weights": "DEFAULT"}, -) - -# Train the model -device = "cuda:0" if torch.cuda.is_available() else "cpu" -trainer = Trainer(model=resnet, device=device) - -print(f"Training on device: {device}") -trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=5, - monitor="accuracy", -) - -print("✓ Model training completed") - -# Evaluate base model on test set -print("\nBase model performance on test set:") -y_true_base, y_prob_base, loss_base = trainer.inference(test_loader) -base_metrics = get_metrics_fn("multiclass")( - y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"] -) -for metric, value in base_metrics.items(): - print(f" {metric}: {value:.4f}") - -# ============================================================================ -# STEP 3: Conventional Conformal Prediction with LABEL -# ============================================================================ -print("\n" + "=" * 80) -print("STEP 3: Conventional Conformal Prediction (LABEL)") -print("=" * 80) - -# Target miscoverage rate of 10% (90% coverage) -alpha = 0.1 -print(f"Target miscoverage rate: {alpha} (90% coverage)") - -# Create LABEL predictor -label_predictor = LABEL(model=resnet, alpha=alpha) - -# Calibrate on calibration set -print("Calibrating 
LABEL predictor...") -label_predictor.calibrate(cal_dataset=cal_data) - -# Evaluate on test set -print("Evaluating LABEL predictor on test set...") -y_true_label, y_prob_label, _, extra_label = Trainer(model=label_predictor).inference( - test_loader, additional_outputs=["y_predset"] -) - -label_metrics = get_metrics_fn("multiclass")( - y_true_label, - y_prob_label, - metrics=["accuracy", "miscoverage_ps"], - y_predset=extra_label["y_predset"], -) - -# Calculate average set size -predset_label = ( - torch.tensor(extra_label["y_predset"]) - if isinstance(extra_label["y_predset"], np.ndarray) - else extra_label["y_predset"] -) -avg_set_size_label = predset_label.float().sum(dim=1).mean().item() - -# Extract scalar values from metrics (handle both scalar and array returns) -miscoverage_label = label_metrics["miscoverage_ps"] -if isinstance(miscoverage_label, np.ndarray): - miscoverage_label = float( - miscoverage_label.item() - if miscoverage_label.size == 1 - else miscoverage_label.mean() - ) -else: - miscoverage_label = float(miscoverage_label) - -print("\nLABEL Results:") -print(f" Accuracy: {label_metrics['accuracy']:.4f}") -print(f" Empirical miscoverage: {miscoverage_label:.4f}") -print(f" Average set size: {avg_set_size_label:.2f}") -print(f" Target miscoverage: {alpha:.2f}") - -# ============================================================================ -# STEP 4: Covariate Shift Adaptive Conformal Prediction -# ============================================================================ -print("\n" + "=" * 80) -print("STEP 4: Covariate Shift Adaptive Conformal Prediction") -print("=" * 80) - -# Extract embeddings from the model -# For TorchvisionModel, we extract features from avgpool layer (before fc) -print("Extracting embeddings from calibration set...") -cal_embeddings = extract_embeddings(resnet, cal_data, batch_size=32, device=device) -print(f" Cal embeddings shape: {cal_embeddings.shape}") - -print("Extracting embeddings from test set...") -test_embeddings = extract_embeddings(resnet, test_data, batch_size=32, device=device) -print(f" Test embeddings shape: {test_embeddings.shape}") - -# Create CovariateLabel predictor -print("\nCreating CovariateLabel predictor...") -covariate_predictor = CovariateLabel(model=resnet, alpha=alpha) - -# Calibrate with embeddings (KDEs will be fitted automatically) -print("Calibrating CovariateLabel predictor...") -print(" - Fitting KDEs for covariate shift correction...") -covariate_predictor.calibrate( - cal_dataset=cal_data, cal_embeddings=cal_embeddings, test_embeddings=test_embeddings -) -print("✓ Calibration completed") - -# Evaluate on test set -print("Evaluating CovariateLabel predictor on test set...") -y_true_cov, y_prob_cov, _, extra_cov = Trainer(model=covariate_predictor).inference( - test_loader, additional_outputs=["y_predset"] -) - -cov_metrics = get_metrics_fn("multiclass")( - y_true_cov, - y_prob_cov, - metrics=["accuracy", "miscoverage_ps"], - y_predset=extra_cov["y_predset"], -) - -# Calculate average set size -predset_cov = ( - torch.tensor(extra_cov["y_predset"]) - if isinstance(extra_cov["y_predset"], np.ndarray) - else extra_cov["y_predset"] -) -avg_set_size_cov = predset_cov.float().sum(dim=1).mean().item() - -# Extract scalar values from metrics (handle both scalar and array returns) -miscoverage_cov = cov_metrics["miscoverage_ps"] -if isinstance(miscoverage_cov, np.ndarray): - miscoverage_cov = float( - miscoverage_cov.item() if miscoverage_cov.size == 1 else miscoverage_cov.mean() - ) -else: - miscoverage_cov = 
float(miscoverage_cov) - -print("\nCovariateLabel Results:") -print(f" Accuracy: {cov_metrics['accuracy']:.4f}") -print(f" Empirical miscoverage: {miscoverage_cov:.4f}") -print(f" Average set size: {avg_set_size_cov:.2f}") -print(f" Target miscoverage: {alpha:.2f}") - -# ============================================================================ -# STEP 5: Compare Methods -# ============================================================================ -print("\n" + "=" * 80) -print("STEP 5: Comparison of Methods") -print("=" * 80) - -print(f"\nTarget: {1-alpha:.0%} coverage (max {alpha:.0%} miscoverage)") -print("\n{:<40} {:<15} {:<15}".format("Metric", "LABEL", "CovariateLabel")) -print("-" * 70) - -# Coverage (1 - miscoverage) -label_coverage = 1 - miscoverage_label -cov_coverage = 1 - miscoverage_cov -print( - "{:<40} {:<15.2%} {:<15.2%}".format( - "Empirical Coverage", label_coverage, cov_coverage - ) -) - -# Miscoverage -print( - "{:<40} {:<15.4f} {:<15.4f}".format( - "Empirical Miscoverage", - miscoverage_label, - miscoverage_cov, - ) -) - -# Average set size (smaller is better for efficiency) -print( - "{:<40} {:<15.2f} {:<15.2f}".format( - "Average Set Size", avg_set_size_label, avg_set_size_cov - ) -) - -# Efficiency (inverse of average set size) -efficiency_label = 1.0 / avg_set_size_label -efficiency_cov = 1.0 / avg_set_size_cov -print( - "{:<40} {:<15.4f} {:<15.4f}".format( - "Efficiency (1/avg_set_size)", efficiency_label, efficiency_cov - ) -) - -print("\n" + "=" * 80) -print("Summary") -print("=" * 80) - -print("\nKey Observations:") -print("1. Both methods achieve near-target coverage guarantees") -print("2. LABEL: Standard conformal prediction") -print("3. CovariateLabel: Adapts to distribution shift between cal and test") -print("\nWhen to use CovariateLabel:") -print(" - When test distribution differs from calibration distribution") -print(" - When you have access to test embeddings/features") -print(" - When you want more robust coverage under distribution shift") -print("\nWhen to use LABEL:") -print(" - When cal and test distributions are similar (exchangeable)") -print(" - Simpler method, no need to fit KDEs") -print(" - Computationally more efficient") - -# ============================================================================ -# STEP 6: Visualize Prediction Sets -# ============================================================================ -print("\n" + "=" * 80) -print("STEP 6: Example Predictions") -print("=" * 80) - -# Show first 5 test examples -n_examples = 5 -print(f"\nShowing first {n_examples} test examples:") -print("-" * 80) - -for i in range(min(n_examples, len(y_true_label))): - true_class = int(y_true_label[i]) - - # LABEL prediction set - if isinstance(predset_label, np.ndarray): - label_set = np.where(predset_label[i])[0] - else: - label_set = torch.where(predset_label[i])[0].cpu().numpy() - - # CovariateLabel prediction set - if isinstance(predset_cov, np.ndarray): - cov_set = np.where(predset_cov[i])[0] - else: - cov_set = torch.where(predset_cov[i])[0].cpu().numpy() - - print(f"\nExample {i+1}:") - print(f" True class: {true_class}") - print(f" LABEL set: {label_set.tolist()} (size: {len(label_set)})") - print(f" CovariateLabel set: {cov_set.tolist()} (size: {len(cov_set)})") - print(f" Correct in LABEL? {true_class in label_set}") - print(f" Correct in CovariateLabel? 
{true_class in cov_set}")
-
-print("\n" + "=" * 80)
-print("Example completed successfully!")
-print("=" * 80)

From 9abd437e86f558a5c727af074e333e6a0907fcd2 Mon Sep 17 00:00:00 2001
From: Giovanni Dall'Olio
Date: Fri, 21 Nov 2025 21:03:53 +0000
Subject: [PATCH 9/9] fixed missing import

---
 pyhealth/datasets/base_image_dataset.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyhealth/datasets/base_image_dataset.py b/pyhealth/datasets/base_image_dataset.py
index 78c578e2d..361871a64 100644
--- a/pyhealth/datasets/base_image_dataset.py
+++ b/pyhealth/datasets/base_image_dataset.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Any, Dict, Optional
 
 from pyhealth.datasets.base_dataset import BaseDataset
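
Usage sketch (illustrative only, not part of the patches): the snippet below
shows how the pieces introduced in this series fit together. It assumes the
COVID-19 Radiography Database has already been downloaded and extracted; the
root path is a placeholder.

    import os

    from pyhealth.datasets import COVID19CXRDataset, get_dataloader

    # A missing download now fails fast with a ValueError pointing at the
    # Kaggle URL (patch 2). Note that COVID19CXRDataset checks the raw files
    # before BaseImageDataset expands "~" (patch 5), so pass an already
    # expanded path here.
    root = os.path.expanduser("~/Downloads/COVID-19_Radiography_Dataset")
    base_dataset = COVID19CXRDataset(root)

    # With no explicit "image" input processor, set_task falls back to
    # ImageProcessor(image_size=299, mode="L") (patch 2).
    sample_dataset = base_dataset.set_task()
    loader = get_dataloader(sample_dataset, batch_size=32, shuffle=True)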