diff --git a/.gitignore b/.gitignore index 0f42621..6b4e173 100644 --- a/.gitignore +++ b/.gitignore @@ -154,3 +154,18 @@ data/ #Models models/ +*.pkl +*.pth + +# Claude +.claude/ +CLAUDE*.md + +# nbdev processed notebooks +_proc/ + +# Cline +.clinerules + +# MLflow +mlruns/ diff --git a/fastMONAI/__init__.py b/fastMONAI/__init__.py index 43a1e95..6b27eee 100644 --- a/fastMONAI/__init__.py +++ b/fastMONAI/__init__.py @@ -1 +1 @@ -__version__ = "0.5.3" +__version__ = "0.5.4" diff --git a/fastMONAI/_modidx.py b/fastMONAI/_modidx.py index 9562d1a..4e7f772 100644 --- a/fastMONAI/_modidx.py +++ b/fastMONAI/_modidx.py @@ -12,12 +12,18 @@ 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.MedDataset._get_data_info': ( 'dataset_info.html#meddataset._get_data_info', 'fastMONAI/dataset_info.py'), + 'fastMONAI.dataset_info.MedDataset._visualize_single_case': ( 'dataset_info.html#meddataset._visualize_single_case', + 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.MedDataset.get_largest_img_size': ( 'dataset_info.html#meddataset.get_largest_img_size', 'fastMONAI/dataset_info.py'), + 'fastMONAI.dataset_info.MedDataset.get_volume_summary': ( 'dataset_info.html#meddataset.get_volume_summary', + 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.MedDataset.suggestion': ( 'dataset_info.html#meddataset.suggestion', 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.MedDataset.summary': ( 'dataset_info.html#meddataset.summary', 'fastMONAI/dataset_info.py'), + 'fastMONAI.dataset_info.MedDataset.visualize_cases': ( 'dataset_info.html#meddataset.visualize_cases', + 'fastMONAI/dataset_info.py'), 'fastMONAI.dataset_info.get_class_weights': ( 'dataset_info.html#get_class_weights', 'fastMONAI/dataset_info.py')}, 'fastMONAI.external_data': { 'fastMONAI.external_data.MURLs': ('external_data.html#murls', 'fastMONAI/external_data.py'), @@ -202,33 +208,13 @@ 'fastMONAI.vision_core.MedImage': ('vision_core.html#medimage', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core.MedMask': ('vision_core.html#medmask', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core.MetaResolver': ('vision_core.html#metaresolver', 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback': ( 'vision_core.html#vscodeprogresscallback', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.__init__': ( 'vision_core.html#vscodeprogresscallback.__init__', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback._detect_vscode_environment': ( 'vision_core.html#vscodeprogresscallback._detect_vscode_environment', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.after_batch': ( 'vision_core.html#vscodeprogresscallback.after_batch', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.after_fit': ( 'vision_core.html#vscodeprogresscallback.after_fit', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.after_validate': ( 'vision_core.html#vscodeprogresscallback.after_validate', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.before_epoch': ( 'vision_core.html#vscodeprogresscallback.before_epoch', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.before_fit': ( 'vision_core.html#vscodeprogresscallback.before_fit', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.VSCodeProgressCallback.before_validate': ( 'vision_core.html#vscodeprogresscallback.before_validate', - 'fastMONAI/vision_core.py'), 
'fastMONAI.vision_core._load_and_preprocess': ( 'vision_core.html#_load_and_preprocess', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core._multi_channel': ( 'vision_core.html#_multi_channel', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core._preprocess': ('vision_core.html#_preprocess', 'fastMONAI/vision_core.py'), 'fastMONAI.vision_core.med_img_reader': ( 'vision_core.html#med_img_reader', - 'fastMONAI/vision_core.py'), - 'fastMONAI.vision_core.setup_vscode_progress': ( 'vision_core.html#setup_vscode_progress', - 'fastMONAI/vision_core.py')}, + 'fastMONAI/vision_core.py')}, 'fastMONAI.vision_data': { 'fastMONAI.vision_data.MedDataBlock': ('vision_data.html#meddatablock', 'fastMONAI/vision_data.py'), 'fastMONAI.vision_data.MedDataBlock.__init__': ( 'vision_data.html#meddatablock.__init__', 'fastMONAI/vision_data.py'), @@ -285,14 +271,36 @@ 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.binary_hausdorff_distance': ( 'vision_metrics.html#binary_hausdorff_distance', 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.binary_lesion_detection_rate': ( 'vision_metrics.html#binary_lesion_detection_rate', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.binary_precision': ( 'vision_metrics.html#binary_precision', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.binary_sensitivity': ( 'vision_metrics.html#binary_sensitivity', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.binary_signed_rve': ( 'vision_metrics.html#binary_signed_rve', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.calculate_confusion_metrics': ( 'vision_metrics.html#calculate_confusion_metrics', + 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.calculate_dsc': ( 'vision_metrics.html#calculate_dsc', 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.calculate_haus': ( 'vision_metrics.html#calculate_haus', 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.calculate_lesion_detection_rate': ( 'vision_metrics.html#calculate_lesion_detection_rate', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.calculate_signed_rve': ( 'vision_metrics.html#calculate_signed_rve', + 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.multi_dice_score': ( 'vision_metrics.html#multi_dice_score', 'fastMONAI/vision_metrics.py'), 'fastMONAI.vision_metrics.multi_hausdorff_distance': ( 'vision_metrics.html#multi_hausdorff_distance', - 'fastMONAI/vision_metrics.py')}, + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.multi_lesion_detection_rate': ( 'vision_metrics.html#multi_lesion_detection_rate', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.multi_precision': ( 'vision_metrics.html#multi_precision', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.multi_sensitivity': ( 'vision_metrics.html#multi_sensitivity', + 'fastMONAI/vision_metrics.py'), + 'fastMONAI.vision_metrics.multi_signed_rve': ( 'vision_metrics.html#multi_signed_rve', + 'fastMONAI/vision_metrics.py')}, 'fastMONAI.vision_plot': { 'fastMONAI.vision_plot._get_slice': ('vision_plot.html#_get_slice', 'fastMONAI/vision_plot.py'), 'fastMONAI.vision_plot.find_max_slice': ( 'vision_plot.html#find_max_slice', 'fastMONAI/vision_plot.py'), diff --git a/fastMONAI/dataset_info.py b/fastMONAI/dataset_info.py index 202c02b..339d6ba 100644 --- a/fastMONAI/dataset_info.py +++ b/fastMONAI/dataset_info.py @@ -5,6 +5,7 @@ # %% ../nbs/08_dataset_info.ipynb 2 from .vision_core import * +from .vision_plot import find_max_slice from sklearn.utils.class_weight 
import compute_class_weight from concurrent.futures import ThreadPoolExecutor @@ -12,28 +13,22 @@ import numpy as np import torch import glob +import matplotlib.pyplot as plt -# %% ../nbs/08_dataset_info.ipynb 4 +# %% ../nbs/08_dataset_info.ipynb 3 class MedDataset: """A class to extract and present information about the dataset.""" - def __init__(self, path=None, postfix: str = '', img_list: list = None, + def __init__(self, dataframe=None, img_col:str =None, mask_col:str ="mask_path", path=None, postfix: str = '', reorder: bool = False, dtype: (MedImage, MedMask) = MedImage, max_workers: int = 1): - """Constructs MedDataset object. + """Constructs MedDataset object.""" - Args: - path (str, optional): Path to the image folder. - postfix (str, optional): Specify the file type if there are different files in the folder. - img_list (List[str], optional): Alternatively, pass in a list with image paths. - reorder (bool, optional): Whether to reorder the data to be closest to canonical (RAS+) orientation. - dtype (Union[MedImage, MedMask], optional): Load data as datatype. Default is MedImage. - max_workers (int, optional): The number of worker threads. Default is 1. - """ - + self.input_df = dataframe + self.img_col = img_col + self.mask_col = mask_col self.path = path self.postfix = postfix - self.img_list = img_list self.reorder = reorder self.dtype = dtype self.max_workers = max_workers @@ -41,19 +36,34 @@ def __init__(self, path=None, postfix: str = '', img_list: list = None, def _create_data_frame(self): """Private method that returns a dataframe with information about the dataset.""" - + + # Handle path-based initialization (legacy mode) if self.path: - self.img_list = glob.glob(f'{self.path}/*{self.postfix}*') - if not self.img_list: print('Could not find images. Check the image path') + img_list = glob.glob(f'{self.path}/*{self.postfix}*') + if not img_list: + print('Could not find images. Check the image path') + return pd.DataFrame() + + # Handle dataframe-based initialization (new mode) + elif self.input_df is not None and self.mask_col in self.input_df.columns: + img_list = self.input_df[self.mask_col].tolist() + + else: + print('Error: Must provide either path or dataframe with mask_col') + return pd.DataFrame() + # Process images to extract metadata with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - data_info_dict = list(executor.map(self._get_data_info, self.img_list)) + data_info_dict = list(executor.map(self._get_data_info, img_list)) df = pd.DataFrame(data_info_dict) - if df.orientation.nunique() > 1: - print('The volumes in this dataset have different orientations. ' - 'Recommended to pass in the argument reorder=True when creating a MedDataset object for this dataset') + if len(df) > 0 and df.orientation.nunique() > 1 and not self.reorder: + raise ValueError( + 'Mixed orientations detected in dataset. 
' + 'Please recreate MedDataset with reorder=True to get correct resample values: ' + 'MedDataset(..., reorder=True)' + ) return df @@ -74,18 +84,32 @@ def suggestion(self): def _get_data_info(self, fn: str): """Private method to collect information about an image file.""" - _, o, _ = med_img_reader(fn, reorder=self.reorder, only_tensor=False, dtype=self.dtype) - - info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3], - 'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4), - 'orientation': f'{"".join(o.orientation)}+'} - - if self.dtype is MedMask: - mask_labels_dict = o.count_labels() - mask_labels_dict = {f'voxel_count_{int(key)}': val for key, val in mask_labels_dict.items()} - info_dict.update(mask_labels_dict) - - return info_dict + try: + _, o, _ = med_img_reader(fn, reorder=self.reorder, only_tensor=False, dtype=self.dtype) + + info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3], + 'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4), + 'orientation': f'{"".join(o.orientation)}+'} + + if self.dtype is MedMask: + # Calculate voxel volume in mm³ + voxel_volume = o.spacing[0] * o.spacing[1] * o.spacing[2] + + # Get voxel counts for each label + mask_labels_dict = o.count_labels() + + # Calculate volumes for each label > 0 (skip background) + for key, voxel_count in mask_labels_dict.items(): + label_int = int(key) + if label_int > 0 and voxel_count > 0: # Skip background (label 0) + volume_mm3 = voxel_count * voxel_volume + info_dict[f'label_{label_int}_volume_mm3'] = round(volume_mm3, 4) + + return info_dict + + except Exception as e: + print(f"Warning: Failed to process {fn}: {e}") + return {'path': fn, 'error': str(e)} def get_largest_img_size(self, resample: list = None) -> list: """Get the largest image size in the dataset.""" @@ -105,6 +129,102 @@ def get_largest_img_size(self, resample: list = None) -> list: return dims + def get_volume_summary(self): + """Get summary statistics for volume columns.""" + volume_cols = [col for col in self.df.columns if col.endswith('_volume_mm3')] + + if not volume_cols: + print("No volume columns found. 
Make sure dtype=MedMask when creating the dataset.") + return None + + print("📊 Volume Summary:") + print("=" * 50) + + for col in volume_cols: + # Get non-zero volumes + non_zero_volumes = self.df[self.df[col] > 0][col] + + if len(non_zero_volumes) > 0: + print(f"\n{col}:") + print(f" Cases with volume: {len(non_zero_volumes)}") + print(f" Mean volume: {non_zero_volumes.mean():.2f} mm³") + print(f" Median volume: {non_zero_volumes.median():.2f} mm³") + print(f" Min volume: {non_zero_volumes.min():.2f} mm³") + print(f" Max volume: {non_zero_volumes.max():.2f} mm³") + else: + print(f"\n{col}: No cases with volume > 0") + + def _visualize_single_case(self, img_path, mask_path, case_id, anatomical_plane=2, cmap='hot', figsize=(12, 5)): + """Helper method to visualize a single case.""" + try: + # Create MedImage and MedMask with current preprocessing settings + resample, reorder = self.suggestion() + MedBase.item_preprocessing(resample=resample, reorder=reorder) + + img = MedImage.create(img_path) + mask = MedMask.create(mask_path) + + # Find optimal slice using explicit function + mask_data = mask.numpy()[0] # Remove channel dimension + optimal_slice = find_max_slice(mask_data, anatomical_plane) + + # Create subplot + fig, axes = plt.subplots(1, 2, figsize=figsize) + + # Show image + img.show(ctx=axes[0], anatomical_plane=anatomical_plane, slice_index=optimal_slice) + axes[0].set_title(f"{case_id} - Image (slice {optimal_slice})") + + # Show overlay + img.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice) + mask.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice, + alpha=0.3, cmap=cmap) + axes[1].set_title(f"{case_id} - Overlay (slice {optimal_slice})") + + # Bring the plots closer (call tight_layout first so the wspace tweak is not overridden) + plt.tight_layout() + plt.subplots_adjust(wspace=0.1) + plt.show() + + except Exception as e: + print(f"❌ Failed to visualize case {case_id}: {e}") + + def visualize_cases(self, n_cases=4, anatomical_plane=2, cmap='hot', figsize=(12, 5)): + """ + Visualize cases from the dataset. + + Args: + n_cases: Number of cases to show. If None, shows all cases. + anatomical_plane: 0=sagittal, 1=coronal, 2=axial + cmap: Colormap for mask overlay + figsize: Figure size for each case + """ + if self.input_df is None: + print("Error: No dataframe provided. Cannot visualize cases.") + return + + if self.img_col is None: + print("Error: No img_col specified. Cannot visualize cases.") + return + + # Check if required columns exist + if self.img_col not in self.input_df.columns: + print(f"Error: Column '{self.img_col}' not found in dataframe.") + return + + if self.mask_col not in self.input_df.columns: + print(f"Error: Column '{self.mask_col}' not found in dataframe.") + return + + # Support n_cases=None (show all cases), as documented above + n_show = len(self.input_df) if n_cases is None else min(n_cases, len(self.input_df)) + for idx in range(n_show): + row = self.input_df.iloc[idx] + case_id = row.get('case_id', f'Case_{idx}') # Fallback if no case_id + img_path = row[self.img_col] + mask_path = row[self.mask_col] + + self._visualize_single_case(img_path, mask_path, case_id, anatomical_plane, cmap, figsize) + print("-" * 60) + # %% ../nbs/08_dataset_info.ipynb 5 def get_class_weights(labels: (np.array, list), class_weight: str = 'balanced') -> torch.Tensor: """Calculates and returns the class weights. 
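A minimal usage sketch of the new dataframe-based MedDataset API above; the CSV file name and the case_id/image/mask column names are hypothetical placeholders, not part of the change:

import pandas as pd
from fastMONAI.vision_core import MedMask
from fastMONAI.dataset_info import MedDataset

# Hypothetical table with one row per case: case_id, image_path, mask_path
df = pd.read_csv('dataset.csv')

# Metadata is read from the mask paths; dtype=MedMask also computes the
# per-label volume columns (label_*_volume_mm3) used by get_volume_summary()
med_dataset = MedDataset(dataframe=df, img_col='image_path', mask_col='mask_path',
                         reorder=True, dtype=MedMask, max_workers=4)

resample, reorder = med_dataset.suggestion()   # suggested preprocessing values
med_dataset.get_volume_summary()               # per-label volume statistics in mm³
med_dataset.visualize_cases(n_cases=2)         # image + mask overlay at the densest mask slice
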
diff --git a/fastMONAI/vision_augmentation.py b/fastMONAI/vision_augmentation.py index 445c82e..857e08c 100644 --- a/fastMONAI/vision_augmentation.py +++ b/fastMONAI/vision_augmentation.py @@ -212,7 +212,11 @@ def __init__(self, intensity=(0.5, 1), p=0.5): self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p) def encodes(self, o: MedImage): - return MedImage.create(self.add_ghosts(o)) + result = self.add_ghosts(o) + # Handle potential complex values from k-space operations + if result.is_complex(): + result = torch.real(result) + return MedImage.create(result) def encodes(self, o: MedMask): return o @@ -227,7 +231,11 @@ def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5): self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p) def encodes(self, o:MedImage): - return MedImage.create(self.add_spikes(o)) + result = self.add_spikes(o) + # Handle potential complex values from k-space operations + if result.is_complex(): + result = torch.real(result) + return MedImage.create(result) def encodes(self, o:MedMask): return o @@ -316,7 +324,11 @@ def __init__( ) def encodes(self, o: MedImage): - return MedImage.create(self.add_motion(o)) + result = self.add_motion(o) + # Handle potential complex values from k-space operations + if result.is_complex(): + result = torch.real(result) + return MedImage.create(result) def encodes(self, o: MedMask): return o diff --git a/fastMONAI/vision_core.py b/fastMONAI/vision_core.py index 8c080c1..fca6739 100644 --- a/fastMONAI/vision_core.py +++ b/fastMONAI/vision_core.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_vision_core.ipynb. # %% auto 0 -__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask', 'VSCodeProgressCallback', 'setup_vscode_progress'] +__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask'] # %% ../nbs/01_vision_core.ipynb 2 from .vision_plot import * @@ -247,106 +247,3 @@ class MedImage(MedBase): class MedMask(MedBase): """Subclass of MedBase that represents an mask object.""" _show_args = {'alpha':0.5, 'cmap':'tab20'} - -# %% ../nbs/01_vision_core.ipynb 14 -import os -from fastai.callback.progress import ProgressCallback -from fastai.callback.core import Callback -import sys -from IPython import get_ipython - -class VSCodeProgressCallback(ProgressCallback): - """Enhanced progress callback that works better in VS Code notebooks.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.is_vscode = self._detect_vscode_environment() - self.lr_find_progress = None - - def _detect_vscode_environment(self): - """Detect if running in VS Code Jupyter environment.""" - ipython = get_ipython() - if ipython is None: - return True # Assume VS Code if no IPython (safer default) - # VS Code detection - more comprehensive check - kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else '' - return ('vscode' in kernel_name or - 'zmq' in kernel_name or # VS Code often uses ZMQInteractiveShell - not hasattr(ipython, 'display_pub')) # Missing display publisher often indicates VS Code - - def before_fit(self): - """Initialize progress tracking before training.""" - if self.is_vscode: - if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder: - # This is lr_find, handle differently - print("🔍 Starting Learning Rate Finder...") - self.lr_find_progress = 0 - else: - # Regular training - print(f"🚀 Training for {self.learn.n_epoch} epochs...") - super().before_fit() - - def before_epoch(self): - """Initialize epoch 
progress.""" - if self.is_vscode: - if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder: - print(f"📊 LR Find - Testing learning rates...") - else: - print(f"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}") - sys.stdout.flush() - super().before_epoch() - - def after_batch(self): - """Update progress after each batch.""" - super().after_batch() - if self.is_vscode: - if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder: - # Special handling for lr_find - self.lr_find_progress = getattr(self, 'iter', 0) + 1 - total = getattr(self, 'n_iter', 100) - if self.lr_find_progress % max(1, total // 10) == 0: - progress = (self.lr_find_progress / total) * 100 - print(f"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)") - sys.stdout.flush() - else: - # Regular training progress - if hasattr(self, 'iter') and hasattr(self, 'n_iter'): - if self.iter % max(1, self.n_iter // 20) == 0: - progress = (self.iter / self.n_iter) * 100 - print(f"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)") - sys.stdout.flush() - - def after_fit(self): - """Complete progress tracking after training.""" - if self.is_vscode: - if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder: - print("✅ Learning Rate Finder completed!") - else: - print("✅ Training completed!") - sys.stdout.flush() - super().after_fit() - - def before_validate(self): - """Update before validation.""" - if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder): - print("🔄 Validating...") - sys.stdout.flush() - super().before_validate() - - def after_validate(self): - """Update after validation.""" - if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder): - print("✅ Validation completed") - sys.stdout.flush() - super().after_validate() - -def setup_vscode_progress(): - """Configure fastai to use VS Code-compatible progress callback.""" - from fastai.learner import defaults - - # Replace default ProgressCallback with VSCodeProgressCallback - if ProgressCallback in defaults.callbacks: - defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback - for cb in defaults.callbacks] - - print("✅ Configured VS Code-compatible progress callback") diff --git a/fastMONAI/vision_metrics.py b/fastMONAI/vision_metrics.py index d2ca7a2..51dafc3 100644 --- a/fastMONAI/vision_metrics.py +++ b/fastMONAI/vision_metrics.py @@ -2,12 +2,15 @@ # %% auto 0 __all__ = ['calculate_dsc', 'calculate_haus', 'binary_dice_score', 'multi_dice_score', 'binary_hausdorff_distance', - 'multi_hausdorff_distance'] + 'multi_hausdorff_distance', 'calculate_confusion_metrics', 'binary_sensitivity', 'multi_sensitivity', + 'binary_precision', 'multi_precision', 'calculate_lesion_detection_rate', 'binary_lesion_detection_rate', + 'multi_lesion_detection_rate', 'calculate_signed_rve', 'binary_signed_rve', 'multi_signed_rve'] # %% ../nbs/05_vision_metrics.ipynb 1 import torch import numpy as np -from monai.metrics import compute_hausdorff_distance, compute_dice +from monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric +from scipy.ndimage import label as scipy_label from .vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask # %% ../nbs/05_vision_metrics.ipynb 3 @@ -18,9 +21,19 @@ def calculate_dsc(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: # %% ../nbs/05_vision_metrics.ipynb 4 def calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: - """MONAI 
`compute_hausdorff_distance`""" - - return torch.Tensor([compute_hausdorff_distance(p[None], t[None]) for p, t in list(zip(pred,targ))]) + """Compute 95th percentile Hausdorff distance (HD95) using MONAI. + + HD95 is more robust than standard Hausdorff distance as it ignores + the top 5% of outlier distances. + + Args: + pred: Binary prediction tensor [B, C, W, H, D]. + targ: Binary target tensor [B, C, W, H, D]. + + Returns: + HD95 values for each sample in batch. + """ + return torch.Tensor([compute_hausdorff_distance(p[None], t[None], percentile=95) for p, t in list(zip(pred,targ))]) # %% ../nbs/05_vision_metrics.ipynb 5 def binary_dice_score(act: torch.tensor, targ: torch.Tensor) -> torch.Tensor: @@ -62,39 +75,275 @@ def multi_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: # %% ../nbs/05_vision_metrics.ipynb 7 def binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: - """Calculate the mean Hausdorff distance for binary semantic segmentation tasks. + """Calculate the mean HD95 for binary semantic segmentation tasks. Args: act: Activation tensor with dimensions [B, C, W, H, D]. targ: Target masks with dimensions [B, C, W, H, D]. Returns: - Mean Hausdorff distance. + Mean HD95. """ - - pred = pred_to_binary_mask(act) - haus = calculate_haus(pred.cpu(), targ.cpu()) return torch.mean(haus) # %% ../nbs/05_vision_metrics.ipynb 8 -def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor : - """Calculate the mean Hausdorff distance for each class in multi-class semantic segmentation tasks. +def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate the mean HD95 for each class in multi-class semantic segmentation tasks. Args: act: Activation tensor with dimensions [B, C, W, H, D]. targ: Target masks with dimensions [B, C, W, H, D]. Returns: - Mean Hausdorff distance for each class. + Mean HD95 for each class. """ - pred, n_classes = batch_pred_to_multiclass_mask(act) binary_haus = [] for c in range(1, n_classes): c_pred, c_targ = torch.where(pred==c, 1, 0), torch.where(targ==c, 1, 0) - haus = calculate_haus(pred, targ) + haus = calculate_haus(c_pred, c_targ) binary_haus.append(np.nanmean(haus)) return torch.Tensor(binary_haus) + +# %% ../nbs/05_vision_metrics.ipynb 10 +def calculate_confusion_metrics(pred: torch.Tensor, targ: torch.Tensor, metric_name: str) -> torch.Tensor: + """Calculate confusion matrix-based metric using MONAI. + + Args: + pred: Binary prediction tensor [B, C, W, H, D]. + targ: Binary target tensor [B, C, W, H, D]. + metric_name: One of "sensitivity", "precision", "specificity", "f1 score". + + Returns: + Metric values for each sample in batch. + """ + # get_confusion_matrix expects one-hot format and returns [B, n_class, 4] where 4 = [TP, FP, TN, FN] + confusion_matrix = get_confusion_matrix(pred, targ, include_background=False) + metric = compute_confusion_matrix_metric(metric_name, confusion_matrix) + return metric + +# %% ../nbs/05_vision_metrics.ipynb 11 +def binary_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean sensitivity (recall) for binary segmentation. + + Sensitivity = TP / (TP + FN) - measures the proportion of actual positives + that are correctly identified. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean sensitivity score. 
+ """ + pred = pred_to_binary_mask(act) + sens = calculate_confusion_metrics(pred.cpu(), targ.cpu(), "sensitivity") + return torch.nanmean(sens) + +# %% ../nbs/05_vision_metrics.ipynb 12 +def multi_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean sensitivity for each class in multi-class segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean sensitivity for each class. + """ + pred, n_classes = batch_pred_to_multiclass_mask(act) + class_sens = [] + + for c in range(1, n_classes): + c_pred = torch.where(pred == c, 1, 0) + c_targ = torch.where(targ == c, 1, 0) + sens = calculate_confusion_metrics(c_pred, c_targ, "sensitivity") + class_sens.append(np.nanmean(sens.numpy())) + + return torch.Tensor(class_sens) + +# %% ../nbs/05_vision_metrics.ipynb 13 +def binary_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean precision for binary segmentation. + + Precision = TP / (TP + FP) - measures the proportion of positive predictions + that are actually correct. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean precision score. + """ + pred = pred_to_binary_mask(act) + prec = calculate_confusion_metrics(pred.cpu(), targ.cpu(), "precision") + return torch.nanmean(prec) + +# %% ../nbs/05_vision_metrics.ipynb 14 +def multi_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean precision for each class in multi-class segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean precision for each class. + """ + pred, n_classes = batch_pred_to_multiclass_mask(act) + class_prec = [] + + for c in range(1, n_classes): + c_pred = torch.where(pred == c, 1, 0) + c_targ = torch.where(targ == c, 1, 0) + prec = calculate_confusion_metrics(c_pred, c_targ, "precision") + class_prec.append(np.nanmean(prec.numpy())) + + return torch.Tensor(class_prec) + +# %% ../nbs/05_vision_metrics.ipynb 16 +def calculate_lesion_detection_rate(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate lesion-wise detection rate. + + For each connected component (lesion) in the target, check if there is + any overlap with the prediction. A lesion is considered detected if + at least one voxel overlaps. + + Args: + pred: Binary prediction tensor [B, C, W, H, D]. + targ: Binary target tensor [B, C, W, H, D]. + + Returns: + Detection rate (detected lesions / total lesions) for each sample. + """ + detection_rates = [] + + for p, t in zip(pred, targ): + p_np = p.squeeze().cpu().numpy() + t_np = t.squeeze().cpu().numpy() + + # Label connected components in target + labeled_targ, n_lesions = scipy_label(t_np) + + if n_lesions == 0: + detection_rates.append(float('nan')) + continue + + detected = 0 + for lesion_id in range(1, n_lesions + 1): + lesion_mask = (labeled_targ == lesion_id) + overlap = (p_np * lesion_mask).sum() + if overlap > 0: + detected += 1 + + detection_rates.append(detected / n_lesions) + + return torch.Tensor(detection_rates) + +# %% ../nbs/05_vision_metrics.ipynb 17 +def binary_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean lesion detection rate for binary segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean lesion detection rate. 
+ """ + pred = pred_to_binary_mask(act) + ldr = calculate_lesion_detection_rate(pred.cpu(), targ.cpu()) + return torch.nanmean(ldr) + +# %% ../nbs/05_vision_metrics.ipynb 18 +def multi_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean lesion detection rate for each class in multi-class segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean lesion detection rate for each class. + """ + pred, n_classes = batch_pred_to_multiclass_mask(act) + class_ldr = [] + + for c in range(1, n_classes): + c_pred = torch.where(pred == c, 1, 0) + c_targ = torch.where(targ == c, 1, 0) + ldr = calculate_lesion_detection_rate(c_pred, c_targ) + class_ldr.append(np.nanmean(ldr.numpy())) + + return torch.Tensor(class_ldr) + +# %% ../nbs/05_vision_metrics.ipynb 20 +def calculate_signed_rve(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate signed Relative Volume Error. + + RVE = (pred_volume - targ_volume) / targ_volume + + Positive values indicate over-segmentation (model predicts too large), + negative values indicate under-segmentation (model predicts too small). + + Args: + pred: Binary prediction tensor [B, C, W, H, D]. + targ: Binary target tensor [B, C, W, H, D]. + + Returns: + Signed RVE for each sample in batch. + """ + rve_values = [] + + for p, t in zip(pred, targ): + pred_vol = p.sum().float() + targ_vol = t.sum().float() + + if targ_vol == 0: + rve_values.append(float('nan')) + else: + rve = (pred_vol - targ_vol) / targ_vol + rve_values.append(rve.item()) + + return torch.Tensor(rve_values) + +# %% ../nbs/05_vision_metrics.ipynb 21 +def binary_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean signed RVE for binary segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean signed RVE. + """ + pred = pred_to_binary_mask(act) + rve = calculate_signed_rve(pred.cpu(), targ.cpu()) + return torch.nanmean(rve) + +# %% ../nbs/05_vision_metrics.ipynb 22 +def multi_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: + """Calculate mean signed RVE for each class in multi-class segmentation. + + Args: + act: Activation tensor [B, C, W, H, D]. + targ: Target masks [B, C, W, H, D]. + + Returns: + Mean signed RVE for each class. 
+ """ + pred, n_classes = batch_pred_to_multiclass_mask(act) + class_rve = [] + + for c in range(1, n_classes): + c_pred = torch.where(pred == c, 1, 0) + c_targ = torch.where(targ == c, 1, 0) + rve = calculate_signed_rve(c_pred, c_targ) + class_rve.append(np.nanmean(rve.numpy())) + + return torch.Tensor(class_rve) diff --git a/nbs/01_vision_core.ipynb b/nbs/01_vision_core.ipynb index 3661582..59f2844 100644 --- a/nbs/01_vision_core.ipynb +++ b/nbs/01_vision_core.ipynb @@ -24,7 +24,13 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "#| export\nfrom fastMONAI.vision_plot import *\nfrom fastai.data.all import *\nfrom torchio import ScalarImage, LabelMap, ToCanonical, Resample\nimport copy" + "source": [ + "#| export\n", + "from fastMONAI.vision_plot import *\n", + "from fastai.data.all import *\n", + "from torchio import ScalarImage, LabelMap, ToCanonical, Resample\n", + "import copy" + ] }, { "cell_type": "markdown", @@ -201,7 +207,123 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "#| export\nclass MedBase(torch.Tensor, metaclass=MetaResolver):\n \"\"\"A class that represents an image object.\n Metaclass casts `x` to this class if it is of type `cls._bypass_type`.\"\"\"\n \n _bypass_type = torch.Tensor\n _show_args = {'cmap':'gray'}\n resample, reorder = None, False\n affine_matrix = None\n\n @classmethod\n def create(cls, fn: (Path, str, L, list, torch.Tensor), **kwargs) -> torch.Tensor: \n \"\"\"\n Opens a medical image and casts it to MedBase object.\n If `fn` is a torch.Tensor, it's cast to MedBase object.\n\n Args:\n fn : (Path, str, torch.Tensor)\n Image path or a 4D torch.Tensor.\n kwargs : dict\n Additional parameters for the medical image reader.\n\n Returns:\n torch.Tensor : A 4D tensor as a MedBase object.\n \"\"\"\n if isinstance(fn, torch.Tensor):\n return cls(fn)\n\n return med_img_reader(fn, resample=cls.resample, reorder=cls.reorder, dtype=cls)\n\n def __new__(cls, x, **kwargs):\n \"\"\"Creates a new instance of MedBase from a tensor.\"\"\"\n if isinstance(x, torch.Tensor):\n # Create tensor of the same type and copy data\n res = torch.Tensor._make_subclass(cls, x.data, x.requires_grad)\n # Copy any additional attributes\n if hasattr(x, 'affine_matrix'):\n res.affine_matrix = x.affine_matrix\n return res\n else:\n # Handle other types by converting to tensor first\n tensor = torch.as_tensor(x, **kwargs)\n return cls.__new__(cls, tensor)\n\n def new_empty(self, size, **kwargs):\n \"\"\"Create a new empty tensor of the same type.\"\"\"\n # Create new tensor with same type and device/dtype\n kwargs.setdefault('dtype', self.dtype)\n kwargs.setdefault('device', self.device)\n new_tensor = torch.empty(size, **kwargs)\n # Use __new__ to create proper subclass instance\n return self.__class__.__new__(self.__class__, new_tensor)\n\n def __copy__(self):\n \"\"\"Shallow copy implementation.\"\"\"\n copied = self.__class__.__new__(self.__class__, self.clone())\n # Copy class attributes\n if hasattr(self, 'affine_matrix'):\n copied.affine_matrix = self.affine_matrix\n return copied\n\n def __deepcopy__(self, memo):\n \"\"\"Deep copy implementation.\"\"\"\n # Create a deep copy of the tensor data\n copied_data = self.clone()\n copied = self.__class__.__new__(self.__class__, copied_data)\n # Deep copy class attributes\n if hasattr(self, 'affine_matrix') and self.affine_matrix is not None:\n copied.affine_matrix = copy.deepcopy(self.affine_matrix, memo)\n else:\n copied.affine_matrix = None\n return copied\n\n @classmethod\n def 
item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):\n \"\"\"\n Changes the values for the class variables `resample` and `reorder`.\n\n Args:\n resample : (list, int, tuple)\n A list with voxel spacing.\n reorder : bool\n Whether to reorder the data to be closest to canonical (RAS+) orientation.\n \"\"\"\n cls.resample = resample\n cls.reorder = reorder\n\n def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):\n \"\"\"\n Displays the Medimage using `merge(self._show_args, kwargs)`.\n\n Args:\n ctx : Any, optional\n Context to use for the display. Defaults to None.\n channel : int, optional\n The channel of the image to be displayed. Defaults to 0.\n slice_index : int or None, optional\n Index of the images to be displayed. Defaults to None.\n anatomical_plane : int, optional\n Anatomical plane of the image to be displayed. Defaults to 0.\n kwargs : dict, optional\n Additional parameters for the show function.\n\n Returns:\n Shown image.\n \"\"\"\n return show_med_img(\n self, ctx=ctx, channel=channel, slice_index=slice_index, \n anatomical_plane=anatomical_plane, voxel_size=self.resample, \n **merge(self._show_args, kwargs)\n )\n\n def __repr__(self) -> str:\n \"\"\"Returns the string representation of the MedBase instance.\"\"\"\n return f'{self.__class__.__name__} mode={self.mode} size={\"x\".join([str(d) for d in self.size])}'" + "source": [ + "#| export\n", + "class MedBase(torch.Tensor, metaclass=MetaResolver):\n", + " \"\"\"A class that represents an image object.\n", + " Metaclass casts `x` to this class if it is of type `cls._bypass_type`.\"\"\"\n", + " \n", + " _bypass_type = torch.Tensor\n", + " _show_args = {'cmap':'gray'}\n", + " resample, reorder = None, False\n", + " affine_matrix = None\n", + "\n", + " @classmethod\n", + " def create(cls, fn: (Path, str, L, list, torch.Tensor), **kwargs) -> torch.Tensor: \n", + " \"\"\"\n", + " Opens a medical image and casts it to MedBase object.\n", + " If `fn` is a torch.Tensor, it's cast to MedBase object.\n", + "\n", + " Args:\n", + " fn : (Path, str, torch.Tensor)\n", + " Image path or a 4D torch.Tensor.\n", + " kwargs : dict\n", + " Additional parameters for the medical image reader.\n", + "\n", + " Returns:\n", + " torch.Tensor : A 4D tensor as a MedBase object.\n", + " \"\"\"\n", + " if isinstance(fn, torch.Tensor):\n", + " return cls(fn)\n", + "\n", + " return med_img_reader(fn, resample=cls.resample, reorder=cls.reorder, dtype=cls)\n", + "\n", + " def __new__(cls, x, **kwargs):\n", + " \"\"\"Creates a new instance of MedBase from a tensor.\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " # Create tensor of the same type and copy data\n", + " res = torch.Tensor._make_subclass(cls, x.data, x.requires_grad)\n", + " # Copy any additional attributes\n", + " if hasattr(x, 'affine_matrix'):\n", + " res.affine_matrix = x.affine_matrix\n", + " return res\n", + " else:\n", + " # Handle other types by converting to tensor first\n", + " tensor = torch.as_tensor(x, **kwargs)\n", + " return cls.__new__(cls, tensor)\n", + "\n", + " def new_empty(self, size, **kwargs):\n", + " \"\"\"Create a new empty tensor of the same type.\"\"\"\n", + " # Create new tensor with same type and device/dtype\n", + " kwargs.setdefault('dtype', self.dtype)\n", + " kwargs.setdefault('device', self.device)\n", + " new_tensor = torch.empty(size, **kwargs)\n", + " # Use __new__ to create proper subclass instance\n", + " return self.__class__.__new__(self.__class__, new_tensor)\n", + "\n", + " def 
__copy__(self):\n", + " \"\"\"Shallow copy implementation.\"\"\"\n", + " copied = self.__class__.__new__(self.__class__, self.clone())\n", + " # Copy class attributes\n", + " if hasattr(self, 'affine_matrix'):\n", + " copied.affine_matrix = self.affine_matrix\n", + " return copied\n", + "\n", + " def __deepcopy__(self, memo):\n", + " \"\"\"Deep copy implementation.\"\"\"\n", + " # Create a deep copy of the tensor data\n", + " copied_data = self.clone()\n", + " copied = self.__class__.__new__(self.__class__, copied_data)\n", + " # Deep copy class attributes\n", + " if hasattr(self, 'affine_matrix') and self.affine_matrix is not None:\n", + " copied.affine_matrix = copy.deepcopy(self.affine_matrix, memo)\n", + " else:\n", + " copied.affine_matrix = None\n", + " return copied\n", + "\n", + " @classmethod\n", + " def item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):\n", + " \"\"\"\n", + " Changes the values for the class variables `resample` and `reorder`.\n", + "\n", + " Args:\n", + " resample : (list, int, tuple)\n", + " A list with voxel spacing.\n", + " reorder : bool\n", + " Whether to reorder the data to be closest to canonical (RAS+) orientation.\n", + " \"\"\"\n", + " cls.resample = resample\n", + " cls.reorder = reorder\n", + "\n", + " def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):\n", + " \"\"\"\n", + " Displays the Medimage using `merge(self._show_args, kwargs)`.\n", + "\n", + " Args:\n", + " ctx : Any, optional\n", + " Context to use for the display. Defaults to None.\n", + " channel : int, optional\n", + " The channel of the image to be displayed. Defaults to 0.\n", + " slice_index : int or None, optional\n", + " Index of the images to be displayed. Defaults to None.\n", + " anatomical_plane : int, optional\n", + " Anatomical plane of the image to be displayed. 
Defaults to 0.\n", + " kwargs : dict, optional\n", + " Additional parameters for the show function.\n", + "\n", + " Returns:\n", + " Shown image.\n", + " \"\"\"\n", + " return show_med_img(\n", + " self, ctx=ctx, channel=channel, slice_index=slice_index, \n", + " anatomical_plane=anatomical_plane, voxel_size=self.resample, \n", + " **merge(self._show_args, kwargs)\n", + " )\n", + "\n", + " def __repr__(self) -> str:\n", + " \"\"\"Returns the string representation of the MedBase instance.\"\"\"\n", + " return f'{self.__class__.__name__} mode={self.mode} size={\"x\".join([str(d) for d in self.size])}'" + ] }, { "cell_type": "code", @@ -227,13 +349,6 @@ " _show_args = {'alpha':0.5, 'cmap':'tab20'}" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": "#| export\nimport os\nfrom fastai.callback.progress import ProgressCallback\nfrom fastai.callback.core import Callback\nimport sys\nfrom IPython import get_ipython\n\nclass VSCodeProgressCallback(ProgressCallback):\n \"\"\"Enhanced progress callback that works better in VS Code notebooks.\"\"\"\n \n def __init__(self, **kwargs):\n super().__init__(**kwargs)\n self.is_vscode = self._detect_vscode_environment()\n self.lr_find_progress = None\n \n def _detect_vscode_environment(self):\n \"\"\"Detect if running in VS Code Jupyter environment.\"\"\"\n ipython = get_ipython()\n if ipython is None:\n return True # Assume VS Code if no IPython (safer default)\n # VS Code detection - more comprehensive check\n kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else ''\n return ('vscode' in kernel_name or \n 'zmq' in kernel_name or # VS Code often uses ZMQInteractiveShell\n not hasattr(ipython, 'display_pub')) # Missing display publisher often indicates VS Code\n \n def before_fit(self):\n \"\"\"Initialize progress tracking before training.\"\"\"\n if self.is_vscode:\n if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:\n # This is lr_find, handle differently\n print(\"🔍 Starting Learning Rate Finder...\")\n self.lr_find_progress = 0\n else:\n # Regular training\n print(f\"🚀 Training for {self.learn.n_epoch} epochs...\")\n super().before_fit()\n \n def before_epoch(self):\n \"\"\"Initialize epoch progress.\"\"\"\n if self.is_vscode:\n if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:\n print(f\"📊 LR Find - Testing learning rates...\")\n else:\n print(f\"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}\")\n sys.stdout.flush()\n super().before_epoch()\n \n def after_batch(self):\n \"\"\"Update progress after each batch.\"\"\"\n super().after_batch()\n if self.is_vscode:\n if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:\n # Special handling for lr_find\n self.lr_find_progress = getattr(self, 'iter', 0) + 1\n total = getattr(self, 'n_iter', 100)\n if self.lr_find_progress % max(1, total // 10) == 0:\n progress = (self.lr_find_progress / total) * 100\n print(f\"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)\")\n sys.stdout.flush()\n else:\n # Regular training progress\n if hasattr(self, 'iter') and hasattr(self, 'n_iter'):\n if self.iter % max(1, self.n_iter // 20) == 0:\n progress = (self.iter / self.n_iter) * 100\n print(f\"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)\")\n sys.stdout.flush()\n \n def after_fit(self):\n \"\"\"Complete progress tracking after training.\"\"\"\n if self.is_vscode:\n if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:\n print(\"✅ Learning Rate Finder completed!\")\n else:\n 
print(\"✅ Training completed!\")\n sys.stdout.flush()\n super().after_fit()\n \n def before_validate(self):\n \"\"\"Update before validation.\"\"\"\n if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):\n print(\"🔄 Validating...\")\n sys.stdout.flush()\n super().before_validate()\n \n def after_validate(self):\n \"\"\"Update after validation.\"\"\"\n if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):\n print(\"✅ Validation completed\")\n sys.stdout.flush()\n super().after_validate()\n\ndef setup_vscode_progress():\n \"\"\"Configure fastai to use VS Code-compatible progress callback.\"\"\"\n from fastai.learner import defaults\n \n # Replace default ProgressCallback with VSCodeProgressCallback\n if ProgressCallback in defaults.callbacks:\n defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback \n for cb in defaults.callbacks]\n \n print(\"✅ Configured VS Code-compatible progress callback\")" - }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/03_vision_augment.ipynb b/nbs/03_vision_augment.ipynb index 5587ca6..0a53c3d 100644 --- a/nbs/03_vision_augment.ipynb +++ b/nbs/03_vision_augment.ipynb @@ -304,44 +304,14 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "class RandomGhosting(DisplayedTransform):\n", - " \"\"\"Apply TorchIO `RandomGhosting`.\"\"\"\n", - " \n", - " split_idx, order = 0, 1\n", - "\n", - " def __init__(self, intensity=(0.5, 1), p=0.5):\n", - " self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)\n", - "\n", - " def encodes(self, o: MedImage):\n", - " return MedImage.create(self.add_ghosts(o))\n", - "\n", - " def encodes(self, o: MedMask):\n", - " return o" - ] + "source": "#| export\nclass RandomGhosting(DisplayedTransform):\n \"\"\"Apply TorchIO `RandomGhosting`.\"\"\"\n \n split_idx, order = 0, 1\n\n def __init__(self, intensity=(0.5, 1), p=0.5):\n self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)\n\n def encodes(self, o: MedImage):\n result = self.add_ghosts(o)\n # Handle potential complex values from k-space operations\n if result.is_complex():\n result = torch.real(result)\n return MedImage.create(result)\n\n def encodes(self, o: MedMask):\n return o" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "class RandomSpike(DisplayedTransform):\n", - " '''Apply TorchIO `RandomSpike`.'''\n", - " \n", - " split_idx,order=0,1\n", - "\n", - " def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):\n", - " self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)\n", - "\n", - " def encodes(self, o:MedImage): \n", - " return MedImage.create(self.add_spikes(o))\n", - " \n", - " def encodes(self, o:MedMask):\n", - " return o" - ] + "source": "#| export\nclass RandomSpike(DisplayedTransform):\n '''Apply TorchIO `RandomSpike`.'''\n \n split_idx,order=0,1\n\n def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):\n self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)\n\n def encodes(self, o:MedImage): \n result = self.add_spikes(o)\n # Handle potential complex values from k-space operations\n if result.is_complex():\n result = torch.real(result)\n return MedImage.create(result)\n \n def encodes(self, o:MedMask):\n return o" }, { "cell_type": "code", @@ -437,35 +407,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "class 
RandomMotion(DisplayedTransform):\n", - " \"\"\"Apply TorchIO `RandomMotion`.\"\"\"\n", - "\n", - " split_idx, order = 0, 1\n", - "\n", - " def __init__(\n", - " self, \n", - " degrees=10, \n", - " translation=10, \n", - " num_transforms=2, \n", - " image_interpolation='linear', \n", - " p=0.5\n", - " ):\n", - " self.add_motion = tio.RandomMotion(\n", - " degrees=degrees, \n", - " translation=translation, \n", - " num_transforms=num_transforms, \n", - " image_interpolation=image_interpolation, \n", - " p=p\n", - " )\n", - "\n", - " def encodes(self, o: MedImage):\n", - " return MedImage.create(self.add_motion(o))\n", - "\n", - " def encodes(self, o: MedMask):\n", - " return o" - ] + "source": "#| export\nclass RandomMotion(DisplayedTransform):\n \"\"\"Apply TorchIO `RandomMotion`.\"\"\"\n\n split_idx, order = 0, 1\n\n def __init__(\n self, \n degrees=10, \n translation=10, \n num_transforms=2, \n image_interpolation='linear', \n p=0.5\n ):\n self.add_motion = tio.RandomMotion(\n degrees=degrees, \n translation=translation, \n num_transforms=num_transforms, \n image_interpolation=image_interpolation, \n p=p\n )\n\n def encodes(self, o: MedImage):\n result = self.add_motion(o)\n # Handle potential complex values from k-space operations\n if result.is_complex():\n result = torch.real(result)\n return MedImage.create(result)\n\n def encodes(self, o: MedMask):\n return o" }, { "cell_type": "markdown", diff --git a/nbs/05_vision_metrics.ipynb b/nbs/05_vision_metrics.ipynb index 5ea3bbc..6bdd2bc 100644 --- a/nbs/05_vision_metrics.ipynb +++ b/nbs/05_vision_metrics.ipynb @@ -16,13 +16,7 @@ "id": "8b6a83ac", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "import torch\n", - "import numpy as np\n", - "from monai.metrics import compute_hausdorff_distance, compute_dice\n", - "from fastMONAI.vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask" - ] + "source": "#| export\nimport torch\nimport numpy as np\nfrom monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric\nfrom scipy.ndimage import label as scipy_label\nfrom fastMONAI.vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask" }, { "cell_type": "markdown", @@ -53,13 +47,7 @@ "id": "4f815b1d-ea53-4f65-a7f0-299ea54e872b", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "def calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"MONAI `compute_hausdorff_distance`\"\"\"\n", - "\n", - " return torch.Tensor([compute_hausdorff_distance(p[None], t[None]) for p, t in list(zip(pred,targ))])" - ] + "source": "#| export\ndef calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Compute 95th percentile Hausdorff distance (HD95) using MONAI.\n \n HD95 is more robust than standard Hausdorff distance as it ignores \n the top 5% of outlier distances.\n \n Args:\n pred: Binary prediction tensor [B, C, W, H, D].\n targ: Binary target tensor [B, C, W, H, D].\n \n Returns:\n HD95 values for each sample in batch.\n \"\"\"\n return torch.Tensor([compute_hausdorff_distance(p[None], t[None], percentile=95) for p, t in list(zip(pred,targ))])" }, { "cell_type": "code", @@ -121,25 +109,7 @@ "id": "a390762f-d1a9-4674-b099-2369769f4198", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "def binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Calculate the mean Hausdorff distance for binary semantic segmentation tasks.\n", - " \n", - " 
Args:\n", - " act: Activation tensor with dimensions [B, C, W, H, D].\n", - " targ: Target masks with dimensions [B, C, W, H, D].\n", - "\n", - " Returns:\n", - " Mean Hausdorff distance.\n", - " \"\"\"\n", - " \n", - "\n", - " pred = pred_to_binary_mask(act)\n", - "\n", - " haus = calculate_haus(pred.cpu(), targ.cpu())\n", - " return torch.mean(haus)" - ] + "source": "#| export\ndef binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate the mean HD95 for binary semantic segmentation tasks.\n \n Args:\n act: Activation tensor with dimensions [B, C, W, H, D].\n targ: Target masks with dimensions [B, C, W, H, D].\n\n Returns:\n Mean HD95.\n \"\"\"\n pred = pred_to_binary_mask(act)\n haus = calculate_haus(pred.cpu(), targ.cpu())\n return torch.mean(haus)" }, { "cell_type": "code", @@ -147,28 +117,113 @@ "id": "ea94dea5", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor :\n", - " \"\"\"Calculate the mean Hausdorff distance for each class in multi-class semantic segmentation tasks.\n", - " \n", - " Args:\n", - " act: Activation tensor with dimensions [B, C, W, H, D].\n", - " targ: Target masks with dimensions [B, C, W, H, D].\n", - "\n", - " Returns:\n", - " Mean Hausdorff distance for each class.\n", - " \"\"\"\n", - "\n", - " pred, n_classes = batch_pred_to_multiclass_mask(act)\n", - " binary_haus = []\n", - "\n", - " for c in range(1, n_classes):\n", - " c_pred, c_targ = torch.where(pred==c, 1, 0), torch.where(targ==c, 1, 0)\n", - " haus = calculate_haus(pred, targ)\n", - " binary_haus.append(np.nanmean(haus))\n", - " return torch.Tensor(binary_haus)" - ] + "source": "#| export\ndef multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate the mean HD95 for each class in multi-class semantic segmentation tasks.\n \n Args:\n act: Activation tensor with dimensions [B, C, W, H, D].\n targ: Target masks with dimensions [B, C, W, H, D].\n\n Returns:\n Mean HD95 for each class.\n \"\"\"\n pred, n_classes = batch_pred_to_multiclass_mask(act)\n binary_haus = []\n\n for c in range(1, n_classes):\n c_pred, c_targ = torch.where(pred==c, 1, 0), torch.where(targ==c, 1, 0)\n haus = calculate_haus(c_pred, c_targ)\n binary_haus.append(np.nanmean(haus))\n return torch.Tensor(binary_haus)" + }, + { + "cell_type": "markdown", + "id": "xgm61dhkbwn", + "metadata": {}, + "source": "## Sensitivity and Precision" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdkdak1uf9s", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef calculate_confusion_metrics(pred: torch.Tensor, targ: torch.Tensor, metric_name: str) -> torch.Tensor:\n \"\"\"Calculate confusion matrix-based metric using MONAI.\n \n Args:\n pred: Binary prediction tensor [B, C, W, H, D].\n targ: Binary target tensor [B, C, W, H, D].\n metric_name: One of \"sensitivity\", \"precision\", \"specificity\", \"f1 score\".\n \n Returns:\n Metric values for each sample in batch.\n \"\"\"\n # get_confusion_matrix expects one-hot format and returns [B, n_class, 4] where 4 = [TP, FP, TN, FN]\n confusion_matrix = get_confusion_matrix(pred, targ, include_background=False)\n metric = compute_confusion_matrix_metric(metric_name, confusion_matrix)\n return metric" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "zbxmbdwq0ao", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef binary_sensitivity(act: torch.Tensor, targ: 
torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean sensitivity (recall) for binary segmentation.\n \n Sensitivity = TP / (TP + FN) - measures the proportion of actual positives\n that are correctly identified.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean sensitivity score.\n \"\"\"\n pred = pred_to_binary_mask(act)\n sens = calculate_confusion_metrics(pred.cpu(), targ.cpu(), \"sensitivity\")\n return torch.nanmean(sens)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "q0f40h69zan", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef multi_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean sensitivity for each class in multi-class segmentation.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean sensitivity for each class.\n \"\"\"\n pred, n_classes = batch_pred_to_multiclass_mask(act)\n class_sens = []\n \n for c in range(1, n_classes):\n c_pred = torch.where(pred == c, 1, 0)\n c_targ = torch.where(targ == c, 1, 0)\n sens = calculate_confusion_metrics(c_pred, c_targ, \"sensitivity\")\n class_sens.append(np.nanmean(sens.numpy()))\n \n return torch.Tensor(class_sens)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dlnx9vkx98j", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef binary_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean precision for binary segmentation.\n \n Precision = TP / (TP + FP) - measures the proportion of positive predictions\n that are actually correct.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean precision score.\n \"\"\"\n pred = pred_to_binary_mask(act)\n prec = calculate_confusion_metrics(pred.cpu(), targ.cpu(), \"precision\")\n return torch.nanmean(prec)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cwb5sivj8es", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef multi_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean precision for each class in multi-class segmentation.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean precision for each class.\n \"\"\"\n pred, n_classes = batch_pred_to_multiclass_mask(act)\n class_prec = []\n \n for c in range(1, n_classes):\n c_pred = torch.where(pred == c, 1, 0)\n c_targ = torch.where(targ == c, 1, 0)\n prec = calculate_confusion_metrics(c_pred, c_targ, \"precision\")\n class_prec.append(np.nanmean(prec.numpy()))\n \n return torch.Tensor(class_prec)" + }, + { + "cell_type": "markdown", + "id": "8ud77cvi8fg", + "metadata": {}, + "source": "## Lesion Detection Rate" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "flomm5fkz69", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef calculate_lesion_detection_rate(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate lesion-wise detection rate.\n \n For each connected component (lesion) in the target, check if there is\n any overlap with the prediction. 
A lesion is considered detected if\n at least one voxel overlaps.\n \n Args:\n pred: Binary prediction tensor [B, C, W, H, D].\n targ: Binary target tensor [B, C, W, H, D].\n \n Returns:\n Detection rate (detected lesions / total lesions) for each sample.\n \"\"\"\n detection_rates = []\n \n for p, t in zip(pred, targ):\n p_np = p.squeeze().cpu().numpy()\n t_np = t.squeeze().cpu().numpy()\n \n # Label connected components in target\n labeled_targ, n_lesions = scipy_label(t_np)\n \n if n_lesions == 0:\n detection_rates.append(float('nan'))\n continue\n \n detected = 0\n for lesion_id in range(1, n_lesions + 1):\n lesion_mask = (labeled_targ == lesion_id)\n overlap = (p_np * lesion_mask).sum()\n if overlap > 0:\n detected += 1\n \n detection_rates.append(detected / n_lesions)\n \n return torch.Tensor(detection_rates)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "jyjuvh9f09n", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef binary_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean lesion detection rate for binary segmentation.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean lesion detection rate.\n \"\"\"\n pred = pred_to_binary_mask(act)\n ldr = calculate_lesion_detection_rate(pred.cpu(), targ.cpu())\n return torch.nanmean(ldr)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5cwiazhpp", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef multi_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean lesion detection rate for each class in multi-class segmentation.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean lesion detection rate for each class.\n \"\"\"\n pred, n_classes = batch_pred_to_multiclass_mask(act)\n class_ldr = []\n \n for c in range(1, n_classes):\n c_pred = torch.where(pred == c, 1, 0)\n c_targ = torch.where(targ == c, 1, 0)\n ldr = calculate_lesion_detection_rate(c_pred, c_targ)\n class_ldr.append(np.nanmean(ldr.numpy()))\n \n return torch.Tensor(class_ldr)" + }, + { + "cell_type": "markdown", + "id": "19f3lszfd5y", + "metadata": {}, + "source": "## Signed Relative Volume Error (RVE)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8aat811h3zo", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef calculate_signed_rve(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate signed Relative Volume Error.\n \n RVE = (pred_volume - targ_volume) / targ_volume\n \n Positive values indicate over-segmentation (model predicts too large),\n negative values indicate under-segmentation (model predicts too small).\n \n Args:\n pred: Binary prediction tensor [B, C, W, H, D].\n targ: Binary target tensor [B, C, W, H, D].\n \n Returns:\n Signed RVE for each sample in batch.\n \"\"\"\n rve_values = []\n \n for p, t in zip(pred, targ):\n pred_vol = p.sum().float()\n targ_vol = t.sum().float()\n \n if targ_vol == 0:\n rve_values.append(float('nan'))\n else:\n rve = (pred_vol - targ_vol) / targ_vol\n rve_values.append(rve.item())\n \n return torch.Tensor(rve_values)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a18cg6z9qtr", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef binary_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean signed RVE for binary segmentation.\n \n Args:\n act: 
Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean signed RVE.\n \"\"\"\n pred = pred_to_binary_mask(act)\n rve = calculate_signed_rve(pred.cpu(), targ.cpu())\n return torch.nanmean(rve)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aetwom9qd3p", + "metadata": {}, + "outputs": [], + "source": "#| export\ndef multi_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate mean signed RVE for each class in multi-class segmentation.\n \n Args:\n act: Activation tensor [B, C, W, H, D].\n targ: Target masks [B, C, W, H, D].\n \n Returns:\n Mean signed RVE for each class.\n \"\"\"\n pred, n_classes = batch_pred_to_multiclass_mask(act)\n class_rve = []\n \n for c in range(1, n_classes):\n c_pred = torch.where(pred == c, 1, 0)\n c_targ = torch.where(targ == c, 1, 0)\n rve = calculate_signed_rve(c_pred, c_targ)\n class_rve.append(np.nanmean(rve.numpy()))\n \n return torch.Tensor(class_rve)" }, { "cell_type": "code", @@ -176,22 +231,7 @@ "id": "b84e25bf-11f5-4806-9657-1353054428e4", "metadata": {}, "outputs": [], - "source": [ - "#| hide \n", - "\n", - "# Test Dice score and Hausdorff distance \n", - "pred = torch.zeros((1,1,10,10,10))\n", - "pred[:,:,:5, :5, :5] = 1\n", - "\n", - "targ = torch.zeros((1,1,10,10,10))\n", - "targ[:,:,:5, :5, :5] = 1\n", - "\n", - "dsc = float(calculate_dsc(pred, targ)) \n", - "haus = float(calculate_haus(pred,targ))\n", - "\n", - "assert dsc == 1.0\n", - "assert haus == 0.0" - ] + "source": "#| hide \n\n# Test Dice score and Hausdorff distance (HD95)\npred = torch.zeros((1,1,10,10,10))\npred[:,:,:5, :5, :5] = 1\n\ntarg = torch.zeros((1,1,10,10,10))\ntarg[:,:,:5, :5, :5] = 1\n\ndsc = float(calculate_dsc(pred, targ)) \nhaus = float(calculate_haus(pred,targ))\n\nassert dsc == 1.0\nassert haus == 0.0\n\n# Test Signed RVE - perfect overlap should give RVE = 0\nrve = float(calculate_signed_rve(pred, targ))\nassert rve == 0.0, f\"Expected RVE=0 for same volumes, got {rve}\"\n\n# Test Lesion Detection Rate - single lesion fully detected\nldr = float(calculate_lesion_detection_rate(pred, targ))\nassert ldr == 1.0, f\"Expected LDR=1 for detected lesion, got {ldr}\"\n\n# Test over-segmentation: RVE should be positive\npred_over = torch.zeros((1, 1, 10, 10, 10))\npred_over[:, :, :6, :6, :6] = 1 # Larger than target (216 vs 125 voxels)\nrve_over = float(calculate_signed_rve(pred_over, targ))\nassert rve_over > 0, f\"Expected positive RVE for over-segmentation, got {rve_over}\"\n\n# Test under-segmentation: RVE should be negative\npred_under = torch.zeros((1, 1, 10, 10, 10))\npred_under[:, :, :3, :3, :3] = 1 # Smaller than target (27 vs 125 voxels)\nrve_under = float(calculate_signed_rve(pred_under, targ))\nassert rve_under < 0, f\"Expected negative RVE for under-segmentation, got {rve_under}\"\n\n# Test missed lesion: LDR should be 0\npred_miss = torch.zeros((1, 1, 10, 10, 10))\npred_miss[:, :, 6:10, 6:10, 6:10] = 1 # No overlap with target\nldr_miss = float(calculate_lesion_detection_rate(pred_miss, targ))\nassert ldr_miss == 0.0, f\"Expected LDR=0 for missed lesion, got {ldr_miss}\"\n\nprint(\"All metric tests passed!\")" } ], "metadata": { diff --git a/nbs/08_dataset_info.ipynb b/nbs/08_dataset_info.ipynb index 412a0bb..0c7438c 100644 --- a/nbs/08_dataset_info.ipynb +++ b/nbs/08_dataset_info.ipynb @@ -30,15 +30,25 @@ "source": [ "#| export \n", "from fastMONAI.vision_core import *\n", + "from fastMONAI.vision_plot import find_max_slice\n", "\n", "from sklearn.utils.class_weight import 
compute_class_weight\n", "from concurrent.futures import ThreadPoolExecutor\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", - "import glob" + "import glob\n", + "import matplotlib.pyplot as plt" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7401beac", + "metadata": {}, + "outputs": [], + "source": "#| export\nclass MedDataset:\n \"\"\"A class to extract and present information about the dataset.\"\"\"\n\n def __init__(self, dataframe=None, img_col:str =None, mask_col:str =\"mask_path\", path=None, postfix: str = '',\n reorder: bool = False, dtype: (MedImage, MedMask) = MedImage,\n max_workers: int = 1):\n \"\"\"Constructs MedDataset object.\"\"\"\n\n self.input_df = dataframe\n self.img_col = img_col\n self.mask_col = mask_col\n self.path = path\n self.postfix = postfix\n self.reorder = reorder\n self.dtype = dtype\n self.max_workers = max_workers\n self.df = self._create_data_frame()\n\n def _create_data_frame(self):\n \"\"\"Private method that returns a dataframe with information about the dataset.\"\"\"\n \n # Handle path-based initialization (legacy mode)\n if self.path:\n img_list = glob.glob(f'{self.path}/*{self.postfix}*')\n if not img_list: \n print('Could not find images. Check the image path')\n return pd.DataFrame()\n \n # Handle dataframe-based initialization (new mode)\n elif self.input_df is not None and self.mask_col in self.input_df.columns:\n img_list = self.input_df[self.mask_col].tolist()\n \n else:\n print('Error: Must provide either path or dataframe with mask_col')\n return pd.DataFrame()\n\n # Process images to extract metadata\n with ThreadPoolExecutor(max_workers=self.max_workers) as executor:\n data_info_dict = list(executor.map(self._get_data_info, img_list))\n\n df = pd.DataFrame(data_info_dict)\n \n if len(df) > 0 and df.orientation.nunique() > 1 and not self.reorder:\n raise ValueError(\n 'Mixed orientations detected in dataset. 
'\n 'Please recreate MedDataset with reorder=True to get correct resample values: '\n 'MedDataset(..., reorder=True)'\n )\n\n return df\n\n def summary(self):\n \"\"\"Summary DataFrame of the dataset with example path for similar data.\"\"\"\n \n columns = ['dim_0', 'dim_1', 'dim_2', 'voxel_0', 'voxel_1', 'voxel_2', 'orientation']\n \n return self.df.groupby(columns, as_index=False).agg(\n example_path=('path', 'min'), total=('path', 'size')\n ).sort_values('total', ascending=False)\n\n def suggestion(self):\n \"\"\"Most frequent voxel spacing along dim_0, dim_1 and dim_2, and whether the data should be reoriented.\"\"\"\n \n resample = [float(self.df.voxel_0.mode()[0]), float(self.df.voxel_1.mode()[0]), float(self.df.voxel_2.mode()[0])]\n return resample, self.reorder\n\n def _get_data_info(self, fn: str):\n \"\"\"Private method to collect information about an image file.\"\"\"\n try:\n _, o, _ = med_img_reader(fn, reorder=self.reorder, only_tensor=False, dtype=self.dtype)\n\n info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],\n 'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),\n 'orientation': f'{\"\".join(o.orientation)}+'}\n\n if self.dtype is MedMask:\n # Calculate voxel volume in mm³\n voxel_volume = o.spacing[0] * o.spacing[1] * o.spacing[2]\n \n # Get voxel counts for each label\n mask_labels_dict = o.count_labels()\n \n # Calculate volumes for each label > 0 (skip background)\n for key, voxel_count in mask_labels_dict.items():\n label_int = int(key)\n if label_int > 0 and voxel_count > 0: # Skip background (label 0)\n volume_mm3 = voxel_count * voxel_volume\n info_dict[f'label_{label_int}_volume_mm3'] = round(volume_mm3, 4)\n\n return info_dict\n \n except Exception as e:\n print(f\"Warning: Failed to process {fn}: {e}\")\n return {'path': fn, 'error': str(e)}\n\n def get_largest_img_size(self, resample: list = None) -> list:\n \"\"\"Get the largest image size in the dataset.\"\"\"\n \n dims = None\n\n if resample is not None:\n org_voxels = self.df[[\"voxel_0\", \"voxel_1\", 'voxel_2']].values\n org_dims = self.df[[\"dim_0\", \"dim_1\", 'dim_2']].values\n\n ratio = org_voxels/resample\n new_dims = (org_dims * ratio).T\n dims = [float(new_dims[0].max().round()), float(new_dims[1].max().round()), float(new_dims[2].max().round())]\n\n else:\n dims = [float(self.df.dim_0.max()), float(self.df.dim_1.max()), float(self.df.dim_2.max())]\n\n return dims\n\n def get_volume_summary(self):\n \"\"\"Get summary statistics for volume columns.\"\"\"\n volume_cols = [col for col in self.df.columns if col.endswith('_volume_mm3')]\n \n if not volume_cols:\n print(\"No volume columns found. 
Make sure dtype=MedMask when creating the dataset.\")\n return None\n \n print(\"📊 Volume Summary:\")\n print(\"=\" * 50)\n \n for col in volume_cols:\n # Get non-zero volumes\n non_zero_volumes = self.df[self.df[col] > 0][col]\n \n if len(non_zero_volumes) > 0:\n print(f\"\\n{col}:\")\n print(f\" Cases with volume: {len(non_zero_volumes)}\")\n print(f\" Mean volume: {non_zero_volumes.mean():.2f} mm³\")\n print(f\" Median volume: {non_zero_volumes.median():.2f} mm³\")\n print(f\" Min volume: {non_zero_volumes.min():.2f} mm³\")\n print(f\" Max volume: {non_zero_volumes.max():.2f} mm³\")\n else:\n print(f\"\\n{col}: No cases with volume > 0\")\n \n def _visualize_single_case(self, img_path, mask_path, case_id, anatomical_plane=2, cmap='hot', figsize=(12, 5)):\n \"\"\"Helper method to visualize a single case.\"\"\"\n try:\n # Create MedImage and MedMask with current preprocessing settings\n resample, reorder = self.suggestion()\n MedBase.item_preprocessing(resample=resample, reorder=reorder)\n \n img = MedImage.create(img_path)\n mask = MedMask.create(mask_path)\n \n # Find optimal slice using explicit function\n mask_data = mask.numpy()[0] # Remove channel dimension\n optimal_slice = find_max_slice(mask_data, anatomical_plane)\n \n # Create subplot\n fig, axes = plt.subplots(1, 2, figsize=figsize)\n \n # Show image\n img.show(ctx=axes[0], anatomical_plane=anatomical_plane, slice_index=optimal_slice)\n axes[0].set_title(f\"{case_id} - Image (slice {optimal_slice})\")\n \n # Show overlay\n img.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice)\n mask.show(ctx=axes[1], anatomical_plane=anatomical_plane, slice_index=optimal_slice, \n alpha=0.3, cmap=cmap)\n axes[1].set_title(f\"{case_id} - Overlay (slice {optimal_slice})\")\n \n # Apply tight layout first, then pull the two panels closer together\n # (a later tight_layout call would override the manual wspace)\n plt.tight_layout()\n plt.subplots_adjust(wspace=0.1)\n plt.show()\n \n except Exception as e:\n print(f\"❌ Failed to visualize case {case_id}: {e}\")\n\n def visualize_cases(self, n_cases=4, anatomical_plane=2, cmap='hot', figsize=(12, 5)):\n \"\"\"\n Visualize cases from the dataset.\n \n Args:\n n_cases: Number of cases to show. If None, shows all cases.\n anatomical_plane: 0=sagittal, 1=coronal, 2=axial\n cmap: Colormap for mask overlay\n figsize: Figure size for each case\n \"\"\"\n if self.input_df is None:\n print(\"Error: No dataframe provided. Cannot visualize cases.\")\n return\n \n if self.img_col is None:\n print(\"Error: No img_col specified. 
Cannot visualize cases.\")\n return\n \n # Check if required columns exist\n if self.img_col not in self.input_df.columns:\n print(f\"Error: Column '{self.img_col}' not found in dataframe.\")\n return\n \n if self.mask_col not in self.input_df.columns:\n print(f\"Error: Column '{self.mask_col}' not found in dataframe.\")\n return\n\n for idx in range(min(n_cases, len(self.input_df))):\n row = self.input_df.iloc[idx]\n case_id = row.get('case_id', f'Case_{idx}') # Fallback if no case_id\n img_path = row[self.img_col]\n mask_path = row[self.mask_col]\n\n self._visualize_single_case(img_path, mask_path, case_id, anatomical_plane, cmap, figsize)\n print(\"-\" * 60)" + }, { "cell_type": "markdown", "id": "74812108-f3eb-4a8d-9f2d-b93132619008", @@ -48,14 +58,6 @@ ">" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "3593203e-e5e1-4564-94d4-8e31b7048cf9", - "metadata": {}, - "outputs": [], - "source": "#| export\nclass MedDataset:\n \"\"\"A class to extract and present information about the dataset.\"\"\"\n\n def __init__(self, path=None, postfix: str = '', img_list: list = None,\n reorder: bool = False, dtype: (MedImage, MedMask) = MedImage,\n max_workers: int = 1):\n \"\"\"Constructs MedDataset object.\n\n Args:\n path (str, optional): Path to the image folder.\n postfix (str, optional): Specify the file type if there are different files in the folder.\n img_list (List[str], optional): Alternatively, pass in a list with image paths.\n reorder (bool, optional): Whether to reorder the data to be closest to canonical (RAS+) orientation.\n dtype (Union[MedImage, MedMask], optional): Load data as datatype. Default is MedImage.\n max_workers (int, optional): The number of worker threads. Default is 1.\n \"\"\"\n \n self.path = path\n self.postfix = postfix\n self.img_list = img_list\n self.reorder = reorder\n self.dtype = dtype\n self.max_workers = max_workers\n self.df = self._create_data_frame()\n\n def _create_data_frame(self):\n \"\"\"Private method that returns a dataframe with information about the dataset.\"\"\"\n\n if self.path:\n self.img_list = glob.glob(f'{self.path}/*{self.postfix}*')\n if not self.img_list: print('Could not find images. Check the image path')\n\n with ThreadPoolExecutor(max_workers=self.max_workers) as executor:\n data_info_dict = list(executor.map(self._get_data_info, self.img_list))\n\n df = pd.DataFrame(data_info_dict)\n \n if df.orientation.nunique() > 1:\n print('The volumes in this dataset have different orientations. 
'\n 'Recommended to pass in the argument reorder=True when creating a MedDataset object for this dataset')\n\n return df\n\n def summary(self):\n \"\"\"Summary DataFrame of the dataset with example path for similar data.\"\"\"\n \n columns = ['dim_0', 'dim_1', 'dim_2', 'voxel_0', 'voxel_1', 'voxel_2', 'orientation']\n \n return self.df.groupby(columns, as_index=False).agg(\n example_path=('path', 'min'), total=('path', 'size')\n ).sort_values('total', ascending=False)\n\n def suggestion(self):\n \"\"\"Voxel value that appears most often in dim_0, dim_1 and dim_2, and whether the data should be reoriented.\"\"\"\n \n resample = [float(self.df.voxel_0.mode()[0]), float(self.df.voxel_1.mode()[0]), float(self.df.voxel_2.mode()[0])]\n return resample, self.reorder\n\n def _get_data_info(self, fn: str):\n \"\"\"Private method to collect information about an image file.\"\"\"\n _, o, _ = med_img_reader(fn, reorder=self.reorder, only_tensor=False, dtype=self.dtype)\n\n info_dict = {'path': fn, 'dim_0': o.shape[1], 'dim_1': o.shape[2], 'dim_2': o.shape[3],\n 'voxel_0': round(o.spacing[0], 4), 'voxel_1': round(o.spacing[1], 4), 'voxel_2': round(o.spacing[2], 4),\n 'orientation': f'{\"\".join(o.orientation)}+'}\n\n if self.dtype is MedMask:\n mask_labels_dict = o.count_labels()\n mask_labels_dict = {f'voxel_count_{int(key)}': val for key, val in mask_labels_dict.items()}\n info_dict.update(mask_labels_dict)\n\n return info_dict\n\n def get_largest_img_size(self, resample: list = None) -> list:\n \"\"\"Get the largest image size in the dataset.\"\"\"\n \n dims = None\n\n if resample is not None:\n org_voxels = self.df[[\"voxel_0\", \"voxel_1\", 'voxel_2']].values\n org_dims = self.df[[\"dim_0\", \"dim_1\", 'dim_2']].values\n\n ratio = org_voxels/resample\n new_dims = (org_dims * ratio).T\n dims = [float(new_dims[0].max().round()), float(new_dims[1].max().round()), float(new_dims[2].max().round())]\n\n else:\n dims = [float(self.df.dim_0.max()), float(self.df.dim_1.max()), float(self.df.dim_2.max())]\n\n return dims" - }, { "cell_type": "code", "execution_count": null, diff --git a/settings.ini b/settings.ini index f283124..1b44b57 100644 --- a/settings.ini +++ b/settings.ini @@ -5,7 +5,7 @@ ### Python Library ### lib_name = fastMONAI min_python = 3.10 -version = 0.5.3 +version = 0.5.4 ### OPTIONAL ### requirements = fastai==2.8.3 monai==1.5.0 torchio==0.20.19 xlrd>=1.2.0 scikit-image==0.25.2 imagedata==3.8.14 mlflow==3.3.1 huggingface-hub gdown gradio opencv-python plum-dispatch
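
Usage note for the new vision_metrics helpers: they follow the same (act, targ) batch convention as the existing Dice and HD95 functions, so they can be sanity-checked on synthetic tensors before being attached to a Learner. A minimal sketch, assuming pred_to_binary_mask applies a sigmoid and a 0.5 threshold to the logits (as the binary Dice/HD95 helpers suggest); the tensor values below are illustrative:

    import torch
    from fastMONAI.vision_metrics import (
        binary_sensitivity, binary_precision,
        binary_lesion_detection_rate, binary_signed_rve,
    )

    # One 5x5x5 cube lesion in a 10x10x10 volume, with batch and channel dims.
    targ = torch.zeros((1, 1, 10, 10, 10))
    targ[:, :, :5, :5, :5] = 1

    # Logits: strongly negative background, strongly positive (slightly too large) lesion.
    act = torch.full((1, 1, 10, 10, 10), -10.0)
    act[:, :, :6, :6, :6] = 10.0

    print(float(binary_sensitivity(act, targ)))            # 1.0: no target voxel is missed
    print(float(binary_precision(act, targ)))              # 125/216 ~ 0.58: 91 false-positive voxels
    print(float(binary_lesion_detection_rate(act, targ)))  # 1.0: the single lesion is overlapped
    print(float(binary_signed_rve(act, targ)))             # (216-125)/125 ~ 0.73: over-segmentation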
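
Usage note for the reworked MedDataset: the dataframe-based mode replaces the old img_list argument, and the volume and visualization helpers build on it. A minimal sketch, assuming an installed fastMONAI and a hypothetical cohort CSV; the file name and column names are illustrative, not part of this changeset:

    import pandas as pd
    from fastMONAI.vision_core import MedMask
    from fastMONAI.dataset_info import MedDataset

    df = pd.read_csv('cohort.csv')  # hypothetical columns: case_id, image_path, mask_path

    # dtype=MedMask is required for the per-label label_*_volume_mm3 columns.
    med_dataset = MedDataset(dataframe=df, img_col='image_path', mask_col='mask_path',
                             dtype=MedMask, reorder=True, max_workers=4)

    resample, reorder = med_dataset.suggestion()  # most common voxel spacing per axis
    med_dataset.get_volume_summary()              # per-label volume statistics in mm³
    med_dataset.visualize_cases(n_cases=2)        # image + mask overlay at the max-mask slice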