From d891514a641d28abcdf6ca1c6b6a9a4b2ca21164 Mon Sep 17 00:00:00 2001 From: Ana Marta Sequeira Date: Thu, 28 Aug 2025 16:41:58 +0100 Subject: [PATCH 1/3] Rename package to GRIDGENE - src --- gridgene/__init__.py | 0 gridgene/binsom.py | 339 ++++++++++ gridgene/contours.py | 1099 ++++++++++++++++++++++++++++++ gridgene/get_arrays.py | 193 ++++++ gridgene/get_masks.py | 1278 +++++++++++++++++++++++++++++++++++ gridgene/logger.py | 31 + gridgene/mask_properties.py | 500 ++++++++++++++ gridgene/overlay.py | 465 +++++++++++++ 8 files changed, 3905 insertions(+) create mode 100644 gridgene/__init__.py create mode 100644 gridgene/binsom.py create mode 100644 gridgene/contours.py create mode 100644 gridgene/get_arrays.py create mode 100644 gridgene/get_masks.py create mode 100644 gridgene/logger.py create mode 100644 gridgene/mask_properties.py create mode 100644 gridgene/overlay.py diff --git a/gridgene/__init__.py b/gridgene/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gridgene/binsom.py b/gridgene/binsom.py new file mode 100644 index 0000000..d96a844 --- /dev/null +++ b/gridgene/binsom.py @@ -0,0 +1,339 @@ +""" +File for get the tum stroma mask using bins of image +and the SOM clustering +""" +import logging +import cv2 +import os +import time +import numpy as np +import pandas as pd +import scanpy as sc +import anndata as ad +from minisom import MiniSom +from itertools import product +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +from typing import List, Optional, Tuple, Dict, Any, Union +# from mpl_toolkits.axes_grid1 import make_axes_locatable +from gridgene.logger import get_logger + + +# TODO make bins overlapping? + + +class GetBins: + """ + Bin spatial transcriptomics data into grid cells and create AnnData objects. + """ + + def __init__(self, bin_size: int, unique_targets: List[str], logger: Optional[logging.Logger] = None): + """ + Initialize GetBins. 
+ + Parameters + ---------- + bin_size : int + Size of bins in pixels. + unique_targets : List[str] + List of target genes. + logger : Optional[logging.Logger], optional + Logger instance, by default None + """ + self.bin_size = bin_size + self.unique_targets = unique_targets + self.adata = None + self.eval_som_statistical_df = None + self.logger = logger or get_logger(f'{__name__}.{contour_name or "GetContour"}') + self.logger.info("Initialized GetContour") + + def get_bin_df(self, df: pd.DataFrame, df_name: str) -> ad.AnnData: + """ + Convert a DataFrame of cells with spatial coordinates and target labels into a binned AnnData object. + + Parameters + ---------- + df : pd.DataFrame + DataFrame with columns ['X', 'Y', 'target'] representing cell positions and target labels. + df_name : str + Identifier for the dataset. + + Returns + ------- + ad.AnnData + AnnData object with spatial bins and counts per target. + """ + # Calculate grid positions + df['x_grid'] = df['X'] // self.bin_size + df['y_grid'] = df['Y'] // self.bin_size + + # Count occurrences of each target in each grid cell + quadrant_counts = df.groupby(['x_grid', 'y_grid', 'target']).size().unstack(fill_value=0) + + # Reindex to ensure all targets are included, even if they have 0 counts + quadrant_counts = quadrant_counts.reindex(columns=self.unique_targets, fill_value=0) + + # Convert the counts DataFrame to a numpy array for AnnData + quadrant_counts_array = quadrant_counts.values + + # Create an AnnData object + adata = sc.AnnData(X=quadrant_counts_array) + + # Set observation and variable (gene) names + adata.obs_names = [f"grid_{x}_{y}" for x, y in quadrant_counts.index] + adata.var_names = quadrant_counts.columns + adata.obs['name'] = df_name + + # Calculate centroid coordinates + adata.obs['x_centroid'] = df.groupby(['x_grid', 'y_grid'])['X'].mean().values * self.bin_size + adata.obs['y_centroid'] = df.groupby(['x_grid', 'y_grid'])['Y'].mean().values * self.bin_size + + # Store grid positions in 
the observation metadata + adata.obs['x_grid'] = [x for x, y in quadrant_counts.index] + adata.obs['y_grid'] = [y for x, y in quadrant_counts.index] + + # Store spatial information + adata.obs['x_coord'] = df.groupby(['x_grid', 'y_grid'])['X'].first().values * self.bin_size + adata.obs['y_coord'] = df.groupby(['x_grid', 'y_grid'])['Y'].first().values * self.bin_size + adata.obsm["spatial"] = adata.obs[["x_centroid", "y_centroid"]].copy().to_numpy() + + self.adata = adata + return adata + + def get_bin_cohort(self, df_list: List[pd.DataFrame], df_name_list: List[str], cohort_name: str) -> None: + """ + Process multiple datasets into binned AnnData objects and concatenate them into a cohort. + + Parameters + ---------- + df_list : List[pd.DataFrame] + List of DataFrames to process. + df_name_list : List[str] + List of dataset names corresponding to each DataFrame. + cohort_name : str + Name of the cohort to assign to all data. + """ + start_time = time.time() + adata_list = [] + for df, df_name in zip(df_list, df_name_list): + adata = self.get_bin_df(df, df_name) + adata.obs['cohort'] = cohort_name + adata_list.append(adata) + combined_adata = ad.concat(adata_list, join='outer') + self.adata = combined_adata + self.logger.info(f'Time to get bins for {len(df_list)} dataframes: {time.time() - start_time:.2f} seconds') + self.logger.info(f'Number of bins: {len(combined_adata)}') + self.logger.info(f'Number of genes: {len(combined_adata.var_names)}') + + def preprocess_bin(self, min_counts: int = 10, adata: Optional[ad.AnnData] = None) -> None: + """ + Filter and normalize the binned AnnData. 
+ + Parameters + ---------- + min_counts : int, optional + Minimum total counts per bin to retain it, by default 10 + adata : Optional[ad.AnnData], optional + AnnData object to preprocess (defaults to internal one), by default None + """ + if adata is None: + adata = self.adata + sc.pp.filter_cells(adata, min_counts=min_counts) + adata.layers["counts"] = adata.X.copy() + adata.obs['total_counts'] = adata.X.sum(axis=1) + adata.obs['n_genes_by_counts'] = (adata.X > 0).sum(axis=1) + + sc.pp.normalize_total(adata, inplace=True) + sc.pp.log1p(adata) + self.adata = adata + + +class GetContour: + """ + Perform SOM clustering on spatial bins and evaluate clusters. + """ + + def __init__(self, adata: ad.AnnData, logger: Optional[logging.Logger] = None): + """ + Initialize GetContour. + + Parameters + ---------- + adata : ad.AnnData + AnnData object containing binned spatial transcriptomics data. + logger : Optional[logging.Logger], optional + Logger instance, by default None + """ + self.adata = adata + self.logger = logger + if logger is None: + # Configure default logger if none is provided + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + else: + self.logger = logger + + def run_som( + self, + som_shape: Tuple[int, int] = (2, 1), + n_iter: int = 5000, + sigma: float = 0.5, + learning_rate: float = 0.5, + random_state: int = 42 + ) -> None: + """ + Apply SOM clustering on the AnnData object. 
+ + Parameters + ---------- + som_shape : Tuple[int, int], optional + Shape of the SOM grid (rows, columns), by default (2, 1) + n_iter : int, optional + Number of iterations for SOM training, by default 5000 + sigma : float, optional + Width of the Gaussian neighborhood function, by default 0.5 + learning_rate : float, optional + Learning rate for SOM training, by default 0.5 + random_state : int, optional + Random seed for reproducibility, by default 42 + """ + start = time.time() + som = MiniSom(som_shape[0], som_shape[1], self.adata.shape[1], + sigma=sigma, learning_rate=learning_rate, random_seed=random_state) + som.train_random(self.adata.X,n_iter) + + # Step 3: Assign Clusters + clusters = np.zeros(len(self.adata), dtype=int) + clusters = list(clusters) + + possible_tuples = list(product(range(som_shape[0]), range(som_shape[1]))) + table_values = list(range(len(possible_tuples))) + # Create a dictionary to map tuples to values + table_dict = {t: v for t, v in zip(possible_tuples, table_values)} + # print(table_dict) + + for i, q in enumerate(self.adata.X): + # print(som.winner(q)) #(x,y) + x, y = som.winner(q) + clusters[i] = int(table_dict.get((x, y))) + + self.adata.obs['cluster_som'] = pd.Categorical(clusters) + self.logger.info(f'Time to run som on {len(self.adata.X)} bins: {time.time() - start:.2f}') + self.logger.info(f'Number of clusters: {len(set(clusters))}') + self.logger.info(f'number of bins in each cluster: {self.adata.obs["cluster_som"].value_counts()}') + + def eval_som_statistical(self, top_n: int = 20) -> None: + """ + Compute and log top ranked features per SOM cluster. 
+ + Parameters + ---------- + top_n : int, optional + Number of top features to retrieve for each cluster, by default 20 + """ + sc.tl.rank_genes_groups(self.adata, "cluster_som", method="t-test") + stats = [] + groups = self.adata.uns['rank_genes_groups']['names'].dtype.names + for group in groups: + df = sc.get.rank_genes_groups_df(self.adata, group) + df['group'] = group + df_sorted = df.sort_values(by='scores', ascending=False).head(top_n) + self.logger.info(f"n top genes for group {group}") + self.logger.info("\n" + df_sorted.to_string()) + stats.append(df_sorted) + self.eval_som_statistical_df = pd.concat(stats, ignore_index=True) + + def create_cluster_image(self, adata: ad.AnnData, grid_size: int) -> np.ndarray: + """ + Reconstruct an image from cluster annotations in the AnnData object. + + Parameters + ---------- + adata : ad.AnnData + AnnData object containing clustering results and grid positions. + grid_size : int + Size of each grid cell in pixels. + + Returns + ------- + np.ndarray + 2D array with cluster IDs as pixel values. 
+ """ + # Initialize an empty image + max_x_grid = adata.obs['x_grid'].max() + max_y_grid = adata.obs['y_grid'].max() + image_shape = (int((max_x_grid + 1) * grid_size), int((max_y_grid + 1) * grid_size)) + + reconstructed_image = np.zeros(image_shape) + + # Iterate over the observations in the AnnData object + for _, row in adata.obs.iterrows(): + # Retrieve the SOM cluster and grid coordinates + cluster = row['cluster_som'] +1 + x_start = int(row['x_grid'] * grid_size) + y_start = int(row['y_grid'] * grid_size) + + # Set all pixels in the corresponding grid to the SOM cluster value + reconstructed_image[x_start:x_start + grid_size, y_start:y_start + grid_size] = cluster + + return reconstructed_image + + def plot_som( + self, + som_image: np.ndarray, + cmap: Optional[Any] = None, + path: Optional[str] = None, + show: bool = False, + figsize: Tuple[int, int] = (10, 10), + ax: Optional[plt.Axes] = None, + legend_labels: Optional[Dict[int, str]] = None + ) -> plt.Axes: + """ + Visualize the SOM cluster map. + + Parameters + ---------- + som_image : np.ndarray + 2D array representing the SOM clusters. + cmap : Optional[Any], optional + Colormap to use for visualization, by default None (uses 'tab10') + path : Optional[str], optional + Optional path to save the plot image, by default None + show : bool, optional + Whether to display the plot, by default False + figsize : Tuple[int, int], optional + Size of the figure, by default (10, 10) + ax : Optional[plt.Axes], optional + Matplotlib Axes to plot on, by default None (creates new figure) + legend_labels : Optional[Dict[int, str]], optional + Dictionary mapping cluster indices to labels for legend, by default None + + Returns + ------- + plt.Axes + The matplotlib Axes object containing the plot. 
+ """ + if ax is None: + plt.figure(figsize=figsize) + ax = plt.gca() + ax.imshow(som_image, cmap=cmap, interpolation='none', origin='lower') + ax.set_title('SOM clustering') + ax.set_xlabel('X-axis') + ax.set_ylabel('Y-axis') + + if legend_labels: + # Create custom legend handles + handles = [mpatches.Patch(color=cmap(idx / max(legend_labels.keys())), label=label) + for idx, label in legend_labels.items()] + # ax.legend(handles=handles, loc='upper right', title="Clusters") + ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1.05, 0.5), title="Clusters") + + if path is not None: + save_path = os.path.join(path, 'SOM_clustering.png') + plt.savefig(save_path, dpi=1000, bbox_inches='tight') + self.logger.info(f'Plot saved at {save_path}') + + if show: + plt.show() + + return ax \ No newline at end of file diff --git a/gridgene/contours.py b/gridgene/contours.py new file mode 100644 index 0000000..2da1288 --- /dev/null +++ b/gridgene/contours.py @@ -0,0 +1,1099 @@ +import logging +import cv2 +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable +import os +from scipy.signal import convolve2d +from scipy.ndimage import convolve1d +from scipy.spatial import ConvexHull +import alphashape +from shapely.geometry import Polygon +import time +from scipy.spatial import KDTree +from sklearn.cluster import DBSCAN +from gridgene.logger import get_logger +from typing import Optional, Tuple, List, Dict, Any, Union +from matplotlib.figure import Figure + +from matplotlib.axes import Axes +from functools import wraps +from sklearn.neighbors import BallTree + +def timeit(func): + @wraps(func) + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + elapsed = end - start + + # Try to log using `self.logger` if available + instance = args[0] # self, assuming it's a method + logger = getattr(instance, 'logger', None) + if logger: + 
logger.info(f"{func.__name__} took {elapsed:.4f} seconds") + else: + print(f"{func.__name__} took {elapsed:.4f} seconds") + + return result + return wrapper + +class GetContour: + """ + Parent class for contour handling and filtering from a 3D array. + + This class handles extraction and filtering of contours from a 3D array where the first two + dimensions represent spatial coordinates (x and y), and the third dimension contains gene-specific values. + + Attributes + ---------- + array : np.ndarray + The 3D array from which contours are to be extracted. The first two dimensions are x and y + spatial positions; the third dimension corresponds to genes. + local_sum_image : np.ndarray + The 2D array representing the local sum of the input array. + contours : list of np.ndarray + List of contours extracted from the array. + contour_name : str + Name of the contour for identification. + total_valid_contours : int + Total number of valid contours after filtering. + contours_filtered_area : int + Number of contours remaining after area filtering. + logger : logging.Logger + Logger instance for logging information and errors. + points_x_y : np.ndarray, optional + Optional 2D array of shape (N, 2) containing (x, y) points for plotting or further analysis. + + Methods + ------- + __init__(array_to_contour, logger=None, contour_name=None, points_x_y=None) + Initializes the GetContour object. + check_contours() + Validates and closes contours. + filter_contours_area(min_area_threshold) + Filters contours based on a minimum area threshold. + filter_contours_no_counts() + Filters out contours with no transcript counts in the array. + filter_contours_by_gene_threshold(gene_array, threshold, gene_name="") + Filters contours based on gene-specific thresholds. + filter_contours_by_gene_comparison(gene_array1, gene_array2, gene_name1="", gene_name2="") + Filters contours where gene_array1 has higher signal than gene_array2. + plot_contours_scatter(...) 
+ Creates a scatter plot with overlaid contours. + plot_conv_sum(...) + Plots the convolution sum image with contours. + """ + + def __init__(self, array_to_contour, logger=None, contour_name=None, points_x_y: np.ndarray = None): + """ + Initialize the contour handler with a 3D input array. + + Parameters + ---------- + array_to_contour : np.ndarray + 3D array from which contours are to be derived. + logger : logging.Logger, optional + Optional logger instance; a default logger will be used if None. + contour_name : str, optional + Optional name for this set of contours. + points_x_y : np.ndarray, optional + Optional array of (x, y) positions for scatter plotting. + """ + self.array = array_to_contour + self.contours = None + self.contour_name = contour_name + self.total_valid_contours = 0 + self.contours_filtered_area = 0 + self.logger = logger or get_logger(f'{__name__}.{contour_name or "GetContour"}') + self.logger.info("Initialized GetContour") + self.points_x_y = points_x_y + + ############################# + # Filtering of contours + + def check_contours(self) -> None: + """ + Validate and clean contour data. + + Ensures each contour has at least 3 points, closes open contours, + and converts them to integer format for OpenCV compatibility. + """ + + if not self.contours: + self.logger.info("No contours to check.") + return + + filtered_and_closed_contours = [] + for contour in self.contours: + squeezed = np.squeeze(contour) + if squeezed.ndim != 2 or squeezed.shape[0] < 3 or squeezed.shape[1] < 2: + continue + if not np.array_equal(squeezed[0], squeezed[-1]): + squeezed = np.vstack([squeezed, squeezed[0]]) + filtered_and_closed_contours.append(squeezed.astype(np.int32)) + + self.contours = filtered_and_closed_contours + + def filter_contours_area(self, min_area_threshold: float) -> None: + """ + Filter contours by minimum area. + + Parameters + ---------- + min_area_threshold : float + Contours with area below this threshold are discarded. 
+ """ + self.contours = [contour for contour in self.contours if cv2.contourArea(contour) >= min_area_threshold] + # self.contours_filtered_area = len(self.contours) + self.logger.info(f'Number of contours after area filtering: {len(self.contours)}') + + def filter_contours_no_counts(self) -> List[np.ndarray]: + """ + Remove contours that do not contain any non-zero values in the array. + + Returns + ------- + List[np.ndarray] + Contours that contain non-zero signal in the original array. + """ + # todo check what is more efficient + array2d = np.sum(self.array, axis=2) + valid_contours = [] + + mask_all = np.zeros_like(array2d, dtype=np.uint8) + cv2.drawContours(mask_all, self.contours, -1, 1, thickness=cv2.FILLED) + # Multiply once to get the masked array + masked_array2d = array2d * mask_all + + # valid_contours = [ + # contour for contour in self.contours + # if np.sum(cv2.drawContours( + # np.zeros_like(array2d, dtype=np.uint8), [contour], -1, 1, thickness=cv2.FILLED) * masked_array2d) > 0 + # ] + for contour in self.contours: + x, y, w, h = cv2.boundingRect(contour) # bounding box of contour + roi = array2d[y:y + h, x:x + w] # region of interest in array2d + + mask = np.zeros((h, w), dtype=np.uint8) + # Shift contour points to roi coords + shifted_contour = contour - [x, y] + cv2.drawContours(mask, [shifted_contour], -1, 1, thickness=cv2.FILLED) + + if np.sum(roi * mask) > 0: + valid_contours.append(contour) + + self.contours = valid_contours + self.logger.info(f'Number of contours after filtering no counts: {len(self.contours)}') + return self.contours + + def filter_contours_no_counts_and_area(self, min_area_threshold: float) -> List[np.ndarray]: + """ + Filter contours based on both signal presence and minimum area. + + Parameters + ---------- + min_area_threshold : float + Minimum area a contour must have to be retained. + + Returns + ------- + List[np.ndarray] + Valid contours satisfying both count and area criteria. 
+ """ + if self.array.ndim == 3: + array2d = np.sum(self.array, axis=2) + elif self.array.ndim == 2: + array2d = self.array + else: + raise ValueError(f"Unexpected array shape: {array.shape}") + valid_contours = [] + + for contour in self.contours: + area = cv2.contourArea(contour) + if area < min_area_threshold: + continue + + x, y, w, h = cv2.boundingRect(contour) + roi = array2d[y:y + h, x:x + w] + + mask = np.zeros((h, w), dtype=np.uint8) + shifted_contour = contour - [x, y] + cv2.drawContours(mask, [shifted_contour], -1, 1, thickness=cv2.FILLED) + + if np.sum(roi * mask) > 0: + valid_contours.append(contour) + + self.contours = valid_contours + # self.contours_filtered_area = len(self.contours) + self.logger.info(f'Number of contours after filtering no counts: {len(self.contours)}') + return self.contours + + def filter_contours_by_gene_threshold( + self, + gene_array: np.ndarray, + threshold: float, + gene_name: Optional[str] = "" + ) -> None: + """ + Retain contours where the gene signal meets a minimum threshold. + + Parameters + ---------- + gene_array : np.ndarray + 2D or 3D array of gene expression values. + threshold : float + Minimum gene signal required inside the contour. + gene_name : str, optional + Optional name of the gene, used for logging. + """ + valid_contours = [] + for i, contour in enumerate(self.contours): + mask_ = np.zeros((gene_array.shape[0], gene_array.shape[1]), dtype=np.uint8) + cv2.drawContours(mask_, [contour], -1, 1, thickness=cv2.FILLED) + gene_count = np.sum(gene_array * mask_) + if gene_count >= threshold: + valid_contours.append(contour) + else: + self.logger.info( + f'Excluding contour {i}. 
Gene {gene_name} count {gene_count} is below threshold {threshold}') + self.contours = valid_contours + self.logger.info(f'Number of contours remaining: {len(valid_contours)}') + + def filter_contours_by_gene_comparison( + self, + gene_array1: np.ndarray, + gene_array2: np.ndarray, + gene_name1: Optional[str] = "", + gene_name2: Optional[str] = "" + ) -> None: + """ + Filter contours by comparing signal from two gene arrays. + + Only retain contours where the first gene has higher counts than the second. + + Parameters + ---------- + gene_array1 : np.ndarray + First gene array. + gene_array2 : np.ndarray + Second gene array. + gene_name1 : str, optional + Name for the first gene (used in logging). + gene_name2 : str, optional + Name for the second gene (used in logging). + """ + # Ensure arrays are 2D by summing if needed + if gene_array1.ndim == 3: + gene_array1 = np.sum(gene_array1, axis=-1) + if gene_array2.ndim == 3: + gene_array2 = np.sum(gene_array2, axis=-1) + + height, width = gene_array1.shape + valid_contours = [] + for i, contour in enumerate(self.contours): + mask_ = np.zeros((gene_array1.shape[0], gene_array1.shape[1]), dtype=np.uint8) + cv2.drawContours(mask_, [contour], -1, 1, thickness=cv2.FILLED) + gene_count1 = np.sum(gene_array1 * mask_) + gene_count2 = np.sum(gene_array2 * mask_) + if gene_count1 > gene_count2: + valid_contours.append(contour) + else: + self.logger.info( + f'Excluding contour {i}. 
' + f'{gene_name1 or "Gene1"} count {gene_count1:.2f} ' + f'≤ {gene_name2 or "Gene2"} count {gene_count2:.2f}' + ) + self.contours = valid_contours + self.logger.info(f'Contours remaining after gene comparison: {len(valid_contours)}') + + # Plotting + def plot_contours_scatter( + self, + path: Optional[str] = None, + show: bool = False, + s: float = 0.1, + alpha: float = 0.5, + linewidth: float = 1, + c_points: str = 'blue', + c_contours: str = 'red', + figsize: Tuple[int, int] = (10, 10), + ax: Optional[Axes] = None, + **kwargs: Dict[str, Any] + ) -> Axes: + """ + Plot a scatter plot of spatial points overlaid with contours. + + Parameters + ---------- + path : str, optional + Directory where the plot will be saved (if specified). + show : bool + Whether to display the plot interactively. + s : float + Size of scatter points. + alpha : float + Transparency of scatter points. + linewidth : float + Width of contour lines. + c_points : str + Color for scatter points. + c_contours : str + Color for contour lines. + figsize : Tuple[int, int] + Size of the figure. + ax : matplotlib.axes.Axes, optional + Axes on which to plot; if None, a new figure is created. + **kwargs : dict + Additional keyword arguments for customizing scatter and line plots. + + Returns + ------- + Axes + The matplotlib Axes object used for plotting. 
+ """ + if self.points_x_y is not None: + x = self.points_x_y[:, 0].astype(int) # X column + y = self.points_x_y[:, 1].astype(int) # Y column + else: + x, y = np.where(np.sum(self.array, axis=2) > 0) + + if ax is None: + plt.figure(figsize=figsize) + ax = plt.gca() + + # Extract specific kwargs for scatter and plot if provided + scatter_kwargs = kwargs.get('scatter_kwargs', {}) + plot_kwargs = kwargs.get('plot_kwargs', {}) + + # Scatter plot with original coordinates + ax.scatter(x, y, c=c_points, marker='.', s=s, alpha=alpha, **scatter_kwargs) + + # Rescale and plot the contours + for contour in self.contours: + ax.plot(contour[:, 1], contour[:, 0], linewidth=linewidth, color=c_contours, **plot_kwargs) + + ax.set_title(f'Scatter with contours and genes {self.contour_name}') + + if path is not None: + save_path = os.path.join(path, f'Scatter_contours_{self.contour_name}.png') + plt.savefig(save_path, dpi=1000, bbox_inches='tight') + self.logger.info(f'Plot saved at {save_path}') + + if show: + plt.show() + + return ax + + def plot_conv_sum( + self, + cmap: str = 'plasma', + c_countour: str = 'white', + path: Optional[str] = None, + show: bool = False, + figsize: Tuple[int, int] = (10, 10), + ax: Optional[Axes] = None + ) -> Axes: + """ + Plot the local sum (convolution) image with contour overlays. + + Parameters + ---------- + cmap : str + Colormap for the local sum image. + c_countour : str + Color for overlaying contours. + path : str, optional + Path to save the figure (if specified). + show : bool + Whether to display the plot interactively. + figsize : Tuple[int, int] + Size of the figure. + ax : matplotlib.axes.Axes, optional + Axes on which to plot; if None, a new figure is created. + + Returns + ------- + Axes + The matplotlib Axes object used for plotting. 
+ """ + if ax is None: + plt.figure(figsize=figsize) + ax = plt.gca() + + im = ax.imshow(self.local_sum_image.T, cmap=cmap, interpolation='none', origin='lower') + ax.set_title(f'Count distribution with contours for {self.contour_name}') + ax.set_xlabel('X-axis') + ax.set_ylabel('Y-axis') + + # Rescale and plot the contours + for contour in self.contours: + ax.plot(contour[:, 1], contour[:, 0], linewidth=1, color=c_countour) + + # Add a colorbar for the colormap + # cbar = plt.colorbar(im, ax=ax) + # cbar.set_label('Color scale', rotation=270) + # Create a divider for the existing axes instance + divider = make_axes_locatable(ax) + # Append axes to the right of ax, with 5% width of ax + cax = divider.append_axes("right", size="5%", pad=0.05) + + # Create colorbar in the appended axes + # `cax` argument places the colorbar in the cax axes + cbar = plt.colorbar(im, cax=cax) + cbar.set_label('Local Transcript Sum', rotation=270, labelpad=20) + + if path is not None: + save_path = os.path.join(path, f'count_dist_contours_{self.contour_name}.png') + plt.savefig(save_path, dpi=1000, bbox_inches='tight') + + if show: + plt.show() + + return ax + +class ConvolutionContours(GetContour): + """ + A subclass of GetContour for generating contours based on a convolution + of the input array. This class provides methods for computing a local + density map and extracting contours based on intensity thresholds. + + Attributes + ---------- + local_sum_image : np.ndarray or None + 2D array containing the result of convolution on the input data, + used as a basis for contour detection. + """ + + def __init__(self, array_to_contour: np.ndarray, logger=None, contour_name: Optional[str] = None): + """ + Initialize the ConvolutionContours instance. + + Parameters + ---------- + array_to_contour : np.ndarray + 3D input array used to compute the convolution-based contours. + logger : logging.Logger, optional + Optional logger instance for debugging and logging purposes. 
+ contour_name : str, optional + Optional identifier for this contour set. + """ + # Initialize parent class attributes + super().__init__(array_to_contour, logger, contour_name) + # Initialize subclass-specific attributes + self.local_sum_image: Optional[np.ndarray] = None + + @timeit + def get_conv_sum(self, kernel_size: int, kernel_shape: str = 'square') -> None: + """ + Compute a 2D convolution sum across the 3D input array to create a density map. + + Parameters + ---------- + kernel_size : int + The size of the convolution kernel. + kernel_shape : str, optional + Shape of the kernel: either 'square' or 'circle'. Defaults to 'square'. + + Raises + ------ + ValueError + If `kernel_shape` is not one of {'square', 'circle'}. + """ + if kernel_shape not in {'square', 'circle'}: + raise ValueError("kernel_shape must be either 'square' or 'circle'.") + + kernel = (np.ones((kernel_size, kernel_size), dtype=np.float32) + if kernel_shape == 'square' + else cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))) + + array_sum = np.sum(self.array, axis=2) + # self.local_sum_image = cv2.filter2D(array_sum, -1, kernel) + self.local_sum_image = cv2.filter2D(array_sum, -1, kernel, borderType=cv2.BORDER_REFLECT_101) + + del array_sum + + @timeit + def contours_from_sum( + self, + density_threshold: float, + min_area_threshold: float, + directionality: str = 'higher' + ) -> None: + """ + Generate contours from the convolution sum using a threshold, and filter by area. + + Parameters + ---------- + density_threshold : float + The threshold applied to the convolution image to create a binary mask. + min_area_threshold : float + Minimum area that a contour must have to be retained. + directionality : str, optional + Direction of thresholding: + - 'higher': select pixels greater than the threshold + - 'lower': select pixels less than the threshold + Default is 'higher'. 
+ + Raises + ------ + RuntimeError + If the convolution sum image (`local_sum_image`) has not been computed. + ValueError + If `directionality` is not 'higher' or 'lower'. + """ + if self.local_sum_image is None: + raise RuntimeError("local_sum_image is not computed. Run get_conv_sum() first.") + + if directionality == 'higher': + binary_mask = (self.local_sum_image > density_threshold).astype(np.uint8) + elif directionality == 'lower': + binary_mask = (self.local_sum_image < density_threshold).astype(np.uint8) + else: + raise ValueError("directionality must be either 'higher' or 'lower'.") + + contours, _ = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + self.contours = contours + + self.check_contours() + self.filter_contours_no_counts_and_area(min_area_threshold) + # self.logger.info(f'Contours extracted from sum after checks: {len(self.contours)}') + + +class KDTreeContours(GetContour): + """ + A subclass of GetContour for generating contours from spatial point data + using KD-tree or BallTree-based neighbor counts, clustering, and geometric hulls. + + This class supports estimating local point densities via neighbor search, + extracting density-based arrays, applying clustering (e.g., DBSCAN), + and generating contours from those clusters using circle or concave hulls. + + Attributes + ---------- + kd_tree_data : pd.DataFrame + Input coordinate data, with 'X' and 'Y' columns. + points_x_y : np.ndarray + Array of shape (n_samples, 2) containing the input (X, Y) coordinates. + height : int + Height of the image space (used to define output array size). + width : int + Width of the image space. + image_size : tuple + Tuple of (height+1, width+1) defining the output array dimensions. + radius : float + Radius used for neighborhood calculations and clustering. 
+ """ + + def __init__( + self, + kd_tree_data: Union[pd.DataFrame, np.ndarray], + logger=None, + contour_name: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + ): + """ + Initialize the KDTreeContours instance with spatial point data. + + Parameters + ---------- + kd_tree_data : Union[pd.DataFrame, np.ndarray] + Input data with spatial X, Y coordinates. If ndarray, it must be of shape (n, 2). + logger : logging.Logger, optional + Optional logger for debugging or information output. + contour_name : str, optional + Optional name for labeling neighbor count columns and logs. + height : int, optional + Optional height (Y-extent) of the output image grid. Defaults to max Y in data. + width : int, optional + Optional width (X-extent) of the output image grid. Defaults to max X in data. + """ + # Coerce kd_tree_data to DataFrame if ndarray: + self.radius = None + if isinstance(kd_tree_data, np.ndarray): + kd_tree_data = pd.DataFrame(kd_tree_data, columns=["X", "Y"]) + assert isinstance(kd_tree_data, pd.DataFrame), "kd_tree_data must be DataFrame" + + # Set attributes + self.kd_tree_data = kd_tree_data.copy() + self.points_x_y = self.kd_tree_data[["X", "Y"]].to_numpy() + self.contour_name = contour_name or "contour" + self.height = height or int(self.kd_tree_data["Y"].max()) + self.width = width or int(self.kd_tree_data["X"].max()) + self.image_size = (self.height + 1, self.width + 1) + + super().__init__( + kd_tree_data, + logger, + self.contour_name, + self.points_x_y, + ) + + @timeit + def get_kdt_dist(self, radius: int) -> None: + """ + Compute the number of neighbors for each point using BallTree within a given radius. + + The neighbor count is stored in a new column of `kd_tree_data` named + '{contour_name}_neighbor_count'. + + Parameters + ---------- + radius : int + Search radius used to define neighborhood around each point. 
+ """ + self.radius = radius # max_dist + # # Query neighbors within the radiu + ball_tree = BallTree(self.points_x_y) + neighbor_counts = ball_tree.query_radius(self.points_x_y, self.radius) + + self.kd_tree_data[f'{self.contour_name}_neighbor_count'] = np.array( + [len(neighbors) for neighbors in neighbor_counts] + ) + + @timeit + def get_neighbour_array(self) -> np.ndarray: + """ + Create a 2D array where each pixel value corresponds to the neighbor count + at the rounded integer (X, Y) location of the points. + + Returns + ------- + np.ndarray + 2D array with shape (height+1, width+1) where each pixel represents + the number of neighbors for that spatial location. + """ + self.array_total_nei = np.zeros((self.height + 1, self.width + 1)) + + # Get rounded integer indices as NumPy arrays + x_indices = np.round(self.kd_tree_data['X']).astype(int).to_numpy() + y_indices = np.round(self.kd_tree_data['Y']).astype(int).to_numpy() + values = self.kd_tree_data[f'{self.contour_name}_neighbor_count'].to_numpy() + # Assign values directly using advanced indexing + self.array_total_nei[x_indices, y_indices] = values + + return self.array_total_nei + + def interpolate_array(self) -> np.ndarray: + """ + Fill zero values in the neighbor count array using OpenCV inpainting. + + This helps in smoothing sparse or missing regions in the data grid. + + Returns + ------- + np.ndarray + Inpainted version of `array_total_nei`. + """ + assert hasattr(self, "array_total_nei"), "Call get_neighbour_array first" + + # Convert zeros to NaN to create a mask + mask = (self.array_total_nei == 0).astype(np.uint8) # Mask of missing values + self.array_total_nei = cv2.inpaint(self.array_total_nei.astype(np.float32), mask, inpaintRadius=3, + flags=cv2.INPAINT_TELEA) + return self.array_total_nei + + # same as covcontours_from_sum. 
change this future version + @timeit + def contours_from_neighbors(self, density_threshold: float, min_area_threshold: float, + directionality: str = 'higher') -> None: + """ + Extract contours from a local sum image using a density threshold. + + Parameters + ---------- + density_threshold : float + Density threshold for extracting contours. + min_area_threshold : float + Minimum area threshold to keep a contour. + directionality : str, optional + Direction to threshold ('higher' or 'lower'), by default 'higher'. + + Returns + ------- + None + """ + if self.array_total_nei is None: + raise RuntimeError("local_sum_image is not computed. Run get_conv_sum() first.") + + # check give the same name + self.local_sum_image = self.array_total_nei.copy() + self.array = self.array_total_nei.copy()[...,np.newaxis] # Add a new axis to make it 3D + + if directionality == 'higher': + binary_mask = (self.array_total_nei > density_threshold).astype(np.uint8) + elif directionality == 'lower': + binary_mask = (self.array_total_nei < density_threshold).astype(np.uint8) + else: + raise ValueError("directionality must be either 'higher' or 'lower'.") + + contours, _ = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + self.contours = contours + + self.check_contours() + self.filter_contours_no_counts_and_area(min_area_threshold) + self.logger.info(f'Contours extracted from neighboor counts after checks: {len(self.contours)}') + + def find_points_with_neighoors(self, radius: float, min_neighbours: int) -> None: + """ + Find points with neighbors within a given radius using KDTree. + + Parameters + ---------- + radius : float + Search radius to find neighbors around each point. + min_neighbours : int + Minimum number of neighbors for a point to be considered valid. 
+ + Returns + ------- + None + """ + self.radius = radius # todo add check for array_points + self.min_neighbours = min_neighbours + + # Initialize KDTree with the array of points + kd_tree = KDTree(self.points_x_y) + + # Count points within the radius for each point + counts = [len(kd_tree.query_ball_point(p, radius)) for p in self.points_x_y] + # Filter points with more than 2 neighbors + filtered_points = self.points_x_y[np.array(counts) > min_neighbours] + self.logger.info("Points with more than %d neighbors: %d", min_neighbours, len(filtered_points)) + + self.points_w_neig = filtered_points + if self.points_w_neig.ndim == 3: + self.points_w_neig = self.points_w_neig[:, :2] + + if len(filtered_points) == 0: + self.logger.WARNING("No points with neighbors found within the given radius.") + + def label_points_with_neigbors(self) -> None: + """ + Label clustered points with DBSCAN based on neighbor relationships. + + Uses the radius and minimum number of neighbors to identify point clusters. + + Returns + ------- + None + """ + # eps = 60 # Search radius (similar to the one used in KDTree) + min_samples = max(self.min_neighbours, 2) # Minimum number of points in a cluster + # is going to be min_neighboors, or in case is 1, 2 + + # Extract only the first two dimensions (x, y) if the points are in 3D + # Initialize DBSCAN with the given parameters + db = DBSCAN(eps=self.radius, min_samples=min_samples) + self.dbscan_labels = db.fit_predict(self.points_w_neig) + self.logger.info("Points w/ neig agglomerated in DBSCAN labels: %d", len(self.dbscan_labels)) + + def contours_from_kd_tree_simple_circle(self) -> None: + """ + Create contours using circular masks centered on DBSCAN clusters. + + A circle is drawn around each cluster, with a radius covering all points. 
+ + Returns + ------- + None + """ + contours_list = [] + unique_labels = set(self.dbscan_labels) + for label in unique_labels: + if label == -1: # Skip noise points + continue + + # Select points in this cluster + cluster_points = self.points_w_neig[self.dbscan_labels == label] + + # Compute the centroid of the cluster + centroid = np.mean(cluster_points, axis=0) + + # Compute the maximum distance from the centroid to any point in the cluster + distances = np.linalg.norm(cluster_points - centroid, axis=1) + radius = np.max(distances) # This ensures all points are within the circle + + # Ensure the minimum radius is 20 pixels (minimum diameter of 40) + radius = max(radius, self.radius // 2) # todo problem here! + + # Create an image to draw the circle and convert it to a contour + circle_image = np.zeros(self.image_size, dtype=np.uint8) + center = (int(centroid[1]), int(centroid[0])) # Convert to (x, y) format for OpenCV + cv2.circle(circle_image, center, int(radius), (255), thickness=-1) # Fill the circle + + # Find contours from the circle mask + contours, _ = cv2.findContours(circle_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours_list.append(contours[0]) # Only need the external contour + + self.contours = contours_list + self.logger.info("N contours: %d", len(self.contours)) + del contours_list + + def contours_from_kd_tree_concave_hull(self, alpha: float = 0.1) -> None: + """ + Create contours using concave hulls (alpha shapes) on DBSCAN clusters. + + Parameters + ---------- + alpha : float + Alpha parameter controlling the shape tightness. Smaller values yield tighter shapes. 
+ + Returns + ------- + None + """ + alpha = max(0.05, max(1.0, len(self.points_w_neig) / 1000)) + alpha = 10000 + print(alpha) + + contours_list = [] + unique_labels = set(self.dbscan_labels) + + for label in unique_labels: + if label == -1: # Skip noise points + continue + + # Select points in this cluster + cluster_points = self.points_w_neig[self.dbscan_labels == label] + + if len(cluster_points) < 3: + continue # Skip clusters with fewer than 3 points + + # Compute the concave hull using alphashape + concave_hull = alphashape.alphashape(cluster_points, alpha) + + # Ensure the concave hull is valid + if isinstance(concave_hull, Polygon): + # Extract exterior coordinates of the hull as a contour + contour = np.array(concave_hull.exterior.coords) + contours_list.append(contour) + + # Store the contours and log + self.contours = contours_list + self.logger.info("Generated %d concave hull contours", len(self.contours)) + + def contours_from_kd_tree_complex_hull(self) -> None: + """ + Create contours using convex hulls from DBSCAN clusters. + + Uses scipy.spatial.ConvexHull to wrap cluster points. 
+ + Returns + ------- + None + """ + # todo not working properly + contours_list = [] + unique_labels = set(self.dbscan_labels) + + for label in unique_labels: + if label == -1: # Skip noise points + continue + + # Select points in this cluster + cluster_points = self.points_w_neig[self.dbscan_labels == label] + + ## Handle clusters with fewer than 3 points + if len(cluster_points) < 3: + self.logger.warning("Cluster with label %d has less than 3 points; skipping.", label) + continue + + try: + # Compute the convex hull for the points + hull = ConvexHull(cluster_points) + hull_points = cluster_points[hull.vertices] # Get the points on the hull + + # Ensure the contour is closed by adding the first point to the end + hull_points = np.vstack([hull_points, hull_points[0]]) + + # Convert the points into an OpenCV-compatible format + contour = hull_points.astype(np.int32) + + # Add the contour to the list + contours_list.append(contour) + + except Exception as e: + self.logger.error("Error processing cluster %d: %s", label, str(e)) + continue + + self.contours = contours_list + self.logger.info("N contours: %d", len(self.contours)) + + # def refine_contours(self): + # pass + + def get_contours_around_points_with_neighboors(self, type_contouring: str = 'simple_circle') -> None: + """ + Generate contours around points with neighbors using KDTree + DBSCAN. + + Parameters + ---------- + type_contouring : str, default='simple_circle' + Method for generating contours: + - 'simple_circle' : Circles around clusters. + - 'complex_hull' : Convex hulls around clusters. + - 'concave_hull' : Concave hulls (alpha shapes) around clusters. + + Returns + ------- + None + """ + self.label_points_with_neigbors() # 2. Label close points with DBSCAN + + # 3. 
Create contours from the labeled points + if type_contouring == 'simple_circle': + self.contours_from_kd_tree_simple_circle() + elif type_contouring == 'complex_hull': + self.contours_from_kd_tree_complex_hull() + elif type_contouring == 'concave_hull': + self.contours_from_kd_tree_concave_hull() + # 4. Check contours + self.check_contours() + self.total_valid_contours = len(self.contours) + + def plot_point_clusters_with_contours(self, show: bool = False, figsize: Tuple = (10, 10)) -> plt.Figure: + """ + Plot DBSCAN clusters with overlayed contour boundaries. + + Parameters + ---------- + show : bool, default=False + Whether to call `plt.show()` immediately. + figsize : tuple, default=(10, 10) + Size of the plot in inches. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. + """ + + if not hasattr(self, "contours"): # Validate that required attributes exist + raise AttributeError("`self.contours` is not defined. Run contour-generation first.") + if not hasattr(self, "points_w_neig") or not hasattr(self, "dbscan_labels"): + raise AttributeError("`self.points_w_neig` or `self.dbscan_labels` not defined. Run DBSCAN steps first.") + + # Create figure and axis + fig, ax = plt.subplots(figsize=figsize) + + # Plot each contour + for idx, contour in enumerate(self.contours): + # contour may have shape (N, 1, 2) or (N, 2). Squeeze the singleton dim if present. + arr = contour + try: + # If shape is (N, 1, 2), squeeze middle dimension + if arr.ndim == 3 and arr.shape[1] == 1: + arr = arr.squeeze(1) # now (N, 2) + except Exception: + raise ValueError(f"Contour at index {idx} has unexpected shape {contour.shape}") + + if arr.ndim != 2 or arr.shape[1] != 2: + # Skip or raise error; here we choose to skip with a warning + # You may replace with: raise ValueError(...) 
+ import warnings + warnings.warn(f"Skipping contour at index {idx}: expected shape (N,2) after squeeze, got {arr.shape}") + continue + + # Plot: x = arr[:, 0], y = arr[:, 1] + ax.plot(arr[:, 0], arr[:, 1], color='blue', linestyle='-', alpha=0.7, label="_nolegend_") + + # Plot cluster points (skip noise label = -1) + labels = self.dbscan_labels + pts = self.points_w_neig + # Optionally, you could collect unique labels except -1: + unique_labels = sorted(set(labels) - {-1}) + for label in unique_labels: + mask = (labels == label) + cluster_points = pts[mask] + if cluster_points.size == 0: + continue + ax.scatter(cluster_points[:, 0], cluster_points[:, 1], + label=f"Cluster {label}", alpha=0.7, s=10) + + # If desired, plot centroids: + # centroid = cluster_points.mean(axis=0) + # ax.scatter(centroid[0], centroid[1], color='red', marker='x', + # label=f"Centroid {label}") + + ax.set_title("Clusters with Boundaries (Contours)") + ax.set_xlabel("X-axis") + ax.set_ylabel("Y-axis") + # You can enable legend if desired: + # ax.legend(loc='best', fontsize='small') + + if show: + plt.show() + + return plt + + def plot_dbscan_labels(self, show: bool = True, figsize: Tuple[int, int] = (8, 8)) -> Figure: + """ + Plot points colored by DBSCAN cluster labels. + + Parameters + ---------- + show : bool, default=True + Whether to call `plt.show()` immediately. + figsize : tuple of int, default=(8, 8) + Size of the figure in inches. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. 
+ """ + # Validate attributes exist + if not hasattr(self, "points_w_neig") or not hasattr(self, "dbscan_labels"): + raise AttributeError( + "`self.points_w_neig` and `self.dbscan_labels` must be set before calling plot_dbscan_labels.") + + points = self.points_w_neig + labels = self.dbscan_labels + + # Basic shape/length check + if not hasattr(points, "shape") or points.ndim != 2 or points.shape[1] < 2: + raise ValueError( + f"`self.points_w_neig` must be an array of shape (N, 2); got shape {getattr(points, 'shape', None)}") + if labels.shape[0] != points.shape[0]: + raise ValueError(f"Length mismatch: points length {points.shape[0]}, labels length {labels.shape[0]}") + + # Create figure and axis + fig, ax = plt.subplots(figsize=figsize) + + # Determine unique labels + unique_labels = sorted(set(labels)) + # Use tab20 colormap to get distinct colors + colors = plt.cm.tab20(np.linspace(0, 1, len(unique_labels))) + + for label, color in zip(unique_labels, colors): + mask = (labels == label) + pts = points[mask] + if pts.size == 0: + continue + + if label == -1: + # Noise points: marker 'x' + ax.scatter( + pts[:, 0], pts[:, 1], + c=[color], label="Noise", marker="x", s=20, alpha=0.6 + ) + else: + ax.scatter( + pts[:, 0], pts[:, 1], + c=[color], label=f"Cluster {label}", marker="o", s=20, alpha=0.6 + ) + + ax.set_title("DBSCAN Clusters and Noise Points") + ax.set_xlabel("X Coordinate") + ax.set_ylabel("Y Coordinate") + ax.legend(loc="best", fontsize="small") + ax.grid(True) + + if show: + plt.show() + + return plt \ No newline at end of file diff --git a/gridgene/get_arrays.py b/gridgene/get_arrays.py new file mode 100644 index 0000000..48ba201 --- /dev/null +++ b/gridgene/get_arrays.py @@ -0,0 +1,193 @@ +import pandas as pd +import numpy as np +import timeit + +def transform_df_to_array(df: pd.DataFrame, target_dict: dict, array_shape: tuple) -> np.ndarray: + """ + Transforms a DataFrame into a 3D numpy array based on specified target dictionary and array shape. 
+ + Parameters + ---------- + df : pd.DataFrame + The input DataFrame containing 'X', 'Y', and 'target' columns. + target_dict : dict + A dictionary mapping target values to unique indices. + array_shape : tuple + The shape of the output array (max(X)+1, max(Y)+1, number of targets). + + Returns + ------- + np.ndarray + A 3D numpy array with dimensions specified by array_shape, where each position [x, y, target_index] is set to 1 + if there is an entry in the DataFrame with coordinates (x, y) and the corresponding target. + """ + + # Create a numpy array of zeros with the specified shape + output_array = np.zeros(array_shape, dtype=np.int8) + + # Map the target values to their indices using the target_dict + target_indices = df['target'].map(target_dict).values + + # Extract x and y coordinates + x_coords = df['X'].astype(int).values + y_coords = df['Y'].astype(int).values + + # Set the appropriate positions in the output array to 1 using advanced indexing + output_array[x_coords, y_coords, target_indices] = 1 + + return output_array + + + + + +def get_subset_arrays_V1(df_total: pd.DataFrame, target_list: list, target_col: str = 'target', + col_x: str = 'X', col_y: str = 'Y') -> tuple: + """ + PROBABLY LESS EFFICIENT ! + + Filters the DataFrame based on target_list, then creates and returns a subset DataFrame, a dictionary of target mappings, + a 3D array representing the data, and a 2D summed array along the third axis. + + Parameters + ---------- + df_total : pd.DataFrame + The input DataFrame containing the data. + target_list : list + List of target values to filter the DataFrame. + target_col : str, optional + Column name in the DataFrame containing target values, by default 'target'. + col_x : str, optional + Column name in the DataFrame representing the X-coordinate, by default 'X'. + col_y : str, optional + Column name in the DataFrame representing the Y-coordinate, by default 'Y'. 
+ + Returns + ------- + tuple + A tuple containing: + - df_subset (pd.DataFrame): The filtered DataFrame. + - target_dict_subset (dict): A dictionary mapping each target to a unique index. + - array_subset (np.ndarray): A 3D numpy array of shape (max(X)+1, max(Y)+1, len(target_list)), filled based on the filtered DataFrame. + - array_subset_2d (np.ndarray): A 2D numpy array obtained by summing `array_subset` along the third axis. + """ + + # Filter the DataFrame based on target_list + df_subset = df_total.loc[df_total[target_col].isin(target_list)] + + # Create a dictionary mapping each target to a unique index + target_dict_subset = {target: index for index, target in enumerate(df_subset[target_col].unique())} + + # Define the shape of the 3D array + array_shape_subset = (df_total[col_x].max() + 1, df_total[col_y].max() + 1, len(target_list)) + + # Create the 3D array using the provided get_array function + array_subset = transform_df_to_array(df=df_subset, target_dict=target_dict_subset, array_shape=array_shape_subset).astype(np.int8) + + # # Sum the 3D array along the third axis to create a 2D array + # array_subset_2d = np.sum(array_subset, axis=2) + + return df_subset, array_subset, target_dict_subset + + +def get_subset_arrays(df_total: pd.DataFrame, array_total: np.ndarray, target_dict_total: dict, + target_list: list, target_col: str = 'target') -> tuple: + """ + Get a subset of the DataFrame, the corresponding slices from the total array, and the subset target dictionary. + + Parameters + ---------- + df_total : pd.DataFrame + The input DataFrame containing the data. + array_total : np.ndarray + The 3D array representing the entire dataset. + target_dict_total : dict + A dictionary mapping each target in the total dataset to its index. + target_list : list + List of target values to filter the DataFrame and array. + target_col : str, optional + Column name in the DataFrame containing target values, by default 'target'. 
+ + Returns + ------- + tuple + A tuple containing: + - df_subset (pd.DataFrame): The filtered DataFrame. + - array_subset (np.ndarray): The subset of the array corresponding to the target_list. + - target_dict_subset (dict): The subset dictionary mapping the filtered targets to indices. + """ + + # Filter the DataFrame based on target_list + df_subset = df_total.loc[df_total[target_col].isin(target_list)] + + # Create a mapping from target_list to indices in the total array + target_indices_subset = [target_dict_total.get(target, -1) for target in target_list] + + # Initialize an array of zeros with the same shape as array_total for the first two dimensions, + # and the length of target_list for the last dimension + array_subset = np.zeros(array_total.shape[:2] + (len(target_list),)) + + # Extract the relevant slices from the array + for i, target_index in enumerate(target_indices_subset): + if target_index != -1: # if the target is in target_dict_total + array_subset[:, :, i] = array_total[:, :, target_index] + + # Create the subset target dictionary + target_dict_subset = {target: index for index, target in enumerate(target_list)} + + return df_subset, array_subset, target_dict_subset + +if __name__ == "__main__": + + def compare_functions(df_total, array_total, target_dict_total, target_list): + setup_code = """ +import pandas as pd +import numpy as np +from __main__ import get_subset_arrays, get_subset_arrays_V1, df_total, array_total, target_dict_total, target_list +""" + stmt_V1 = "get_subset_arrays(df_total, array_total, target_dict_total, target_list)" + stmt_V2 = "get_subset_arrays_V1(df_total, target_list)" + + time_V1 = timeit.timeit(stmt=stmt_V1, setup=setup_code, number=100) + time_V2 = timeit.timeit(stmt=stmt_V2, setup=setup_code, number=100) + + result_V1 = get_subset_arrays(df_total, array_total, target_dict_total, target_list) + result_V2 = get_subset_arrays_V1(df_total, target_list) + + df_equal = result_V1[0].equals(result_V2[0]) + 
target_dict_equal = result_V1[2] == result_V2[2] + arrays_equal = np.array_equal(result_V1[1], result_V2[1]) + + return time_V1, time_V2, df_equal, target_dict_equal, arrays_equal + + # Sample data for testing + data = {'X': np.random.randint(0, 10, size=1000), + 'Y': np.random.randint(0, 10, size=1000), + 'target': np.random.choice(['target1', 'target2', 'target3'], size=1000)} + df_total = pd.DataFrame(data) + + target_dict_total = {target: index for index, target in enumerate(df_total['target'].unique())} + height, width = df_total['X'].max() + 1, df_total['Y'].max() + 1 + array_total = transform_df_to_array(df=df_total, target_dict=target_dict_total, + array_shape=(height, width, len(target_dict_total))).astype(np.int8) + + target_list = ['target1', 'target2'] + + time_V1, time_V2, df_equal, target_dict_equal, arrays_equal = compare_functions(df_total, array_total, + target_dict_total, target_list) + + print(f"Execution time for get_subset_arrays: {time_V1:.6f} seconds") + print(f"Execution time for get_subset_arrays_V1: {time_V2:.6f} seconds") + print(f"DataFrames are equal: {df_equal}") + print(f"Target dictionaries are equal: {target_dict_equal}") + print(f"Arrays are equal: {arrays_equal}") + + + """ + Execution time for get_subset_arrays: 0.028953 seconds ----- !!!!!!! 
+ Execution time for get_subset_arrays_V1: 0.087730 seconds + DataFrames are equal: True + Target dictionaries are equal: True + Arrays are equal: True + + """ diff --git a/gridgene/get_masks.py b/gridgene/get_masks.py new file mode 100644 index 0000000..191b0e5 --- /dev/null +++ b/gridgene/get_masks.py @@ -0,0 +1,1278 @@ +import logging +import cv2 +import numpy as np +import os +import matplotlib # added for docs generation +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import matplotlib.cm as cm +import matplotlib.axes +from scipy.spatial import Voronoi, voronoi_plot_2d +from shapely.geometry import Polygon +from typing import Dict, List, Tuple, Union +from matplotlib.lines import Line2D +from matplotlib.colors import ListedColormap +from gridgene.logger import get_logger +from typing import Optional, Tuple, Dict, Any, List +from scipy.ndimage import distance_transform_edt +from skimage.measure import label +from scipy.ndimage import distance_transform_edt +from skimage.measure import regionprops +import cv2 +import numpy as np +from shapely.geometry import Polygon, box + +def timeit(func): + @wraps(func) + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + print(f"{func.__name__} took {end - start:.4f} seconds") + return result + return wrapper + +class GetMasks: + """ + Class to handle mask processing operations such as filtering, creation, morphology, subtraction, saving, and plotting. + + Parameters + ---------- + logger : logging.Logger, optional + Logger instance for logging messages. If None, a default logger is configured. + image_shape : tuple of int, optional + Tuple representing the shape of the image (height, width). + """ + + def __init__(self, logger: Optional[logging.Logger] = None, image_shape: Optional[Tuple[int, int]] = None): + """ + Initialize the GetMasks class. 
+ + Parameters + ---------- + logger : logging.Logger, optional + Logger instance for logging messages. If None, a default logger is created. + image_shape : tuple of int, optional + Tuple representing the shape of the image (height, width). + + Returns + ------- + None + """ + self.image_shape = image_shape + self.height = self.image_shape[0] if self.image_shape is not None else None + self.width = self.image_shape[1] if self.image_shape is not None else None + self.logger = logger or get_logger(f'{__name__}.{"GetMasks"}') + self.logger.info("Initialized GetMasks") + + def filter_binary_mask_by_area(self, mask: np.ndarray, min_area: int) -> np.ndarray: + """ + Remove small connected components from a binary mask. + + Parameters + ---------- + mask : np.ndarray + Binary mask (0 or 1). + min_area : int + Minimum area threshold. + + Returns + ------- + np.ndarray + Filtered binary mask. + """ + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8), connectivity=8) + + output_mask = np.zeros_like(mask, dtype=np.uint8) + for i in range(1, num_labels): # skip background + area = stats[i, cv2.CC_STAT_AREA] + if area >= min_area: + output_mask[labels == i] = 1 + return output_mask + + def filter_labeled_mask_by_area(self, mask: np.ndarray, min_area: int) -> np.ndarray: + """ + Filter a labeled mask by keeping only components with area >= min_area. + + Parameters + ---------- + mask : np.ndarray + Input labeled mask (integer labels). + min_area : int + Minimum area threshold. + + Returns + ------- + np.ndarray + Filtered labeled mask preserving label IDs. 
+ """ + mask = mask.astype(np.int32) + unique_labels, counts = np.unique(mask, return_counts=True) + labels_to_keep = unique_labels[(counts >= min_area) & (unique_labels != 0)] + + filtered_mask = np.zeros_like(mask, dtype=np.int32) + for label in labels_to_keep: + filtered_mask[mask == label] = label + + # if logger: + self.logger.info(f'Filtered labeled mask by area >= {min_area}, kept {len(labels_to_keep)} components.') + + return filtered_mask + + def create_mask(self, contours: List[np.ndarray]) -> np.ndarray: + """ + Create a binary mask from contours. + + Parameters + ---------- + contours : list of np.ndarray + List of contours. + + Returns + ------- + np.ndarray + Binary mask. + + Raises + ------ + ValueError + If image shape is not defined. + """ + if self.height is None or self.width is None: + raise ValueError("Image shape must be defined to create mask.") + mask = np.zeros((self.height, self.width), dtype=np.uint8) + cv2.drawContours(mask, contours, -1, color=1, thickness=cv2.FILLED) + return mask + + def fill_holes(self, mask: np.ndarray) -> np.ndarray: + """ + Fill holes inside a binary mask. + + Parameters + ---------- + mask : np.ndarray + Binary mask. + + Returns + ------- + np.ndarray + Hole-filled binary mask. + """ + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + filled_mask = np.zeros_like(mask) + cv2.drawContours(filled_mask, contours, -1, color=1, thickness=cv2.FILLED) + return filled_mask + + def apply_morphology(self, mask: np.ndarray, operation: str = "open", kernel_size: int = 3) -> np.ndarray: + """ + Apply morphological operations to a binary mask. + + Parameters + ---------- + mask : np.ndarray + Binary mask to process. + operation : str, optional + Morphological operation: "open", "close", "erode", or "dilate" (default is "open"). + kernel_size : int, optional + Size of the structuring element (default is 3). + + Returns + ------- + np.ndarray + Processed binary mask. 
+ """ + kernel = np.ones((kernel_size, kernel_size), np.uint8) + + if operation == "open": + result = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) + elif operation == "close": + result = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + elif operation == "erode": + result = cv2.erode(mask, kernel, iterations=1) + elif operation == "dilate": + result = cv2.dilate(mask, kernel, iterations=1) + else: + self.logger.warning(f"Unknown morphological operation '{operation}', returning original mask.") + result = mask + + self.logger.info(f'Applied morphology operation "{operation}" with kernel size {kernel_size}.') + return result + + def subtract_masks(self, base_mask: np.ndarray, *masks: np.ndarray) -> np.ndarray: + """ + Subtract one or more masks from a base mask. + + Parameters + ---------- + base_mask : np.ndarray + Initial binary mask. + *masks : np.ndarray + Masks to subtract from the base mask. + + Returns + ------- + np.ndarray + Resulting mask after subtraction. + """ + result_mask = base_mask.copy() + for mask in masks: + result_mask = cv2.subtract(result_mask, mask) + self.logger.info(f'Subtracted masks from base mask.') + return result_mask + + def save_masks_npy(self, mask: np.ndarray, save_path: str) -> None: + """ + Save mask as a .npy file. + + Parameters + ---------- + mask : np.ndarray + Mask to save. + save_path : str + Path to save the .npy file. + + Returns + ------- + None + """ + np.save(save_path, mask) + self.logger.info(f'Mask saved at {save_path}') + + def save_masks(self, mask: np.ndarray, path: str) -> None: + """ + Save mask as an image file. + + Parameters + ---------- + mask : np.ndarray + Binary mask to save. + path : str + Path to save the image file. 
+ + Returns + ------- + None + """ + cv2.imwrite(path, mask * 255) + self.logger.info(f'Mask saved at {path}') + + def plot_masks( + self, + masks: List[np.ndarray], + mask_names: List[str], + background_color: Tuple[int, int, int] = (0, 0, 0), + mask_colors: Optional[Dict[str, Tuple[int, int, int]]] = None, + path: Optional[str] = None, + show: bool = True, + ax: Optional[plt.Axes] = None, + figsize: Tuple[int, int] = (10, 10) + ) -> None: + """ + Plot multiple masks with their corresponding names. + + Parameters + ---------- + masks : list of np.ndarray + List of masks to plot. + mask_names : list of str + Names corresponding to each mask. + background_color : tuple of int, optional + RGB color tuple for background areas (default (0, 0, 0)). + mask_colors : dict, optional + Mapping of mask names to RGB colors. + path : str, optional + Directory path to save the plot image. + show : bool, optional + Whether to display the plot (default True). + ax : matplotlib.axes.Axes, optional + Matplotlib axis to plot on. Creates new figure if None. + figsize : tuple of int, optional + Size of the figure in inches (width, height). 
+ + Returns + ------- + None + """ + if len(masks) != len(mask_names): + self.logger.error('The number of masks and mask names must be the same.') + return + + # Create a background image filled with the background color + background = np.full((self.height, self.width, 3), background_color) + + # Create a list to store the patches for the legend + legend_patches = [] + + # Choose a colormap based on the number of masks + colormap = cm.get_cmap('tab10') if len(masks) <= 10 else cm.get_cmap('tab20') + + # Add each mask to the background image + for i, (mask, mask_name) in enumerate(zip(masks, mask_names)): + # Choose a color for the mask + if mask_colors and mask_name in mask_colors: + mask_color = np.array(mask_colors[mask_name]) + else: + mask_color = (np.array(colormap(i % colormap.N)[:3]) * 255).astype(int) + # Apply the mask color to the mask image + background[mask!=0] = mask_color + + # Create a patch for the legend + legend_patches.append(mpatches.Patch(color=mask_color / 255, label=mask_name)) + + # Flip the mask horizontally and rotate 90 degrees clockwise + background = np.fliplr(background) + background = np.rot90(background, k=1) + created_fig = False + if ax is None: + created_fig = True + fig, ax = plt.subplots(figsize=figsize) + + # Plot the background image + ax.imshow(background, origin='lower') + ax.set_axis_off() + + # Add legend + ax.legend( + handles=legend_patches, + bbox_to_anchor=(1.05, 1), + loc='upper left', + bbox_transform=ax.transAxes + ) + + # Save the image if path is provided + if path is not None: + save_path = os.path.join( + path, + f'masks_{"_".join(mask_names).replace(" ", "").lower()}.png' + ) + plt.savefig(save_path, dpi=1000, bbox_inches='tight') + self.logger.info(f'Plot saved at {save_path}') + + # Show the plot if required + if show: + plt.show() + plt.close() + + # Close the figure if it was created within this function + if created_fig: + plt.close(fig) + + + def plot_labeled_masks(self, label_mask,mask_name, show=False, 
save_path=None, dpi=300): + """ + Plot the labeled mask with colored objects and bounding boxes. + Parameters + ---------- + mask_dict : dict (required) + show : bool (optional) + + Returns + ------- + + """ + unique_labels = np.unique(label_mask) + + # Generate random colors for each label using a colormap + colormap = cm.get_cmap('tab10', len(unique_labels)) + colors = {label: colormap(i) for i, label in enumerate(unique_labels) if label != 0} + + # Create a colored mask + colored_mask = np.zeros((self.height,self.width, 3), dtype=np.float32) + for label in unique_labels: + if label == 0: + continue + colored_mask[label_mask == label] = colors[label][:3] + + # Create a figure and axis to plot the mask + fig, ax = plt.subplots() + ax.imshow(colored_mask, origin='lower') + + # Plot each labeled object with its corresponding color and label number + for region in regionprops(label_mask): + if region.label == 0: + continue + minr, minc, maxr, maxc = region.bbox + rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr, + fill=False, edgecolor=colors[region.label], linewidth=2) + ax.add_patch(rect) + y, x = region.centroid + ax.text(x, y, str(region.label), color='white', fontsize=8, ha='center', va='center') + + # Set the title and show the plot + ax.set_title(mask_name) + + # # Save the plot as a high-resolution image + if save_path is not None: + fig.savefig(save_path, dpi=dpi, bbox_inches='tight') + + # Show the plot if requested + if show: + plt.show() + + return fig, ax + +# CancerStromaInterfaceanalysis +class ConstrainedMaskExpansion(GetMasks): + """ + Class for expanding a seed mask with constraints, generating binary, labeled, and referenced expansions. + """ + + def __init__( + self, + seed_mask: np.ndarray, + constraint_mask: Optional[np.ndarray] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + """ + Initialize the ConstrainedMaskExpansion object. 
+ + Parameters + ---------- + seed_mask : np.ndarray + Binary seed mask to expand (non-zero labeled regions). + constraint_mask : np.ndarray, optional + Binary mask to limit the expansion area. If None, no constraint is applied. + logger : logging.Logger, optional + Logger instance for logging messages. + + Raises + ------ + ValueError + If seed_mask is None. + """ + if seed_mask is None: + raise ValueError("Seed mask cannot be None.") + + self.seed_mask_raw = seed_mask.astype(np.uint8) + self.seed_mask = label(self.seed_mask_raw) # connected components + self.constraint_mask = ( + constraint_mask.astype(np.uint8) + if constraint_mask is not None + else np.ones_like(seed_mask, dtype=np.uint8) + ) + + image_shape = self.seed_mask.shape + super().__init__(logger=logger, image_shape=image_shape) + + self.binary_expansions: Dict[str, np.ndarray] = {} + self.labeled_expansions: Dict[str, np.ndarray] = {} + self.referenced_expansions: Dict[str, np.ndarray] = {} + + def expand_mask( + self, + expansion_pixels: List[int], + min_area: Optional[int] = None, + restrict_to_limit: bool = True, + ) -> None: + """ + Expand the seed mask outward by specified pixel distances with optional area filtering and constraints. + + Parameters + ---------- + expansion_pixels : list of int + List of expansion distances (in pixels) from the seed mask. + min_area : int, optional + Minimum area threshold for keeping connected components in each expansion ring. + restrict_to_limit : bool, optional + If True, limit the expansion within the constraint mask. 
+ + Returns + ------- + None + """ + sorted_dists = sorted(expansion_pixels) + dist_map = distance_transform_edt(self.seed_mask == 0) + + previous_mask = np.zeros_like(self.seed_mask, dtype=bool) + current_labels = self.seed_mask.copy() + for dist in sorted_dists: + if dist == sorted_dists[0]: + ring = (dist_map <= dist) & (self.seed_mask == 0) + else: + prev_dist = sorted_dists[sorted_dists.index(dist) - 1] + ring = (dist_map <= dist) & (dist_map > prev_dist) & (self.seed_mask == 0) + + if restrict_to_limit: + ring &= self.constraint_mask.astype(bool) + + ring &= ~previous_mask + + if min_area: + ring = self.filter_binary_mask_by_area(ring.astype(np.uint8), min_area).astype(bool) + + previous_mask |= ring + + # Store binary mask + self.binary_expansions[f"expansion_{dist}"] = ring.astype(np.uint8) + + # Store labeled components using skimage + self.labeled_expansions[f"expansion_{dist}"] = label(ring.astype(np.uint8)) + + # Store label-referenced expansion using seed_mask + # referenced = self.propagate_labels(self.seed_mask, ring) + referenced = self.propagate_labels(current_labels, ring) + referenced[~ring] = 0 + self.referenced_expansions[f"expansion_{dist}"] = referenced + current_labels[referenced > 0] = referenced[referenced > 0] + + self.binary_expansions["seed_mask"] = (self.seed_mask > 0).astype(np.uint8) + self.labeled_expansions["seed_mask"] = self.seed_mask.copy() + self.referenced_expansions["seed_mask"] = self.seed_mask.copy() + + constraint_remaining = (self.constraint_mask.astype(bool) & ~previous_mask).astype(np.uint8) + self.binary_expansions["constraint_remaining"] = constraint_remaining + # self.labeled_expansions["constraint_remaining"] = np.zeros_like(self.seed_mask, dtype=np.int32) + self.labeled_expansions["constraint_remaining"] = label(constraint_remaining) + self.referenced_expansions["constraint_remaining"] = np.zeros_like(self.seed_mask, dtype=np.int32) + + def propagate_labels(self, seed_labeled: np.ndarray, expansion_mask: 
np.ndarray) -> np.ndarray: + """ + Propagate labels from the seed labeled mask into the expansion region + using nearest-neighbor distance transform. + + Parameters + ---------- + seed_labeled : np.ndarray + Labeled seed mask where non-zero values indicate components. + expansion_mask : np.ndarray + Binary mask indicating the expansion region to propagate labels into. + + Returns + ------- + np.ndarray + Labeled mask with propagated labels in the expansion area. + """ + output = np.zeros_like(seed_labeled, dtype=np.int32) + + # Compute distance transform on inverse of seed (background = True) + # Return indices of nearest labeled pixels + distance, indices = distance_transform_edt(seed_labeled == 0, return_indices=True) + + # Use the nearest labeled pixel for expansion mask locations + nearest_labels = seed_labeled[tuple(indices)] + + # Fill only the expansion region with nearest labels + output[expansion_mask.astype(bool)] = nearest_labels[expansion_mask.astype(bool)] + + # Preserve original seed labels + output[seed_labeled > 0] = seed_labeled[seed_labeled > 0] + + return output + +class SingleClassObjectAnalysis(GetMasks): + """ + Analyze and expand a single binary object mask using distance-based ring expansion. + + This class computes concentric ring-based expansions of a binary mask, + assigns unique labels to each expanded region, and tracks mask lineage + through label propagation. + + Attributes + ---------- + mask : np.ndarray + Binary mask of the object to be expanded. + expansion_distances : List[int] + List of expansion radii in pixels. + labelled_mask : np.ndarray + Resulting labeled mask with original and expanded areas. + binary_masks : Dict[str, np.ndarray] + Dictionary of binary masks keyed by expansion distance. + labelled_masks : Dict[str, np.ndarray] + Dictionary of labeled masks keyed by expansion distance. + reference_masks : Dict[str, np.ndarray] + Masks encoding reference to original object. 
    def __init__(
        self,
        get_masks_instance: GetMasks,
        contours_object: List[np.ndarray],
        contour_name: str = ""
    ) -> None:
        """
        Initialize SingleClassObjectAnalysis with contour data and a GetMasks utility instance.

        Parameters
        ----------
        get_masks_instance : GetMasks
            Instance of GetMasks providing access to shape and filtering methods.
        contours_object : List[np.ndarray]
            List of contours representing the object.
        contour_name : str, optional
            Optional name identifier for the object.
        """
        # Borrow geometry and logging from the shared GetMasks helper.
        self.get_masks_instance = get_masks_instance
        self.height = get_masks_instance.height
        self.width = get_masks_instance.width
        self.logger = get_masks_instance.logger

        # Populated later by get_mask_objects() / get_objects_expansion().
        self.mask_object_SA: Optional[np.ndarray] = None
        self.binary_expansions: Dict[str, np.ndarray] = {}
        self.labeled_expansions: Dict[str, np.ndarray] = {}
        self.referenced_expansions: Dict[str, np.ndarray] = {}
        self.contours_object = contours_object
        self.contour_name = contour_name
+ + Returns + ------- + None + """ + mask_object = np.zeros((self.height, self.width), dtype=np.uint8) + cv2.drawContours(mask_object, self.contours_object, -1, color=1, thickness=cv2.FILLED) + + if exclude_masks: + for mask in exclude_masks: + mask_object = cv2.subtract(mask_object, mask) + + if filter_area is not None: + self.logger.info(f"Filtering object mask by area: {filter_area}") + mask_object = self.get_masks_instance.filter_mask_by_area(mask_object, min_area=filter_area) + + self.mask_object_SA = mask_object + self.logger.info("Mask for objects created.") + + def get_objects_expansion( + self, + expansions_pixels: Optional[List[int]] = None, + filter_area: Optional[int] = None + ) -> None: + """ + Expand the object mask using distance-based rings and optionally filter + each ring by minimum area. Generates binary, labeled, and propagated-label expansion masks. + + Parameters + ---------- + expansions_pixels : list of int, optional + List of pixel distances for expansion. + filter_area : int, optional + Minimum area threshold to retain connected components in each expansion ring. 
+ + Returns + ------- + None + """ + if self.mask_object_SA is None: + self.logger.error("No object mask to expand.") + return + + if expansions_pixels is None: + expansions_pixels = [] + + self.seed_mask = label(self.mask_object_SA) + dist_map = distance_transform_edt(self.seed_mask == 0) + previous_mask = np.zeros_like(self.seed_mask, dtype=bool) + current_labels = self.seed_mask.copy() + + for i, dist in enumerate(sorted(expansions_pixels)): + prev_dist = sorted(expansions_pixels)[i - 1] if i > 0 else 0 + raw_ring = (dist_map <= dist) & (dist_map > prev_dist) & (self.seed_mask == 0) + + if filter_area: + raw_ring = self.get_masks_instance.filter_binary_mask_by_area(raw_ring.astype(np.uint8), + filter_area).astype(bool) + key = f"expansion_{dist}" + + ring = raw_ring & (~previous_mask) + + if not np.any(ring): + self.logger.warning(f"Expansion ring for distance {dist} is empty.") + empty_mask = np.zeros_like(self.seed_mask, dtype=np.uint8) + self.binary_expansions[key] = empty_mask + self.labeled_expansions[key] = empty_mask + self.referenced_expansions[key] = empty_mask + continue + + previous_mask |= ring + + self.binary_expansions[key] = ring.astype(np.uint8) + self.labeled_expansions[key] = label(ring.astype(np.uint8)) + + referenced = self.propagate_labels(current_labels, ring) + referenced[~ring] = 0 + self.referenced_expansions[key] = referenced + current_labels[referenced > 0] = referenced[referenced > 0] + + # Store the base seed info + self.binary_expansions["seed_mask"] = (self.seed_mask > 0).astype(np.uint8) + self.labeled_expansions["seed_mask"] = self.seed_mask.copy() + self.referenced_expansions["seed_mask"] = self.seed_mask.copy() + + def propagate_labels(self, seed_labeled: np.ndarray, expansion_mask: np.ndarray) -> np.ndarray: + """ + Propagate labeled regions from a seed mask into the expansion area using iterative dilation. 
+ + Parameters + ---------- + seed_labeled : np.ndarray + Input labeled mask where each connected component has a unique integer label. + expansion_mask : np.ndarray + Binary mask indicating the region where labels should expand. + + Returns + ------- + np.ndarray + Labeled mask with labels propagated into the expansion region. + """ + output = np.zeros_like(seed_labeled, dtype=np.int32) + # output[seed_labeled > 0] = seed_labeled[seed_labeled > 0] + # + # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) + # expansion_mask = expansion_mask.astype(bool) + # iteration = 0 + # + # while True: + # iteration += 1 + # prev = output.copy() + # + # mask_to_fill = (output == 0) & expansion_mask + # dilated = cv2.dilate(output.astype(np.float32), kernel) + # dilated = dilated.astype(np.int32) + # + # output[mask_to_fill] = dilated[mask_to_fill] + # + # if np.array_equal(output, prev): + # break + # if iteration > 1000: + # if self.logger: + # self.logger.warning("Label propagation exceeded 1000 iterations.") + # break + # + # return output + distance, indices = distance_transform_edt(seed_labeled == 0, return_indices=True) + + # Use the nearest labeled pixel for expansion mask locations + nearest_labels = seed_labeled[tuple(indices)] + + # Fill only the expansion region with nearest labels + output[expansion_mask.astype(bool)] = nearest_labels[expansion_mask.astype(bool)] + + # Preserve original seed labels + output[seed_labeled > 0] = seed_labeled[seed_labeled > 0] + + return output + +# Propagate labels: If performance is a concern, the dilation-based propagation loop can be optimized with a queue-based BFS flood-fill instead. +class MultiClassObjectAnalysis(GetMasks): + """ + Analyze and expand multiple object contours across different classes using Voronoi constraints. + + Constructs Voronoi diagrams to limit spatial expansion, assigns unique labels to each object, + and tracks class-wise and parent-wise mask lineage for downstream analysis. 
    def __init__(self, get_masks_instance, multiple_contours: dict, save_path: str = None):
        """
        Initialize MultiClassObjectAnalysis instance.

        Parameters
        ----------
        get_masks_instance : GetMasks
            Instance of GetMasks class with base image properties.
        multiple_contours : dict[str, list[np.ndarray]]
            Dictionary mapping class names to lists of contours.
        save_path : str, optional
            Directory path to save outputs (default is None).
        """
        # NOTE(review): ConstrainedMaskExpansion passes logger/image_shape to
        # the GetMasks base, but here super() is called with no arguments —
        # confirm GetMasks.__init__ has suitable defaults.
        super().__init__()
        self.get_masks_instance = get_masks_instance

        # Geometry and logging are borrowed from the shared GetMasks helper.
        self.height = self.get_masks_instance.height
        self.width = self.get_masks_instance.width
        self.logger = self.get_masks_instance.logger

        # Remove tumour/stroma mask references as per your note
        self.multiple_contours = multiple_contours
        # Populated later by derive_voronoi_from_contours() and
        # generate_expanded_masks_limited_by_voronoi().
        self.masks = None
        self.vor = None
        self.list_of_polygons = None
        self.class_labels = None
        self.all_centroids = None
        self.voronoi_regions = None
        self.voronoi_vertices = None
        self.save_path = save_path

        # Reverse the point order of every usable contour, in place.
        # NOTE(review): the warning says "Skipping", but a short contour is
        # kept in the dict — only its reversal is skipped. Also note this
        # mutates the caller's dict in place; confirm both are intended.
        for class_label, contours in self.multiple_contours.items():
            for i, contour in enumerate(contours):
                if contour.shape[0] < 4:
                    self.logger.warning(f"Skipping contour with less than 4 points for class '{class_label}'.")
                    continue
                self.multiple_contours[class_label][i] = contour[::-1]
+ """ + if vor.points.shape[1] != 2: + raise ValueError("Requires 2D input") + + new_regions = [] + new_vertices = vor.vertices.tolist() + + center = vor.points.mean(axis=0) + if radius is None: + radius = vor.points.ptp().max() * 2 + + # Map of all ridges for a point + all_ridges = {} + for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices): + all_ridges.setdefault(p1, []).append((p2, v1, v2)) + all_ridges.setdefault(p2, []).append((p1, v1, v2)) + + # Reconstruct finite polygons + for p1, region_index in enumerate(vor.point_region): + vertices = vor.regions[region_index] + + if all(v >= 0 for v in vertices): + # Finite region + new_regions.append(vertices) + continue + + ridges = all_ridges[p1] + new_region = [v for v in vertices if v >= 0] + + for p2, v1, v2 in ridges: + if v1 >= 0 and v2 >= 0: + continue + + t = vor.points[p2] - vor.points[p1] # tangent + t /= np.linalg.norm(t) + n = np.array([-t[1], t[0]]) # normal vector + + midpoint = vor.points[[p1, p2]].mean(axis=0) + direction = np.sign(np.dot(midpoint - center, n)) * n + far_point = vor.vertices[v1 if v1 >= 0 else v2] + direction * radius + + new_vertices.append(far_point.tolist()) + new_region.append(len(new_vertices) - 1) + + # Sort region counterclockwise + vs = np.array([new_vertices[v] for v in new_region]) + c = vs.mean(axis=0) + angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0]) + new_region = [new_region[i] for i in np.argsort(angles)] + + new_regions.append(new_region) + + return new_regions, np.asarray(new_vertices) + + def get_polygons_from_contours(self, contours: List[np.ndarray]) -> List[Polygon]: + """ + Convert contours into Shapely polygons. + + Parameters + ---------- + contours : list[np.ndarray] + List of contour arrays of shape (N, 2). + + Returns + ------- + polygons : list[Polygon] + List of valid Shapely Polygon objects. 
    def get_polygons_from_contours(self, contours: List[np.ndarray]) -> List[Polygon]:
        """
        Convert contours into Shapely polygons.

        Parameters
        ----------
        contours : list[np.ndarray]
            List of contour arrays of shape (N, 2), possibly with an extra
            singleton axis as produced by OpenCV.

        Returns
        -------
        polygons : list[Polygon]
            List of valid Shapely Polygon objects; degenerate, invalid or
            zero-area contours are silently dropped.
        """
        polygons = []
        for cnt in contours:
            if cnt.shape[0] < 4:
                continue  # Too few points to form a polygon

            # Drop a possible singleton axis: (N, 1, 2) -> (N, 2).
            coords = cnt.squeeze()

            if coords.shape[0] < 4:
                continue  # Still too few after squeezing

            # Ensure it's closed (first point == last point).
            if not np.array_equal(coords[0], coords[-1]):
                coords = np.vstack([coords, coords[0]])

            try:
                polygon = Polygon(coords)
                if not polygon.is_valid or polygon.area == 0:
                    continue  # Skip invalid or zero-area polygons
                polygons.append(polygon)
            except Exception:
                continue  # Defensive: skip any contour shapely cannot handle
        return polygons

    def derive_voronoi_from_contours(self) -> None:
        """
        Compute a Voronoi diagram from centroids of contours.

        Computes Voronoi regions and finite polygons clipped to a large radius.
        Stores regions, vertices, class labels, and centroids for further
        processing.

        Raises
        ------
        ValueError
            If no contours are available to derive the Voronoi diagram.
        """
        # Flatten all classes into one list; contours shorter than 4 points
        # cannot form polygons and are excluded.
        all_contours = [contour for contour_points in self.multiple_contours.values() for contour in contour_points if contour.shape[0] >= 4]
        if not all_contours:
            raise ValueError("No contours found to derive Voronoi diagram.")

        list_of_polygons = self.get_polygons_from_contours(all_contours)

        # One centroid (with its class label) per usable contour.
        centroids = []
        class_labels = []
        for class_label, contours in self.multiple_contours.items():
            for contour in contours:
                contour = contour.squeeze()

                # NOTE(review): the `is not None` check is dead (squeeze()
                # would already have raised on None), and the warning mentions
                # 4 points while the condition accepts 3 — confirm the
                # intended minimum.
                if contour is not None and len(contour) >= 3:
                    polygon = Polygon(contour)
                    centroids.append(polygon.centroid)
                    class_labels.append(class_label)
                else:
                    self.logger.warning(f"Skipping contour with less than 4 points for class '{class_label}'.")
                    continue

        if len(centroids) < 4:
            # Not enough data to compute Voronoi — store what we have and
            # leave the Voronoi attributes unset so callers can fall back.
            self.logger.warning("Not enough valid centroids for Voronoi diagram. Skipping Voronoi computation.")
            self.list_of_polygons = list_of_polygons
            self.class_labels = class_labels
            self.all_centroids = np.array([(c.x, c.y) for c in centroids]) if centroids else None
            self.vor = None
            self.voronoi_regions = None
            self.voronoi_vertices = None
            return

        all_centroids = np.array([(c.x, c.y) for c in centroids])
        vor = Voronoi(all_centroids)

        # Use finite polygons clipped to a large radius (image max dimension * 2)
        regions, vertices = self.voronoi_finite_polygons_2d(vor, radius=max(self.height, self.width) * 2)

        self.list_of_polygons = list_of_polygons
        self.class_labels = class_labels
        self.all_centroids = all_centroids
        self.vor = vor
        self.voronoi_regions = regions
        self.voronoi_vertices = vertices
+ """ + mask = np.zeros((self.height, self.width), dtype=np.uint8) + + # If Voronoi could not be computed, default to full image for that category + if self.voronoi_regions is None or self.voronoi_vertices is None: + # Option 1: Allow expansion to go anywhere + mask[:, :] = 255 + return mask + + # Normal case + for idx, (label, region) in enumerate(zip(self.class_labels, self.voronoi_regions)): + if label != category_name: + continue + polygon = self.voronoi_vertices[region] + polygon[:, 0] = np.clip(polygon[:, 0], 0, self.width - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, self.height - 1) + int_polygon = polygon.astype(np.int32) + if len(int_polygon) >= 3: + cv2.fillPoly(mask, [int_polygon], color=255) + + return mask + + def expand_mask(self, mask: np.ndarray, expansion_distance: int) -> np.ndarray: + """ + Expand a binary mask by a given pixel distance using distance transform. + + The returned mask corresponds to the expansion region excluding the original mask. + + Parameters + ---------- + mask : np.ndarray + Binary input mask to expand. + expansion_distance : int + Number of pixels to expand the mask by. + + Returns + ------- + np.ndarray + Binary mask representing the expansion area only. + """ + if not np.any(mask): + return np.zeros_like(mask, dtype=np.uint8) + + # Compute distance from the background to the object mask + dist_transform = distance_transform_edt(mask == 0) + + # Select pixels within the expansion distance (excluding original mask) + expanded_mask = (dist_transform <= expansion_distance) & (mask == 0) + expanded_mask = expanded_mask.astype(np.uint8) # Convert to binary mask + return expanded_mask + + def generate_expanded_masks_limited_by_voronoi( + self, + expansion_distances: list[int] + ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]: + """ + Generate expanded masks for each object limited by their Voronoi regions. 
    def generate_expanded_masks_limited_by_voronoi(
        self,
        expansion_distances: list[int]
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
        """
        Generate expanded masks for each object limited by their Voronoi regions.

        For each class and its contours, original masks are created and then
        expanded by the specified distances, clipped to the corresponding
        Voronoi region. All expansions are labeled and tracked with parent IDs.

        Parameters
        ----------
        expansion_distances : list[int]
            List of pixel distances for mask expansion rings.

        Returns
        -------
        tuple of dict
            - binary_masks: dict mapping mask names to binary masks.
            - labeled_masks: dict mapping mask names to labeled masks with unique IDs.
            - referenced_masks: dict mapping mask names to masks referencing parent object IDs.
        """
        # Step 1: Generate masks for each contour, and label objects
        masks = {}
        labeled_masks = {}
        # Global per-pixel map back to the parent object that owns the pixel.
        referenced_labeled_mask = np.zeros((self.height, self.width), dtype=np.int32)

        parent_id_counter = 1  # unique ID for each original object across all classes

        # Map from category -> list of (parent_id, mask)
        original_masks_info = {}

        # Create binary masks for each individual contour, label them, assign parent IDs
        for category_name, contours in self.multiple_contours.items():
            if not contours or all(c.shape[0] < 4 for c in contours):
                # Category has no usable contours: store empty masks so the
                # expected output keys still exist.
                empty_mask = np.zeros((self.height, self.width), dtype=np.uint8)
                empty_labeled = np.zeros_like(empty_mask, dtype=np.int32)
                key = f"{category_name}"

                masks[key] = empty_mask
                labeled_masks[key] = empty_labeled

                original_masks_info[category_name] = []
                # Add empty expansions too
                for expansion_distance in expansion_distances:
                    exp_key = f"{category_name}_expansion_{expansion_distance}"
                    masks[exp_key] = empty_mask.copy()
                    labeled_masks[exp_key] = empty_labeled.copy()
                # NOTE(review): there is no `continue` here, so the loop below
                # still runs on the (empty/short) contour list — harmless, but
                # it re-assigns original_masks_info[category_name] afterwards.

            category_masks = []
            for contour in contours:
                mask = np.zeros((self.height, self.width), dtype=np.uint8)
                cv2.drawContours(mask, [contour], -1, 1, thickness=cv2.FILLED)
                # Label connected components (should be 1 per mask but be safe)
                labeled = label(mask > 0)
                # Extract regionprops if needed, here we just assign parent_id directly
                labeled_mask = np.zeros_like(labeled, dtype=np.int32)
                # Assign the unique parent ID to all pixels in this object
                labeled_mask[labeled > 0] = parent_id_counter

                # Update global referenced mask
                referenced_labeled_mask[labeled_mask > 0] = parent_id_counter

                # Store original mask and label
                masks[f'{category_name}_{parent_id_counter}'] = mask
                labeled_masks[f'{category_name}_{parent_id_counter}'] = labeled_mask

                category_masks.append((parent_id_counter, mask))
                parent_id_counter += 1
            original_masks_info[category_name] = category_masks

        # Step 2: Generate expansions and label them, mapping back to parent IDs
        expanded_masks = {}
        expanded_labeled_masks = {}

        for category_name, masks_info in original_masks_info.items():
            # All expansions of this category are clipped to its Voronoi cells.
            voronoi_mask = self.get_voronoi_mask(category_name)
            for parent_id, base_mask in masks_info:
                previous_expansion_mask = np.zeros((self.height, self.width), dtype=np.uint8)
                for expansion_distance in expansion_distances:
                    # Ring = expansion minus the inner rings, clipped to Voronoi.
                    current_expansion_mask = self.expand_mask(base_mask.copy(), expansion_distance)
                    current_expansion_mask = cv2.bitwise_and(current_expansion_mask,
                                                             cv2.bitwise_not(previous_expansion_mask))
                    current_expansion_mask = cv2.bitwise_and(current_expansion_mask, voronoi_mask)

                    # Label this expanded mask (connected components)
                    labeled_expansion = label(current_expansion_mask > 0)
                    labeled_mask = np.zeros_like(labeled_expansion, dtype=np.int32)

                    # For each component in expansion assign a unique label encoding:
                    # parent_id * 1000 + expansion_distance (assuming expansion_distance < 1000)
                    # This allows tracing expansions to parent
                    # label_value = parent_id * 1000 + expansion_distance
                    label_value = parent_id

                    labeled_mask[labeled_expansion > 0] = label_value

                    # Update global referenced mask — careful to avoid overwriting originals
                    referenced_labeled_mask[labeled_mask > 0] = label_value

                    key = f'{category_name}_expansion_{expansion_distance}_parent_{parent_id}'
                    expanded_masks[key] = current_expansion_mask
                    expanded_labeled_masks[key] = labeled_mask

                    previous_expansion_mask = cv2.bitwise_or(previous_expansion_mask, current_expansion_mask)

        # Combine all masks and labeled masks
        masks.update(expanded_masks)
        labeled_masks.update(expanded_labeled_masks)
        # Step 3: Aggregate masks by class and expansion name
        aggregate_binary = {}
        aggregate_labeled = {}
        aggregate_referenced = {}

        for key, mask in masks.items():
            # NOTE(review): this parsing assumes category names contain no
            # underscores; a name like 'cd8_t' would break parts[0]/parts[2].
            parts = key.split('_')

            if 'expansion' in parts:
                category = parts[0]
                expansion_distance = parts[2]
                agg_key = f"{category}_expansion_{expansion_distance}"
            else:
                category = parts[0]
                agg_key = category

            if agg_key not in aggregate_binary:
                aggregate_binary[agg_key] = np.zeros_like(mask)
                aggregate_labeled[agg_key] = np.zeros_like(mask, dtype=np.int32)
                aggregate_referenced[agg_key] = np.zeros_like(mask, dtype=np.int32)

            aggregate_binary[agg_key] = cv2.bitwise_or(aggregate_binary[agg_key], mask)
            aggregate_labeled[agg_key] = np.maximum(aggregate_labeled[agg_key], labeled_masks[key])

            # Referenced mask is pulled from the global referenced_labeled_mask
            aggregate_referenced[agg_key] = np.maximum(
                aggregate_referenced[agg_key],
                np.where(mask > 0, referenced_labeled_mask, 0)
            )

        # Final output
        self.binary_masks = aggregate_binary
        self.labeled_masks = aggregate_labeled
        self.referenced_masks = aggregate_referenced
        return self.binary_masks, self.labeled_masks, self.referenced_masks
+ background_color (Tuple[int, int, int], optional): RGB color for background. Defaults to white. + show (bool, optional): If True, displays the plot. Defaults to True. + axes (matplotlib.axes.Axes, optional): Existing axes to plot on. + figsize (Tuple[int, int], optional): Figure size for new plot. + + Returns: + matplotlib.axes.Axes: The plot axes (if `axes` was provided). + """ + masks = self.binary_masks + background = np.full((self.height, self.width, 3), background_color, dtype=np.uint8) + fig, ax = plt.subplots(figsize=figsize) if axes is None else (None, axes) + legend_patches = [] + seen_classes = set() + + for mask_name, mask in masks.items(): + # Identify base class: 'gd' or 'cd8' from names like 'gd_expansion_30_0' + base_class = mask_name.split('_')[0] + + # Get color for this base class + color = np.array(mask_colors.get(base_class, (128, 128, 128))) + background[mask != 0] = color + + # Add legend entry only once per base class + if base_class not in seen_classes: + legend_patches.append(mpatches.Patch(color=color / 255, label=base_class)) + seen_classes.add(base_class) + + ax.imshow(background, origin='lower') + + # Draw Voronoi edges + if self.vor: + voronoi_plot_2d(self.vor, ax=ax, show_vertices=False, line_colors='black', line_alpha=0.6) + + # Plot centroids (smaller dots) + if self.all_centroids is not None: + centroids = np.array(self.all_centroids) + ax.plot(centroids[:, 0], centroids[:, 1], '*', markersize=1, alpha=0.6) + + # Add clean legend (gd, cd8) + ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', bbox_transform=ax.transAxes) + + if self.save_path: + save_path = os.path.join(self.save_path, 'masks_with_voronoi_edges.png') + plt.savefig(save_path, dpi=1000, bbox_inches='tight') + self.logger.info(f'Plot saved at {save_path}') + + if show: + plt.show() + + return ax if axes is not None else None diff --git a/gridgene/logger.py b/gridgene/logger.py new file mode 100644 index 0000000..2563e17 --- /dev/null +++ 
import logging
from typing import Optional


def get_logger(name: Optional[str] = None, level: int = logging.INFO) -> logging.Logger:
    """
    Get a logger instance with a standard format and no duplicated handlers.

    Parameters
    ----------
    name : str, optional
        Name of the logger. Defaults to None which uses the root logger.
        (Annotation fixed: the default is None, so the type is Optional[str].)
    level : int, optional
        Logging level, defaults to logging.INFO.

    Returns
    -------
    logging.Logger
        Configured logger instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # hasHandlers() also sees ancestor handlers, so an already-configured
    # root logger prevents attaching a duplicate stream handler here.
    if not logger.hasHandlers():
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# --- gridgene/mask_properties.py ---
import time
from functools import wraps


# TODO(review): route this through the module logger instead of print().
def timeit(func):
    """Decorator that prints the wall-clock duration of each call."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} took {end - start:.4f} seconds")
        return result
    return wrapper
+ """ + mask: np.ndarray + mask_name: str + analysis_type: str = "per_object" # "per_object", "bulk", "grid" + grid_size: Optional[int] = None + +@dataclass +class MaskAnalysisResult: + """ + Container for the results of mask analysis. + + Parameters + ---------- + mask_name : str + Name of the analyzed mask. + analysis_type : str + Analysis type performed. + features : list of dict + List of extracted features per object. + """ + mask_name: str + analysis_type: str + features: List[Dict[str, Any]] + +class MorphologyExtractor: + """ + Extracts morphological features from labeled masks. + """ + + def extract_per_object_features(self, labeled_mask: np.ndarray) -> List[Dict[str, Any]]: + """ + Extract per-object morphological features from a labeled mask. + + Parameters + ---------- + labeled_mask : np.ndarray + Mask where each object is labeled with an integer. + + Returns + ------- + list of dict + List of dictionaries containing features per object. + """ + properties = [ + 'label', + 'area', + 'perimeter', + 'eccentricity', + 'solidity', + 'centroid', + 'bbox', + ] + props = regionprops_table(labeled_mask, properties=properties) + + rows = [] + for i in range(len(props['label'])): + row = { + 'object_id': props['label'][i], + 'area': props['area'][i], + 'perimeter': props['perimeter'][i], + 'eccentricity': props['eccentricity'][i], + 'solidity': props['solidity'][i], + 'centroid_y': props['centroid-0'][i], + 'centroid_x': props['centroid-1'][i], + 'min_row': props['bbox-0'][i], + 'min_col': props['bbox-1'][i], + 'max_row': props['bbox-2'][i], + 'max_col': props['bbox-3'][i], + } + rows.append(row) + + return rows + + def extract_bulk_features(self, mask: np.ndarray) -> List[Dict[str, Any]]: + """ + Extract bulk features for a whole mask (e.g., total area). + + Parameters + ---------- + mask : np.ndarray + Binary mask. + + Returns + ------- + list of dict + Single-item list with total area and object_id='bulk'. 
+ """ + total_area = int(np.sum(mask)) + return [{'area': total_area, 'object_id': 'bulk'}] + + def extract_grid_features_per_object(self, labeled_mask: np.ndarray, grid_size: int) -> List[Dict[str, Any]]: + """ + Extract features per grid tile for each labeled object. + """ + features = [] + for region in regionprops(labeled_mask): + object_id = region.label + coords = region.coords # N x 2 array: (row, col) + grid_map = {} + + for y_px, x_px in coords: + gx = (x_px // grid_size) * grid_size + gy = (y_px // grid_size) * grid_size + key = (gx, gy) + + if key not in grid_map: + grid_map[key] = {'area': 0, 'x_sum': 0, 'y_sum': 0, 'count': 0} + + grid_map[key]['area'] += 1 + grid_map[key]['x_sum'] += x_px + grid_map[key]['y_sum'] += y_px + grid_map[key]['count'] += 1 + + for (gx, gy), values in grid_map.items(): + features.append({ + 'x': gx, + 'y': gy, + 'object_id': f'{object_id}', + 'area': values['area'], + 'centroid_x': values['x_sum'] / values['count'], + 'centroid_y': values['y_sum'] / values['count'], + }) + + return features +class GeneCounter: + """ + Counts gene expression values within masks. + """ + + def count_genes_per_object(self, labeled_mask: np.ndarray, array_counts: np.ndarray, target_dict: Dict[str, int]) -> List[Dict[str, Any]]: + """ + Count genes per labeled object. + + Parameters + ---------- + labeled_mask : np.ndarray + Labeled mask array. + array_counts : np.ndarray + 3D array of gene counts per pixel. + target_dict : dict + Mapping from gene names to indices in array_counts. + + Returns + ------- + list of dict + List of gene counts per object. 
+ """ + results = [] + for obj_id in np.unique(labeled_mask): + if obj_id == 0: + continue + mask = labeled_mask == obj_id + # counts = np.einsum('ijk,ij->k', array_counts.astype(np.int16), mask.astype(np.int8)) + counts = array_counts[mask].sum(axis=0, dtype=np.int64) + + counts_dict = {gene: counts[i] for gene, i in target_dict.items()} + counts_dict['object_id'] = obj_id + results.append(counts_dict) + return results + + def count_genes_bulk(self, mask: np.ndarray, array_counts: np.ndarray, target_dict: Dict[str, int]) -> List[Dict[str, Any]]: + """ + Count genes in a bulk mask. + + Parameters + ---------- + mask : np.ndarray + Binary mask. + array_counts : np.ndarray + 3D array of gene counts. + target_dict : dict + Mapping from gene names to indices. + + Returns + ------- + list of dict + Single-item list with gene counts. + """ + mask = mask.astype(bool) + # counts = np.einsum('ijk,ij->k', array_counts.astype(np.int64), mask.astype(np.int64)) + counts = array_counts[mask].sum(axis=0, dtype=np.int64) + + counts_dict = {gene: counts[i] for gene, i in target_dict.items()} + counts_dict['object_id'] = 'bulk' + return [counts_dict] + + def count_genes_grid_per_object(self, labeled_mask: np.ndarray, array_counts: np.ndarray, + target_dict: Dict[str, int], grid_size: int) -> List[Dict[str, Any]]: + """ + Count genes per grid tile within each object. 
+ """ + results = [] + for obj_id in np.unique(labeled_mask): + if obj_id == 0: + continue + coords = np.argwhere(labeled_mask == obj_id) + grid_map = {} + + for y_px, x_px in coords: + gx = (x_px // grid_size) * grid_size + gy = (y_px // grid_size) * grid_size + key = (gx, gy) + + if key not in grid_map: + grid_map[key] = [] + + grid_map[key].append((y_px, x_px)) + + for (gx, gy), pixels in grid_map.items(): + ys, xs = zip(*pixels) + gene_counts = array_counts[ys, xs].sum(axis=0, dtype=np.int64) + counts_dict = {gene: gene_counts[i] for gene, i in target_dict.items()} + counts_dict['object_id'] = f'{obj_id}' + counts_dict['x'] = gx + counts_dict['y'] = gy + results.append(counts_dict) + + return results + +class HierarchyMapper: + """ + Maps child objects to parent objects based on label overlaps. + """ + + def map_hierarchy(self, source_labels: np.ndarray, target_labels: np.ndarray) -> Dict[int, List[int]]: + """ + Map each source object ID to a list of parent object IDs from target mask. + + Parameters + ---------- + source_labels : np.ndarray + Labeled source mask. + target_labels : np.ndarray + Labeled target mask (parent). + + Returns + ------- + dict + Mapping from source object ID to list of parent IDs. + """ + mapping = {} + for src_id in np.unique(source_labels): + if src_id == 0: + continue + overlap = target_labels[source_labels == src_id] + mapping[src_id] = list(np.unique(overlap[overlap > 0])) + return mapping + +class MaskAnalysisPipeline: + """ + Main pipeline for analyzing masks with gene counts and morphology. + """ + + def __init__(self, mask_definitions: List[MaskDefinition], array_counts: np.ndarray, target_dict: Dict[str, int], + logger: Optional[logging.Logger] = None) -> None: + """ + Initialize the pipeline. + + Parameters + ---------- + mask_definitions : list of MaskDefinition + List of mask definitions. + array_counts : np.ndarray + 3D gene counts array. + target_dict : dict + Mapping gene names to indices in array_counts. 
+ logger : logging.Logger or None, optional + Logger instance, by default None. + """ + self.mask_definitions = mask_definitions + self.array_counts = array_counts + self.target_dict = target_dict + self.extractor = MorphologyExtractor() + self.counter = GeneCounter() + self.results: List[MaskAnalysisResult] = [] + self.labeled_masks: Dict[str, np.ndarray] = {} # Store labeled versions of masks + self.logger = logger or get_logger(f'{__name__}.{"GetMasks"}') + self.logger.info(f"Initialized MaskAnalysisPipeline with {len(mask_definitions)} masks.") + + @timeit + def run(self) -> List[MaskAnalysisResult]: + """ + Run the full analysis pipeline on all mask definitions. + + Returns + ------- + list of MaskAnalysisResult + List of results per mask. + """ + self.results.clear() + + for defn in self.mask_definitions: + self.logger.info(f"Processing mask: {defn.mask_name} ({defn.analysis_type})") + + # if defn.analysis_type == 'per_object': + # labeled = label(defn.mask) + # self.labeled_masks[defn.mask_name] = labeled + if defn.analysis_type == "per_object" and defn.mask_name not in self.labeled_masks: + self.labeled_masks[defn.mask_name] = label(defn.mask) + morpho = self.extractor.extract_per_object_features(self.labeled_masks[defn.mask_name]) + counts = self.counter.count_genes_per_object(self.labeled_masks[defn.mask_name], self.array_counts.astype(np.int16), self.target_dict) + merged = self._merge_dicts_by_key(morpho, counts, 'object_id') + + elif defn.analysis_type == 'bulk': + morpho = self.extractor.extract_bulk_features(defn.mask) + counts = self.counter.count_genes_bulk(defn.mask, self.array_counts.astype(np.int16), self.target_dict) + merged = self._merge_dicts_by_key(morpho, counts, 'object_id') + + # elif defn.analysis_type == 'grid': + # if defn.grid_size is None: + # raise ValueError("Grid size required for grid analysis.") + # morpho = self.extractor.extract_grid_features(defn.mask, defn.grid_size) + # counts = self.counter.count_genes_grid(defn.mask, 
self.array_counts.astype(np.int16), self.target_dict, defn.grid_size) + # merged = self._merge_dicts_by_key(morpho, counts, 'object_id') if counts else morpho + elif defn.analysis_type == 'grid': + if defn.grid_size is None: + raise ValueError("Grid size required for grid analysis.") + if defn.mask_name not in self.labeled_masks: + self.labeled_masks[defn.mask_name] = label(defn.mask) + + labeled = self.labeled_masks[defn.mask_name] + + morpho = self.extractor.extract_grid_features_per_object(labeled, defn.grid_size) + counts = self.counter.count_genes_grid_per_object( + labeled, self.array_counts.astype(np.int16), self.target_dict, defn.grid_size + ) + merged = self._merge_dicts_by_key(morpho, counts, 'object_id') if counts else morpho + + else: + raise ValueError(f"Unsupported analysis type: {defn.analysis_type}, should be one of 'per_object', 'bulk', or 'grid'.") + + # Check for negative gene counts + for c in counts: + for gene, value in c.items(): + if gene != 'object_id' and value < 0: + print(f"Warning: Negative count for gene '{gene}' in object '{c.get('object_id')}'") + + for item in merged: + item['mask_name'] = defn.mask_name + item['analysis_type'] = defn.analysis_type + + self.results.append(MaskAnalysisResult(defn.mask_name, defn.analysis_type, merged)) + # self.logger.info(f"Finished {defn.mask_name} in {elapsed:.2f} sec") + + return self.results + + def get_results_df(self) -> pd.DataFrame: + """ + Get all results concatenated into a single pandas DataFrame. + + Returns + ------- + pandas.DataFrame + DataFrame with all extracted features. + """ + if not self.results: + self.run() + all_features = [item for r in self.results for item in r.features] + return pd.DataFrame(all_features) + + def _merge_dicts_by_key(self, list1: List[Dict[str, Any]], list2: List[Dict[str, Any]], key: str) -> List[ + Dict[str, Any]]: + """ + Merge two lists of dictionaries by matching values of a specified key. 
+ + Parameters + ---------- + list1 : list of dict + First list of dictionaries. + list2 : list of dict + Second list of dictionaries. + key : str + Key to merge on. + + Returns + ------- + list of dict + Merged list of dictionaries. + """ + if not list1: + return list2 + if not list2: + return list1 + index2 = {d[key]: d for d in list2} + return [{**d1, **index2.get(d1[key], {})} for d1 in list1] + + @timeit + def map_hierarchies(self, hierarchy_definitions: Dict[str, Dict[str, Any]], save_dir: Optional[str] = None) -> pd.DataFrame: + """ + Map child objects to their parent objects using reference labeled masks. + + Parameters + ---------- + hierarchy_definitions : dict + Dictionary defining the hierarchy relationships. + save_dir : str or None, optional + Directory to save labeled masks, by default None. + + Returns + ------- + pandas.DataFrame + DataFrame with mapping of child to parent objects. + """ + records = [] + + for child_name, definition in hierarchy_definitions.items(): + # reference_labels = definition["labels"] + # parent_name = definition["level_hierarchy"] + # + # # Make sure both masks are labeled + # if parent_name not in self.labeled_masks: + # self.labeled_masks[parent_name] = label( + # next(d.mask for d in self.mask_definitions if d.mask_name == parent_name) + # ) + # if save_dir: + # os.makedirs(save_dir, exist_ok=True) + # np.save(os.path.join(save_dir, f"{parent_name}_labeled.npy"), self.labeled_masks[parent_name]) + # + # parent_labels = self.labeled_masks[parent_name] + # + # mapper = HierarchyMapper() + # hierarchy_map = mapper.map_hierarchy(reference_labels, parent_labels) + referenced_labels = definition["labels"] # The parent IDs per pixel + parent_name = definition["level_hierarchy"] + + # Get child labels: this corresponds to the same expansion + child_labels = next( + d for d in self.mask_definitions if d.mask_name == child_name + ).mask # binary — we need to label it! 
+ + if child_name not in self.labeled_masks: + self.labeled_masks[child_name] = label(child_labels) + + labeled_child = self.labeled_masks[child_name] + + mapper = HierarchyMapper() + hierarchy_map = mapper.map_hierarchy(labeled_child, referenced_labels) + + # Update results + for result in self.results: + if result.mask_name == child_name: + for row in result.features: + obj_id = row.get("object_id") + try: + int_obj_id = int(obj_id) + except Exception: + continue + row["parent_ids"] = hierarchy_map.get(int_obj_id, []) + + # Collect for output + for obj_id, parent_ids in hierarchy_map.items(): + records.append({ + "mask_name": child_name, + "object_id": obj_id, + "parent_mask": parent_name, + "parent_ids": parent_ids + }) + + # Optionally save reference label + if save_dir: + np.save(os.path.join(save_dir, f"{child_name}_ref_labels.npy"), reference_labels) + + return pd.DataFrame(records) \ No newline at end of file diff --git a/gridgene/overlay.py b/gridgene/overlay.py new file mode 100644 index 0000000..ca5e1ac --- /dev/null +++ b/gridgene/overlay.py @@ -0,0 +1,465 @@ +import logging +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap, Normalize +from matplotlib import patches as mpatches +from matplotlib import patches +from PIL import Image, ImageDraw +import spatialdata as sd +import xarray as xr +from functools import wraps +import time +import logging +from functools import wraps +from typing import Optional, Dict, List, Tuple + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.colors import Normalize, ListedColormap +from PIL import Image, ImageDraw +import shapely.geometry as sg + +try: + import spatialdata as sd +except ImportError: + sd = None + +from shapely.geometry import Polygon +from matplotlib.colors import Normalize +from scipy.spatial import ConvexHull +from gridgene.logger import get_logger + +# TODO spatialdata support + +def timeit(func): 
# ============================================================
# gridgene/overlay.py
# ============================================================
import logging
import time
from functools import wraps
from typing import Optional, Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize, ListedColormap
from PIL import Image, ImageDraw

# spatialdata is optional. Bug fix: the original also imported it
# unconditionally above this try/except, which defeated the fallback.
try:
    import spatialdata as sd
except ImportError:
    sd = None

from gridgene.logger import get_logger

# TODO spatialdata support


def timeit(func):
    """
    Decorator that reports how long a call took.

    Logs via ``args[0].logger`` when decorating a method whose instance has a
    logger, otherwise prints.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        # Bug fix: the original wrote `duration = end = time.time() - start`,
        # leaving `end` holding an elapsed time rather than an end timestamp.
        duration = time.time() - start
        owner = args[0] if args else None
        if hasattr(owner, "logger"):
            owner.logger.info(f"{func.__name__} took {duration:.4f} seconds")
        else:
            print(f"{func.__name__} took {duration:.4f} seconds")
        return result
    return wrapper


class Overlay:
    """
    Overlay segmentation with binary masks for comparison and visualization.

    Supports polygon-based (GeoJSON-like) and label-mask segmentations.

    Parameters
    ----------
    mask_dict : dict[str, np.ndarray]
        Dictionary of masks, where keys are mask names and values are binary
        mask arrays.
    segmentation : dict or np.ndarray or SpatialData
        Segmentation data: GeoJSON-like dict, label mask ndarray, or SpatialData.
    segmentation_type : {'auto', 'polygons', 'label_mask'}, optional
        Type of segmentation input. Default is 'auto' to detect automatically.
    save_path : str, optional
        Optional path to save visualizations.
    min_x : float, optional
        Minimum x coordinate to shift polygons. Default is 0.
    min_y : float, optional
        Minimum y coordinate to shift polygons. Default is 0.
    flip_masks : bool, optional
        Whether to flip masks vertically and rotate 90 degrees clockwise.
        Default is True.
    logger : logging.Logger, optional
        Custom logger. If None, a default logger is used.
    """

    def __init__(self, mask_dict: Dict[str, np.ndarray], segmentation, segmentation_type: str = "auto",
                 save_path: Optional[str] = None, min_x: float = 0, min_y: float = 0, flip_masks: bool = True,
                 logger: Optional[logging.Logger] = None) -> None:
        self.mask_dict = mask_dict
        self.save_path = save_path
        self.segmentation = segmentation
        self.min_x = min_x
        self.min_y = min_y
        self.logger = logger or get_logger(__name__)
        self.logger.info("Initialized Overlay")

        self.segmentation_type = (self._detect_segmentation_type()
                                  if segmentation_type == "auto" else segmentation_type)
        if self.min_x != 0 or self.min_y != 0:
            self.logger.info(f"Shifting polygons by min_x={self.min_x}, min_y={self.min_y}")
            if self.segmentation_type == 'polygons':
                self.shift_polygons()
        self.results = None
        if flip_masks:
            self.logger.info("Flipping masks vertically")
            for mask_name, mask in self.mask_dict.items():
                # Flip vertically, then rotate 90 degrees clockwise.
                self.mask_dict[mask_name] = np.rot90(np.flip(mask, 0), -1)

    @property
    def mask_shape(self):
        """
        Shape of the masks (taken from the first mask in the dictionary).
        """
        return next(iter(self.mask_dict.values())).shape

    def _detect_segmentation_type(self) -> str:
        """
        Detect segmentation type based on input structure.

        Returns
        -------
        str
            'polygons' or 'label_mask'.

        Raises
        ------
        ValueError
            If segmentation type cannot be detected.
        """
        if isinstance(self.segmentation, dict) and 'geometries' in self.segmentation:
            return 'polygons'
        if isinstance(self.segmentation, np.ndarray):
            return 'label_mask'
        if sd and isinstance(self.segmentation, sd.SpatialData):
            if 'shapes' in self.segmentation:
                return 'polygons'
            if 'labels' in self.segmentation:
                return 'label_mask'
        raise ValueError("Unable to detect segmentation type. Please specify it explicitly.")

    def shift_polygons(self) -> None:
        """
        Shift all polygon coordinates in place by (-min_x, -min_y), rounding
        the results to integer pixel coordinates.
        """
        offset = np.array([self.min_x, self.min_y])
        for geometry in self.segmentation['geometries']:
            shifted = np.array(geometry['coordinates'][0]) - offset
            geometry['coordinates'][0] = [(round(x), round(y)) for x, y in shifted]

    def _get_cell_masks_from_polygons(self, segmentation: Optional[dict] = None) -> Dict[int, np.ndarray]:
        """
        Rasterize polygon segmentation into per-cell binary masks.

        Parameters
        ----------
        segmentation : dict, optional
            GeoJSON-like dict with a 'geometries' list. Defaults to
            ``self.segmentation`` (backward compatible with the old
            zero-argument call).

        Returns
        -------
        dict[int, np.ndarray]
            Dictionary mapping cell IDs to binary masks.
        """
        source = self.segmentation if segmentation is None else segmentation
        height, width = self.mask_shape
        masks = {}
        for geometry in source['geometries']:
            cell_id = int(geometry['cell'])
            poly = [(round(x), round(y)) for x, y in geometry['coordinates'][0]]
            canvas = Image.new('L', (width, height), 0)
            ImageDraw.Draw(canvas).polygon(poly, outline=1, fill=1)
            masks[cell_id] = np.array(canvas)
        return masks

    def _get_cell_masks_from_label_mask(self, label_mask) -> Dict[int, np.ndarray]:
        """
        Split a labeled mask into one binary mask per non-zero label.

        Parameters
        ----------
        label_mask : np.ndarray
            Label mask array.

        Returns
        -------
        dict[int, np.ndarray]
            Dictionary mapping label IDs to binary masks.
        """
        return {cid: (label_mask == cid).astype(np.uint8)
                for cid in np.unique(label_mask) if cid != 0}

    def _extract_segmentation_masks(self) -> Dict[int, np.ndarray]:
        """
        Extract binary cell masks from the segmentation input.

        Returns
        -------
        dict[int, np.ndarray]
            Dictionary mapping cell IDs to binary masks.

        Raises
        ------
        ValueError
            If segmentation type is unsupported.
        """
        if self.segmentation_type == 'polygons':
            if sd and isinstance(self.segmentation, sd.SpatialData):
                shapes = self.segmentation.shapes[list(self.segmentation.shapes.keys())[0]]
                polygons = [poly.exterior.coords for poly in shapes.geometry]
                cell_ids = shapes.obs.get("cell", np.arange(len(shapes)))
                geojson_like = {'geometries': [
                    {'coordinates': [list(p)], 'cell': str(cid)}
                    for p, cid in zip(polygons, cell_ids)
                ]}
                # Bug fix: the original built this dict and then discarded it,
                # rasterizing self.segmentation (a SpatialData object) instead.
                return self._get_cell_masks_from_polygons(geojson_like)
            return self._get_cell_masks_from_polygons()
        if self.segmentation_type == 'label_mask':
            if sd and isinstance(self.segmentation, sd.SpatialData):
                label_mask = self.segmentation.labels[list(self.segmentation.labels.keys())[0]].data.values
            else:
                label_mask = self.segmentation
            return self._get_cell_masks_from_label_mask(label_mask)
        raise ValueError("Unsupported segmentation type")

    @timeit
    def compute_overlap(self):
        """
        Compute overlap between masks and segmented regions.

        Returns
        -------
        dict
            Nested dictionary of overlap counts
            {cell_id: {mask_name: pixel_overlap_count}}.

        Raises
        ------
        NotImplementedError
            If overlap for label masks is requested.
        """
        if self.segmentation_type == 'polygons':
            self.map_mask_cell_polygons()
        elif self.segmentation_type == 'label_mask':
            # TODO add label mask overlap computation
            raise NotImplementedError("Label mask overlap not implemented.")
        self.logger.info("Computed overlap between masks and segmentation.")
        return self.results

    def map_mask_cell_masks(self) -> None:
        """
        Placeholder for label mask-based overlap computation.
        """
        pass

    def map_mask_cell_polygons(self) -> None:
        """
        Compute overlap counts between each segmentation polygon and each mask.

        Stores results in ``self.results`` as:
        {cell_id: {mask_name: pixel_overlap_count, ...}, ...}
        """
        height, width = self.mask_shape
        overlaps = {}
        for geometry in self.segmentation['geometries']:
            cell_id = int(geometry['cell'])
            poly = [(round(x), round(y)) for x, y in geometry['coordinates'][0]]
            canvas = Image.new('L', (width, height), 0)
            ImageDraw.Draw(canvas).polygon(poly, outline=1, fill=1)
            poly_mask = np.array(canvas, dtype=bool)

            overlaps[cell_id] = {name: np.count_nonzero(mask[poly_mask])
                                 for name, mask in self.mask_dict.items()}
        self.results = overlaps

    def plot_masks_overlay_segmentation(self, titles: List[str], colors: List[str], background: str = 'white',
                                        save_path: Optional[str] = None, show: bool = True,
                                        show_legend: bool = True) -> None:
        """
        Overlay binary masks and segmentation polygons for visualization.

        Parameters
        ----------
        titles : list of str
            Legend labels, one per mask (ordered like ``colors``).
        colors : list of str
            Colors corresponding to each mask.
        background : str, optional
            Background color, by default 'white'.
        save_path : str, optional
            Path to save the overlay plot, by default None.
        show : bool, optional
            Whether to display the plot, by default True.
        show_legend : bool, optional
            Whether to show legend, by default True.

        Raises
        ------
        NotImplementedError
            If the segmentation is a label mask.
        """
        if self.segmentation_type != 'polygons':
            raise NotImplementedError("Label mask overlay plot not implemented.")
        fig, ax = self._plot_masks_overlay_segmentation_polygons(titles, colors, background)

        ax.set_title("Overlay of Masks and Segmentation", fontsize=16)
        ax.axis('off')
        if show_legend:
            handles = [mpatches.Patch(color=color, label=label)
                       for color, label in zip(colors, titles)]
            legend = ax.legend(
                handles=handles,
                loc='lower left',
                bbox_to_anchor=(0, -0.1),
                frameon=True,
                framealpha=1,
                facecolor='white',
                edgecolor='black',
                fontsize=14,
            )
            legend.set_title("Masks", prop={'size': 16})

        plt.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=500)
            self.logger.info(f"Saved overlay plot to {save_path}")
        if show:
            plt.show()
        # Always close so figures do not accumulate across repeated calls.
        plt.close(fig)

    def _plot_masks_overlay_segmentation_polygons(self, titles: List[str], colors: List[str],
                                                  background: str = 'white') -> Tuple[plt.Figure, plt.Axes]:
        """
        Internal: overlay binary masks and polygon outlines on a canvas.

        Parameters
        ----------
        titles : list of str
            Titles for each mask (unused here; kept for interface symmetry).
        colors : list of str
            Colors for each mask.
        background : str, optional
            Background color, by default 'white'.

        Returns
        -------
        tuple
            Matplotlib figure and axes with overlay.
        """
        # (The original guarded `hasattr(self, 'mask_shape')` and assigned to
        # it — dead code that would have raised on the read-only property.)
        fig, ax = plt.subplots(figsize=(10, 10))
        cmap = ListedColormap([background] + colors)

        # Stack masks into one integer canvas: 0 = background, i + 1 = mask i.
        overlay = np.zeros(self.mask_shape, dtype=np.int64)
        for index, mask in enumerate(self.mask_dict.values()):
            overlay += mask * (index + 1)

        ax.imshow(overlay, cmap=cmap, origin='lower')

        # Draw polygon outlines on top of the mask canvas.
        for geometry in self.segmentation['geometries']:
            outline = np.array(geometry['coordinates'][0])
            ax.plot(outline[:, 0], outline[:, 1], color='black', linewidth=0.3)

        ax.axis('off')
        return fig, ax

    def plot_colored_by_mask_overlap(self, mask_to_color: List[str], color_map: str = 'Reds', show: bool = True,
                                     save_path: Optional[str] = None, figsize: Tuple[int, int] = (15, 15)) -> None:
        """
        Color segmented polygons based on overlap percentage with specified masks.

        Parameters
        ----------
        mask_to_color : list of str
            Mask names to base coloring on.
        color_map : str, optional
            Matplotlib colormap, by default 'Reds'.
        show : bool, optional
            Whether to display the plot, by default True.
        save_path : str, optional
            Path to save the plot, by default None.
        figsize : tuple, optional
            Figure size, by default (15, 15).

        Raises
        ------
        NotImplementedError
            If the segmentation is a label mask.
        """
        if self.results is None:
            self.logger.info("No results found. Computing overlap...")
            self.compute_overlap()
        if self.segmentation_type != 'polygons':
            raise NotImplementedError("Only polygon plotting is implemented.")
        fig, ax = self._plot_colored_by_mask_overlap_polygons(mask_to_color, color_map, figsize)

        if save_path:
            fig.savefig(save_path, dpi=500)
            self.logger.info(f"Saved colored overlap plot to {save_path}")
        if show:
            plt.show()
        plt.close(fig)  # free memory whether or not the figure was shown

    def _plot_colored_by_mask_overlap_polygons(self, mask_to_color: List[str], color_map: str,
                                               figsize: Tuple[int, int]) -> Tuple[plt.Figure, plt.Axes]:
        """
        Internal: color polygons based on mask overlap percentages.

        Parameters
        ----------
        mask_to_color : list of str
            Masks used for coloring.
        color_map : str
            Colormap name.
        figsize : tuple
            Figure size.

        Returns
        -------
        tuple
            Matplotlib figure and axes.
        """
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        cmap = plt.get_cmap(color_map)
        norm = Normalize(0, 100)  # overlap expressed as a percentage

        for feature in self.segmentation['geometries']:
            cell_id = int(feature['cell'])
            polygon_array = np.array(feature['coordinates'][0])
            cell_overlaps = self.results.get(cell_id, {})
            total_pixels = sum(cell_overlaps.values())

            if total_pixels:
                percentages = {k: v / total_pixels * 100 for k, v in cell_overlaps.items()}
            else:
                percentages = {k: 0 for k in mask_to_color}

            total_percentage = sum(percentages.get(k, 0) for k in mask_to_color)
            face = cmap(norm(total_percentage))

            ax.add_patch(mpatches.Polygon(polygon_array, facecolor=face,
                                          edgecolor='black', linewidth=0.3))

        # Fit the axes to all polygon coordinates.
        all_coords = np.vstack([np.array(f['coordinates'][0])
                                for f in self.segmentation['geometries']])
        ax.set_xlim(all_coords[:, 0].min(), all_coords[:, 0].max())
        ax.set_ylim(all_coords[:, 1].min(), all_coords[:, 1].max())
        ax.set_aspect('equal')
        ax.axis('off')
        return fig, ax
From: Ana Marta Sequeira Date: Thu, 28 Aug 2025 16:42:55 +0100 Subject: [PATCH 2/3] Rename package to GRIDGENE - tests --- tests/test_binsom.py | 2 +- tests/test_contours.py | 8 ++++---- tests/test_masks.py | 2 +- tests/test_overlay.py | 2 +- tests/test_properties.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_binsom.py b/tests/test_binsom.py index 2c8cf63..93ecf96 100644 --- a/tests/test_binsom.py +++ b/tests/test_binsom.py @@ -3,7 +3,7 @@ import pandas as pd import logging from anndata import AnnData -from gridgen.binsom import GetBins, GetContour # Assuming your classes are in `binsom.py` +from gridgene.binsom import GetBins, GetContour # Assuming your classes are in `binsom.py` class TestBinSOM(unittest.TestCase): diff --git a/tests/test_contours.py b/tests/test_contours.py index 07a6fbe..f48a891 100644 --- a/tests/test_contours.py +++ b/tests/test_contours.py @@ -7,9 +7,9 @@ from matplotlib.axes import Axes import matplotlib.pyplot as plt from matplotlib.figure import Figure -from gridgen.contours import GetContour -from gridgen.contours import ConvolutionContours -from gridgen.contours import KDTreeContours +from gridgene.contours import GetContour +from gridgene.contours import ConvolutionContours +from gridgene.contours import KDTreeContours def make_dummy_contour(center, radius, points=8): @@ -180,4 +180,4 @@ def test_plot_dbscan_labels(self): unittest.main() -# (GRIDGEN) martinha@gaia:~/PycharmProjects/phd/spatial_transcriptomics/GRIDGEN$ python -m unittest tests/test_contours.py +# (GRIDGENE) martinha@gaia:~/PycharmProjects/phd/spatial_transcriptomics/GRIDGENE$ python -m unittest tests/test_contours.py diff --git a/tests/test_masks.py b/tests/test_masks.py index 3da05d9..561ea70 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -2,7 +2,7 @@ import numpy as np import cv2 import os -from gridgen.get_masks import (GetMasks, ConstrainedMaskExpansion, +from gridgene.get_masks import (GetMasks, 
ConstrainedMaskExpansion, SingleClassObjectAnalysis, MultiClassObjectAnalysis) import logging from unittest.mock import MagicMock diff --git a/tests/test_overlay.py b/tests/test_overlay.py index ac4c255..6226389 100644 --- a/tests/test_overlay.py +++ b/tests/test_overlay.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from gridgen.overlay import Overlay # Assuming your class is in a file named overlay.py +from gridgene.overlay import Overlay # Assuming your class is in a file named overlay.py from PIL import Image class TestOverlay(unittest.TestCase): diff --git a/tests/test_properties.py b/tests/test_properties.py index 7dbc2e0..e42334b 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from gridgen.mask_properties import MaskDefinition, MaskAnalysisPipeline +from gridgene.mask_properties import MaskDefinition, MaskAnalysisPipeline from skimage.measure import label, regionprops_table From 5e4b86de41adc09b980ff38dc58e6ad608a0d5b3 Mon Sep 17 00:00:00 2001 From: Ana Marta Sequeira Date: Thu, 28 Aug 2025 16:43:47 +0100 Subject: [PATCH 3/3] Rename package to GRIDGENE - setup --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d4b77a7..1a1a1b1 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,10 @@ from setuptools import setup, find_packages setup( - name="gridgen", + name="gridgene", version="0.1.0", author="AM Sequeira", - description="GRIDGEN project", + description="GRIDGENE project", packages=find_packages(), install_requires=[ "numpy",