diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 171f69f4..a0935c6d 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -26,8 +26,14 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        python -m pip install pytest
-        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+        pip install -e .[dev,test]
+    - name: Check style
+      run: |
+        black stlearn tests
+        ruff check stlearn tests
+    - name: Check types
+      run: |
+        mypy stlearn tests
     - name: Test with pytest
       run: |
-        pytest
+        pytest
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index c5ab06d4..fcd9e498 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,21 +1,71 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
 *.pyc
-.ipynb_checkpoints
-*/.ipynb_checkpoints/*
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
 build/
+docs/api/
+docs/_build/
+docs/generated/
+data/samples
+develop-eggs/
 dist/
-*.egg-info
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Unit tests / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
 /*.ipynb
+
+# Data files
 /*.csv
-output/
+
+# MacOS caching
 .DS_Store
 */.DS_Store
+
+# PyCharm etc
 .idea/
-docs/_build
+
+# Sphinx documentation
+docs.bk/_build
+
+# Distribution/package/temporary files
 data/
 tiling/
-.pytest_cache
 figures/
 *.h5ad
-inferCNV/
-stlearn/tools/microenv/cci/junk_code.py
-stlearn/tools/microenv/cci/.Rhistory
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 6a8f1a14..e841d344 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -1,4 +1,25 @@
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version, and other tools you might need
 build:
-  image: latest
+  os: ubuntu-24.04
+  tools:
+    python: "3.10"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: docs/conf.py
+
+# Optional, but recommended:
+# declare the Python requirements required to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
 python:
-  version: 3.8
+  install:
+    - method: pip
+      path: .
+      extra_requirements:
+        - dev
\ No newline at end of file
diff --git a/AUTHORS.rst b/AUTHORS.rst
index d30eaa6e..a024f3f5 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -5,9 +5,12 @@ Credits
 Development Lead
 ----------------
 
-* Genomics and Machine Learning lab
+* Genomics and Machine Learning Lab
 
 Contributors
 ------------
 
-None yet. Why not be the first?
+* Brad Balderson
+* Andrew Newman
+* Duy Pham
+* Xiao Tan
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index aa232892..b9769b45 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -64,11 +64,19 @@ Ready to contribute? Here's how to set up `stlearn` for local development.
 
     $ git clone git@github.com:your_name_here/stlearn.git
 
-3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
+3. Install your local copy into a virtualenv.
+   This is how you set up your fork for local development::
 
-    $ mkvirtualenv stlearn
+    $ conda create -n stlearn-dev python=3.10 -y
+    $ conda activate stlearn-dev
     $ cd stlearn/
-    $ python setup.py develop
+    $ pip install -e .[dev,test]
+
+   Or if you prefer pip/virtualenv::
+
+    $ python -m venv stlearn-env
+    $ source stlearn-env/bin/activate  # On Windows: stlearn-env\Scripts\activate
+    $ cd stlearn/
+    $ pip install -e .[dev,test]
 
 4. Create a branch for local development::
 
@@ -76,14 +84,16 @@ Ready to contribute? Here's how to set up `stlearn` for local development.
 
    Now you can make your changes locally.
 
-5. When you're done making changes, check that your changes pass flake8 and the
-   tests, including testing other Python versions with tox::
+5. When you're done making changes, check that your changes pass linters and tests::
 
-    $ flake8 stlearn tests
-    $ python setup.py test or pytest
-    $ tox
+    $ black stlearn tests
+    $ ruff check stlearn tests
+    $ mypy stlearn tests
+    $ pytest
+
+   Or run everything with tox::
 
-    To get flake8 and tox, just pip install them into your virtualenv.
+    $ tox
 
 6. Commit your changes and push your branch to GitHub::
diff --git a/HISTORY.rst b/HISTORY.rst
index 39a6759c..815cb9dd 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -2,24 +2,51 @@ History
 =======
 
+1.1.0 (2025-07-02)
+------------------
+* Support Python 3.10.x
+* Added black, ruff, and mypy quality checks and fixed the affected source code.
+* Copy parameters now work with the same semantics as scanpy.
+* Library upgrades for leidenalg, louvain, numba, numpy, scanpy, and tensorflow.
+* datasets.xenium_sge - loads Xenium data (and caches it) similarly to scanpy.visium_sge.
+
+API and Bug Fixes:
+* Xenium TIFF and cell positions are now aligned.
+* Made type annotations consistent - mainly adding missing None annotations.
+* pl.cluster_plot - Does not keep colours from previous runs when clustering.
+* pl.trajectory.pseudotime_plot - Fix typing of cluster values in .uns["split_node"].
+* Removed datasets.example_bcba - Replaced with wrapper for scanpy.visium_sge.
+* Moved spatials directory to spatial, cleaned up pl and tl packages.
+
 0.4.11 (2022-11-25)
 ------------------
+
 0.4.10 (2022-11-22)
 ------------------
+
 0.4.8 (2022-06-15)
 ------------------
+
 0.4.7 (2022-03-28)
 ------------------
+
 0.4.6 (2022-03-09)
 ------------------
+
 0.4.5 (2022-03-02)
 ------------------
+
 0.4.0 (2022-02-03)
 ------------------
+
 0.3.2 (2021-03-29)
 ------------------
+
 0.3.1 (2020-12-24)
 ------------------
+
 0.2.7 (2020-09-12)
 ------------------
+
 0.2.6 (2020-08-04)
+------------------
diff --git a/LICENSE b/LICENSE
index 626beb6e..fafffeca 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,8 +1,6 @@
-
-
 BSD License
 
-Copyright (c) 2020, Genomics and Machine Learning lab
+Copyright (c) 2020-2025, Genomics and Machine Learning lab
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without modification,
diff --git a/docs/Makefile b/docs/Makefile
index 96688bf3..d4bb2cbb 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -1,10 +1,10 @@
 # Minimal makefile for Sphinx documentation
 #
 
-# You can set these variables from the command line.
-SPHINXOPTS    =
-SPHINXBUILD   = python -msphinx
-SPHINXPROJ    = stlearn
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
 SOURCEDIR     = .
 BUILDDIR      = _build
diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css
new file mode 100644
index 00000000..6beb551f
--- /dev/null
+++ b/docs/_static/css/custom.css
@@ -0,0 +1,5 @@
+/* Custom styling for stLearn documentation */
+
+p img {
+    vertical-align: bottom
+}
diff --git a/docs/_temp/example_cci.py b/docs/_temp/example_cci.py
deleted file mode 100644
index 15fe9a84..00000000
--- a/docs/_temp/example_cci.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# """ Example code for running CCI analysis using new interface/approach.
-
-# Tested: * Within-spot mode
-#         * Between-spot mode
-
-# TODO tests: * Above with cell heterogeneity information
-# """
-
-################################################################################
-# Environment setup #
-################################################################################
-import stlearn as st
-import matplotlib.pyplot as plt
-
-################################################################################
-# Load your data #
-################################################################################
-# TODO - load as an AnnData & perform usual pre-processing.
-data = None  # replace with your code
-
-# """ # Adding cell heterogeneity information if you have it.
-# st.add.labels(data, 'tutorials/label_transfer_bc.csv', sep='\t')
-# st.pl.cluster_plot(data, use_label="predictions")
-# """
-
-################################################################################
-# Performing cci_rank analysis #
-################################################################################
-# Load the NATMI literature-curated database of LR pairs, data formatted #
-lrs = st.tl.cci.load_lrs(["connectomeDB2020_lit"])
-
-st.tl.cci.run(
-    data,
-    lrs,
-    use_label=None,  # Need to add the label transfer results to object first, above code puts into 'label_transfer'
-    use_het="cell_het",  # Slot for cell het. results in adata.obsm, only if use_label specified
-    min_spots=6,  # Filter out any LR pairs with no scores for less than 6 spots
-    distance=None,  # distance=0 for within-spot mode, None to auto-select distance to nearest neighbourhood.
-    n_pairs=1000,  # Number of random pairs to generate
-    adj_method="fdr_bh",  # MHT correction method
-    min_expr=0,  # min expression for gene to be considered expressed.
-    pval_adj_cutoff=0.05,
-)
-# """
-# Example output:
-
-# Calculating neighbours...
-# 0 spots with no neighbours, 6 median spot neighbours.
-# Spot neighbour indices stored in adata.uns['spot_neighbours']
-# Altogether 1393 valid L-R pairs
-# Generating random gene pairs...
-# Generating the background...
-# Calculating p-values for each LR pair in each spot...: 100%|██████████ [ time left: 00:00 ]
-
-# Storing results:
-
-# lr_scores stored in adata.obsm['lr_scores'].
-# p_vals stored in adata.obsm['p_vals'].
-# p_adjs stored in adata.obsm['p_adjs'].
-# -log10(p_adjs) stored in adata.obsm['-log10(p_adjs)'].
-# lr_sig_scores stored in adata.obsm['lr_sig_scores'].
-
-# Per-spot results in adata.obsm have columns in same order as rows in adata.uns['lr_summary'].
-# Summary of LR results in adata.uns['lr_summary'].
-# """
-
-################################################################################
-# Visualising results #
-################################################################################
-# Plotting the -log10(p_adjs) for the lr with the highest number of spots.
-# Set use_lr to any listed in data.uns['lr_summary'] to visualise alternate lrs.
-st.pl.lr_result_plot(
-    data,
-    use_lr=None,  # Which LR to use, if None then uses top resuls from data.uns['lr_results']
-    use_result="-log10(p_adjs)",  # Which result to visualise, must be one of
-    # p_vals, p_adjs, -log10(p_adjs), lr_sig_scores
-)
-plt.show()
-
-################################################################################
-# Extra diagnostic plots for results #
-################################################################################
-# TODO:
-# Below needs to be updated with new way of storing results.
-
-# Looking at which LR pairs were significant across the most spots #
-print(data.uns["lr_summary"])  # Rank-ordered by pairs with most significant spots
-
-# Now looking at the LR pair with the highest number of sig. spots #
-best_lr = data.uns["lr_summary"].index.values[0]
-
-# Binary LR coexpression plot for all spots #
-st.pl.lr_plot(
-    data,
-    best_lr,
-    inner_size_prop=0.1,
-    outer_mode="binary",
-    pt_scale=10,
-    use_label=None,
-    show_image=True,
-    sig_spots=False,
-)
-plt.show()
-
-# Significance scores for all spots #
-st.pl.lr_plot(
-    data,
-    best_lr,
-    inner_size_prop=1,
-    outer_mode=None,
-    pt_scale=20,
-    use_label="lr_scores",
-    show_image=True,
-    sig_spots=False,
-)
-plt.show()
-
-# Binary LR coexpression plot for significant spots #
-st.pl.lr_plot(
-    data,
-    best_lr,
-    outter_size_prop=1,
-    outer_mode="binary",
-    pt_scale=20,
-    use_label=None,
-    show_image=True,
-    sig_spots=True,
-)
-plt.show()
-
-# Continuous LR coexpression for signficant spots #
-st.pl.lr_plot(
-    data,
-    best_lr,
-    inner_size_prop=0.1,
-    middle_size_prop=0.2,
-    outter_size_prop=0.4,
-    outer_mode="continuous",
-    pt_scale=150,
-    use_label=None,
-    show_image=True,
-    sig_spots=True,
-)
-plt.show()
-
-# Continous LR coexpression for significant spots with tissue_type information #
-st.pl.lr_plot(
-    data,
-    best_lr,
-    inner_size_prop=0.08,
-    middle_size_prop=0.3,
-    outter_size_prop=0.5,
-    outer_mode="continuous",
-    pt_scale=150,
-    use_label="tissue_type",
-    show_image=True,
-    sig_spots=True,
-)
-plt.show()
-
-
-# # Old version of visualisation #
-# """
-# # LR enrichment scores
-# data.obsm[f'{best_lr}_scores'] = data.uns['per_lr_results'][best_lr].loc[:,
-#                                                            'lr_scores'].values
-# # -log10(p_adj) of LR enrichment scores
-# data.obsm[f'{best_lr}_log-p_adj'] = data.uns['per_lr_results'][best_lr].loc[:,
-#                                                            '-log10(p_adj)'].values
-# # Significant LR enrichment scores
-# data.obsm[f'{best_lr}_sig-scores'] = data.uns['per_lr_results'][best_lr].loc[:,
-#                                                            'lr_sig_scores'].values
-
-# # Visualising these results #
-# st.pl.het_plot(data, use_het=f'{best_lr}_scores', cell_alpha=0.7)
-# plt.show()
-
-# st.pl.het_plot(data, use_het=f'{best_lr}_sig-scores', cell_alpha=0.7)
-# plt.show()
-# """
diff --git a/docs/_templates/autosummary/base.rst b/docs/_templates/autosummary/base.rst
deleted file mode 100644
index 7a780868..00000000
--- a/docs/_templates/autosummary/base.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-
-{% extends "!autosummary/base.rst" %}
-
-.. http://www.sphinx-doc.org/en/stable/ext/autosummary.html#customizing-templates
diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst
deleted file mode 100644
index 42c37f16..00000000
--- a/docs/_templates/autosummary/class.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-{{ fullname | escape | underline}}
-
-.. currentmodule:: {{ module }}
-
-.. add toctree option to make autodoc generate the pages
diff --git a/docs/api.rst b/docs/api.rst
index 19568d0a..c27132ff 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -13,11 +13,10 @@ Import stLearn as::
 Wrapper functions: `wrapper`
 ------------------------------
 
-.. module:: stlearn.wrapper
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    Read10X
    ReadOldST
@@ -31,11 +30,10 @@ Wrapper functions: `wrapper`
 Add: `add`
 -------------------
 
-.. module:: stlearn.add
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    add.image
    add.positions
@@ -56,7 +54,7 @@ Preprocessing: `pp`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    pp.filter_genes
    pp.log1p
@@ -75,7 +73,7 @@ Embedding: `em`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    em.run_pca
    em.run_umap
@@ -91,7 +89,7 @@ Spatial: `spatial`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    spatial.clustering.localization
 
@@ -99,7 +97,7 @@ Spatial: `spatial`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    spatial.trajectory.pseudotime
    spatial.trajectory.pseudotimespace_global
@@ -113,7 +111,7 @@ Spatial: `spatial`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    spatial.morphology.adjust
 
@@ -121,7 +119,7 @@ Spatial: `spatial`
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    spatial.SME.SME_impute0
    spatial.SME.pseudo_spot
@@ -130,22 +128,13 @@ Spatial: `spatial`
 Tools: `tl`
 -------------------
 
-.. module:: stlearn.tl.clustering
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    tl.clustering.kmeans
    tl.clustering.louvain
-
-
-.. module:: stlearn.tl.cci
-.. currentmodule:: stlearn
-
-.. autosummary::
-   :toctree: .
-
    tl.cci.load_lrs
    tl.cci.grid
    tl.cci.run
@@ -156,11 +145,10 @@ Tools: `tl`
 Plot: `pl`
 -------------------
 
-.. module:: stlearn.pl
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    pl.QC_plot
    pl.gene_plot
@@ -168,7 +156,6 @@ Plot: `pl`
    pl.cluster_plot
    pl.cluster_plot_interactive
    pl.subcluster_plot
-   pl.subcluster_plot
    pl.non_spatial_plot
    pl.deconvolution_plot
    pl.plot_mask
@@ -186,11 +173,10 @@ Plot: `pl`
    pl.lr_plot_interactive
    pl.spatialcci_plot_interactive
 
-.. module:: stlearn.pl.trajectory
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
    pl.trajectory.pseudotime_plot
    pl.trajectory.local_plot
@@ -198,13 +184,13 @@ Plot: `pl`
    pl.trajectory.transition_markers_plot
    pl.trajectory.DE_transition_plot
 
-Tools: `datasets`
--------------------
+Datasets: `datasets`
+---------------------------
 
-.. module:: stlearn.datasets
 .. currentmodule:: stlearn
 
 .. autosummary::
-   :toctree: .
+   :toctree: api/
 
-   datasets.example_bcba()
+   datasets.visium_sge
+   datasets.xenium_sge
diff --git a/docs/conf.py b/docs/conf.py
index 272a059b..ccda1308 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,216 +1,78 @@
-#!/usr/bin/env python
-#
-# stlearn documentation build configuration file, created by
-# sphinx-quickstart on Fri Jun 9 13:47:02 2017.
-#
-# This file is execfile()d with the current directory set to its
-# containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-# If extensions (or modules to document with autodoc) are in another
-# directory, add these directories to sys.path here. If the directory is
-# relative to the documentation root, use os.path.abspath to make it
-# absolute, like shown here.
-#
 import os
 import sys
+import re
+import requests
+
+
+def download_gdrive_file(file_id, filename):
+    session = requests.Session()
+    url = f"https://docs.google.com/uc?export=download&id={file_id}"
+    response = session.get(url)
+
+    form_action_match = re.search(r'action="([^"]+)"', response.text)
+    if not form_action_match:
+        raise Exception("Could not find form action URL")
+    download_url = form_action_match.group(1)
+
+    params = {}
+    hidden_inputs = re.findall(
+        r'
:5000` in your web browser.
-
-Check the detail tutorial in this pdf file: `Link `_
diff --git a/docs/list_tutorial.txt b/docs/list_tutorial.txt
deleted file mode 100644
index 53badb09..00000000
--- a/docs/list_tutorial.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Pseudo-time-space-tutorial.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Read_MERFISH.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Read_seqfish.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Read_slideseq.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/ST_deconvolution_visualization.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Working-with-Old-Spatial-Transcriptomics-data.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/stLearn-CCI.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/stSME_clustering.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/stSME_comparison.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Xenium_PSTS.ipynb
-https://raw.githubusercontent.com/BiomedicalMachineLearning/stLearn/master/tutorials/Xenium_CCI.ipynb
diff --git a/docs/make.bat b/docs/make.bat
index 2afd47f0..954237b9 100644
--- a/docs/make.bat
+++ b/docs/make.bat
@@ -5,32 +5,31 @@ pushd %~dp0
 
 REM Command file for Sphinx documentation
 
 if "%SPHINXBUILD%" == "" (
-	set SPHINXBUILD=python -msphinx
+	set SPHINXBUILD=sphinx-build
 )
 set SOURCEDIR=.
 set BUILDDIR=_build
-set SPHINXPROJ=stlearn
-
-if "%1" == "" goto help
 
 %SPHINXBUILD% >NUL 2>NUL
 if errorlevel 9009 (
 	echo.
-	echo.The Sphinx module was not found. Make sure you have Sphinx installed,
-	echo.then set the SPHINXBUILD environment variable to point to the full
-	echo.path of the 'sphinx-build' executable. Alternatively you may add the
-	echo.Sphinx directory to PATH.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
 	echo.
 	echo.If you don't have Sphinx installed, grab it from
-	echo.http://sphinx-doc.org/
+	echo.https://www.sphinx-doc.org/
 	exit /b 1
 )
 
-%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
 goto end
 
 :help
-%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
 
 :end
 popd
diff --git a/docs/release_notes/0.3.2.rst b/docs/release_notes/0.3.2.rst
index 9b141ff5..d8e459c7 100644
--- a/docs/release_notes/0.3.2.rst
+++ b/docs/release_notes/0.3.2.rst
@@ -1,7 +1,7 @@
 0.3.2 `2021-03-29`
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. rubric:: Feature
+.. rubric:: Features
 
 - Add interactive plotting functions: :func:`~stlearn.pl.gene_plot_interactive`, :func:`~stlearn.pl.cluster_plot_interactive`, :func:`~stlearn.pl.het_plot_interactive`
 - Add basic unittest (will add more in the future).
diff --git a/docs/release_notes/0.4.6.rst b/docs/release_notes/0.4.6.rst
index b2f08dd6..b8ee0324 100644
--- a/docs/release_notes/0.4.6.rst
+++ b/docs/release_notes/0.4.6.rst
@@ -1,7 +1,7 @@
 0.4.0 `2022-02-03`
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. rubric:: Feature
+.. rubric:: Features
 
 - Upgrade stSME, PSTS and CCI analysis methods.
diff --git a/docs/release_notes/1.1.0.rst b/docs/release_notes/1.1.0.rst
new file mode 100644
index 00000000..bd32dc46
--- /dev/null
+++ b/docs/release_notes/1.1.0.rst
@@ -0,0 +1,19 @@
+1.1.0 `2025-07-02`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. rubric:: Features
+
+* Support Python 3.10.x
+* Added black, ruff, and mypy quality checks and fixed the affected source code.
+* Copy parameters now work with the same semantics as scanpy.
+* Library upgrades for leidenalg, louvain, numba, numpy, scanpy, and tensorflow.
+* datasets.xenium_sge - loads Xenium data (and caches it) similarly to scanpy.visium_sge.
+
+.. rubric:: Bug fixes
+
+* Xenium TIFF and cell positions are now aligned.
+* Made type annotations consistent - mainly adding missing None annotations.
+* pl.cluster_plot - Does not keep colours from previous runs when clustering.
+* pl.trajectory.pseudotime_plot - Fix typing of cluster values in .uns["split_node"].
+* Removed datasets.example_bcba - Replaced with wrapper for scanpy.visium_sge.
+* Moved spatials directory to spatial, cleaned up pl and tl packages.
\ No newline at end of file
diff --git a/docs/release_notes/index.rst b/docs/release_notes/index.rst
index 48c9c3be..6194c62c 100644
--- a/docs/release_notes/index.rst
+++ b/docs/release_notes/index.rst
@@ -1,10 +1,7 @@
-Release notes
+Release Notes
 ===================================================
 
-Version 0.4.9
----------------------------
-
-.. include:: 0.4.10.rst
+.. include:: 1.1.0.rst
 
 .. include:: 0.4.6.rst
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index a2c20d28..00000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,15 +0,0 @@
--r ../requirements.txt
-ipyvolume
-ipywebrtc
-ipywidgets
-jupyter_sphinx
-nbclean
-nbformat
-nbsphinx
-pygments
-recommonmark
-sphinx
-sphinx-autodoc-typehints
-sphinx_gallery==0.10.1
-sphinx_rtd_theme
-typing_extensions
diff --git a/docs/tutorials.rst b/docs/tutorials.rst
index 0c0ecf16..83889f22 100644
--- a/docs/tutorials.rst
+++ b/docs/tutorials.rst
@@ -4,33 +4,24 @@ Tutorials
 .. nbgallery::
    :caption: Main features:
 
-   tutorials/stSME_clustering
-   tutorials/stSME_comparison
-   tutorials/Pseudo-time-space-tutorial
-   tutorials/stLearn-CCI
-   tutorials/Xenium_PSTS
-   tutorials/Xenium_CCI
+   tutorials/cell_cell_interaction
+   tutorials/cell_cell_interaction_xenium
+   tutorials/pseudotime_space
+   tutorials/pseudotime_space_xenium
+   tutorials/stsme_clustering
+   tutorials/stsme_comparison
+
 
 .. nbgallery::
    :caption: Visualisation and additional functionalities:
 
-   tutorials/Interactive_plot
-   tutorials/Core_plots
-   tutorials/ST_deconvolution_visualization
-   tutorials/Integration_multiple_datasets
-
+   tutorials/core_plots
+   tutorials/integrate_multiple_datasets
 
 .. nbgallery::
-   :caption: Supporting platforms:
-
-   tutorials/Read_MERFISH
-   tutorials/Read_seqfish
-   tutorials/Working-with-Old-Spatial-Transcriptomics-data
-   tutorials/Read_slideseq
-
-.. nbgallery::
    :caption: Integration with other spatial tools:
 
-   tutorials/Read_any_data
-   tutorials/Working_with_scanpy
+   tutorials/working_with_scanpy
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 00000000..5d8b6f99
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,4 @@
+[mypy]
+follow_untyped_imports = True
+no_site_packages = True
+ignore_missing_imports = True
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..d6e72578
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,86 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "stlearn"
+version = "1.1.0"
+authors = [
+    {name = "Genomics and Machine Learning lab", email = "andrew.newman@uq.edu.au"},
+]
+description = "A downstream analysis toolkit for Spatial Transcriptomic data"
+readme = {file = "README.md", content-type = "text/markdown"}
+license = {text = "BSD license"}
+requires-python = "~=3.10.0"
+keywords = ["stlearn"]
+classifiers = [
+    "Development Status :: 2 - Pre-Alpha",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: BSD License",
+    "Natural Language :: English",
+    "Programming Language :: Python :: 3.10",
+]
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+dev = [
+    "black>=23.0",
+    "ruff>=0.1.0",
+    "mypy>=1.16",
+    "pytest>=7.0",
+    "tox>=4.0",
+    "ghp-import>=2.1.0",
+    "sphinx>=4.0",
+    "furo==2024.8.6",
+    "myst-parser>=0.18",
+    "nbsphinx>=0.9.0",
+    "sphinx-autodoc-typehints>=1.24.0",
+    "sphinx-autosummary-accessors>=2023.4.0",
+]
+test = [
+    "pytest",
+    "pytest-cov",
+]
+webapp = [
+    "flask>=2.0.0",
+    "flask-wtf>=1.0.0",
+    "wtforms>=3.0.0",
+    "markupsafe>2.1.0",
+]
+jupyter = [
+    "jupyter>=1.0.0",
+    "jupyterlab>=3.0.0",
+    "ipywidgets>=7.6.0",
+    "plotly>=5.0.0",
+    "bokeh>=2.4.0",
+    "rpy2>=3.4.0",
+]
+
+[project.urls]
+Homepage = "https://github.com/BiomedicalMachineLearning/stLearn"
+Repository = "https://github.com/BiomedicalMachineLearning/stLearn"
+
+[project.scripts]
+stlearn = "stlearn.app.cli:main"
+
+[tool.setuptools.packages.find]
+include = ["stlearn", "stlearn.*"]
+
+[tool.setuptools.package-data]
+"*" = ["*"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.ruff]
+line-length=88
+target-version = "py310"
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I", "N", "UP"]
+ignore = ["E722", "F811", "N802", "N803", "N806", "N818", "N999", "UP031"]
+exclude = [".git", "__pycache__", "build", "dist"]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c5452f88..4c11dc46 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,14 @@
-bokeh>= 2.4.2
-click>=8.0.4
-leidenalg
-louvain
-numba<=0.57.1
-numpy>=1.18,<1.22
-Pillow>=9.0.1
-scanpy>=1.8.2
-scikit-image>=0.19.2
-tensorflow
+bokeh==3.7.3
+click==8.2.1
+leidenalg==0.10.2
+louvain==0.8.2
+numba==0.58.1
+numpy==1.26.4
+pillow==11.2.1
+scanpy==1.10.4
+scikit-image==0.22.0
+tensorflow==2.14.1
+keras==2.14.0
+types-tensorflow>=2.8.0
+imageio==2.37.0
+scipy==1.11.4
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index f877626d..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,21 +0,0 @@
-[bumpversion]
-current_version = 0.4.11
-commit = True
-tag = True
-
-[bumpversion:file:setup.py]
-search = version='{current_version}'
-replace = version='{new_version}'
-
-[bumpversion:file:stlearn/__init__.py]
-search = __version__ = '{current_version}'
-replace = __version__ = '{new_version}'
-
-[bdist_wheel]
-universal = 1
-
-[flake8]
-exclude = docs
-
-[aliases]
-# Define setup.py command aliases here
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e728fba4..00000000
--- a/setup.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-
-"""The setup script."""
-
-from setuptools import setup, find_packages
-
-with open("README.md", encoding="utf8") as readme_file:
-    readme = readme_file.read()
-
-with open("HISTORY.rst") as history_file:
-    history = history_file.read()
-
-with open("requirements.txt") as f:
-    requirements = f.read().splitlines()
-
-
-setup_requirements = []
-
-test_requirements = []
-
-setup(
-    author="Genomics and Machine Learning lab",
-    author_email="duy.pham@uq.edu.au",
-    python_requires=">=3.7",
-    classifiers=[
-        "Development Status :: 2 - Pre-Alpha",
-        "Intended Audience :: Developers",
-        "License :: OSI Approved :: BSD License",
-        "Natural Language :: English",
-        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.7",
-        "Programming Language :: Python :: 3.8",
-    ],
-    description="A downstream analysis toolkit for Spatial Transcriptomic data",
-    entry_points={
-        "console_scripts": [
-            "stlearn=stlearn.app.cli:main",
-        ],
-    },
-    install_requires=requirements,
-    license="BSD license",
-    long_description=readme + "\n\n" + history,
-    long_description_content_type="text/markdown",
-    include_package_data=True,
-    keywords="stlearn",
-    name="stlearn",
-    packages=find_packages(include=["stlearn", "stlearn.*"]),
-    setup_requires=setup_requirements,
-    test_suite="tests",
-    tests_require=test_requirements,
-    url="https://github.com/BiomedicalMachineLearning/stLearn",
-    version="0.4.11",
-    zip_safe=False,
-)
diff --git a/stlearn/__init__.py b/stlearn/__init__.py
index 1fc79b20..213fe82f 100644
--- a/stlearn/__init__.py
+++ b/stlearn/__init__.py
@@ -1,30 +1,43 @@
 """Top-level package for stLearn."""
 
-__author__ = """Genomics and Machine Learning lab"""
-__email__ = "duy.pham@uq.edu.au"
-__version__ = "0.4.11"
-
-
-from . import add
-from . import pp
-from . import em
-from . import tl
-from . import pl
-from . import spatial
-from . import datasets
-
-# Wrapper
-
-from .wrapper.read import ReadSlideSeq
-from .wrapper.read import Read10X
-from .wrapper.read import ReadOldST
-from .wrapper.read import ReadMERFISH
-from .wrapper.read import ReadSeqFish
-from .wrapper.read import ReadXenium
-from .wrapper.read import create_stlearn
+__author__ = """Genomics and Machine Learning Lab"""
+__email__ = "andrew.newman@uq.edu.au"
+__version__ = "1.1.0"
 
+from . import add, datasets, em, pl, pp, spatial, tl, types
 from ._settings import settings
-from .wrapper.convert_scanpy import convert_scanpy
 from .wrapper.concatenate_spatial_adata import concatenate_spatial_adata
+from .wrapper.convert_scanpy import convert_scanpy
+
+# Wrapper
+from .wrapper.read import (
+    Read10X,
+    ReadMERFISH,
+    ReadOldST,
+    ReadSeqFish,
+    ReadSlideSeq,
+    ReadXenium,
+    create_stlearn,
+)
 
 # from . import cli
+__all__ = [
+    "add",
+    "pp",
+    "em",
+    "tl",
+    "pl",
+    "spatial",
+    "datasets",
+    "ReadSlideSeq",
+    "Read10X",
+    "ReadOldST",
+    "ReadMERFISH",
+    "ReadSeqFish",
+    "ReadXenium",
+    "create_stlearn",
+    "settings",
+    "types",
+    "convert_scanpy",
+    "concatenate_spatial_adata",
+]
diff --git a/stlearn/__main__.py b/stlearn/__main__.py
index 981709a2..43559dfc 100644
--- a/stlearn/__main__.py
+++ b/stlearn/__main__.py
@@ -2,9 +2,7 @@
 
 """Package entry point."""
 
-
-from stlearn.app import main
-
+from stlearn.app import cli
 
 if __name__ == "__main__":  # pragma: no cover
-    main()
+    cli.main()
diff --git a/stlearn/_compat.py b/stlearn/_compat.py
deleted file mode 100644
index 0ef291a2..00000000
--- a/stlearn/_compat.py
+++ /dev/null
@@ -1,15 +0,0 @@
-try:
-    from typing import Literal
-except ImportError:
-    try:
-        from typing_extensions import Literal
-    except ImportError:
-
-        class LiteralMeta(type):
-            def __getitem__(cls, values):
-                if not isinstance(values, tuple):
-                    values = (values,)
-                return type("Literal_", (Literal,), dict(__args__=values))
-
-        class Literal(metaclass=LiteralMeta):
-            pass
diff --git a/stlearn/_datasets/_datasets.py b/stlearn/_datasets/_datasets.py
index 19ffb6d5..56f17fd5 100644
--- a/stlearn/_datasets/_datasets.py
+++ b/stlearn/_datasets/_datasets.py
@@ -1,18 +1,89 @@
+import zipfile as zf
+
 import scanpy as sc
-from .._settings import settings
-from pathlib import Path
 from anndata import AnnData
 
+from .._settings import settings
+
+
+# TODO - Add scanpy and convert this over.
+def visium_sge(
+    sample_id="V1_Breast_Cancer_Block_A_Section_1",
+    *,
+    include_hires_tiff: bool = False,
+) -> AnnData:
+    """Processed Visium Spatial Gene Expression data from 10x Genomics’ database.
 
-def example_bcba() -> AnnData:
-    """\
-    Download processed BCBA data (10X genomics published data).
-    Reference: https://support.10xgenomics.com/spatial-gene-expression/datasets/1.1.0/V1_Breast_Cancer_Block_A_Section_1
+    The database_ can be browsed online to find the ``sample_id`` you want.
+
+    .. _database: https://support.10xgenomics.com/spatial-gene-expression/datasets
+
+    Parameters
+    ----------
+    sample_id
+        The ID of the data sample in 10x’s spatial database.
+    include_hires_tiff
+        Download and include the high-resolution tissue image (tiff) in
+        `adata.uns["spatial"][sample_id]["metadata"]["source_image_path"]`.
+
+    Returns
+    -------
+    Annotated data matrix.
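+
+    Example
+    -------
+    A minimal usage sketch (the ``sample_id`` below is this function's default,
+    and the call simply wraps ``scanpy.datasets.visium_sge``)::
+
+        import stlearn as st
+
+        adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")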
""" - settings.datasetdir.mkdir(exist_ok=True) - filename = settings.datasetdir / "example_bcba.h5" - url = "https://www.dropbox.com/s/u3m2f16mvdom1am/example_bcba.h5ad?dl=1" - if not filename.is_file(): - sc.readwrite._download(url=url, path=filename) - adata = sc.read_h5ad(filename) - return adata + sc.settings.datasetdir = settings.datasetdir + return sc.datasets.visium_sge(sample_id, include_hires_tiff=include_hires_tiff) + + +def xenium_sge( + base_url="https://cf.10xgenomics.com/samples/xenium/1.0.1", + image_filename="he_image.ome.tif", + alignment_filename="he_imagealignment.csv", + zip_filename="outs.zip", + library_id="Xenium_FFPE_Human_Breast_Cancer_Rep1", + include_hires_tiff: bool = False, +): + """ + Download and extract Xenium SGE data files. Unlike scanpy this current does not + load the data. Data is located in `settings.datasetdir` / `library_id`. + + Args: + base_url: Base URL for downloads + image_filename: Name of the image file to download + alignment_filename: Name of the affine transformation file to download + zip_filename: Name of the zip file to download + library_id: Identifier for the library + include_hires_tiff: Whether to download the high-res TIFF image + """ + sc.settings.datasetdir = settings.datasetdir + library_dir = settings.datasetdir / library_id + library_dir.mkdir(parents=True, exist_ok=True) + + files_to_extract = ["cell_feature_matrix.h5", "cells.csv.gz", "experiment.xenium"] + all_sge_files_exist = all( + (library_dir / sge_file).exists() for sge_file in files_to_extract + ) + + download_filenames = [] + if not all_sge_files_exist: + download_filenames.append(zip_filename) + if include_hires_tiff and ( + not (library_dir / alignment_filename).exists() + or not (library_dir / image_filename).exists() + ): + download_filenames += [alignment_filename, image_filename] + + for file_name in download_filenames: + file_path = library_dir / file_name + url = f"{base_url}/{library_id}/{library_id}_{file_name}" + if not file_path.is_file(): + sc.readwrite._download(url=url, path=file_path) + + if not all_sge_files_exist: + try: + zip_file_path = library_dir / zip_filename + with zf.ZipFile(zip_file_path, "r") as zip_ref: + for zip_filename in files_to_extract: + with open(library_dir / zip_filename, "wb") as file_name: + file_name.write(zip_ref.read(f"outs/{zip_filename}")) + except zf.BadZipFile: + raise ValueError(f"Invalid zip file: {library_dir / zip_filename}") diff --git a/stlearn/_settings.py b/stlearn/_settings.py index 30eb017a..9e75a8d4 100644 --- a/stlearn/_settings.py +++ b/stlearn/_settings.py @@ -1,21 +1,20 @@ import inspect import sys +from collections.abc import Iterable, Iterator from contextlib import contextmanager from enum import IntEnum +from logging import getLevelName from pathlib import Path from time import time -from logging import getLevelName -from typing import Any, Union, Optional, Iterable, TextIO -from typing import Tuple, List, ContextManager +from typing import Any, Literal, TextIO from . 
import logging -from .logging import _set_log_level, _set_log_file, _RootLogger -from ._compat import Literal +from .logging import _RootLogger, _set_log_file, _set_log_level # All the code here migrated from scanpy # It help to work with scanpy package -_VERBOSITY_TO_LOGLEVEL = { +_VERBOSITY_TO_LOGLEVEL: dict[str | int, str] = { "error": "ERROR", "warning": "WARNING", "info": "INFO", @@ -40,7 +39,7 @@ def level(self) -> int: return getLevelName(_VERBOSITY_TO_LOGLEVEL[self]) @contextmanager - def override(self, verbosity: "Verbosity") -> ContextManager["Verbosity"]: + def override(self, verbosity: "Verbosity") -> Iterator["Verbosity"]: """\ Temporarily override verbosity """ @@ -49,7 +48,7 @@ def override(self, verbosity: "Verbosity") -> ContextManager["Verbosity"]: settings.verbosity = self -def _type_check(var: Any, varname: str, types: Union[type, Tuple[type, ...]]): +def _type_check(var: Any, varname: str, types: type | tuple[type, ...]): if isinstance(var, types): return if isinstance(types, type): @@ -62,11 +61,15 @@ def _type_check(var: Any, varname: str, types: Union[type, Tuple[type, ...]]): raise TypeError(f"{varname} must be of type {possible_types_str}") -class stLearnConfig: +class stLearnConfig: # noqa N801 """\ Config manager for scanpy. """ + _logpath: Path | None + _logfile: TextIO + _verbosity: Verbosity + def __init__( self, *, @@ -76,14 +79,14 @@ def __init__( file_format_figs: str = "pdf", autosave: bool = False, autoshow: bool = True, - writedir: Union[str, Path] = "./write/", - cachedir: Union[str, Path] = "./cache/", - datasetdir: Union[str, Path] = "./data/", - figdir: Union[str, Path] = "./figures/", - cache_compression: Union[str, None] = "lzf", + writedir: str | Path = "./write/", + cachedir: str | Path = "./cache/", + datasetdir: str | Path = "./data/", + figdir: str | Path = "./figures/", + cache_compression: str | None = "lzf", max_memory=15, n_jobs=1, - logfile: Union[str, Path, None] = None, + logfile: str | Path | None = None, categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"), _frameon: bool = True, _vector_friendly: bool = False, @@ -139,14 +142,14 @@ def verbosity(self) -> Verbosity: return self._verbosity @verbosity.setter - def verbosity(self, verbosity: Union[Verbosity, int, str]): + def verbosity(self, verbosity: Verbosity | int | str): verbosity_str_options = [ v for v in _VERBOSITY_TO_LOGLEVEL if isinstance(v, str) ] if isinstance(verbosity, Verbosity): - self._verbosity = verbosity + new_verbosity = verbosity elif isinstance(verbosity, int): - self._verbosity = Verbosity(verbosity) + new_verbosity = Verbosity(verbosity) elif isinstance(verbosity, str): verbosity = verbosity.lower() if verbosity not in verbosity_str_options: @@ -155,10 +158,9 @@ def verbosity(self, verbosity: Union[Verbosity, int, str]): f"Accepted string values are: {verbosity_str_options}" ) else: - self._verbosity = Verbosity(verbosity_str_options.index(verbosity)) - else: - _type_check(verbosity, "verbosity", (str, int)) - _set_log_level(self, _VERBOSITY_TO_LOGLEVEL[self._verbosity]) + new_verbosity = Verbosity(verbosity_str_options.index(verbosity)) + self._verbosity = new_verbosity + _set_log_level(self, self._verbosity) @property def plot_suffix(self) -> str: @@ -207,7 +209,8 @@ def file_format_figs(self, figure_format: str): @property def autosave(self) -> bool: """\ - Automatically save figures in :attr:`~stlearn._settings.stLearnConfig.figdir` (default `False`). 
+        Automatically save figures in :attr:`~stlearn._settings.stLearnConfig.figdir`
+        (default `False`).
 
         Do not show plots/figures interactively.
         """
@@ -240,7 +243,7 @@ def writedir(self) -> Path:
         return self._writedir
 
     @writedir.setter
-    def writedir(self, writedir: Union[str, Path]):
+    def writedir(self, writedir: str | Path):
         _type_check(writedir, "writedir", (str, Path))
         self._writedir = Path(writedir)
 
@@ -252,7 +255,7 @@ def cachedir(self) -> Path:
         return self._cachedir
 
    @cachedir.setter
-    def cachedir(self, cachedir: Union[str, Path]):
+    def cachedir(self, cachedir: str | Path):
         _type_check(cachedir, "cachedir", (str, Path))
         self._cachedir = Path(cachedir)
 
@@ -264,7 +267,7 @@ def datasetdir(self) -> Path:
         return self._datasetdir
 
     @datasetdir.setter
-    def datasetdir(self, datasetdir: Union[str, Path]):
+    def datasetdir(self, datasetdir: str | Path):
         _type_check(datasetdir, "datasetdir", (str, Path))
         self._datasetdir = Path(datasetdir).resolve()
 
@@ -276,12 +279,12 @@ def figdir(self) -> Path:
         return self._figdir
 
     @figdir.setter
-    def figdir(self, figdir: Union[str, Path]):
+    def figdir(self, figdir: str | Path):
         _type_check(figdir, "figdir", (str, Path))
         self._figdir = Path(figdir)
 
     @property
-    def cache_compression(self) -> Optional[str]:
+    def cache_compression(self) -> str | None:
         """\
         Compression for `sc.read(..., cache=True)` (default `'lzf'`).
 
@@ -290,7 +293,7 @@ def cache_compression(self) -> Optional[str]:
         return self._cache_compression
 
     @cache_compression.setter
-    def cache_compression(self, cache_compression: Optional[str]):
+    def cache_compression(self, cache_compression: str | None):
         if cache_compression not in {"lzf", "gzip", None}:
             raise ValueError(
                 f"`cache_compression` ({cache_compression}) "
@@ -299,7 +302,7 @@ def cache_compression(self, cache_compression: Optional[str]):
         self._cache_compression = cache_compression
 
     @property
-    def max_memory(self) -> Union[int, float]:
+    def max_memory(self) -> int | float:
         """\
         Maximal memory usage in Gigabyte.
 
@@ -308,7 +311,7 @@ def max_memory(self) -> Union[int, float]:
         return self._max_memory
 
     @max_memory.setter
-    def max_memory(self, max_memory: Union[int, float]):
+    def max_memory(self, max_memory: int | float):
         _type_check(max_memory, "max_memory", (int, float))
         self._max_memory = max_memory
 
@@ -325,18 +328,21 @@ def n_jobs(self, n_jobs: int):
         self._n_jobs = n_jobs
 
     @property
-    def logpath(self) -> Optional[Path]:
+    def logpath(self) -> Path | None:
         """\
         The file path `logfile` was set to.
         """
         return self._logpath
 
     @logpath.setter
-    def logpath(self, logpath: Union[str, Path, None]):
-        _type_check(logpath, "logfile", (str, Path))
-        # set via “file object” branch of logfile.setter
-        self.logfile = Path(logpath).open("a")
-        self._logpath = Path(logpath)
+    def logpath(self, logpath: str | Path | None):
+        if logpath is None:
+            self._logpath = None
+        else:
+            _type_check(logpath, "logpath", (str, Path))
+            # set via “file object” branch of logfile.setter
+            self.logfile = Path(logpath).open("a")
+            self._logpath = Path(logpath)
 
     @property
     def logfile(self) -> TextIO:
@@ -347,23 +353,27 @@ def logfile(self) -> TextIO:
 
         The default `None` corresponds to :obj:`sys.stdout` in jupyter notebooks
         and to :obj:`sys.stderr` otherwise.
 
-        For backwards compatibility, setting it to `''` behaves like setting it to `None`.
+        For backwards compatibility, setting it to `''` behaves like setting it
+        to `None`.
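+
+        Example (a sketch; the log file name is a placeholder)::
+
+            import stlearn as st
+
+            st.settings.logfile = "stlearn.log"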
""" return self._logfile @logfile.setter - def logfile(self, logfile: Union[str, Path, TextIO, None]): - if not hasattr(logfile, "write") and logfile: - self.logpath = logfile - else: # file object - if not logfile: # None or '' - logfile = sys.stdout if self._is_run_from_ipython() else sys.stderr + def logfile(self, logfile: str | Path | TextIO | None): + if logfile is None or logfile == "": + self._logfile = sys.stdout if self._is_run_from_ipython() else sys.stderr + self._logpath = None + elif isinstance(logfile, (str | Path)): + path = Path(logfile) + self._logfile = path.open("a") + self._logpath = path + elif isinstance(logfile, TextIO): self._logfile = logfile self._logpath = None - _set_log_file(self) + _set_log_file(self) @property - def categories_to_ignore(self) -> List[str]: + def categories_to_ignore(self) -> list[str]: """\ Categories that are omitted in plotting etc. """ @@ -403,7 +413,7 @@ def set_figure_params( frameon: bool = True, vector_friendly: bool = True, fontsize: int = 14, - color_map: Optional[str] = None, + color_map: str | None = None, format: _Format = "pdf", transparent: bool = False, ipython_format: str = "png2x", @@ -414,18 +424,21 @@ def set_figure_params( Parameters ---------- dpi - Resolution of rendered figures – this influences the size of figures in notebooks. + Resolution of rendered figures – this influences the size of figures + in notebooks. dpi_save Resolution of saved figures. This should typically be higher to achieve publication quality. frameon Add frames and axes labels to scatter plots. vector_friendly - Plot scatter plots using `png` backend even when exporting as `pdf` or `svg`. + Plot scatter plots using `png` backend even when exporting as + `pdf` or `svg`. fontsize Set the fontsize for several `rcParams` entries. Ignored if `scanpy=False`. color_map - Convenience method for setting the default color map. Ignored if `scanpy=False`. + Convenience method for setting the default color map. Ignored if + `scanpy=False`. format This sets the default format for saving figures: `file_format_figs`. 
         transparent
@@ -438,9 +451,7 @@ def set_figure_params(
         try:
             import IPython
 
-            if isinstance(ipython_format, str):
-                ipython_format = [ipython_format]
-            IPython.display.set_matplotlib_formats(*ipython_format)
+            IPython.display.set_matplotlib_formats(*[ipython_format])
         except Exception:
             pass
         from matplotlib import rcParams
diff --git a/stlearn/add.py b/stlearn/add.py
index fde7173d..025a232a 100644
--- a/stlearn/add.py
+++ b/stlearn/add.py
@@ -1,10 +1,22 @@
+from .adds.add_deconvolution import add_deconvolution
 from .adds.add_image import image
-from .adds.add_positions import positions
-from .adds.parsing import parsing
-from .adds.add_lr import lr
-from .adds.annotation import annotation
 from .adds.add_labels import labels
-from .adds.add_deconvolution import add_deconvolution
-from .adds.add_mask import add_mask
-from .adds.add_mask import apply_mask
 from .adds.add_loupe_clusters import add_loupe_clusters
+from .adds.add_lr import lr
+from .adds.add_mask import add_mask, apply_mask
+from .adds.add_positions import positions
+from .adds.annotation import annotation
+from .adds.parsing import parsing
+
+__all__ = [
+    "image",
+    "positions",
+    "parsing",
+    "lr",
+    "annotation",
+    "labels",
+    "add_deconvolution",
+    "add_mask",
+    "apply_mask",
+    "add_loupe_clusters",
+]
diff --git a/stlearn/adds/add_deconvolution.py b/stlearn/adds/add_deconvolution.py
index 3b5445be..5d892dda 100644
--- a/stlearn/adds/add_deconvolution.py
+++ b/stlearn/adds/add_deconvolution.py
@@ -1,16 +1,14 @@
-from typing import Optional, Union
-from anndata import AnnData
-import pandas as pd
-import numpy as np
 from pathlib import Path
 
+import pandas as pd
+from anndata import AnnData
+
 
 def add_deconvolution(
     adata: AnnData,
-    annotation_path: Union[Path, str],
+    annotation_path: Path | str,
     copy: bool = False,
-) -> Optional[AnnData]:
-
+) -> AnnData | None:
     """\
     Adding label transfered from Seurat
 
@@ -29,7 +27,11 @@
         The annotation of cluster results.
     """
 
+    adata = adata.copy() if copy else adata
+
     label = pd.read_csv(annotation_path, index_col=0)
     label = label[adata.obs_names]
 
     adata.obsm["deconvolution"] = label[adata.obs.index].T
+
+    return adata
diff --git a/stlearn/adds/add_image.py b/stlearn/adds/add_image.py
index 83c92d6b..20376ece 100644
--- a/stlearn/adds/add_image.py
+++ b/stlearn/adds/add_image.py
@@ -1,8 +1,8 @@
-from typing import Optional, Union
+import os
+from pathlib import Path
+
 from anndata import AnnData
 from matplotlib import pyplot as plt
-from pathlib import Path
-import os
 from PIL import Image
 
 Image.MAX_IMAGE_PIXELS = None
@@ -10,15 +10,14 @@
 
 def image(
     adata: AnnData,
-    imgpath: Union[Path, str],
+    imgpath: Path | str | None,
     library_id: str,
     quality: str = "hires",
     scale: float = 1.0,
     visium: bool = False,
     spot_diameter_fullres: float = 50,
     copy: bool = False,
-) -> Optional[AnnData]:
-
+) -> AnnData | None:
     """\
     Adding image data to the Anndata object
 
@@ -29,11 +28,13 @@
     imgpath
         Image path.
     library_id
-        Identifier for the visium library. Can be modified when concatenating multiple adata objects.
+        Identifier for the visium library. Can be modified when concatenating
+        multiple adata objects.
     scale
         Set scale factor.
     quality
-        Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'].
+        Set the quality for stlearn to use. Stored in
+        anndata.obs['imagecol' & 'imagerow'].
     visium
         Is this anndata read from Visium platform or not.
     copy
@@ -44,6 +45,7 @@
     **tissue_img** : `adata.uns` field
         Array format of image, saving by Pillow package.
     """
+    adata = adata.copy() if copy else adata
 
     if imgpath is not None and os.path.isfile(imgpath):
         try:
@@ -68,8 +70,6 @@
             adata.obs[["imagecol", "imagerow"]] = adata.obsm["spatial"] * scale
 
             print("Added tissue image to the object!")
-
-            return adata if copy else None
         except:
             raise ValueError(
                 f"""\
diff --git a/stlearn/adds/add_labels.py b/stlearn/adds/add_labels.py
index d4a05451..d11cad49 100644
--- a/stlearn/adds/add_labels.py
+++ b/stlearn/adds/add_labels.py
@@ -1,41 +1,47 @@
-from typing import Optional, Union
-from anndata import AnnData
-from pathlib import Path
-import os
-import pandas as pd
 import numpy as np
+import pandas as pd
+from anndata import AnnData
 from natsort import natsorted
 
 
 def labels(
     adata: AnnData,
-    label_filepath: str = None,
+    label_filepath: str,
     index_col: int = 0,
-    use_label: str = None,
+    use_label: str | None = None,
     sep: str = "\t",
     copy: bool = False,
-) -> Optional[AnnData]:
+) -> AnnData | None:
     """\
     Add label transfer results into AnnData object
 
     Parameters
     ----------
-    adata: AnnData          The data object to add L-R info into
-    label_filepath: str     The path to the label transfer results file
-    use_label: str          Where to store the label_transfer results, defaults to 'predictions' in adata.obs & 'label_transfer' in adata.uns.
-    sep: str                Separator of the csv file
-    copy: bool              Copy flag indicating copy or direct edit
+    adata: AnnData
+        The data object to add L-R info into
+    label_filepath: str
+        The path to the label transfer results file
+    use_label: str
+        Where to store the label_transfer results, defaults to 'predictions'
+        in adata.obs & 'label_transfer' in adata.uns.
+    sep: str
+        Separator of the csv file
+    copy: bool
+        Copy flag indicating copy or direct edit
 
     Returns
     -------
-    adata: AnnData          The data object that L-R added into
+    adata: AnnData
+        The data object that L-R added into
     """
+
+    adata = adata.copy() if copy else adata
+
     labels = pd.read_csv(label_filepath, index_col=index_col, sep=sep)
 
-    uns_key = "label_transfer" if type(use_label) == type(None) else use_label
+    uns_key = "label_transfer" if use_label is None else use_label
     adata.uns[uns_key] = labels.drop(["predicted.id", "prediction.score.max"], axis=1)
 
-    key_add = "predictions" if type(use_label) == type(None) else use_label
+    key_add = "predictions" if use_label is None else use_label
     key_source = "predicted.id"
 
     adata.obs[key_add] = pd.Categorical(
         values=np.array(labels[key_source]).astype("U"),
diff --git a/stlearn/adds/add_loupe_clusters.py b/stlearn/adds/add_loupe_clusters.py
index af614bd8..f257f80f 100644
--- a/stlearn/adds/add_loupe_clusters.py
+++ b/stlearn/adds/add_loupe_clusters.py
@@ -1,19 +1,17 @@
-from typing import Optional, Union
-from anndata import AnnData
-import pandas as pd
-import numpy as np
-import stlearn
 from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
 from natsort import natsorted
 
 
 def add_loupe_clusters(
     adata: AnnData,
-    loupe_path: Union[Path, str],
+    loupe_path: Path | str,
     key_add: str = "multiplex",
     copy: bool = False,
-) -> Optional[AnnData]:
-
+) -> AnnData | None:
     """\
     Adding label transfered from Seurat
 
@@ -36,9 +34,13 @@
         The annotation of cluster results.
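+
+    Example
+    -------
+    A minimal usage sketch (the CSV path is a placeholder for a Loupe
+    cluster export)::
+
+        import stlearn as st
+
+        st.add.add_loupe_clusters(adata, loupe_path="clusters.csv", key_add="multiplex")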
""" + adata = adata.copy() if copy else adata + label = pd.read_csv(loupe_path) adata.obs[key_add] = pd.Categorical( values=np.array(label[key_add]).astype("U"), categories=natsorted(label[key_add].unique().astype("U")), ) + + return adata if copy else None diff --git a/stlearn/adds/add_lr.py b/stlearn/adds/add_lr.py index 6ed99cde..d40d11a8 100644 --- a/stlearn/adds/add_lr.py +++ b/stlearn/adds/add_lr.py @@ -1,32 +1,35 @@ -from typing import Optional, Union -from anndata import AnnData -from pathlib import Path -import os import pandas as pd +from anndata import AnnData def lr( adata: AnnData, - db_filepath: str = None, + db_filepath: str, sep: str = "\t", source: str = "connectomedb", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """Add significant Ligand-Receptor pairs into AnnData object Parameters ---------- - adata: AnnData The data object to add L-R info into - db_filepath: str The path to the CPDB results file - sep: str Separator of the CPDB results file - source: str Source of LR database (default: connectomedb, can also support 'cellphonedb') - copy: bool Copy flag indicating copy or direct edit + adata: AnnData + The data object to add L-R info into + db_filepath: str + The path to the CPDB results file + sep: str + Separator of the CPDB results file + source: str + Source of LR database (default: connectomedb, can also support 'cellphonedb') + copy: bool + Copy flag indicating copy or direct edit Returns ------- adata: AnnData The data object that L-R added into """ + adata = adata.copy() if copy else adata if source == "cellphonedb": cpdb = pd.read_csv(db_filepath, sep=sep) diff --git a/stlearn/adds/add_mask.py b/stlearn/adds/add_mask.py index 6608e00f..d25a488c 100644 --- a/stlearn/adds/add_mask.py +++ b/stlearn/adds/add_mask.py @@ -1,19 +1,18 @@ +import os from pathlib import Path + import matplotlib -from matplotlib import pyplot as plt import numpy as np -from typing import Optional, Union from anndata import AnnData -import os -from stlearn._compat import Literal +from matplotlib import pyplot as plt def add_mask( adata: AnnData, - imgpath: Union[Path, str], + imgpath: Path | str, key: str = "mask", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Adding binary mask image to the Anndata object @@ -33,12 +32,14 @@ def add_mask( **mask_image** : `adata.uns` field Array format of image, saving by Pillow package. """ + adata = adata.copy() if copy else adata + try: library_id = list(adata.uns["spatial"].keys())[0] quality = adata.uns["spatial"][library_id]["use_quality"] except: raise KeyError( - f"""\ + """\ Please read ST data first and try again """ ) @@ -59,8 +60,6 @@ def add_mask( adata.uns["mask_image"][library_id][key][quality] = img print("Added tissue mask to the object!") - - return adata if copy else None except: raise ValueError( f"""\ @@ -78,11 +77,11 @@ def add_mask( def apply_mask( adata: AnnData, - masks: Optional[list] = "all", + masks: list | str = "all", select: str = "black", - cmap: str = "default", + cmap_name: str = "default", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Parsing the old spaital transcriptomics data @@ -106,19 +105,22 @@ def apply_mask( Array format of image, saving by Pillow package. 
""" from scanpy.plotting import palettes - from stlearn.plotting import palettes_st - if cmap == "vega_10_scanpy": + from stlearn.pl import palettes_st + + adata = adata.copy() if copy else adata + + if cmap_name == "vega_10_scanpy": cmap = palettes.vega_10_scanpy - elif cmap == "vega_20_scanpy": + elif cmap_name == "vega_20_scanpy": cmap = palettes.vega_20_scanpy - elif cmap == "default_102": + elif cmap_name == "default_102": cmap = palettes.default_102 - elif cmap == "default_28": + elif cmap_name == "default_28": cmap = palettes.default_28 - elif cmap == "jana_40": + elif cmap_name == "jana_40": cmap = palettes_st.jana_40 - elif cmap == "default": + elif cmap_name == "default": cmap = palettes_st.default else: raise ValueError( @@ -126,7 +128,6 @@ def apply_mask( ) cmaps = matplotlib.colors.LinearSegmentedColormap.from_list("", cmap) - cmap_ = plt.cm.get_cmap(cmaps) try: @@ -134,7 +135,7 @@ def apply_mask( quality = adata.uns["spatial"][library_id]["use_quality"] except: raise KeyError( - f"""\ + """\ Please read ST data first and try again """ ) @@ -163,16 +164,18 @@ def apply_mask( mask_image = np.where(mask_image > 155, 0, 1) else: raise ValueError( - f"""\ + """\ Only support black and white mask yet. """ ) mask_image_2d = mask_image.mean(axis=2) - apply_spot_mask = ( - lambda x: [i, mask] - if mask_image_2d[int(x["imagerow"]), int(x["imagecol"])] == 1 - else [x[key + "_code"], x[key]] - ) + + def apply_spot_mask(x): + if mask_image_2d[int(x["imagerow"]), int(x["imagecol"])] == 1: + return [i, mask] + else: + return [x[key + "_code"], x[key]] + spot_mask_df = adata.obs.apply(apply_spot_mask, axis=1, result_type="expand") adata.obs[key + "_code"] = spot_mask_df[0] adata.obs[key] = spot_mask_df[1] diff --git a/stlearn/adds/add_positions.py b/stlearn/adds/add_positions.py index 52872384..7b4c3cb7 100644 --- a/stlearn/adds/add_positions.py +++ b/stlearn/adds/add_positions.py @@ -1,18 +1,16 @@ -from typing import Optional, Union -from anndata import AnnData -import pandas as pd from pathlib import Path -import os + +import pandas as pd +from anndata import AnnData def positions( adata: AnnData, - position_filepath: Union[Path, str] = None, - scale_filepath: Union[Path, str] = None, + position_filepath: Path | str, + scale_filepath: Path | str, quality: str = "low", copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Adding spatial information into the Anndata object @@ -35,6 +33,8 @@ def positions( Spatial information of the tissue image. """ + adata = adata.copy() if copy else adata + tissue_positions = pd.read_csv(position_filepath, header=None) tissue_positions.columns = [ "barcode", diff --git a/stlearn/adds/annotation.py b/stlearn/adds/annotation.py index a8bc1ac9..8f5df9db 100644 --- a/stlearn/adds/annotation.py +++ b/stlearn/adds/annotation.py @@ -1,16 +1,12 @@ -from typing import Optional, Union, List from anndata import AnnData -from matplotlib import pyplot as plt -from pathlib import Path -import os def annotation( adata: AnnData, - label_list: List[str], + label_list: list[str], use_label: str = "louvain", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Adding annotation for cluster @@ -30,10 +26,11 @@ def annotation( **[cluster method name]_anno** : `adata.obs` field The annotation of cluster results. 
""" - if label_list is None: raise ValueError("Please give the label list!") + adata = adata.copy() if copy else adata + if len(label_list) != len(adata.obs[use_label].unique()): raise ValueError("Please give the correct number of label list!") diff --git a/stlearn/adds/parsing.py b/stlearn/adds/parsing.py index d92932cc..0ae6a9f0 100644 --- a/stlearn/adds/parsing.py +++ b/stlearn/adds/parsing.py @@ -1,18 +1,14 @@ -from typing import Optional, Union -from anndata import AnnData -from matplotlib import pyplot as plt -from pathlib import Path -import os -import sys +from os import PathLike + import numpy as np +from anndata import AnnData def parsing( adata: AnnData, - coordinates_file: Union[Path, str], + coordinates_file: int | str | bytes | PathLike[str] | PathLike[bytes], copy: bool = True, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Parsing the old spaital transcriptomics data @@ -33,7 +29,7 @@ def parsing( # Get a map of the new coordinates new_coordinates = dict() - with open(coordinates_file, "r") as filehandler: + with open(coordinates_file) as filehandler: for line in filehandler.readlines(): tokens = line.split() assert len(tokens) >= 6 or len(tokens) == 4 @@ -52,6 +48,8 @@ def parsing( "the coordinates file only contains 4 columns\n" ) + adata = adata.copy() if copy else adata + counts_table = adata.to_df() new_index_values = list() @@ -66,7 +64,7 @@ def parsing( imgcol.append(new_x) imgrow.append(new_y) - new_index_values.append("{0}x{1}".format(new_x, new_y)) + new_index_values.append(f"{new_x}x{new_y}") except KeyError: counts_table.drop(index, inplace=True) @@ -80,7 +78,6 @@ def parsing( adata.obs["imagecol"] = imgcol adata.obs["imagerow"] = imgrow - adata.obsm["spatial"] = np.c_[[imgcol, imgrow]].reshape(-1, 2) return adata if copy else None diff --git a/stlearn/app/app.py b/stlearn/app/app.py index 6eb6a5dc..d1e914f3 100644 --- a/stlearn/app/app.py +++ b/stlearn/app/app.py @@ -1,53 +1,36 @@ -import os, sys, subprocess +import os +import sys +from threading import Thread sys.path.append(os.path.dirname(__file__)) -try: - import flask -except ImportError: - subprocess.call( - "pip install -r " + os.path.dirname(__file__) + "//requirements.txt", shell=True - ) +import asyncio +import tempfile +import numpy +import numpy as np +import scanpy +from bokeh.application import Application +from bokeh.application.handlers import FunctionHandler +from bokeh.embed import server_document +from bokeh.layouts import row +from bokeh.server.server import Server from flask import ( Flask, - render_template, - request, flash, - url_for, redirect, - session, + render_template, + request, send_file, + url_for, ) -from bokeh.embed import components -from bokeh.plotting import figure -from bokeh.resources import INLINE +from tornado.ioloop import IOLoop from werkzeug.utils import secure_filename -import tempfile -import traceback - -import tempfile -import shutil import stlearn -import scanpy -import numpy -import numpy as np - -import asyncio -from bokeh.server.server import BaseServer -from bokeh.server.tornado import BokehTornado -from tornado.httpserver import HTTPServer -from tornado.ioloop import IOLoop -from bokeh.application import Application -from bokeh.application.handlers import FunctionHandler -from bokeh.server.server import Server -from bokeh.embed import server_document - -from bokeh.layouts import column, row # Functions related to processing the forms. 
-from source.forms import views # for changing data in response to input +from stlearn.app.source.forms import views # for changing data in response to input # Global variables. @@ -171,7 +154,6 @@ def folder_uploader(): uploaded = [] i = 0 for file in files: - filename = secure_filename(file.filename) if allow_files[0] in filename: @@ -243,7 +225,6 @@ def folder_uploader(): @app.route("/file_uploader", methods=["GET", "POST"]) def file_uploader(): if request.method == "POST": - global adata, step_log # Clean uploads folder before upload a new data @@ -385,7 +366,7 @@ def save_adata(): def modify_doc_gene_plot(doc): - from stlearn.plotting.classes_bokeh import BokehGenePlot + from stlearn.pl.classes_bokeh import BokehGenePlot gp_object = BokehGenePlot(adata) doc.add_root(row(gp_object.layout, width=800)) @@ -402,7 +383,7 @@ def modify_doc_gene_plot(doc): def modify_doc_cluster_plot(doc): - from stlearn.plotting.classes_bokeh import BokehClusterPlot + from stlearn.pl.classes_bokeh import BokehClusterPlot gp_object = BokehClusterPlot(adata) doc.add_root(row(gp_object.layout, width=800)) @@ -423,7 +404,7 @@ def modify_doc_cluster_plot(doc): def modify_doc_spatial_cci_plot(doc): - from stlearn.plotting.classes_bokeh import BokehSpatialCciPlot + from stlearn.pl.classes_bokeh import BokehSpatialCciPlot gp_object = BokehSpatialCciPlot(adata) doc.add_root(row(gp_object.layout, width=800)) @@ -439,7 +420,7 @@ def modify_doc_spatial_cci_plot(doc): def modify_doc_lr_plot(doc): - from stlearn.plotting.classes_bokeh import BokehLRPlot + from stlearn.pl.classes_bokeh import BokehLRPlot gp_object = BokehLRPlot(adata) doc.add_root(row(gp_object.layout, width=800)) @@ -453,7 +434,7 @@ def modify_doc_lr_plot(doc): def modify_doc_annotate_plot(doc): - from stlearn.plotting.classes_bokeh import Annotate + from stlearn.pl.classes_bokeh import Annotate gp_object = Annotate(adata) doc.add_root(row(gp_object.layout, width=800)) @@ -491,12 +472,10 @@ def bk_worker(): "/bokeh_annotate_plot": bkapp4, }, io_loop=IOLoop(), - allow_websocket_origin=["127.0.0.1:5000", "localhost:5000"], + allow_websocket_origin=["127.0.0.1:3000", "localhost:3000"], ) server.start() server.io_loop.start() -from threading import Thread - Thread(target=bk_worker).start() diff --git a/stlearn/app/cli.py b/stlearn/app/cli.py index 20154df4..78bfe02b 100644 --- a/stlearn/app/cli.py +++ b/stlearn/app/cli.py @@ -1,7 +1,8 @@ +import errno + import click -from .. import __version__ -import os +from .. import __version__ @click.group( @@ -18,7 +19,6 @@ help="Show the software version and exit.", ) def main(): - os._exit click.echo("Please run `stlearn launch` to start the web app") @@ -27,10 +27,14 @@ def launch(): from .app import app try: - app.run(host="0.0.0.0", port=5000, debug=True, use_reloader=False) + app.run(host="0.0.0.0", port=3000, debug=True, use_reloader=False) except OSError as e: if e.errno == errno.EADDRINUSE: raise click.ClickException( "Port is in use, please specify an open port using the --port flag." ) from e raise + + +if __name__ == "__main__": + main() diff --git a/stlearn/app/source/forms/form_validators.py b/stlearn/app/source/forms/form_validators.py index 3a82f887..4a279164 100644 --- a/stlearn/app/source/forms/form_validators.py +++ b/stlearn/app/source/forms/form_validators.py @@ -1,16 +1,15 @@ -""" Contains different kinds of form validators. 
-""" +"""Contains different kinds of form validators.""" + from wtforms.validators import ValidationError -class CheckNumberRange(object): +class CheckNumberRange: def __init__(self, lower, upper, hint=""): self.lower = lower self.upper = upper self.hint = hint def __call__(self, form, field): - if field.data is not None: if not (self.lower <= float(field.data) <= self.upper): if self.hint: diff --git a/stlearn/app/source/forms/forms.py b/stlearn/app/source/forms/forms.py index 0eef6b1d..466c1da1 100644 --- a/stlearn/app/source/forms/forms.py +++ b/stlearn/app/source/forms/forms.py @@ -1,64 +1,63 @@ """Purpose of this script is to create general forms that are programmable with - particular input. Will impliment forms for subsetting the data and - visualisation options in a general way so can be used with any - SingleCellAnalysis dataset. +particular input. Will impliment forms for subsetting the data and +visualisation options in a general way so can be used with any +SingleCellAnalysis dataset. """ -import sys +import wtforms from flask_wtf import FlaskForm # from flask_wtf.file import FileField -from wtforms import SelectMultipleField, SelectField -import wtforms +from wtforms import SelectField, SelectMultipleField def createSuperForm(elements, element_fields, element_values, validators=None): """ Creates a general form; goal is to create a fully programmable form \ - that essentially governs all the options the user will select. + that essentially governs all the options the user will select. - Args: - elements (list): Element names to be rendered on the page, in \ - order of how they will appear on the page. + Args: + elements (list): Element names to be rendered on the page, in \ + order of how they will appear on the page. - element_fields (list): The names of the fields to be rendered. \ - Each field is in same order as 'elements'. \ - Currently supported are: \ - 'Title', 'SelectMultipleField', 'SelectField', \ - 'StringField', 'Text', 'List'. + element_fields (list): The names of the fields to be rendered. \ + Each field is in same order as 'elements'. \ + Currently supported are: \ + 'Title', 'SelectMultipleField', 'SelectField', \ + 'StringField', 'Text', 'List'. - element_values (list): The information which will be put into \ - the field. Changes depending on field: \ + element_values (list): The information which will be put into \ + the field. Changes depending on field: \ - 'Title' and 'Text': 'object' is a string - containing the title which will be added as \ - a heading when rendered on the page. + 'Title' and 'Text': 'object' is a string + containing the title which will be added as \ + a heading when rendered on the page. - 'SelectMultipleField' and 'SelectField': - 'object' is list of options to select from. + 'SelectMultipleField' and 'SelectField': + 'object' is list of options to select from. - 'StringField': - The example values to display within the \ - fields text area. The 'placeholder' option. + 'StringField': + The example values to display within the \ + fields text area. The 'placeholder' option. - 'List': - A list of objects which will be attached \ - to the form. + 'List': + A list of objects which will be attached \ + to the form. - validators (list): A list of functions which take the \ - form as input, used to construct the form validator. \ - Form validator constructed by calling these \ - sequentially with form 'self' as input. + validators (list): A list of functions which take the \ + form as input, used to construct the form validator. 
\ + Form validator constructed by calling these \ + sequentially with form 'self' as input. - Args: - form (list): A WTForm which has attached as variable all the \ - fields mentioned, so then when rendered as input to - 'SuperDataDisplay.html' shows the form. - """ + Args: + form (list): A WTForm which has attached as variable all the \ + fields mentioned, so then when rendered as input to + 'SuperDataDisplay.html' shows the form. + """ class SuperForm(FlaskForm): """A base form on which all of the fields will be added.""" - if type(validators) == type(None): + if validators is None: validators = [None] * len(elements) # Add the information # @@ -82,7 +81,7 @@ class SuperForm(FlaskForm): # left. setattr(SuperForm, element + "_number", int(multiSelectLeft)) # inverts, so if left, goes right for the next multiSelectField - multiSelectLeft = multiSelectLeft == False + multiSelectLeft = not multiSelectLeft else: multiSelectLeft = True # Reset the MultiSelectField position @@ -100,9 +99,9 @@ class SuperForm(FlaskForm): ) # elif fieldName == 'FileField': - # setattr(SuperForm, element, FileField(validators=validators[i])) - # setattr(SuperForm, element + '_placeholder', # Setting default - # element_values[i]) + # setattr(SuperForm, element, FileField(validators=validators[i])) + # setattr(SuperForm, element + '_placeholder', # Setting default + # element_values[i]) elif fieldName in [ "StringField", @@ -198,10 +197,13 @@ def getCCIForm(adata): related to CCI analysis. """ elements = [ - "Cell information (only discrete labels available, unless mixture already in anndata.uns)", + "Cell information (only discrete labels available, unless mixture already in " + + "anndata.uns)", "Minimum spots for LR to be considered", - "Spot mixture (only if the 'Cell Information' label selected available in anndata.uns)", - "Cell proportion cutoff (value above which cell is considered in spot if 'Spot mixture' selected)", + "Spot mixture (only if the 'Cell Information' label selected available in " + + "anndata.uns)", + "Cell proportion cutoff (value above which cell is considered in spot " + + "if 'Spot mixture' selected)", "Permutations (recommend atleast 1000)", ] element_fields = [ @@ -211,12 +213,12 @@ def getCCIForm(adata): "FloatField", "IntegerField", ] - if type(adata) == type(None): + if adata is None: fields = [] mix = False else: fields = [ - key for key in adata.obs.keys() if type(adata.obs[key].values[0]) == str + key for key in adata.obs.keys() if isinstance(adata.obs[key].values[0], str) ] mix = fields[0] in adata.uns.keys() element_values = [fields, 20, mix, 0.2, 100] @@ -279,7 +281,7 @@ def getPSTSForm(trajectory, clusts, options): Args: cluster_set (numpy.array): The clusters which can be selected as - the root for psts analysis. + the root for psts analysis. Returns: FlaskForm: With attributes that allow input related to psts. @@ -308,7 +310,7 @@ def getDEAForm(list_labels, methods): Args: cluster_set (numpy.array): The clusters which can be selected as - the root for psts analysis. + the root for psts analysis. Returns: FlaskForm: With attributes that allow input related to psts. @@ -325,40 +327,40 @@ def getDEAForm(list_labels, methods): ######################## Junk Code ############################################# # def getCCIForm(step_log): -# """ Gets the CCI form generated from the superform above. +# """ Gets the CCI form generated from the superform above. # -# Returns: -# FlaskForm: With attributes that allow for inputs that are related to -# CCI analysis. 
-# """ -# elements, element_fields, element_values = [], [], [] -# if type(step_log['cci_het']) == type(None): -# # Analysis type form version # -# analysis_elements = ['Cell Heterogeneity Information', # Title -# 'cci_het', -# 'Permutation Testing', # Title -# 'cci_perm'] -# analysis_fields = ['Title', 'SelectField', 'Title', 'SelectField'] -# label_transfer_options = ['Upload Cell Label Transfer', -# 'No Cell Label Transfer'] -# permutation_options = ['With permutation testing', -# 'Without permutation testing'] -# analysis_values = ['', label_transfer_options, '', permutation_options] -# elements += analysis_elements -# element_fields += analysis_fields -# element_values += analysis_values +# Returns: +# FlaskForm: With attributes that allow for inputs that are related to +# CCI analysis. +# """ +# elements, element_fields, element_values = [], [], [] +# if type(step_log['cci_het']) == type(None): +# # Analysis type form version # +# analysis_elements = ['Cell Heterogeneity Information', # Title +# 'cci_het', +# 'Permutation Testing', # Title +# 'cci_perm'] +# analysis_fields = ['Title', 'SelectField', 'Title', 'SelectField'] +# label_transfer_options = ['Upload Cell Label Transfer', +# 'No Cell Label Transfer'] +# permutation_options = ['With permutation testing', +# 'Without permutation testing'] +# analysis_values = ['', label_transfer_options, '', permutation_options] +# elements += analysis_elements +# element_fields += analysis_fields +# element_values += analysis_values # -# else: -# # Core elements regardless of CCI mode # -# elements += ['Neighbourhood distance', -# 'L-R pair input (e.g. L1_R1, L2_R2, ...)'] -# element_fields += ['IntegerField', 'StringField'] -# element_values += [5, ''] +# else: +# # Core elements regardless of CCI mode # +# elements += ['Neighbourhood distance', +# 'L-R pair input (e.g. L1_R1, L2_R2, ...)'] +# element_fields += ['IntegerField', 'StringField'] +# element_values += [5, ''] # -# if step_log['cci_perm']: -# # Including cell heterogeneity information # -# elements += ['Permutations'] -# element_fields += ['IntegerField'] -# element_values += [200] +# if step_log['cci_perm']: +# # Including cell heterogeneity information # +# elements += ['Permutations'] +# element_fields += ['IntegerField'] +# element_values += [200] # -# return createSuperForm(elements, element_fields, element_values, None) +# return createSuperForm(elements, element_fields, element_values, None) diff --git a/stlearn/app/source/forms/helper_functions.py b/stlearn/app/source/forms/helper_functions.py index e9a64e40..692c98a9 100644 --- a/stlearn/app/source/forms/helper_functions.py +++ b/stlearn/app/source/forms/helper_functions.py @@ -1,6 +1,5 @@ # Purpose of this script is to write the functions that help facilitate # subsetting of the data depending on the users input -import numpy def printOut(text, fileName="stdout.txt", close=True, file=None): @@ -8,7 +7,7 @@ def printOut(text, fileName="stdout.txt", close=True, file=None): If close is Fale, returns open file. 
""" - if type(file) == type(None): + if file is None: file = open(fileName, "w") print(text, file=file) @@ -21,7 +20,7 @@ def printOut(text, fileName="stdout.txt", close=True, file=None): def filterOptions(metaDataSets, options): """Returns options that overlap with keys in metaDataSets dictionary""" - if type(options) == type(None): + if options is None: options = list(metaDataSets.keys()) else: options = [option for option in options if option in metaDataSets.keys()] diff --git a/stlearn/app/source/forms/utils.py b/stlearn/app/source/forms/utils.py index 3782c74f..42121bcf 100644 --- a/stlearn/app/source/forms/utils.py +++ b/stlearn/app/source/forms/utils.py @@ -1,7 +1,6 @@ -# -*- coding: utf-8 -*- """Helper utilities and decorators.""" + from flask import flash -import matplotlib.pyplot as plt def flash_errors(form, category="warning"): @@ -12,7 +11,6 @@ def flash_errors(form, category="warning"): def get_all_paths(adata): - import networkx as nx G = nx.from_numpy_array(adata.uns["paga"]["connectivities_tree"].toarray()) diff --git a/stlearn/app/source/forms/view_helpers.py b/stlearn/app/source/forms/view_helpers.py index 499edd7e..3c2de3d0 100644 --- a/stlearn/app/source/forms/view_helpers.py +++ b/stlearn/app/source/forms/view_helpers.py @@ -1,7 +1,4 @@ -""" Helper functions for views.py. -""" - -import numpy +"""Helper functions for views.py.""" def getVal(form, element): diff --git a/stlearn/app/source/forms/views.py b/stlearn/app/source/forms/views.py index 551c737e..3a857319 100644 --- a/stlearn/app/source/forms/views.py +++ b/stlearn/app/source/forms/views.py @@ -1,24 +1,20 @@ """ This is more a general views focussed on defining functions which are \ - called by other views for specify pages. This way different pages can be \ - used to display different data, but in a consistent way. + called by other views for specify pages. This way different pages can be \ + used to display different data, but in a consistent way. 
""" import sys -import numpy -import numpy as np -from flask import flash -from source.forms import forms - -from source.forms.utils import flash_errors -import source.forms.view_helpers as vhs import traceback -from flask import render_template - +import numpy +import numpy as np import scanpy as sc -import stlearn as st +from flask import flash, render_template -from scipy.spatial.distance import cosine +import stlearn as st +import stlearn.app.source.forms.view_helpers as vhs +from stlearn.app.source.forms import forms +from stlearn.app.source.forms.utils import flash_errors # Creating the forms using a class generator # PreprocessForm = forms.getPreprocessForm() @@ -35,7 +31,7 @@ def run_preprocessing(request, adata, step_log): if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == type(None): + elif adata is None: flash("Need to load data first!") else: @@ -87,13 +83,12 @@ def run_lr(request, adata, step_log): if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == type(None): + elif adata is None: flash("Need to load data first!") else: step_log["lr_params"] = vhs.getData(form) print(step_log["lr_params"], file=sys.stdout) - elements = numpy.array(list(step_log["lr_params"].keys())) # order: Species, Spot neighbourhood, min_spots, n_pairs, CPUs element_values = list(step_log["lr_params"].values()) dist = element_values[1] @@ -134,13 +129,12 @@ def run_cci(request, adata, step_log): if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == type(None): + elif adata is None: flash("Need to load data first!") else: step_log["cci_params"] = vhs.getData(form) print(step_log["cci_params"], file=sys.stdout) - elements = numpy.array(list(step_log["cci_params"].keys())) # order: cell_type, min_spots, spot_mixtures, cell_prop_cutoff, sig_spots # n_perms element_values = list(step_log["cci_params"].values()) @@ -188,14 +182,13 @@ def run_clustering(request, adata, step_log): step_log["cluster_params"] = vhs.getData(form) print(step_log["cluster_params"], file=sys.stdout) - elements = list(step_log["cluster_params"].keys()) # order: pca_comps, SME bool, method, method_param element_values = list(step_log["cluster_params"].values()) if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == type(None): + elif adata is None: flash("Need to load data first!") else: @@ -275,7 +268,6 @@ def run_psts(request, adata, step_log): step_log["psts_params"] = vhs.getData(form) print(step_log["psts_params"], file=sys.stdout) - elements = list(step_log["psts_params"].keys()) # order: pca_comps, SME bool, method, method_param element_values = list(step_log["psts_params"].values()) @@ -289,12 +281,12 @@ def run_psts(request, adata, step_log): if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == type(None): + elif adata is None: flash("Need to load data first!") else: try: - from stlearn.spatials.trajectory import set_root + from stlearn.spatial.trajectory import set_root root_index = set_root( adata, use_label="clusters", cluster=str(element_values[0]) @@ -349,7 +341,6 @@ def run_psts(request, adata, step_log): def run_dea(request, adata, step_log): - list_labels = [] for col in adata.obs.columns: @@ -366,18 +357,16 @@ def run_dea(request, adata, step_log): step_log["dea_params"] = vhs.getData(form) print(step_log["dea_params"], file=sys.stdout) - elements = list(step_log["dea_params"].keys()) element_values = list(step_log["dea_params"].values()) if not form.validate_on_submit(): flash_errors(form) - elif type(adata) == 
type(None): + elif adata is None: flash("Need to load data first!") else: try: - sc.tl.rank_genes_groups(adata, element_values[0], method=element_values[1]) step_log["dea"][0] = True diff --git a/stlearn/classes.py b/stlearn/classes.py index afa4b997..b131c0d6 100644 --- a/stlearn/classes.py +++ b/stlearn/classes.py @@ -4,40 +4,37 @@ Date: 20 Feb 2021 """ -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - import numpy as np from anndata import AnnData from .utils import ( Empty, - _empty, - _check_spatial_data, + _check_coords, _check_img, - _check_spot_size, _check_scale_factor, - _check_coords, + _check_spatial_data, + _check_spot_size, + _empty, ) -class Spatial(object): +class Spatial: + img: np.ndarray | None + def __init__( self, adata: AnnData, basis: str = "spatial", - img: Union[np.ndarray, None] = None, - img_key: Union[str, None, Empty] = _empty, - library_id: Union[str, None] = _empty, - crop_coord: Optional[bool] = True, - bw: Optional[bool] = False, - scale_factor: Optional[float] = None, - spot_size: Optional[float] = None, - use_raw: Optional[bool] = False, + img: np.ndarray | None = None, + img_key: str | None | Empty = _empty, + library_id: str | None | Empty = _empty, + crop_coord: bool = True, + bw: bool = False, + scale_factor: float | None = None, + spot_size: float | None = None, + use_raw: bool = False, **kwargs, ): - self.adata = (adata,) self.library_id, self.spatial_data = _check_spatial_data(adata.uns, library_id) self.img, self.img_key = _check_img(self.spatial_data, img, img_key, bw=bw) diff --git a/stlearn/datasets.py b/stlearn/datasets.py index 068c89d0..34a6ffd7 100644 --- a/stlearn/datasets.py +++ b/stlearn/datasets.py @@ -1 +1,3 @@ -from ._datasets._datasets import example_bcba +from ._datasets._datasets import visium_sge, xenium_sge + +__all__ = ["visium_sge", "xenium_sge"] diff --git a/stlearn/em.py b/stlearn/em.py index 193ade80..39d0c2db 100644 --- a/stlearn/em.py +++ b/stlearn/em.py @@ -1,7 +1,16 @@ -from .embedding.pca import run_pca -from .embedding.umap import run_umap -from .embedding.ica import run_ica +# from .embedding.scvi import run_ldvae +from .embedding.diffmap import run_diffmap # from .embedding.scvi import run_ldvae from .embedding.fa import run_fa -from .embedding.diffmap import run_diffmap +from .embedding.ica import run_ica +from .embedding.pca import run_pca +from .embedding.umap import run_umap + +__all__ = [ + "run_pca", + "run_umap", + "run_ica", + "run_fa", + "run_diffmap", +] diff --git a/stlearn/embedding/diffmap.py b/stlearn/embedding/diffmap.py index 97f916e8..fb309d9e 100644 --- a/stlearn/embedding/diffmap.py +++ b/stlearn/embedding/diffmap.py @@ -1,9 +1,5 @@ -from typing import Optional, Union -import numpy as np -from anndata import AnnData -from numpy.random.mtrand import RandomState -from scipy.sparse import issparse import scanpy +from anndata import AnnData def run_diffmap(adata: AnnData, n_comps: int = 15, copy: bool = False): @@ -38,10 +34,11 @@ def run_diffmap(adata: AnnData, n_comps: int = 15, copy: bool = False): Eigenvalues of transition matrix. """ - scanpy.tl.diffmap(adata, n_comps=n_comps, copy=copy) + adata = scanpy.tl.diffmap(adata, n_comps=n_comps, copy=copy) print( - "Diffusion Map is done! Generated in adata.obsm['X_diffmap'] nad adata.uns['diffmap_evals']" + "Diffusion Map is done! 
Generated in adata.obsm['X_diffmap'] and " + + "adata.uns['diffmap_evals']" ) - return adata if copy else None + return adata diff --git a/stlearn/embedding/fa.py b/stlearn/embedding/fa.py index 9c463aee..953ff96a 100644 --- a/stlearn/embedding/fa.py +++ b/stlearn/embedding/fa.py @@ -1,10 +1,6 @@ -import numpy as np -import pandas as pd -from typing import Optional - from anndata import AnnData -from sklearn.decomposition import FactorAnalysis from scipy.sparse import issparse +from sklearn.decomposition import FactorAnalysis def run_fa( @@ -15,10 +11,9 @@ def run_fa( svd_method: str = "randomized", iterated_power: int = 3, random_state: int = 2108, - use_data: str = None, + use_data: str | None = None, copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Factor Analysis (FA) A simple linear generative model with Gaussian latent variables. @@ -74,6 +69,8 @@ def run_fa( Factor analysis representation of data. """ + adata = adata.copy() if copy else adata + if use_data is None: if issparse(adata.X): matrix = adata.X.toarray() diff --git a/stlearn/embedding/ica.py b/stlearn/embedding/ica.py index ed2f1d21..fde64c40 100644 --- a/stlearn/embedding/ica.py +++ b/stlearn/embedding/ica.py @@ -1,9 +1,6 @@ -import numpy as np -import pandas as pd -from typing import Optional from anndata import AnnData -from sklearn.decomposition import FastICA from scipy.sparse import issparse +from sklearn.decomposition import FastICA def run_ica( @@ -11,10 +8,9 @@ def run_ica( n_factors: int = 20, fun: str = "logcosh", tol: float = 0.0001, - use_data: str = None, + use_data: str | None = None, copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ FastICA: a fast algorithm for Independent Component Analysis. @@ -47,25 +43,24 @@ def my_g(x): Independent Component Analysis representation of data. """ + adata = adata.copy() if copy else adata + if use_data is None: if issparse(adata.X): matrix = adata.X.toarray() else: matrix = adata.X - else: matrix = adata.obsm[use_data].values ica = FastICA(n_components=n_factors, fun=fun, tol=tol) - latent = ica.fit_transform(matrix) - adata.obsm["X_ica"] = latent - adata.uns["ica"] = {"params": {"n_factors": n_factors, "fun": fun, "tol": tol}} print( - "ICA is done! Generated in adata.obsm['X_ica'] and parameters in adata.uns['ica']" + "ICA is done! 
Generated in adata.obsm['X_ica'] and parameters in " + + "adata.uns['ica']" ) return adata if copy else None diff --git a/stlearn/embedding/pca.py b/stlearn/embedding/pca.py index 040a3b6f..8870994e 100644 --- a/stlearn/embedding/pca.py +++ b/stlearn/embedding/pca.py @@ -1,25 +1,23 @@ -import logging as logg -from typing import Union, Optional, Tuple, Collection, Sequence, Iterable -from anndata import AnnData import numpy as np -from scipy.sparse import issparse, isspmatrix_csr, csr_matrix, spmatrix -from numpy.random.mtrand import RandomState import scanpy +from anndata import AnnData +from numpy.random.mtrand import RandomState +from scipy.sparse import spmatrix def run_pca( - data: Union[AnnData, np.ndarray, spmatrix], + data: AnnData | np.ndarray | spmatrix, n_comps: int = 50, - zero_center: Optional[bool] = True, + zero_center: bool | None = True, svd_solver: str = "auto", - random_state: Optional[Union[int, RandomState]] = 0, + random_state: int | RandomState | None = 0, return_info: bool = False, - use_highly_variable: Optional[bool] = None, + use_highly_variable: bool | None = None, dtype: str = "float32", copy: bool = False, chunked: bool = False, - chunk_size: Optional[int] = None, -) -> Union[AnnData, np.ndarray, spmatrix]: + chunk_size: int | None = None, +) -> AnnData | None: """\ Wrap function scanpy.pp.pca Principal component analysis [Pedregosa11]_. @@ -85,7 +83,7 @@ def run_pca( covariance matrix. """ - scanpy.pp.pca( + adata = scanpy.pp.pca( data, n_comps=n_comps, zero_center=zero_center, @@ -100,5 +98,8 @@ def run_pca( ) print( - "PCA is done! Generated in adata.obsm['X_pca'], adata.uns['pca'] and adata.varm['PCs']" + "PCA is done! Generated in adata.obsm['X_pca'], adata.uns['pca'] and " + + "adata.varm['PCs']" ) + + return adata diff --git a/stlearn/embedding/umap.py b/stlearn/embedding/umap.py index 912aaa00..ad3079ca 100644 --- a/stlearn/embedding/umap.py +++ b/stlearn/embedding/umap.py @@ -1,12 +1,10 @@ -from typing import Optional, Union +from typing import Literal import numpy as np +import scanpy from anndata import AnnData from numpy.random.mtrand import RandomState -from .._compat import Literal -import scanpy - _InitPos = Literal["paga", "spectral", "random"] @@ -15,17 +13,17 @@ def run_umap( min_dist: float = 0.5, spread: float = 1.0, n_components: int = 2, - maxiter: Optional[int] = None, + maxiter: int | None = None, alpha: float = 1.0, gamma: float = 1.0, negative_sample_rate: int = 5, - init_pos: Union[_InitPos, np.ndarray, None] = "spectral", - random_state: Optional[Union[int, RandomState]] = 0, - a: Optional[float] = None, - b: Optional[float] = None, + init_pos: _InitPos | np.ndarray | None = "spectral", + random_state: int | RandomState | None = 0, + a: float | None = None, + b: float | None = None, copy: bool = False, - method: Literal["umap", "rapids"] = "umap", -) -> Optional[AnnData]: + method: Literal["umap", "rapids"] = "umap", # noqa: F821 +) -> AnnData | None: """\ Wrap function scanpy.pp.umap Embed the neighborhood graph using UMAP [McInnes18]_. @@ -58,7 +56,7 @@ def run_umap( """ - scanpy.tl.umap( + adata = scanpy.tl.umap( adata, min_dist=min_dist, spread=spread, @@ -76,3 +74,5 @@ def run_umap( ) print("UMAP is done! 
Generated in adata.obsm['X_umap'] and adata.uns['umap']")
+
+    return adata
diff --git a/stlearn/image_preprocessing/feature_extractor.py b/stlearn/image_preprocessing/feature_extractor.py
index b2946ee8..a4cf5730 100644
--- a/stlearn/image_preprocessing/feature_extractor.py
+++ b/stlearn/image_preprocessing/feature_extractor.py
@@ -1,15 +1,13 @@
-from .model_zoo import encode, Model
-from typing import Optional, Union
-from anndata import AnnData
+from typing import Literal
+
 import numpy as np
-from .._compat import Literal
+from anndata import AnnData
 from PIL import Image
-import pandas as pd
-from pathlib import Path
-
-# Test progress bar
+from sklearn.decomposition import PCA
 from tqdm import tqdm

+from .model_zoo import Model
+
 _CNN_BASE = Literal["resnet50", "vgg16", "inception_v3", "xception"]


@@ -17,68 +15,91 @@ def extract_feature(
     adata: AnnData,
     cnn_base: _CNN_BASE = "resnet50",
     n_components: int = 50,
+    seeds: int = 1,
     verbose: bool = False,
     copy: bool = False,
-    seeds: int = 1,
-) -> Optional[AnnData]:
+) -> AnnData | None:
     """\
     Extract latent morphological features from H&E images using pre-trained
     convolutional neural network base

     Parameters
     ----------
-    adata
+    adata:
         Annotated data matrix.
-    cnn_base
+    cnn_base:
         Established convolutional neural network bases
         choose one from ['resnet50', 'vgg16', 'inception_v3', 'xception']
-    n_components
+    n_components:
         Number of principal components to compute for latent morphological features
-    verbose
+    seeds:
+        Fix random state
+    verbose:
         Verbose output
-    copy
+    copy:
         Return a copy instead of writing to adata.
-    seeds
-        Fix random state

     Returns
     -------
     Depending on `copy`, returns or updates `adata` with the following fields.
     **X_morphology** : `adata.obsm` field
         Dimension reduced latent morphological features.
+    Raises
+    ------
+    ValueError
+        If any image fails to process or if tile_path column is missing.
""" - feature_dfs = [] - model = Model(cnn_base) + + adata = adata.copy() if copy else adata if "tile_path" not in adata.obs: raise ValueError("Please run the function stlearn.pp.tiling") + model = Model(cnn_base) + + # Pre-allocate feature matrix, spot names and arrays to avoid overhead + tile_paths = adata.obs["tile_path"].values + n_spots = len(tile_paths) + if n_spots == 0: + raise ValueError("No tile paths found in adata.obs['tile_path']") + + first_features = _read_and_predict(tile_paths[0], model, verbose=verbose) + n_features = len(first_features) + + # Setup feature matrix + feature_matrix = np.empty((n_spots, n_features), dtype=np.float32) + feature_matrix[0] = first_features + with tqdm( - total=len(adata), + total=n_spots, desc="Extract feature", bar_format="{l_bar}{bar} [ time left: {remaining} ]", + initial=1, # We already processed the first image ) as pbar: - for spot, tile_path in adata.obs["tile_path"].items(): - tile = Image.open(tile_path) - tile = np.asarray(tile, dtype="int32") - tile = tile.astype(np.float32) - tile = np.stack([tile]) - if verbose: - print("extract feature for spot: {}".format(str(spot))) - features = encode(tile, model) - feature_dfs.append(pd.DataFrame(features, columns=[spot])) - pbar.update(1) + for i in range(1, n_spots): + features = _read_and_predict(tile_paths[i], model, verbose=verbose) + feature_matrix[i] = features + if i % 100 == 0: + pbar.update(100) - feature_df = pd.concat(feature_dfs, axis=1) + adata.obsm["X_tile_feature"] = feature_matrix + pca = PCA(n_components=n_components, random_state=seeds) + pca.fit(feature_matrix) + adata.obsm["X_morphology"] = pca.transform(feature_matrix) - adata.obsm["X_tile_feature"] = feature_df.transpose().to_numpy() + print("The morphology feature is added to adata.obsm['X_morphology']!") - from sklearn.decomposition import PCA + return adata if copy else None - pca = PCA(n_components=n_components, random_state=seeds) - pca.fit(feature_df.transpose().to_numpy()) - adata.obsm["X_morphology"] = pca.transform(feature_df.transpose().to_numpy()) +def _read_and_predict(path, model, verbose=False): + try: + with Image.open(path) as img: + tile = np.asarray(img, dtype=np.float32) - print("The morphology feature is added to adata.obsm['X_morphology']!") + if verbose: + print(f"Loaded image: {path}") - return adata if copy else None + tile = tile[np.newaxis, ...] + return model.predict(tile).ravel() + except Exception as e: + raise ValueError(f"Failed to process image: {path}. 
Error: {str(e)}") diff --git a/stlearn/image_preprocessing/image_tiling.py b/stlearn/image_preprocessing/image_tiling.py index bdb88a60..73f8b4b7 100644 --- a/stlearn/image_preprocessing/image_tiling.py +++ b/stlearn/image_preprocessing/image_tiling.py @@ -1,44 +1,46 @@ -from typing import Optional, Union -from anndata import AnnData -from .._compat import Literal -from PIL import Image from pathlib import Path -# Test progress bar -from tqdm import tqdm import numpy as np -import os +from anndata import AnnData +from PIL import Image +from tqdm import tqdm def tiling( adata: AnnData, - out_path: Union[Path, str] = "./tiling", - library_id: Union[str, None] = None, + out_path: Path | str = "./tiling", + library_id: str | None = None, crop_size: int = 40, target_size: int = 299, img_fmt: str = "JPEG", + quality: int = 75, verbose: bool = False, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ - Tiling H&E images to small tiles based on spot spatial location + Tiling H&E images to small tiles based on spot spatial location. Parameters ---------- - adata - Annotated data matrix. - out_path - Path to save spot image tiles - library_id - Library id stored in AnnData. - crop_size - Size of tiles - verbose - Verbose output - copy - Return a copy instead of writing to adata. - target_size - Input size for convolutional neuron network + adata: AnnData + Annotated data matrix containing spatial information. + out_path: Path or str, default "./tiling" + Path to save spot image tiles. + library_id: str, optional + Library id stored in AnnData. If None, uses first available library. + crop_size: int, default 40 + Size of tiles to crop from original image. + target_size: int, default 299 + Target size for resized tiles (input size for CNN). + img_fmt: str, default "JPEG" + Image format ('JPEG' or 'PNG'). + quality: int, default 75 + JPEG quality (1-100). Only used for JPEG format. + verbose: bool, default False + Enable verbose output. + copy: bool, default False + Return a copy instead of modifying adata in-place. + Returns ------- Depending on `copy`, returns or updates `adata` with the following fields. 
@@ -46,22 +48,17 @@ def tiling( Saved path for each spot image tiles """ - if library_id is None: - library_id = list(adata.uns["spatial"].keys())[0] + _validate_inputs(crop_size, target_size, img_fmt, quality) - # Check the exist of out_path - if not os.path.isdir(out_path): - os.mkdir(out_path) + adata = adata.copy() if copy else adata - image = adata.uns["spatial"][library_id]["images"][ - adata.uns["spatial"][library_id]["use_quality"] - ] - if image.dtype == np.float32 or image.dtype == np.float64: - image = (image * 255).astype(np.uint8) - img_pillow = Image.fromarray(image) + out_path = Path(out_path) + out_path.mkdir(parents=True, exist_ok=True) - if img_pillow.mode == "RGBA": - img_pillow = img_pillow.convert("RGB") + library_id = _get_library_id(adata, library_id) + img_pillow = _load_and_prepare_image(adata, library_id) + + coordinates = list(zip(adata.obs["imagerow"], adata.obs["imagecol"])) tile_names = [] @@ -70,35 +67,88 @@ def tiling( desc="Tiling image", bar_format="{l_bar}{bar} [ time left: {remaining} ]", ) as pbar: - for imagerow, imagecol in zip(adata.obs["imagerow"], adata.obs["imagecol"]): - imagerow_down = imagerow - crop_size / 2 - imagerow_up = imagerow + crop_size / 2 - imagecol_left = imagecol - crop_size / 2 - imagecol_right = imagecol + crop_size / 2 + for image_row, image_col in coordinates: + half_crop = crop_size // 2 + image_row_down = max(0, image_row - half_crop) + image_row_up = image_row + half_crop + image_col_left = max(0, image_col - half_crop) + image_col_right = image_col + half_crop + tile = img_pillow.crop( - (imagecol_left, imagerow_down, imagecol_right, imagerow_up) + (image_col_left, image_row_down, image_col_right, image_row_up) ) + tile.thumbnail((target_size, target_size), Image.Resampling.LANCZOS) tile = tile.resize((target_size, target_size)) - tile_name = str(imagecol) + "-" + str(imagerow) + "-" + str(crop_size) + tile_name = str(image_col) + "-" + str(image_row) + "-" + str(crop_size) if img_fmt == "JPEG": out_tile = Path(out_path) / (tile_name + ".jpeg") tile_names.append(str(out_tile)) - tile.save(out_tile, "JPEG") + tile.save(out_tile, "JPEG", quality=quality) else: out_tile = Path(out_path) / (tile_name + ".png") tile_names.append(str(out_tile)) tile.save(out_tile, "PNG") if verbose: - print( - "generate tile at location ({}, {})".format( - str(imagecol), str(imagerow) - ) - ) + print(f"generate tile at location ({str(image_col)}, {str(image_row)})") pbar.update(1) adata.obs["tile_path"] = tile_names return adata if copy else None + + +def _validate_inputs( + crop_size: int, target_size: int, img_fmt: str, quality: int +) -> None: + + if not isinstance(crop_size, int) or crop_size <= 0: + raise ValueError("crop_size must be a positive integer") + + if not isinstance(target_size, int) or target_size <= 0: + raise ValueError("target_size must be a positive integer") + + if img_fmt.upper() not in ["JPEG", "PNG"]: + raise ValueError("img_fmt must be 'JPEG' or 'PNG'") + + if img_fmt.upper() == "JPEG" and ( + not isinstance(quality, int) or not 1 <= quality <= 100 + ): + raise ValueError("quality must be an integer between 1 and 100 for JPEG format") + + +def _get_library_id(adata: AnnData, library_id: str | None) -> str: + if library_id is None: + try: + library_id = list(adata.uns["spatial"].keys())[0] + except (KeyError, IndexError): + raise ValueError("No spatial data found in adata.uns['spatial']") + + if library_id not in adata.uns["spatial"]: + raise ValueError(f"Library '{library_id}' not found in spatial data") + + return 
library_id + + +def _load_and_prepare_image(adata: AnnData, library_id: str) -> Image.Image: + try: + spatial_data = adata.uns["spatial"][library_id] + use_quality = spatial_data["use_quality"] + image = spatial_data["images"][use_quality] + except KeyError as e: + raise ValueError( + f"Could not find image data in adata.uns['spatial']['{library_id}']: {e}" + ) + + if image.dtype in (np.float32, np.float64): + image = np.clip(image, 0, 1) + image = (image * 255).astype(np.uint8) + + img_pillow = Image.fromarray(image) + + if img_pillow.mode == "RGBA": + img_pillow = img_pillow.convert("RGB") + + return img_pillow diff --git a/stlearn/image_preprocessing/model_zoo.py b/stlearn/image_preprocessing/model_zoo.py index a028f75f..1969f9b1 100644 --- a/stlearn/image_preprocessing/model_zoo.py +++ b/stlearn/image_preprocessing/model_zoo.py @@ -1,23 +1,17 @@ -def encode(tiles, model): - features = model.predict(tiles) - features = features.ravel() - return features - - class Model: __name__ = "CNN base model" def __init__(self, base, batch_size=1): - from tensorflow.keras import backend as K + from keras import backend as keras self.base = base self.model, self.preprocess = self.load_model() self.batch_size = batch_size - self.data_format = K.image_data_format() + self.data_format = keras.image_data_format() def load_model(self): if self.base == "resnet50": - from tensorflow.keras.applications.resnet50 import ( + from keras.applications.resnet50 import ( ResNet50, preprocess_input, ) @@ -26,11 +20,11 @@ def load_model(self): include_top=False, weights="imagenet", pooling="avg" ) elif self.base == "vgg16": - from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input + from keras.applications.vgg16 import VGG16, preprocess_input cnn_base_model = VGG16(include_top=False, weights="imagenet", pooling="avg") elif self.base == "inception_v3": - from tensorflow.keras.applications.inception_v3 import ( + from keras.applications.inception_v3 import ( InceptionV3, preprocess_input, ) @@ -39,7 +33,7 @@ def load_model(self): include_top=False, weights="imagenet", pooling="avg" ) elif self.base == "xception": - from tensorflow.keras.applications.xception import ( + from keras.applications.xception import ( Xception, preprocess_input, ) @@ -48,13 +42,13 @@ def load_model(self): include_top=False, weights="imagenet", pooling="avg" ) else: - raise ValueError("{} is not a valid model".format(self.base)) + raise ValueError(f"{self.base} is not a valid model") return cnn_base_model, preprocess_input def predict(self, x): - from tensorflow.keras import backend as K + from keras import backend as keras if self.data_format == "channels_first": x = x.transpose(0, 3, 1, 2) - x = self.preprocess(x.astype(K.floatx())) + x = self.preprocess(x.astype(keras.floatx())) return self.model.predict(x, batch_size=self.batch_size, verbose=False) diff --git a/stlearn/image_preprocessing/segmentation.py b/stlearn/image_preprocessing/segmentation.py deleted file mode 100644 index 76023058..00000000 --- a/stlearn/image_preprocessing/segmentation.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Optional -import histomicstk as htk -import numpy as np -import scipy as sp -import skimage.color -import skimage.io -import skimage.measure -from anndata import AnnData -from scipy import ndimage as ndi -from skimage.feature import peak_local_max -from skimage.segmentation import watershed -from tqdm import tqdm - - -def morph_watershed( - adata: AnnData, - library_id: str = None, - verbose: bool = False, - copy: bool = False, -) 
-> Optional[AnnData]: - """\ - Watershed method to segment nuclei and calculate morphological statistics - - Parameters - ---------- - adata - Annotated data matrix. - library_id - Library id stored in AnnData. - copy - Return a copy instead of writing to adata. - Returns - ------- - Depending on `copy`, returns or updates `adata` with the following fields. - **n_nuclei** : `adata.obs` field - saved number of nuclei of each spot image tiles - **nuclei_total_area** : `adata.obs` field - saved of total area of nuclei of each spot image tiles - **nuclei_mean_area** : `adata.obs` field - saved mean area of nuclei of each spot image tiles - **nuclei_std_area** : `adata.obs` field - saved stand deviation of nuclei area of each spot image tiles - **eccentricity** : `adata.obs` field - saved eccentricity of each spot image tiles - **mean_pix_r** : `adata.obs` field - saved mean pixel value of red channel of of each spot image tiles - **std_pix_r** : `adata.obs` field - saved stand deviation of red channel of each spot image tiles - **mean_pix_g** : `adata.obs` field - saved mean pixel value of green channel of each spot image tiles - **std_pix_g** : `adata.obs` field - saved stand deviation of green channel of each spot image tiles - **mean_pix_b** : `adata.obs` field - saved mean pixel value of blue channel of each spot image tiles - **std_pix_b** : `adata.obs` field - saved stand deviation of blue channel of each spot image tiles - **nuclei_total_area_per_tile** : `adata.obs` field - saved total nuclei area per tile of each spot image tiles - """ - - if library_id is None: - library_id = list(adata.uns["spatial"].keys())[0] - - n_nuclei_list = [] - nuclei_total_area_list = [] - nuclei_mean_area_list = [] - nuclei_std_area_list = [] - eccentricity_list = [] - mean_pix_list_r = [] - std_pix_list_r = [] - mean_pix_list_g = [] - std_pix_list_g = [] - mean_pix_list_b = [] - std_pix_list_b = [] - with tqdm( - total=len(adata), - desc="calculate morphological stats", - bar_format="{l_bar}{bar} [ time left: {remaining} ]", - ) as pbar: - for tile in adata.obs["tile_path"]: - ( - n_nuclei, - nuclei_total_area, - nuclei_mean_area, - nuclei_std_area, - eccentricity, - solidity, - mean_pix_r, - std_pix_r, - mean_pix_g, - std_pix_g, - mean_pix_b, - std_pix_b, - ) = _calculate_morph_stats(tile) - n_nuclei_list.append(n_nuclei) - nuclei_total_area_list.append(nuclei_total_area) - nuclei_mean_area_list.append(nuclei_mean_area) - nuclei_std_area_list.append(nuclei_std_area) - eccentricity_list.append(eccentricity) - mean_pix_list_r.append(mean_pix_r) - std_pix_list_r.append(std_pix_r) - mean_pix_list_g.append(mean_pix_g) - std_pix_list_g.append(std_pix_g) - mean_pix_list_b.append(mean_pix_b) - std_pix_list_b.append(std_pix_b) - pbar.update(1) - - adata.obs["n_nuclei"] = n_nuclei_list - adata.obs["nuclei_total_area"] = nuclei_total_area_list - adata.obs["nuclei_mean_area"] = nuclei_mean_area_list - adata.obs["nuclei_std_area"] = nuclei_std_area_list - adata.obs["eccentricity"] = eccentricity_list - adata.obs["mean_pix_r"] = mean_pix_list_r - adata.obs["std_pix_r"] = std_pix_list_r - adata.obs["mean_pix_g"] = mean_pix_list_g - adata.obs["std_pix_g"] = std_pix_list_g - adata.obs["mean_pix_b"] = mean_pix_list_b - adata.obs["std_pix_b"] = std_pix_list_b - adata.obs["nuclei_total_area_per_tile"] = adata.obs["nuclei_total_area"] / 299 / 299 - return adata if copy else None - - -def _calculate_morph_stats(tile_path): - imInput = skimage.io.imread(tile_path) - stain_color_map = 
htk.preprocessing.color_deconvolution.stain_color_map - stains = [ - "hematoxylin", # nuclei stain - "eosin", # cytoplasm stain - "null", - ] # set to null if input contains only two stains - w_est = htk.preprocessing.color_deconvolution.rgb_separate_stains_macenko_pca( - imInput, 255 - ) - - # Perform color deconvolution - deconv_result = htk.preprocessing.color_deconvolution.color_deconvolution( - imInput, w_est, 255 - ) - - channel = htk.preprocessing.color_deconvolution.find_stain_index( - stain_color_map[stains[0]], w_est - ) - im_nuclei_stain = deconv_result.Stains[:, :, channel] - - thresh = skimage.filters.threshold_otsu(im_nuclei_stain) - # im_fgnd_mask = im_nuclei_stain < thresh - im_fgnd_mask = sp.ndimage.morphology.binary_fill_holes( - im_nuclei_stain < 0.8 * thresh - ) - - distance = ndi.distance_transform_edt(im_fgnd_mask) - coords = peak_local_max(distance, footprint=np.ones((3, 3)), labels=im_fgnd_mask) - mask = np.zeros(distance.shape, dtype=bool) - mask[tuple(coords.T)] = True - markers, _ = ndi.label(mask) - - labels = watershed(im_nuclei_stain, markers, mask=im_fgnd_mask) - min_nucleus_area = 60 - im_nuclei_seg_mask = htk.segmentation.label.area_open( - labels, min_nucleus_area - ).astype(np.int64) - - # compute nuclei properties - objProps = skimage.measure.regionprops(im_nuclei_seg_mask) - - # # Display results - # plt.figure(figsize=(20, 10)) - # plt.imshow(skimage.color.label2rgb(im_nuclei_seg_mask, im_nuclei_stain, bg_label=0), - # origin='upper') - # plt.title('Nuclei segmentation mask overlay') - # plt.savefig("./Nuclei_segmentation_tiles_bc_wh/{}.png".format(tile_path.split("/")[-1].split(".")[0]), dpi=300) - - n_nuclei = len(objProps) - - nuclei_total_area = sum(map(lambda x: x.area, objProps)) - nuclei_mean_area = np.mean(list(map(lambda x: x.area, objProps))) - nuclei_std_area = np.std(list(map(lambda x: x.area, objProps))) - - mean_pix = imInput.reshape(3, -1).mean(1) - std_pix = imInput.reshape(3, -1).std(1) - - eccentricity = np.mean(list(map(lambda x: x.eccentricity, objProps))) - - solidity = np.mean(list(map(lambda x: x.solidity, objProps))) - - return ( - n_nuclei, - nuclei_total_area, - nuclei_mean_area, - nuclei_std_area, - eccentricity, - solidity, - mean_pix[0], - std_pix[0], - mean_pix[1], - std_pix[1], - mean_pix[2], - std_pix[2], - ) diff --git a/stlearn/logging.py b/stlearn/logging.py index 674e7f77..6b4d0ce8 100644 --- a/stlearn/logging.py +++ b/stlearn/logging.py @@ -1,37 +1,47 @@ -"""Logging and Profiling -""" +"""Logging and Profiling""" + import logging -from functools import update_wrapper, partial -from logging import CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET +from collections.abc import Mapping from datetime import datetime, timedelta, timezone -from typing import Optional +from functools import partial, update_wrapper +from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING +from typing import Any, overload import anndata.logging - HINT = (INFO + DEBUG) // 2 logging.addLevelName(HINT, "HINT") +class CustomLogRecord(logging.LogRecord): + """Custom root logger that maintains compatibility with standard logging + interface.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.time_passed: timedelta | None = None + self.deep: str | None = None + + class _RootLogger(logging.RootLogger): def __init__(self, level): super().__init__(level) self.propagate = False _RootLogger.manager = logging.Manager(self) - def log( + def log_with_timing( self, level: int, msg: str, *, - extra: Optional[dict] = None, - 
time: datetime = None, - deep: Optional[str] = None, + extra: dict | None = None, + time: datetime | None = None, + deep: str | None = None, ) -> datetime: from . import settings now = datetime.now(timezone.utc) - time_passed: timedelta = None if time is None else now - time + time_passed: timedelta | None = None if time is None else now - time extra = { **(extra or {}), "deep": deep if settings.verbosity.level < level else None, @@ -40,23 +50,124 @@ def log( super().log(level, msg, extra=extra) return now - def critical(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(CRITICAL, msg, time=time, deep=deep, extra=extra) + def _handle_enhanced_logging( + self, level: int, msg, *args, **kwargs + ) -> datetime | None: + """Handle logging with enhanced features (timing, deep info) or fall back to + standard logging.""" + if "time" in kwargs or "deep" in kwargs or "extra" in kwargs: + # Extract enhanced arguments + time_arg = kwargs.pop("time", None) + deep_arg = kwargs.pop("deep", None) + extra_arg = kwargs.pop("extra", None) + + # Format message if there are remaining args + if args or kwargs: + formatted_msg = msg % args if args else msg + else: + formatted_msg = msg - def error(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(ERROR, msg, time=time, deep=deep, extra=extra) + return self.log_with_timing( + level, formatted_msg, time=time_arg, deep=deep_arg, extra=extra_arg + ) + else: + super().log(level, msg, *args, **kwargs) + return None - def warning(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(WARNING, msg, time=time, deep=deep, extra=extra) + def hint( + self, + msg, + *, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, + ) -> datetime: + return self.log_with_timing(HINT, msg, time=time, deep=deep, extra=extra) - def info(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(INFO, msg, time=time, deep=deep, extra=extra) + @overload + def debug( + self, + msg: object, + *args: object, + exc_info: bool | tuple | BaseException | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... + + @overload + def debug(self, msg, *args, **kwargs): ... + + def debug(self, msg, *args, **kwargs) -> datetime | None: + return self._handle_enhanced_logging(DEBUG, msg, *args, **kwargs) + + @overload + def info( + self, + msg: object, + *args: object, + exc_info: bool | tuple | BaseException | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... + + @overload + def info(self, msg, *args, **kwargs): ... + + def info(self, msg, *args, **kwargs) -> datetime | None: + return self._handle_enhanced_logging(INFO, msg, *args, **kwargs) + + @overload + def warning( + self, + msg: object, + *args: object, + exc_info: bool | tuple | BaseException | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... + + @overload + def warning(self, msg, *args, **kwargs): ... + + def warning(self, msg, *args, **kwargs) -> datetime | None: + return self._handle_enhanced_logging(WARNING, msg, *args, **kwargs) + + @overload + def error( + self, + msg: object, + *args: object, + exc_info: bool | tuple | BaseException | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... 
+ + @overload + def error(self, msg, *args, **kwargs): ... + + def error(self, msg, *args, **kwargs) -> datetime | None: + return self._handle_enhanced_logging(ERROR, msg, *args, **kwargs) + + @overload + def critical( + self, + msg: object, + *args: object, + exc_info: bool | tuple | BaseException | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... - def hint(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(HINT, msg, time=time, deep=deep, extra=extra) + @overload + def critical(self, msg, *args, **kwargs): ... - def debug(self, msg, *, time=None, deep=None, extra=None) -> datetime: - return self.log(DEBUG, msg, time=time, deep=deep, extra=extra) + def critical(self, msg, *args, **kwargs) -> datetime | None: + return self._handle_enhanced_logging(CRITICAL, msg, *args, **kwargs) def _set_log_file(settings): @@ -94,20 +205,24 @@ def format(self, record: logging.LogRecord): self._style._fmt = "--> {message}" elif record.levelno == DEBUG: self._style._fmt = " {message}" - if record.time_passed: - # strip microseconds - if record.time_passed.microseconds: - record.time_passed = timedelta( - seconds=int(record.time_passed.total_seconds()) - ) + + # Handle time_passed if present (should be in extra) + time_passed = getattr(record, "time_passed", None) + if time_passed: + # Strip microseconds + if time_passed.microseconds: + time_passed = timedelta(seconds=int(time_passed.total_seconds())) if "{time_passed}" in record.msg: - record.msg = record.msg.replace( - "{time_passed}", str(record.time_passed) - ) + record.msg = record.msg.replace("{time_passed}", str(time_passed)) else: self._style._fmt += " ({time_passed})" - if record.deep: - record.msg = f"{record.msg}: {record.deep}" + # Add time_passed to record for formatting + record.time_passed = time_passed + + deep = getattr(record, "deep", None) + if deep: + record.msg = f"{record.msg}: {deep}" + result = logging.Formatter.format(self, record) self._style._fmt = format_orig return result @@ -116,7 +231,6 @@ def format(self, record: logging.LogRecord): print_memory_usage = anndata.logging.print_memory_usage get_memory_usage = anndata.logging.get_memory_usage - _DEPENDENCIES_NUMERICS = [ "anndata", # anndata actually shouldn't, but as long as it's in development "umap", @@ -129,7 +243,6 @@ def format(self, record: logging.LogRecord): "louvain", ] - _DEPENDENCIES_PLOTTING = ["matplotlib", "seaborn"] @@ -167,21 +280,22 @@ def print_version_and_date(): from ._settings import settings print( - f"Running Scanpy {__version__}, " f"on {datetime.now():%Y-%m-%d %H:%M}.", + f"Running Scanpy {__version__}, on {datetime.now():%Y-%m-%d %H:%M}.", file=settings.logfile, ) def _copy_docs_and_signature(fn): + """Copy documentation and signature from function.""" return partial(update_wrapper, wrapped=fn, assigned=["__doc__", "__annotations__"]) def error( msg: str, *, - time: datetime = None, - deep: Optional[str] = None, - extra: Optional[dict] = None, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, ) -> datetime: """\ Log message with specific level and return current time. @@ -196,39 +310,67 @@ def error( If `msg` contains `{time_passed}`, the time difference is instead inserted at that position. 
deep - If the current verbosity is higher than the log function’s level, + If the current verbosity is higher than the log function's level, this gets displayed as well extra Additional values you can specify in `msg` like `{time_passed}`. """ from ._settings import settings - return settings._root_logger.error(msg, time=time, deep=deep, extra=extra) + result = settings._root_logger.error(msg, time=time, deep=deep, extra=extra) + return result or datetime.now(timezone.utc) @_copy_docs_and_signature(error) -def warning(msg, *, time=None, deep=None, extra=None) -> datetime: +def warning( + msg: str, + *, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, +) -> datetime: from ._settings import settings - return settings._root_logger.warning(msg, time=time, deep=deep, extra=extra) + result = settings._root_logger.warning(msg, time=time, deep=deep, extra=extra) + return result or datetime.now(timezone.utc) @_copy_docs_and_signature(error) -def info(msg, *, time=None, deep=None, extra=None) -> datetime: +def info( + msg: str, + *, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, +) -> datetime: from ._settings import settings - return settings._root_logger.info(msg, time=time, deep=deep, extra=extra) + result = settings._root_logger.info(msg, time=time, deep=deep, extra=extra) + return result or datetime.now(timezone.utc) @_copy_docs_and_signature(error) -def hint(msg, *, time=None, deep=None, extra=None) -> datetime: +def hint( + msg: str, + *, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, +) -> datetime: from ._settings import settings return settings._root_logger.hint(msg, time=time, deep=deep, extra=extra) @_copy_docs_and_signature(error) -def debug(msg, *, time=None, deep=None, extra=None) -> datetime: +def debug( + msg: str, + *, + time: datetime | None = None, + deep: str | None = None, + extra: dict[str, Any] | None = None, +) -> datetime: from ._settings import settings - return settings._root_logger.debug(msg, time=time, deep=deep, extra=extra) + result = settings._root_logger.debug(msg, time=time, deep=deep, extra=extra) + return result or datetime.now(timezone.utc) diff --git a/stlearn/pl.py b/stlearn/pl.py deleted file mode 100644 index 7f7577d4..00000000 --- a/stlearn/pl.py +++ /dev/null @@ -1,26 +0,0 @@ -from .plotting.gene_plot import gene_plot -from .plotting.gene_plot import gene_plot_interactive -from .plotting.feat_plot import feat_plot -from .plotting.cluster_plot import cluster_plot -from .plotting.cluster_plot import cluster_plot_interactive -from .plotting.subcluster_plot import subcluster_plot -from .plotting.non_spatial_plot import non_spatial_plot -from .plotting.deconvolution_plot import deconvolution_plot -from .plotting.stack_3d_plot import stack_3d_plot -from .plotting import trajectory -from .plotting.QC_plot import QC_plot -from .plotting.cci_plot import het_plot - -# from .plotting.cci_plot import het_plot_interactive -from .plotting.cci_plot import lr_plot_interactive, spatialcci_plot_interactive -from .plotting.cci_plot import grid_plot -from .plotting.cci_plot import lr_diagnostics, lr_n_spots, lr_summary, lr_go -from .plotting.cci_plot import lr_plot, lr_result_plot -from .plotting.cci_plot import ( - ccinet_plot, - cci_map, - lr_cci_map, - lr_chord_plot, - cci_check, -) -from .plotting.mask_plot import plot_mask diff --git a/stlearn/plotting/QC_plot.py b/stlearn/pl/QC_plot.py similarity index 96% rename from 
stlearn/plotting/QC_plot.py rename to stlearn/pl/QC_plot.py
index 186d542b..a2224857 100644
--- a/stlearn/plotting/QC_plot.py
+++ b/stlearn/pl/QC_plot.py
@@ -1,13 +1,12 @@
-from matplotlib import pyplot as plt
 import numpy as np
-from typing import Optional, Union
 from anndata import AnnData
+from matplotlib import pyplot as plt


 def QC_plot(
     adata: AnnData,
-    library_id: str = None,
-    name: str = None,
+    name: str,
+    library_id: str | None = None,
     data_alpha: float = 0.8,
     tissue_alpha: float = 1.0,
     cmap: str = "Spectral_r",
@@ -18,8 +17,8 @@
     cropped: bool = True,
     margin: int = 100,
     dpi: int = 150,
-    output: str = None,
-) -> Optional[AnnData]:
+    output: str | None = None,
+) -> None:
     """\
     QC plot for spatial transcriptomics data.
diff --git a/stlearn/pl/__init__.py b/stlearn/pl/__init__.py
new file mode 100644
index 00000000..c24e63a2
--- /dev/null
+++ b/stlearn/pl/__init__.py
@@ -0,0 +1,77 @@
+# Import individual functions from modules
+from .cci_plot import (
+    cci_check,
+    cci_map,
+    ccinet_plot,
+    grid_plot,
+    het_plot,
+    lr_cci_map,
+    lr_chord_plot,
+    lr_diagnostics,
+    lr_go,
+    lr_n_spots,
+    lr_plot,
+    lr_plot_interactive,
+    lr_result_plot,
+    lr_summary,
+    spatialcci_plot_interactive,
+)
+from .cluster_plot import cluster_plot, cluster_plot_interactive
+from .deconvolution_plot import deconvolution_plot
+from .feat_plot import feat_plot
+from .gene_plot import gene_plot, gene_plot_interactive
+from .mask_plot import plot_mask
+from .non_spatial_plot import non_spatial_plot
+from .QC_plot import QC_plot
+from .stack_3d_plot import stack_3d_plot
+from .subcluster_plot import subcluster_plot
+
+# Import trajectory functions
+from .trajectory import (
+    DE_transition_plot,
+    check_trajectory,
+    local_plot,
+    pseudotime_plot,
+    transition_markers_plot,
+    tree_plot,
+    tree_plot_simple,
+)
+
+__all__ = [
+    # CCI plot functions
+    "cci_check",
+    "cci_map",
+    "ccinet_plot",
+    "grid_plot",
+    "het_plot",
+    "lr_cci_map",
+    "lr_chord_plot",
+    "lr_diagnostics",
+    "lr_go",
+    "lr_n_spots",
+    "lr_plot",
+    "lr_plot_interactive",
+    "lr_result_plot",
+    "lr_summary",
+    "spatialcci_plot_interactive",
+    # Other plot functions
+    "cluster_plot",
+    "cluster_plot_interactive",
+    "deconvolution_plot",
+    "feat_plot",
+    "gene_plot",
+    "gene_plot_interactive",
+    "plot_mask",
+    "non_spatial_plot",
+    "QC_plot",
+    "stack_3d_plot",
+    "subcluster_plot",
+    # Trajectory functions
+    "pseudotime_plot",
+    "local_plot",
+    "tree_plot",
+    "transition_markers_plot",
+    "DE_transition_plot",
+    "tree_plot_simple",
+    "check_trajectory",
+]
diff --git a/stlearn/plotting/_docs.py b/stlearn/pl/_docs.py
similarity index 89%
rename from stlearn/plotting/_docs.py
rename to stlearn/pl/_docs.py
index f9a66165..dbf36984 100644
--- a/stlearn/plotting/_docs.py
+++ b/stlearn/pl/_docs.py
@@ -6,7 +6,8 @@
 figsize
     Figure size with the format (width,height).
 cmap
-    Color map to use for continous variables or discretes variables (e.g. viridis, Set1,...).
+    Color map to use for continuous variables or discrete variables (e.g. viridis,
+    Set1,...).
 use_label
     Key for the label used in `adata.obs` (e.g. `leiden`, `louvain`,...).
 list_clusters
@@ -39,7 +40,8 @@
 doc_gene_plot = """\
 gene_symbols
-    Single gene (str) or multiple genes (list) that user wants to display. It should be available in `adata.var_names`.
+    Single gene (str) or multiple genes (list) that the user wants to display. It
+    should be available in `adata.var_names`.
 threshold
     Threshold to display genes in the figure.
method
@@ -83,23 +85,28 @@
 sig_spots
     Whether to filter to significant spots or not.
 use_label
-    Label to use for the inner points, can be in adata.obs or in the lr stats of adata.uns['per_lr_results'][lr].columns
+    Label to use for the inner points; can be in adata.obs or in the lr stats of
+    adata.uns['per_lr_results'][lr].columns
 use_mix
-    The deconvolution/label_transfer results to use for visualising pie charts in the inner point, not currently implimented.
+    The deconvolution/label_transfer results to use for visualising pie charts in
+    the inner point; not currently implemented.
 outer_mode
-    Either 'binary', 'continuous', or None; controls how ligand-receptor expression shown (or not shown).
+    Either 'binary', 'continuous', or None; controls how ligand-receptor expression
+    is shown (or not shown).
 l_cmap
     matplotlib cmap controlling ligand continuous expression.
 r_cmap
     matplotlib cmap controlling receptor continuous expression.
 lr_cmap
-    matplotlib cmap controlling the ligand receptor binary expression, but have atleast 4 colours.
+    matplotlib cmap controlling the ligand-receptor binary expression; must have
+    at least 4 colours.
 inner_cmap
     matplotlib cmap controlling the inner point colours.
 inner_size_prop
     Multiplier which controls size of inner points.
 middle_size_prop
-    Multiplier which controls size of middle point (only relevant when outer_mode='continuous')
+    Multiplier which controls size of middle point (only relevant when
+    outer_mode='continuous')
 outer_size_prop
     Multiplier which controls size of the outer point.
 pt_scale
@@ -109,12 +116,14 @@
 show_image
     Whether to show the background H&E or not.
 kwargs
-    Extra arguments parsed to the other plotting functions such as gene_plot, cluster_plot, &/or het_plot.
+    Extra arguments passed to the other plotting functions such as gene_plot,
+    cluster_plot, &/or het_plot.
 """

 doc_het_plot = """\
 use_het
-    Single gene (str) or multiple genes (list) that user wants to display. It should be available in `adata.var_names`.
+    Key of the cell heterogeneity results to display (e.g. computed by the CCI
+    analysis), rather than a gene.
 contour
     Option to show the contour plot.
step_size diff --git a/stlearn/plotting/cci_plot.py b/stlearn/pl/cci_plot.py similarity index 80% rename from stlearn/plotting/cci_plot.py rename to stlearn/pl/cci_plot.py index 06d997f8..cc8e0918 100644 --- a/stlearn/plotting/cci_plot.py +++ b/stlearn/pl/cci_plot.py @@ -1,91 +1,86 @@ -from matplotlib import pyplot as plt -from matplotlib.axes import Axes -from matplotlib.figure import Figure -import matplotlib -import pandas as pd -import numpy as np -import networkx as nx +import importlib import math -import matplotlib.patches as patches -from numba.typed import List -import seaborn as sns import sys -from anndata import AnnData -from typing import Optional, Union +from typing import ( + Any, + Optional, # Special +) -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes +import matplotlib +import matplotlib as plt +import matplotlib.axes as plt_axis +import matplotlib.figure as plt_figure +import matplotlib.patches as patches +import networkx as nx +import numpy as np +import pandas as pd +from anndata import AnnData +from bokeh.io import output_notebook +from bokeh.plotting import show +from scipy.stats import gaussian_kde -import warnings +import stlearn.pl.cci_plot_helpers as cci_hs +from stlearn.pl.utils import get_colors -from .classes import CciPlot, LrResultPlot -from .classes_bokeh import BokehSpatialCciPlot, BokehLRPlot -from ._docs import doc_spatial_base_plot, doc_het_plot, doc_lr_plot -from ..utils import Empty, _empty, _AxesSubplot, _docs_params -from .utils import get_cmap, check_cmap, get_colors -from .cluster_plot import cluster_plot -from .deconvolution_plot import deconvolution_plot -from .gene_plot import gene_plot -from stlearn.plotting.utils import get_colors -import stlearn.plotting.cci_plot_helpers as cci_hs +from ..utils import _docs_params +from ._docs import doc_het_plot, doc_spatial_base_plot from .cci_plot_helpers import ( - get_int_df, - add_arrows, - create_flat_df, _box_map, chordDiagram, + create_flat_df, + get_int_df, ) -from scipy.stats import gaussian_kde - -import importlib +from .classes import CciPlot, LrResultPlot +from .classes_bokeh import BokehLRPlot, BokehSpatialCciPlot +from .cluster_plot import cluster_plot +from .gene_plot import gene_plot +from .utils import check_cmap, get_cmap importlib.reload(cci_hs) -from bokeh.io import push_notebook, output_notebook -from bokeh.plotting import show #### Functions for visualising the overall LR results and diagnostics. def lr_diagnostics( adata, - highlight_lrs: list = None, - n_top: int = None, + highlight_lrs: list | None = None, + n_top: int | None = None, color0: str = "turquoise", color1: str = "plum", figsize: tuple = (10, 4), - lr_text_fp: dict = None, + lr_text_fp: dict | None = None, show: bool = True, ): - """Diagnostic plot looking at relationship between technical features of lrs and lr rank. - Two plots generated: left is the average of the median for nonzero - expressing spots for both the ligand and the receptor on the y-axis, & - LR-rank by no. of significant spots on the x-axis. Right is the average - of the proportion of zeros for the ligand and receptor gene on teh y-axis. + """Diagnostic plot looking at relationship between technical features of lrs and + lr rank. Two plots generated: left is the average of the median for nonzero + expressing spots for both the ligand and the receptor on the y-axis, & + LR-rank by no. of significant spots on the x-axis. 
Right is the average
+    of the proportion of zeros for the ligand and receptor gene on the y-axis.

     Parameters
     ----------
-    adata: AnnData
+    adata (AnnData):
         The data object on which st.tl.cci.run has been applied.
-    highlight_lrs: list
-        List of LRs to highlight, will add text and change point color for these LR pairs.
-    n_top: int
+    highlight_lrs (list):
+        List of LRs to highlight; will add text and change point color for these
+        LR pairs.
+    n_top (int):
         The number of LRs to display. If None shows all.
-    color0: str
+    color0 (str):
         The color of the nonzero-median scatter plot.
-    lr_text_fp: dict
+    lr_text_fp (dict):
         Font dict for the LR text if highlight_lrs not None.
-    axis_text_fp: dict
+    axis_text_fp (dict):
         Font dict for the axis text labels.

     Returns
     -------
     Figure, Axes
         Figure and axes of the plot, if show=False.
     """
-    if type(n_top) == type(None):
+    if n_top is None:
         n_top = adata.uns["lr_summary"].shape[0]
-    fig, axes = plt.subplots(ncols=2, figsize=figsize)
+    fig, axes = plt.pyplot.subplots(ncols=2, figsize=figsize)
     cci_hs.lr_scatter(
         adata,
         "nonzero-median",
@@ -107,7 +102,7 @@
         show=False,
     )
     if show:
-        plt.show()
+        plt.pyplot.show()
     else:
         return fig, axes

@@ -115,14 +110,14 @@
 def lr_summary(
     adata,
     n_top: int = 50,
-    highlight_lrs: list = None,
+    highlight_lrs: list | None = None,
     y: str = "n_spots_sig",
     color: str = "gold",
-    figsize: tuple = None,
+    figsize: tuple | None = None,
     highlight_color: str = "red",
     max_text: int = 50,
-    lr_text_fp: dict = None,
-    ax: Axes = None,
+    lr_text_fp: dict | None = None,
+    ax: plt_axis.Axes | None = None,
     show: bool = True,
 ):
     """Plotting the top LRs ranked by number of significant spots.
@@ -137,7 +132,7 @@
         A list of LRs to highlight on the plot; will add text and change the color
         of points for these LRs. Useful for highlighting LRs of interest.
     y: str
-        The way to rank the LRs, default is by the no. of signifcant spots,
+        The way to rank the LRs, default is by the no. of significant spots,
         but can be any column in adata.uns['lr_summary'].
     color: str
         The color of the points.
@@ -183,8 +178,8 @@
 def lr_n_spots(
     adata,
     n_top: int = 100,
-    font_dict: dict = None,
-    xtick_dict: dict = None,
+    font_dict: dict | None = None,
+    xtick_dict: dict | None = None,
     bar_width: float = 1,
     max_text: int = 50,
     non_sig_color: str = "dodgerblue",
@@ -227,16 +222,16 @@
     Fig, Axes
         Figure & axes with the plot drawn on; only if show=False. Else None.
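+
+    Example (illustrative; assumes st.tl.cci.run has already populated
+    adata.uns['lr_summary'] with the 'n_spots' and 'n_spots_sig' columns):
+
+    >>> import stlearn as st
+    >>> st.pl.lr_n_spots(adata, n_top=50)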
""" - if type(font_dict) == type(None): + if font_dict is None: font_dict = {"weight": "bold", "size": 12} - if type(xtick_dict) == type(None): + if xtick_dict is None: xtick_dict = {"fontweight": "bold", "rotation": 90, "size": 6} lrs = adata.uns["lr_summary"].index.values[0:n_top] n_sig = adata.uns["lr_summary"].loc[:, "n_spots_sig"].values n_non_sig = adata.uns["lr_summary"].loc[:, "n_spots"].values - n_sig rank = list(range(len(n_sig))) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.pyplot.subplots(figsize=figsize) ax.bar(rank[0:n_top], n_non_sig[0:n_top], bar_width, color=non_sig_color) ax.bar( rank[0:n_top], @@ -257,7 +252,7 @@ def lr_n_spots( ax.spines["right"].set_visible(False) if show: - plt.show() + plt.pyplot.show() else: return fig, ax @@ -265,10 +260,10 @@ def lr_n_spots( def lr_go( adata, n_top: int = 20, - highlight_go: list = None, + highlight_go: list | None = None, figsize=(6, 4), rot: float = 50, - lr_text_fp: dict = None, + lr_text_fp: dict | None = None, highlight_color: str = "yellow", max_text: int = 50, show: bool = True, @@ -335,7 +330,8 @@ def cci_check( tick_size=14, show=True, ): - """Checks relationship between no. of significant CCI-LR interactions and cell type frequency. + """Checks relationship between no. of significant CCI-LR interactions and cell + type frequency. Parameters ---------- @@ -364,12 +360,11 @@ def cci_check( xs = np.array(list(range(len(label_set)))) int_dfs = adata.uns[f"per_lr_cci_{use_label}"] - # Counting!!! # - cell_counts = [] # Cell type frequencies - cell_sigs = [] # Cell type significant interactions + cell_counts: np.ndarray = np.zeros(len(label_set), dtype=int) + cell_sigs: np.ndarray = np.zeros(len(label_set), dtype=int) for j, label in enumerate(label_set): counts = sum(labels == label) - cell_counts.append(counts) + cell_counts[j] = counts int_count = 0 for lr in int_dfs: @@ -381,18 +376,16 @@ def cci_check( # prevent double counts int_count -= int_bool[label_index, label_index] - cell_sigs.append(int_count) + cell_sigs[j] = int_count - cell_counts = np.array(cell_counts) - cell_sigs = np.array(cell_sigs) - order = np.argsort(cell_counts) + order: np.ndarray = np.argsort(cell_counts) cell_counts = cell_counts[order] cell_sigs = cell_sigs[order] colors = np.array(colors)[order] label_set = label_set[order] # Plotting bar plot # - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.pyplot.subplots(figsize=figsize) ax.bar(xs, cell_counts, color=colors) text_dist = max(cell_counts) * 0.015 fontdict = {"fontweight": "bold", "fontsize": cell_label_size} @@ -420,7 +413,7 @@ def cci_check( fig.tight_layout() if show: - plt.show() + plt.pyplot.show() else: return fig, ax, ax2 @@ -431,28 +424,28 @@ def lr_result_plot( use_lr: Optional["str"] = None, use_result: Optional["str"] = "lr_sig_scores", # plotting param - title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - zoom_coord: Optional[float] = None, - crop: Optional[bool] = True, - margin: Optional[float] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - use_raw: Optional[bool] = False, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + title: str | None = None, + figsize: tuple[float, 
float] | None = None,
+    cmap: str = "Spectral_r",
+    ax: plt_axis.Axes | None = None,
+    fig: matplotlib.figure.Figure | None = None,
+    show_plot: bool = True,
+    show_axis: bool = False,
+    show_image: bool = True,
+    show_color_bar: bool = True,
+    zoom_coord: tuple[float, float, float, float] | None = None,
+    crop: bool = True,
+    margin: float = 100,
+    size: float = 7,
+    image_alpha: float = 1.0,
+    cell_alpha: float = 1.0,
+    use_raw: bool = False,
+    fname: str | None = None,
+    dpi: int = 120,
     contour: bool = False,
-    step_size: Optional[int] = None,
-    vmin: float = None,
-    vmax: float = None,
+    step_size: int | None = None,
+    vmin: float | None = None,
+    vmax: float | None = None,
 ):
     """Plots the per spot statistics for given LR.
@@ -481,6 +474,8 @@
         Whether to show axis or not.
     show_image: bool
         Whether to plot the image.
+    zoom_coord: Tuple[float, float, float, float]
+        Bounding box of the plot.
     show_color_bar: bool
         Whether to show the color bar.
     crop: bool
@@ -547,7 +542,7 @@
 def lr_plot(
     adata: AnnData,
     lr: str,
     min_expr: float = 0,
     sig_spots=True,
-    use_label: str = None,
+    use_label: str | None = None,
     outer_mode: str = "continuous",
     l_cmap=None,
     r_cmap=None,
@@ -560,19 +555,19 @@
     title="",
     show_image: bool = True,
     show_arrows: bool = False,
-    fig: Figure = None,
-    ax: Axes = None,
+    fig_or_none: plt_figure.Figure | None = None,
+    ax_or_none: plt_axis.Axes | None = None,
     arrow_head_width: float = 4,
     arrow_width: float = 0.001,
-    arrow_cmap: str = None,
-    arrow_vmax: float = None,
+    arrow_cmap: str | None = None,
+    arrow_vmax: float | None = None,
     sig_cci: bool = False,
-    lr_colors: dict = None,
+    lr_colors: dict | None = None,
     figsize: tuple = (6.4, 4.8),
-    use_mix: bool = None,
+    use_mix: str | None = None,
     # plotting params
     **kwargs,
-) -> Optional[AnnData]:
+) -> None:
     """Creates different kinds of spatial visualisations for the LR analysis results.
     To see combinations of parameters refer to stLearn CCI tutorial.
@@ -624,9 +619,9 @@
         Whether to show the background image.
     show_arrows: bool
         Whether to plot arrows indicating interactions between spots.
-    fig: Figure
+    fig_or_none: Figure
         Figure to draw on.
-    ax: Axes
+    ax_or_none: Axes
         Axes to draw on.
     arrow_head_width: float
         Width of arrow head; only if show_arrows is true.
@@ -643,7 +638,7 @@
         interactions; particularly relevant when plotting the arrows.
     lr_colors: dict
         Specifies the colors of the LRs when plotting with outer_mode='binary';
-        structures is {'l': color, 'r': color, 'lr': color, '': color};
+        structure is {'ligand': color, 'receptor': color, 'lr': color, '': color};
         the last key-value indicates colour for spots not expressing the ligand
         or receptor.
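+        For example (illustrative, with a hypothetical pair lr='GPC3_IGF1R'):
+        lr_colors={'GPC3': 'r', 'IGF1R': 'limegreen', 'GPC3_IGF1R': 'b', '': '#836BC6'}.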
figsize: tuple
@@ -653,7 +648,7 @@
     """
     # Input checking #
-    l, r = lr.split("_")
+    ligand, receptor = lr.split("_")
     ran_lr = "lr_summary" in adata.uns
     ran_sig = False if not ran_lr else "n_spots_sig" in adata.uns["lr_summary"].columns
     if ran_lr and lr in adata.uns["lr_summary"].index:
@@ -672,7 +667,7 @@

     elif sig_spots and not lr_sig:
         raise Exception(
-            "LR has no significant spots, to visualise anyhow set" "sig_spots=False"
+            "LR has no significant spots; to visualise anyway set sig_spots=False"
         )

     # Making sure run_cci has been run first with the respective labelling #
@@ -700,35 +695,32 @@
         "lr_sig_scores",
     ]

-    if type(use_mix) != type(None) and use_mix not in adata.uns:
+    if use_mix is not None and use_mix not in adata.uns:
         raise Exception(
-            f"Specified use_mix, but no deconvolution results added "
-            "to adata.uns matching the use_mix ({use_mix}) key."
+            "Specified use_mix, but no deconvolution results added "
+            f"to adata.uns matching the use_mix ({use_mix}) key."
         )
     elif (
-        type(use_label) != type(None)
-        and use_label in lr_use_labels
-        and ran_sig
-        and not lr_sig
+        use_label is not None and use_label in lr_use_labels and ran_sig and not lr_sig
     ):
         raise Exception(
-            f"Since use_label refers to lr stats & ran permutation testing, "
-            f"LR needs to be significant to view stats."
+            "Since use_label refers to lr stats & ran permutation testing, "
+            "LR needs to be significant to view stats."
         )
     elif (
-        type(use_label) != type(None)
+        use_label is not None
         and use_label not in adata.obs.keys()
         and use_label not in lr_use_labels
     ):
         raise Exception(
-            f"use_label must be in adata.obs or " f"one of lr stats: {lr_use_labels}."
+            f"use_label must be in adata.obs or one of lr stats: {lr_use_labels}."
         )

     out_options = ["binary", "continuous", None]
     if outer_mode not in out_options:
         raise Exception(f"{outer_mode} should be one of {out_options}")

-    if l not in adata.var_names or r not in adata.var_names:
+    if ligand not in adata.var_names or receptor not in adata.var_names:
         raise Exception("L or R not found in adata.var_names.")

     # Whether to show just the significant spots or all spots
@@ -741,21 +733,21 @@
         adata_full = adata

     # Dealing with the axis #
-    if type(fig) == type(None) or type(ax) == type(None):
-        fig, ax = plt.subplots(figsize=figsize)
+    if fig_or_none is None or ax_or_none is None:
+        fig, ax = plt.pyplot.subplots(figsize=figsize)

     expr = adata.to_df()
-    l_expr = expr.loc[:, l].values
-    r_expr = expr.loc[:, r].values
+    l_expr = expr.loc[:, ligand].values
+    r_expr = expr.loc[:, receptor].values
     # Adding binary points of the ligand/receptor pair #
     if outer_mode == "binary":
         l_bool, r_bool = l_expr > min_expr, r_expr > min_expr
         lr_binary_labels = []
         for i in range(len(l_bool)):
             if l_bool[i] and not r_bool[i]:
-                lr_binary_labels.append(l)
+                lr_binary_labels.append(ligand)
             elif not l_bool[i] and r_bool[i]:
-                lr_binary_labels.append(r)
+                lr_binary_labels.append(receptor)
             elif l_bool[i] and r_bool[i]:
                 lr_binary_labels.append(lr)
             elif not l_bool[i] and not r_bool[i]:
@@ -765,12 +757,12 @@
         ).astype("category")
         adata.obs[f"{lr}_binary_labels"] = lr_binary_labels

-        if type(lr_cmap) == type(None):
+        if lr_cmap is None:
             lr_cmap = "default"  # This gets ignored due to setting colours below
-        if type(lr_colors) == type(None):
+        if lr_colors is None:
             lr_colors = {
-                l: matplotlib.colors.to_hex("r"),
-                r: matplotlib.colors.to_hex("limegreen"),
+                ligand: matplotlib.colors.to_hex("r"),
+                receptor: matplotlib.colors.to_hex("limegreen"),
                 lr: matplotlib.colors.to_hex("b"),
                 "": "#836BC6",  # Neutral color in H&E
images.
             }
@@ -797,13 +789,13 @@

     # Showing continuous gene expression of the LR pair #
     elif outer_mode == "continuous":
-        if type(l_cmap) == type(None):
+        if l_cmap is None:
             l_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
                 "lcmap", [(0, 0, 0), (0.5, 0, 0), (0.75, 0, 0), (1, 0, 0)]
             )
         else:
             l_cmap = check_cmap(l_cmap)
-        if type(r_cmap) == type(None):
+        if r_cmap is None:
             r_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
                 "rcmap", [(0, 0, 0), (0, 0.5, 0), (0, 0.75, 0), (0, 1, 0)]
             )
@@ -812,10 +804,10 @@

         gene_plot(
             adata,
-            gene_symbols=l,
+            gene_symbols=ligand,
             size=outer_size_prop * pt_scale,
             cmap=l_cmap,
-            color_bar_label=l,
+            color_bar_label=ligand,
             ax=ax,
             fig=fig,
             crop=False,
@@ -824,10 +816,10 @@
         )
         gene_plot(
             adata,
-            gene_symbols=r,
+            gene_symbols=receptor,
             size=middle_size_prop * pt_scale,
             cmap=r_cmap,
-            color_bar_label=r,
+            color_bar_label=receptor,
             ax=ax,
             fig=fig,
             crop=False,
@@ -836,11 +828,9 @@
         )

     # Adding the cell type labels #
-    if type(use_label) != type(None):
+    if use_label is not None:
         if use_label in lr_use_labels:
-            inner_cmap = inner_cmap if type(inner_cmap) != type(None) else "copper"
-            # adata.obsm[f'{lr}_{use_label}'] = adata.uns['per_lr_results'][
-            #     lr].loc[adata.obs_names,use_label].values
+            inner_cmap = inner_cmap if inner_cmap is not None else "copper"
             lr_result_plot(
                 adata,
                 use_lr=lr,
@@ -853,7 +843,7 @@
                 **kwargs,
             )
         else:
-            inner_cmap = inner_cmap if type(inner_cmap) != type(None) else "default"
+            inner_cmap = inner_cmap if inner_cmap is not None else "default"
             cluster_plot(
                 adata,
                 use_label=use_label,
@@ -870,8 +860,8 @@

     # Adding in labels which show the interactions between significant spots &
     # neighbours
     if show_arrows:
-        l_expr = adata_full[:, l].X.toarray()[:, 0]
-        r_expr = adata_full[:, r].X.toarray()[:, 0]
+        l_expr = adata_full[:, ligand].X.toarray()[:, 0]
+        r_expr = adata_full[:, receptor].X.toarray()[:, 0]
         if sig_cci:
             int_df = adata.uns[f"per_lr_cci_{use_label}"][lr]
@@ -893,18 +883,6 @@
             arrow_cmap,
             arrow_vmax,
         )
-
-    # Cropping #
-    # if crop:
-    #     x0, x1 = ax.get_xlim()
-    #     y0, y1 = ax.get_ylim()
-    #     x_margin, y_margin = (x1-x0)*margin_ratio, (y1-y0)*margin_ratio
-    #     print(x_margin, y_margin)
-    #     print(x0, x1, y0, y1)
-    #     ax.set_xlim(x0 - x_margin, x1 + x_margin)
-    #     ax.set_ylim(y0 - y_margin, y1 + y_margin)
-    #     #ax.set_ylim(ax.get_ylim()[::-1])
-
     fig.suptitle(title)

@@ -914,33 +892,33 @@
 def het_plot(
     adata: AnnData,
     # plotting param
-    title: Optional["str"] = None,
-    figsize: Optional[Tuple[float, float]] = None,
-    cmap: Optional[str] = "Spectral_r",
-    use_label: Optional[str] = None,
-    list_clusters: Optional[list] = None,
-    ax: Optional[matplotlib.axes.Axes] = None,
-    fig: Optional[matplotlib.figure.Figure] = None,
-    show_plot: Optional[bool] = True,
-    show_axis: Optional[bool] = False,
-    show_image: Optional[bool] = True,
-    show_color_bar: Optional[bool] = True,
-    zoom_coord: Optional[float] = None,
-    crop: Optional[bool] = True,
-    margin: Optional[bool] = 100,
-    size: Optional[float] = 7,
-    image_alpha: Optional[float] = 1.0,
-    cell_alpha: Optional[float] = 1.0,
-    use_raw: Optional[bool] = False,
-    fname: Optional[str] = None,
-    dpi: Optional[int] = 120,
+    title: str | None = None,
+    figsize: tuple[float, float] | None = None,
+    cmap: str = "Spectral_r",
+    use_label: str | None = None,
+    list_clusters: list | None = None,
+    ax: plt_axis.Axes | None = None,
+    fig: matplotlib.figure.Figure | None = None,
+    show_plot: bool =
True,
+    show_axis: bool = False,
+    show_image: bool = True,
+    show_color_bar: bool = True,
+    zoom_coord: tuple[float, float, float, float] | None = None,
+    crop: bool = True,
+    margin: float = 100,
+    size: float = 7,
+    image_alpha: float = 1.0,
+    cell_alpha: float = 1.0,
+    use_raw: bool = False,
+    fname: str | None = None,
+    dpi: int = 120,
     # cci_rank param
-    use_het: Optional[str] = "het",
+    use_het: str = "het",
     contour: bool = False,
-    step_size: Optional[int] = None,
-    vmin: float = None,
-    vmax: float = None,
-) -> Optional[AnnData]:
+    step_size: int | None = None,
+    vmin: float | None = None,
+    vmax: float | None = None,
+) -> None:
     """\
     Allows the visualization of significant cell-cell interaction
     as the values of dot points or contour in the Spatial
@@ -955,7 +933,7 @@
     Examples
     --------
     >>> import stlearn as st
-    >>> adata = st.datasets.example_bcba()
+    >>> adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")
     >>> pvalues = "lr_pvalues"
     >>> st.pl.het_plot(adata, use_het=pvalues)

@@ -998,8 +976,8 @@
 def ccinet_plot(
     adata: AnnData,
     use_label: str,
-    lr: str = None,
-    pos: dict = None,
+    lr: str | None = None,
+    pos: dict | None = None,
     return_pos: bool = False,
     cmap: str = "default",
     font_size: int = 12,
@@ -1007,10 +985,10 @@
     node_size_scaler: int = 1,
     min_counts: int = 0,
     sig_interactions: bool = True,
-    fig: matplotlib.figure.Figure = None,
-    ax: matplotlib.axes.Axes = None,
+    fig_or_none: plt_figure.Figure | None = None,
+    ax_or_none: plt_axis.Axes | None = None,
     pad=0.25,
-    title: str = None,
+    title_or_none: str | None = None,
     figsize: tuple = (10, 10),
 ):
     """Circular celltype-celltype interaction network based on LR-CCI analysis.
@@ -1052,7 +1030,8 @@
     Returns
     -------
     pos: dict
-        Dictionary of positions where the nodes are draw if return_pos is True, useful for consistent layouts.
+        Dictionary of positions where the nodes are drawn if return_pos is True,
+        useful for consistent layouts.
     """
     cmap, cmap_n = get_cmap(cmap)
     # Making sure adata is in the correct state for this function to run #
     if f"per_lr_cci_{use_label}" not in adata.uns:
         raise Exception(
             "Need to first call st.tl.run_cci with the equivalent "
             "use_label to visualise cell-cell interactions."
         )
-    elif type(lr) != type(None) and lr not in adata.uns[f"per_lr_cci_{use_label}"]:
+    elif lr is not None and lr not in adata.uns[f"per_lr_cci_{use_label}"]:
         raise Exception(
             f"{lr} not found in {f'per_lr_cci_{use_label}'}, "
             "suggesting no significant interactions."
) # Either plotting overall interactions, or just for a particular LR # - int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title) + int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title_or_none) # Creating the interaction graph # all_set = int_df.index.values int_matrix = int_df.values @@ -1084,7 +1063,7 @@ def ccinet_plot( graph.add_edge(cell_A, cell_B, weight=count) # Determining graph layout, node sizes, & edge colours # - if type(pos) == type(None): + if pos is None: pos = nx.circular_layout(graph) # position the nodes using the layout total = sum(sum(int_matrix)) node_names = list(graph.nodes.keys()) @@ -1122,8 +1101,10 @@ def ccinet_plot( node_colors = np.array(node_colors)[nodes_indices] #### Drawing the graph ##### - if type(fig) == type(None) or type(ax) == type(None): - fig, ax = plt.subplots(figsize=figsize, facecolor=[0.7, 0.7, 0.7, 0.4]) + ax: plt_axis.Axes + fig: plt_figure.Figure + if fig_or_none is None or ax_or_none is None: + fig, ax = plt.pyplot.subplots(figsize=figsize, facecolor=[0.7, 0.7, 0.7, 0.4]) # Adding in the self-loops # z = 55 @@ -1140,7 +1121,7 @@ def ccinet_plot( width=0.3, height=0.025, lw=5, - ec=plt.cm.get_cmap("Blues")(edge_weights[i]), + ec=plt.colormaps.get_cmap("Blues")(edge_weights[i]), angle=angle, theta1=z, theta2=360 - z, @@ -1148,7 +1129,7 @@ def ccinet_plot( ax.add_patch(arc) # Drawing the main components of the graph # - edges = nx.draw_networkx( + nx.draw_networkx( graph, pos, node_size=node_sizes, @@ -1159,11 +1140,11 @@ def ccinet_plot( font_size=font_size, font_weight="bold", edge_color=edge_weights, - edge_cmap=plt.cm.Blues, + edge_cmap=plt.colormaps.get_cmap("Blues"), ax=ax, ) fig.suptitle(title, fontsize=30) - plt.tight_layout() + plt.pyplot.tight_layout() # Adding padding # xlims = ax.get_xlim() @@ -1178,10 +1159,10 @@ def ccinet_plot( def cci_map( adata: AnnData, use_label: str, - lr: str = None, - ax: matplotlib.figure.Axes = None, + lr_or_none: str | None = None, + ax_or_none: plt_axis.Axes | None = None, show: bool = False, - figsize: tuple = None, + figsize_or_none: tuple | None = None, cmap: str = "Spectral_r", sig_interactions: bool = True, title=None, @@ -1195,14 +1176,14 @@ def cci_map( use_label: str Indicates the cell type labels or deconvolution results used for cell-cell interaction counting by LR pairs. - lr: str + lr_or_none: str The LR pair to visualise the sender->receiver interactions for. If None, will use all pairs via adata.uns[f'lr_cci_{use_label}']. - ax: Axes + ax_or_none: Axes Axes on which to plot the heatmap, if None then generates own. show: bool Whether to show the plot or not; if not, then returns ax. - figsize: tuple + figsize_or_none: tuple (width, height), specifies the dimensions of the figure. Only relevant if ax=None. cmap: str @@ -1220,11 +1201,14 @@ def cci_map( """ # Either plotting overall interactions, or just for a particular LR # - int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title) + int_df, title = get_int_df(adata, lr_or_none, use_label, sig_interactions, title) - if type(figsize) == type(None): # Adjust size depending on no. cell types + figsize: tuple + if figsize_or_none is None: # Adjust size depending on no. 
cell types add = np.array([int_df.shape[0] * 0.1, int_df.shape[0] * 0.05]) figsize = tuple(np.array([6.4, 4.8]) + add) + else: + figsize = figsize_or_none # Rank by total interactions # int_vals = int_df.values @@ -1235,21 +1219,21 @@ def cci_map( # Reformat the interaction df # flat_df = create_flat_df(int_df) - ax = _box_map( + ax: plt_axis.Axes = _box_map( flat_df["x"], flat_df["y"], flat_df["value"].astype(int), - ax=ax, + ax=ax_or_none, figsize=figsize, cmap=cmap, ) ax.set_ylabel("Sender") ax.set_xlabel("Receiver") - plt.suptitle(title) + plt.pyplot.suptitle(title) if show: - plt.show() + plt.pyplot.show() else: return ax @@ -1257,11 +1241,11 @@ def cci_map( def lr_cci_map( adata: AnnData, use_label: str, - lrs: list or np.array = None, + lrs: list | np.ndarray | None = None, n_top_lrs: int = 5, n_top_ccis: int = 15, min_total: int = 0, - ax: matplotlib.figure.Axes = None, + ax_or_none: plt_axis.Axes | None = None, figsize: tuple = (6.48, 4.8), show: bool = False, cmap: str = "Spectral_r", @@ -1279,15 +1263,15 @@ def lr_cci_map( Indicates the cell type labels or deconvolution results used for the cell-cell interaction counting by LR pairs. lrs: list-like - LR pairs to show in the heatmap, if None then top 5 lrs with highest no. - of interactions used from adata.uns['lr_summary']. + LR pairs to show in the heatmap, if None then top 5 lrs with the highest + no. of interactions used from adata.uns['lr_summary']. n_top_lrs: int Indicates how many top lrs to show; is ignored if lrs is not None. n_top_ccis: int Indicates maximum no. of CCIs to show. min_total: int Minimum no. of totals interaction celltypes must have to be shown. - ax: Axes + ax_or_none: Axes Axes on which to draw the heatmap, is generated internally if None. figsize: tuple (width, height), only relevant if ax=None. @@ -1310,7 +1294,7 @@ def lr_cci_map( else: lr_int_dfs = adata.uns[f"per_lr_cci_raw_{use_label}"] - if type(lrs) == type(None): + if lrs is None: lrs = np.array(list(lr_int_dfs.keys())) else: lrs = np.array(lrs) @@ -1353,11 +1337,11 @@ def lr_cci_map( if flat_df.shape[0] == 0 or flat_df.shape[1] == 0: raise Exception(f"No interactions greater than min: {min_total}") - ax = _box_map( + ax: plt_axis.Axes = _box_map( flat_df["x"], flat_df["y"], flat_df["value"].astype(int), - ax=ax, + ax=ax_or_none, cmap=cmap, figsize=figsize, square_scaler=square_scaler, @@ -1367,7 +1351,7 @@ def lr_cci_map( ax.set_xlabel("Cell-cell interaction") if show: - plt.show() + plt.pyplot.show() else: return ax @@ -1375,14 +1359,14 @@ def lr_cci_map( def lr_chord_plot( adata: AnnData, use_label: str, - lr: str = None, + lr: str | None = None, min_ints: int = 2, n_top_ccis: int = 10, cmap: str = "default", sig_interactions: bool = True, label_size: int = 10, label_rotation: float = 0, - title: str = None, + title: str = "", figsize: tuple = (8, 8), show: bool = True, ): @@ -1395,13 +1379,13 @@ def lr_chord_plot( Each cell type has a labelled edge taking up a proportion of the outter circle. Chords connecting cell type edges are coloured by the dominant sending cell. - Each chord linking cell types has an assymetric shape. - For two cell types, A and B, the side of the chord attached to edge A is sized by - the total interactions from B->A, where B is expressing the ligand & A + Each chord linking cell types has an asymmetric shape. + For two cell types, A and B, the side of the chord attached to edge A is + sized by the total interactions from B->A, where B is expressing the ligand & A is expressing the receptor. 
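    (For example, if B->A counts 30 significant interactions and A->B counts 10,
    the chord's side attached to A's edge is sized by 30, while the side attached
    to B's edge is sized by 10.)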
- Hence, the proportion of a cell type's edge in the chordplot circle + Hence, the proportion of a cell type's edge in the chord plot circle represents the total input signals to that cell type; while the - area of the chordplot circle taken up by the outputted chords from a given + area of the chord plot circle taken up by the outputted chords from a given cell type represents the total output signals from that cell type. Parameters @@ -1419,7 +1403,8 @@ def lr_chord_plot( n_top_ccis: int Maximum no. of CCIs to show, will take the top number of these to display. cmap: str - Cmap to use to get colors if colors not already in adata.uns[f'{use_label}_colors'] + Cmap to use to get colors if colors not already in + adata.uns[f'{use_label}_colors'] sig_interactions: bool Whether to show only significant CCIs or all interaction counts. label_size: str @@ -1443,7 +1428,7 @@ def lr_chord_plot( int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title) int_df = int_df.transpose() - fig = plt.figure(figsize=figsize) + fig = plt.pyplot.figure(figsize=figsize) flux = int_df.values total_ints = flux.sum(axis=1) + flux.sum(axis=0) - flux.diagonal() @@ -1455,7 +1440,7 @@ def lr_chord_plot( all_zero = np.array( [np.all(np.logical_and(flux[i, keep] == 0, flux[keep, i] == 0)) for i in keep] ) - keep = keep[all_zero == False] + keep = keep[~all_zero] if len(keep) == 0: # If we don't keep anything, warn the user print( f"Warning: for {lr} at the current min_ints ({min_ints}), there " @@ -1481,10 +1466,10 @@ def lr_chord_plot( # Retrieving colors of cell types # colors = get_colors(adata, use_label, cmap=cmap, label_set=cell_names) - ax = plt.axes([0, 0, 1, 1]) + ax = plt.pyplot.axes((0, 0, 1, 1)) nodePos = chordDiagram(flux, ax, lim=1.25, colors=colors) ax.axis("off") - prop = dict(fontsize=label_size, ha="center", va="center") + prop: dict[str, Any] = dict(fontsize=label_size, ha="center", va="center") label_rotation_ = label_rotation for i in range(len(cell_names)): x, y = nodePos[i][0:2] @@ -1501,21 +1486,22 @@ def lr_chord_plot( ) # size=10, fig.suptitle(title, fontsize=12, fontweight="bold") if show: - plt.show() + plt.pyplot.show() else: return fig, ax def grid_plot( adata, - use_label: str = None, + use_label: str | None = None, n_row: int = 10, n_col: int = 10, size: int = 1, figsize=(4.5, 4.5), show: bool = False, ): - """Plots grid over the top of spatial data to show how cells will be grouped if gridded. + """Plots grid over the top of spatial data to show how cells will be grouped if + gridded. Parameters ---------- @@ -1541,10 +1527,10 @@ def grid_plot( xmin, xmax = min(xedges), max(xedges) ymin, ymax = min(yedges), max(yedges) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.pyplot.subplots(figsize=figsize) # Plotting the points # - if type(use_label) != type(None): + if use_label is not None: if f"{use_label}_colors" in adata.uns: color_map = {} for i, ct in enumerate(adata.obs[use_label].cat.categories): @@ -1559,12 +1545,12 @@ def grid_plot( ax.hlines(-yedges, xmin, xmax, color="#36454F") if show: - plt.show() + plt.pyplot.show() else: return fig, ax -####################### Bokeh Interactive Plots ################################ +# Bokeh Interactive Plots def lr_plot_interactive(adata: AnnData): """Plots the LR scores for significant spots interatively using Bokeh. 
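+
+    Example (a minimal sketch; assumes the CCI analysis, st.tl.cci.run, has
+    already been applied to `adata`, and that this runs inside a Jupyter
+    notebook so the Bokeh app can render):
+
+    >>> import stlearn as st
+    >>> st.pl.lr_plot_interactive(adata)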
@@ -1589,78 +1575,3 @@ def spatialcci_plot_interactive(adata: AnnData): bokeh_object = BokehSpatialCciPlot(adata) output_notebook() show(bokeh_object.app, notebook_handle=True) - - -# def het_plot_interactive(adata: AnnData): -# bokeh_object = BokehCciPlot(adata) -# output_notebook() -# show(bokeh_object.app, notebook_handle=True) - - -# Bokeh & old grid plots; -# has not been tested since multi-LR testing implimentation. - -# def het_plot_interactive(adata: AnnData): -# bokeh_object = BokehCciPlot(adata) -# output_notebook() -# show(bokeh_object.app, notebook_handle=True) - - -# def grid_plot( -# adata: AnnData, -# use_het: str = None, -# num_row: int = 10, -# num_col: int = 10, -# vmin: float = None, -# vmax: float = None, -# cropped: bool = True, -# margin: int = 100, -# dpi: int = 100, -# name: str = None, -# output: str = None, -# copy: bool = False, -# ) -> Optional[AnnData]: -# -# """ -# Cell diversity plot for sptial transcriptomics data. -# -# Parameters -# ---------- -# adata: Annotated data matrix. -# use_het: Cluster heterogeneity count results from tl.cci_rank.het -# num_row: int Number of grids on height -# num_col: int Number of grids on width -# cropped crop image or not. -# margin margin used in cropping. -# dpi: Set dpi as the resolution for the plot. -# name: Name of the output figure file. -# output: Save the figure as file or not. -# copy: Return a copy instead of writing to adata. -# -# Returns -# ------- -# Nothing -# """ -# -# try: -# import seaborn as sns -# except: -# raise ImportError("Please run `pip install seaborn`") -# plt.subplots() -# -# sns.heatmap( -# pd.DataFrame(np.array(adata.obsm[use_het]).reshape(num_col, num_row)).T, -# vmin=vmin, -# vmax=vmax, -# ) -# plt.axis("equal") -# -# if output is not None: -# plt.savefig( -# output + "/" + name + "_heatmap.pdf", -# dpi=dpi, -# bbox_inches="tight", -# pad_inches=0, -# ) -# -# plt.show() diff --git a/stlearn/plotting/cci_plot_helpers.py b/stlearn/pl/cci_plot_helpers.py similarity index 88% rename from stlearn/plotting/cci_plot_helpers.py rename to stlearn/pl/cci_plot_helpers.py index 045612e0..28ac1b5a 100644 --- a/stlearn/plotting/cci_plot_helpers.py +++ b/stlearn/pl/cci_plot_helpers.py @@ -1,24 +1,19 @@ -""" Helper functions for cci_plot.py. -""" +"""Helper functions for cci_plot.py.""" -import sys -import math -import numpy as np -import pandas as pd import matplotlib +import matplotlib.cm as cm +import matplotlib.colors as plt_colors +import matplotlib.patches as patches import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import AnnData from matplotlib.axes import Axes from matplotlib.patches import Arc, Wedge -from mpl_toolkits.axes_grid1 import make_axes_locatable - from matplotlib.path import Path -import matplotlib.patches as patches -import matplotlib.colors as plt_colors -import matplotlib.cm as cm - -from ..tools.microenv.cci.het import get_edges +from mpl_toolkits.axes_grid1 import make_axes_locatable -from anndata import AnnData +from stlearn.tl.cci.het import get_edges # Helper functions for overview plots of the LRs. 
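+# A minimal usage sketch (hypothetical `adata`; these helpers are normally
+# called internally by st.pl.lr_summary and st.pl.lr_diagnostics rather than
+# directly by end users):
+#
+#   lr_scatter(adata, "nonzero-median", n_top=50, show=False)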
@@ -37,22 +32,30 @@ def lr_scatter( show=True, max_text=100, highlight_color="red", - figsize: tuple = None, + figsize: tuple | None = None, show_all: bool = False, ): - """General plotting of the LR features.""" - highlight = type(highlight_lrs) != type(None) - if not highlight: + lr_df = data.uns["lr_summary"] + + if max_text > len(lr_df): + print(f"Note: max_text ({max_text}) exceeds available LRs ({len(lr_df)})") + + if highlight_lrs is None: show_text = show_text if n_top <= max_text else False else: + missing_lrs = [lr for lr in highlight_lrs if lr not in lr_df.index] + if missing_lrs: + raise ValueError( + "The following highlight_lrs are not found in lr_summary index: " + + ",".join(missing_lrs) + ) highlight_lrs = highlight_lrs[0:max_text] - lr_df = data.uns["lr_summary"] lrs = lr_df.index.values.astype(str)[0:n_top] lr_features = data.uns["lrfeatures"] lr_df = pd.concat([lr_df, lr_features], axis=1).loc[lrs, :] if feature not in lr_df.columns: - raise Exception(f"Inputted {feature}; must be one of " f"{list(lr_df.columns)}") + raise ValueError(f"Inputted {feature}; must be one of {list(lr_df.columns)}") rot = 90 if feature != "n_spots_sig" else 70 @@ -77,39 +80,6 @@ def lr_scatter( pad=0, show_all=show_all, ) - # ranks = np.array(list(range(len(n_spots)))) - # - # if type(lr_text_fp)==type(None): - # lr_text_fp = {'weight': 'bold', 'size': 8} - # if type(axis_text_fp)==type(None): - # axis_text_fp = {'weight': 'bold', 'size': 12} - # - # if type(ax)==type(None): - # width = (7.5 / 50) * n_top if show_text and not highlight else 7.5 - # if width > 20: - # width = 20 - # fig, ax = plt.subplots(figsize=(width, 4)) - # - # # Plotting the points # - # ax.scatter(ranks, n_spots, alpha=alpha, c=color) - # - # if show_text: - # if highlight: - # ranks = ranks[[np.where(lrs==lr)[0][0] for lr in highlight_lrs]] - # ax.scatter(ranks, n_spots[ranks], alpha=alpha, c=highlight_color) - # - # for i in ranks: - # ax.text(i-.2, n_spots[i], lrs[i], rotation=rot, fontdict=lr_text_fp) - # - # ax.spines['top'].set_visible(False) - # ax.spines['right'].set_visible(False) - # ax.set_xlabel('LR Rank', axis_text_fp) - # ax.set_ylabel(feature, axis_text_fp) - # - # if show: - # plt.show() - # else: - # return ax def rank_scatter( @@ -139,14 +109,14 @@ def rank_scatter( """General plotting function for showing ranked list of items.""" ranks = np.array(list(range(len(items)))) - highlight = type(highlight_items) != type(None) - if type(lr_text_fp) == type(None): + highlight = highlight_items is not None + if lr_text_fp is None: lr_text_fp = {"weight": "bold", "size": 8} - if type(axis_text_fp) == type(None): + if axis_text_fp is None: axis_text_fp = {"weight": "bold", "size": 12} - if type(ax) == type(None): - if type(figsize) == type(None): + if ax is None: + if figsize is None: width = width_ratio * len(ranks) if show_text and not highlight else 7.5 if width > 20: width = 20 @@ -159,14 +129,14 @@ def rank_scatter( y, alpha=alpha, c=color, - s=None if type(point_sizes) == type(None) else point_sizes ** point_size_exp, + s=None if point_sizes is None else point_sizes**point_size_exp, edgecolors="none", ) y_min, y_max = ax.get_ylim() y_max = y_max + y_max * pad ax.set_ylim(y_min, y_max) - if type(point_sizes) != type(None): - # produce a legend with a cross section of sizes from the scatter + if point_sizes is not None: + # produce a legend with a cross-section of sizes from the scatter handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6, num=4) [handle.set_markeredgecolor("none") for handle in 
handles]
             starts = [label.find("{") for label in labels]
@@ -180,7 +150,7 @@
                 label.replace(label[(starts[i]) : (ends[i])], "{" + str(counts[i]) + "}")
                 for i, label in enumerate(labels)
             ]
-            legend2 = ax.legend(
+            ax.legend(
                 handles,
                 labels2,
                 frameon=False,
@@ -200,7 +170,7 @@
                 c=highlight_color,
                 s=(
                     None
-                    if type(point_sizes) == type(None)
+                    if point_sizes is None
                     else (point_sizes[ranks_] ** point_size_exp)
                 ),
                 edgecolors=color,
@@ -226,18 +196,18 @@

 def add_arrows(
     adata: AnnData,
-    l_expr: np.array,
-    r_expr: np.array,
+    l_expr: np.ndarray,
+    r_expr: np.ndarray,
     min_expr: float,
-    sig_bool: np.array,
+    sig_bool: np.ndarray,
     fig,
     ax: Axes,
-    use_label: str,
-    int_df: pd.DataFrame,
-    head_width=4,
-    width=0.001,
-    arrow_cmap=None,
-    arrow_vmax=None,
+    use_label: str | None,
+    int_df: pd.DataFrame | None,
+    head_width: float = 4,
+    width: float = 0.001,
+    arrow_cmap: str | None = None,
+    arrow_vmax: float | None = None,
 ):
     """Adds arrows from significant spots to the neighbours they interact with.
@@ -262,7 +232,7 @@
     # in the base plotting function class.
     # Reason why is because scale_factor refers to scaling the
     # image to match the spot spatial coordinates, not the
-    # the spots to match the image coordinates!!!
+    # spots to match the image coordinates!!!

     L_bool = l_expr > min_expr
     R_bool = r_expr > min_expr
@@ -271,14 +241,14 @@
     forward_edges, reverse_edges = get_edges(adata, L_bool, R_bool, sig_bool)

     # If int_df is specified, we need to subset to the CCIs which are significant #
-    if type(int_df) != type(None):
+    if int_df is not None:
         spot_bcs = adata.obs_names.values.astype(str)
         spot_labels = adata.obs[use_label].values.astype(str)
         label_set = int_df.index.values.astype(str)
         interact_bool = int_df.values > 0

         # Subsetting to only significant CCI #
-        edges_sub = [[], []]  # forward, reverse
+        edges_sub: list[list[tuple[str, str]]] = [[], []]  # forward, reverse
         # ints_2 = np.zeros(int_df.shape)  # Just for debugging, make sure edge
         # list recapitulates edge-counts.
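+        # For each direction (forward: ligand->receptor spot pairs; reverse:
+        # receptor->ligand), an edge is kept only when interact_bool (int_df
+        # counts > 0) is True for the two spots' cell type labels, i.e. when
+        # that CCI was called significant.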
for i, edges in enumerate([forward_edges, reverse_edges]):
@@ -303,8 +273,8 @@

         forward_edges, reverse_edges = edges_sub

     # If cmap specified, colour arrows by average LR expression on edge #
-    if type(arrow_cmap) != type(None):
-        edges_means = [[], []]
+    if arrow_cmap is not None:
+        edges_means: list[list[float]] = [[], []]
         all_means = []
         for i, edges in enumerate([forward_edges, reverse_edges]):
             for j, edge in enumerate(edges):
@@ -319,13 +289,13 @@
                 all_means.append(mean_expr)

         # Determining the color maps #
-        arrow_vmax = np.max(all_means) if type(arrow_vmax) == type(None) else arrow_vmax
+        arrow_vmax = np.max(all_means) if arrow_vmax is None else arrow_vmax
         cmap = plt.get_cmap(arrow_cmap)
         c_norm = plt_colors.Normalize(vmin=0, vmax=arrow_vmax)
         scalar_map = cm.ScalarMappable(norm=c_norm, cmap=cmap)

         # Determining the edge colors #
-        edges_colors = [[], []]
+        edges_colors: list[list[tuple[float, float, float, float]]] = [[], []]
         for i, edges in enumerate([forward_edges, reverse_edges]):
             for j, edge in enumerate(edges):
                 color_val = scalar_map.to_rgba(edges_means[i][j])
@@ -341,7 +311,7 @@
             axc = fig.add_axes(cax)

     else:
-        edges_colors = [None, None]
+        edges_colors = [[], []]

     # Now performing the plotting #
     # The arrows #
@@ -366,8 +336,8 @@
         edge_colors=edges_colors[1],
     )
     # Adding the color map #
-    if type(arrow_cmap) != type(None):
-        cb1 = matplotlib.colorbar.ColorbarBase(
+    if arrow_cmap is not None:
+        matplotlib.colorbar.ColorbarBase(
             axc, cmap=cmap, norm=c_norm, orientation="horizontal"
         )

@@ -383,6 +353,8 @@
 def add_arrows_by_edges(
     edge_colors=None,
     axc=None,
 ):
+    """Adds the arrows using an edge list."""
+    if edge_colors is None:
+        edge_colors = []
-    """Adds the arrows using an edge list."""
     for i, edge in enumerate(edges):
         # cols = ["imagecol", "imagerow"]
@@ -399,7 +371,7 @@
         x1, y1 = adata.obsm["spatial"][edge0_index, :] * scale_factor
         x2, y2 = adata.obsm["spatial"][edge1_index, :] * scale_factor
         dx, dy = (x2 - x1) * 0.75, (y2 - y1) * 0.75
-        arrow_color = "k" if type(edge_colors) == type(None) else edge_colors[i]
+        arrow_color = "k" if len(edge_colors) == 0 else edge_colors[i]

         ax.arrow(
             x1,
@@ -418,9 +390,9 @@

 def get_int_df(adata, lr, use_label, sig_interactions, title):
     """Retrieves the relevant interaction count matrix."""
-    no_title = type(title) == type(None)
+    no_title = title is None
     labels_ordered = adata.obs[use_label].cat.categories
-    if type(lr) == type(None):  # No LR inputted, so just use all
+    if lr is None:  # No LR inputted, so just use all
         int_df = (
             adata.uns[f"lr_cci_{use_label}"]
             if sig_interactions
@@ -428,7 +400,6 @@
         )[labels_ordered].loc[labels_ordered]
         title = "Cell-Cell LR Interactions" if no_title else title
     else:
-        labels_ordered = adata.obs[use_label].cat.categories

         int_df = (
             adata.uns[f"per_lr_cci_{use_label}"][lr]
@@ -456,10 +427,9 @@ def create_flat_df(int_df):

 def _box_map(x, y, size, ax=None, figsize=(6.48, 4.8), cmap=None, square_scaler=700):
     """Main underlying helper function for generating the heatmaps."""
-    if type(cmap) == type(None):
+    if cmap is None:
         cmap = "Spectral_r"
-
-    if type(ax) == type(None):
+    if ax is None:
         fig, ax = plt.subplots(figsize=figsize)

     # Mapping from column names to integer coordinates
@@ -567,7 +537,7 @@
             ]
         )

-    if ax == None:
+    if ax is None:
         return
verts, codes else: path = Path(verts, codes) @@ -669,7 +639,7 @@ def selfChordArc(start=0, end=60, radius=1.0, chordwidth=0.7, ax=None, color=(1, Path.CURVE4, ] - if ax == None: + if ax is None: return verts, codes else: path = Path(verts, codes) @@ -688,13 +658,15 @@ def chordDiagram(X, ax, colors=None, width=0.1, pad=2, chordwidth=0.7, lim=1.1): ax : matplotlib `axes` to show the plot colors : optional - user defined colors in rgb format. Use function hex2rgb() to convert hex color to rgb color. Default: d3.js category10 + user defined colors in rgb format. Use function hex2rgb() to convert hex + color to rgb color. Default: d3.js category10 width : optional width/thickness of the ideogram arc pad : optional gap pad between two neighboring ideogram arcs, unit: degree, default: 2 degree chordwidth : optional - position of the control points for the chords, controlling the shape of the chords + position of the control points for the chords, controlling the shape + of the chords """ # X[i, j]: i -> j x = X.sum(axis=1) # sum over rows @@ -718,7 +690,7 @@ def chordDiagram(X, ax, colors=None, width=0.1, pad=2, chordwidth=0.7, lim=1.1): ] if len(x) > 10: print("x is too large! Use x smaller than 10") - if type(colors[0]) == str: + if isinstance(colors[0], str): colors = [hex2rgb(colors[i]) for i in range(len(x))] # find position for each start and end diff --git a/stlearn/plotting/classes.py b/stlearn/pl/classes.py similarity index 72% rename from stlearn/plotting/classes.py rename to stlearn/pl/classes.py index e60c7a0e..53eff406 100644 --- a/stlearn/plotting/classes.py +++ b/stlearn/pl/classes.py @@ -4,31 +4,23 @@ Date: 20 Feb 2021 """ -from lib2to3.pgen2.token import OP -from typing import Optional, Union, Mapping, List # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - import numbers +import warnings +from typing import ( # Special + Optional, # Classes +) + +import matplotlib +import matplotlib.pyplot as plt +import networkx as nx import numpy as np import pandas as pd from anndata import AnnData - -from matplotlib import rcParams, ticker, gridspec, axes -import matplotlib.pyplot as plt -import matplotlib from scipy.interpolate import griddata -import networkx as nx from ..classes import Spatial -from ..utils import _AxesSubplot, Axes, _read_graph -from .utils import centroidpython, get_cluster, get_node, check_sublist, get_cmap - -################################################################ -# # -# Spatial base plot class # -# # -################################################################ +from ..utils import Axes, _AxesSubplot, _read_graph +from .utils import centroidpython, check_sublist, get_cluster, get_cmap, get_node class SpatialBasePlot(Spatial): @@ -37,31 +29,32 @@ def __init__( # plotting param adata: AnnData, title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - color_bar_label: Optional[str] = "", - zoom_coord: Optional[float] = None, - crop: Optional[bool] = True, - margin: Optional[bool] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 0.7, - use_raw: Optional[bool] = False, - fname: 
Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "Spectral_r", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + color_bar_label: str = "", + zoom_coord: tuple[float, float, float, float] | None = None, + crop: bool = True, + margin: float = 100, + size: float = 7, + image_alpha: float = 1.0, + cell_alpha: float = 0.7, + use_raw: bool = False, + fname: str | None = None, + dpi: int = 120, **kwds, ): super().__init__( adata, ) + self.title = title self.figsize = figsize self.image_alpha = image_alpha @@ -75,28 +68,25 @@ def __init__( if use_raw: self.query_adata = self.adata[0].raw.to_adata().copy() - if self.list_clusters != None: - assert use_label != None, "Please specify `use_label` parameter!" - - if use_label != None: + if self.list_clusters is not None: + assert use_label is not None, "Please specify `use_label` parameter!" + if use_label is not None: assert ( use_label in self.adata[0].obs.columns ), "Please choose the right label in `adata.obs.columns`!" self.use_label = use_label - if self.list_clusters is None: + unique_categories = np.array(self.adata[0].obs[use_label].cat.categories) - self.list_clusters = np.array( - self.adata[0].obs[use_label].cat.categories - ) + if self.list_clusters is None: + self.list_clusters = unique_categories else: - if type(self.list_clusters) != list: + if not isinstance(self.list_clusters, list): self.list_clusters = [self.list_clusters] clusters_indexes = [ - np.where(adata.obs[use_label].cat.categories == i)[0][0] - for i in self.list_clusters + np.where(unique_categories == i)[0][0] for i in self.list_clusters ] self.list_clusters = np.array(self.list_clusters)[ np.argsort(clusters_indexes) @@ -111,21 +101,21 @@ def __init__( stlearn_cmap = ["jana_40", "default"] cmap_available = plt.colormaps() + scanpy_cmap + stlearn_cmap error_msg = ( - "cmap must be a matplotlib.colors.LinearSegmentedColormap OR" + "cmap must be a matplotlib.colors.LinearSegmentedColormap OR " "one of these: " + str(cmap_available) ) - if type(cmap) == str: + if isinstance(cmap, str): assert cmap in cmap_available, error_msg - elif type(cmap) != matplotlib.colors.LinearSegmentedColormap: + elif not isinstance(cmap, matplotlib.colors.LinearSegmentedColormap): raise Exception(error_msg) self.cmap = cmap - if type(fig) == type(None) and type(ax) == type(None): + if fig is None and ax is None: self.fig, self.ax = self._generate_frame() else: self.fig, self.ax = fig, ax - if show_axis == False: + if not show_axis: self._remove_axis(self.ax) if show_image: @@ -174,7 +164,6 @@ def _add_image(self, main_ax: Axes): ) def _plot_colorbar(self, plot_ax: Axes, color_bar_label: str = ""): - cb = plt.colorbar( plot_ax, aspect=10, shrink=0.5, cmap=self.cmap, label=color_bar_label ) @@ -184,15 +173,13 @@ def _remove_axis(self, main_ax: Axes): main_ax.axis("off") def _crop_image(self, main_ax: _AxesSubplot, margin: float): - main_ax.set_xlim(self.imagecol.min() - margin, self.imagecol.max() + margin) - main_ax.set_ylim(self.imagerow.min() - margin, self.imagerow.max() + margin) - main_ax.set_ylim(main_ax.get_ylim()[::-1]) - def _zoom_image(self, main_ax: _AxesSubplot, zoom_coord: Optional[float]): - + def _zoom_image( + self, main_ax: _AxesSubplot, zoom_coord: tuple[float, float, float, float] + ): 
main_ax.set_xlim(zoom_coord[0], zoom_coord[1]) main_ax.set_ylim(zoom_coord[3], zoom_coord[2]) @@ -219,7 +206,6 @@ def _get_query_clusters_index(self): return index_query def _save_output(self): - self.fig.savefig( fname=self.fname, bbox_inches="tight", pad_inches=0, dpi=self.dpi ) @@ -231,43 +217,43 @@ def _save_output(self): # # ################################################################ -import warnings - class GenePlot(SpatialBasePlot): + gene_symbols: list[str] + def __init__( self, adata: AnnData, # plotting param - title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - color_bar_label: Optional[str] = "", - crop: Optional[bool] = True, - zoom_coord: Optional[float] = None, - margin: Optional[bool] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - use_raw: Optional[bool] = False, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + title: str | None = None, + figsize: tuple[float, float] | None = None, + cmap: str = "Spectral_r", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + color_bar_label: str = "", + crop: bool = True, + zoom_coord: tuple[float, float, float, float] | None = None, + margin: float = 100, + size: float = 7, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + use_raw: bool = False, + fname: str | None = None, + dpi: int = 120, # gene plot param - gene_symbols: Union[str, list] = None, - threshold: Optional[float] = None, + gene_symbols: str | list[str] | None = None, + threshold: float | None = None, method: str = "CumSum", contour: bool = False, - step_size: Optional[int] = None, - vmin: float = None, - vmax: float = None, + step_size: int | None = None, + vmin: float | None = None, + vmax: float | None = None, **kwargs, ): super().__init__( @@ -302,18 +288,18 @@ def __init__( self.step_size = step_size - if self.title == None: - if type(gene_symbols) == str: + if isinstance(gene_symbols, str): + self.gene_symbols = [gene_symbols] + elif gene_symbols is None: + self.gene_symbols = [] + else: + self.gene_symbols = gene_symbols - self.title = str(gene_symbols) - gene_symbols = [gene_symbols] - else: - self.title = ", ".join(gene_symbols) + if self.title is None: + self.title = ", ".join(self.gene_symbols) self._add_title() - self.gene_symbols = gene_symbols - gene_values = self._get_gene_expression() self.available_ids = self._add_threshold(gene_values, threshold) @@ -328,17 +314,15 @@ def __init__( if show_color_bar: self._add_color_bar(plot, color_bar_label=color_bar_label) - if fname != None: + if fname is not None: self._save_output() def _get_gene_expression(self): - # Gene plot option if len(self.gene_symbols) == 0: raise ValueError("Genes should be provided, please input genes") elif len(self.gene_symbols) == 1: - if self.gene_symbols[0] not in self.query_adata.var_names: raise ValueError( self.gene_symbols[0] @@ -349,7 +333,6 @@ def _get_gene_expression(self): return 
colors else: - for gene in self.gene_symbols: if gene not in self.query_adata.var.index: self.gene_symbols.remove(gene) @@ -379,8 +362,7 @@ def _get_gene_expression(self): return colors def _plot_genes(self, gene_values: pd.Series): - - if type(self.vmin) == type(None) and type(self.vmax) == type(None): + if self.vmin is None and self.vmax is None: vmin = min(gene_values) vmax = max(gene_values) else: @@ -398,13 +380,12 @@ def _plot_genes(self, gene_values: pd.Series): marker="o", vmin=vmin, vmax=vmax, - cmap=plt.get_cmap(self.cmap) if type(self.cmap) == str else self.cmap, + cmap=plt.get_cmap(self.cmap) if isinstance(self.cmap, str) else self.cmap, c=gene_values, ) return plot def _plot_contour(self, gene_values: pd.Series): - imgcol_new = self.query_adata.obsm["spatial"][:, 0] * self.scale_factor imgrow_new = self.query_adata.obsm["spatial"][:, 1] * self.scale_factor # Extracting x,y and values (z) @@ -417,7 +398,7 @@ def _plot_contour(self, gene_values: pd.Series): yi = np.linspace(y.min(), y.max(), 100) zi = griddata((x, y), z, (xi[None, :], yi[:, None]), method="linear") - if self.step_size == None: + if self.step_size is None: self.step_size = int(np.max(z) / 50) if self.step_size < 1: self.step_size = 1 @@ -428,13 +409,13 @@ def _plot_contour(self, gene_values: pd.Series): yi, zi, range(0, int(np.nanmax(zi)) + self.step_size, self.step_size), - cmap=plt.get_cmap(self.cmap) if type(self.cmap) == str else self.cmap, + cmap=plt.get_cmap(self.cmap) if isinstance(self.cmap, str) else self.cmap, alpha=self.cell_alpha, ) return cs def _add_threshold(self, gene_values, threshold): - if threshold == None: + if threshold is None: return np.repeat(True, len(gene_values)) else: return gene_values > threshold @@ -453,33 +434,33 @@ def __init__( adata: AnnData, # plotting param title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - color_bar_label: Optional[str] = "", - crop: Optional[bool] = True, - zoom_coord: Optional[float] = None, - margin: Optional[bool] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - use_raw: Optional[bool] = False, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "Spectral_r", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + color_bar_label: str = "", + crop: bool = True, + zoom_coord: tuple[float, float, float, float] | None = None, + margin: float = 100, + size: float = 7, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + use_raw: bool = False, + fname: str | None = None, + dpi: int = 120, # gene plot param - feature: str = None, - threshold: Optional[float] = None, + feature: str | None = None, + threshold: float | None = None, contour: bool = False, - step_size: Optional[int] = None, - vmin: float = None, - vmax: float = None, + step_size: int | None = None, + vmin: float | None = None, + vmax: float | None = None, 
**kwargs, ): super().__init__( @@ -527,11 +508,10 @@ def __init__( if show_color_bar: self._add_color_bar(plot, color_bar_label=color_bar_label) - if fname != None: + if fname is not None: self._save_output() def _get_feature_values(self): - if self.feature not in self.query_adata.obs: raise ValueError( self.feature + " is not in data.obs, please try another feature" @@ -549,8 +529,7 @@ def _get_feature_values(self): return colors def _plot_feature(self, feature_values: pd.Series): - - if type(self.vmin) == type(None) and type(self.vmax) == type(None): + if self.vmin is None and self.vmax is None: vmin = min(feature_values) vmax = max(feature_values) else: @@ -568,13 +547,12 @@ def _plot_feature(self, feature_values: pd.Series): marker="o", vmin=vmin, vmax=vmax, - cmap=plt.get_cmap(self.cmap) if type(self.cmap) == str else self.cmap, + cmap=plt.get_cmap(self.cmap) if isinstance(self.cmap, str) else self.cmap, c=feature_values, ) return plot def _plot_contour(self, feature_values: pd.Series): - imgcol_new = self.query_adata.obsm["spatial"][:, 0] * self.scale_factor imgrow_new = self.query_adata.obsm["spatial"][:, 1] * self.scale_factor # Extracting x,y and values (z) @@ -587,7 +565,7 @@ def _plot_contour(self, feature_values: pd.Series): yi = np.linspace(y.min(), y.max(), 100) zi = griddata((x, y), z, (xi[None, :], yi[:, None]), method="linear") - if self.step_size == None: + if self.step_size is None: self.step_size = int(np.max(z) / 50) if self.step_size < 1: self.step_size = 1 @@ -598,65 +576,59 @@ def _plot_contour(self, feature_values: pd.Series): yi, zi, range(0, int(np.nanmax(zi)) + self.step_size, self.step_size), - cmap=plt.get_cmap(self.cmap) if type(self.cmap) == str else self.cmap, + cmap=plt.get_cmap(self.cmap) if isinstance(self.cmap, str) else self.cmap, alpha=self.cell_alpha, ) return cs def _add_threshold(self, feature_values, threshold): - if threshold == None: + if threshold is None: return np.repeat(True, len(feature_values)) else: return feature_values > threshold -################################################################ -# # -# Cluster plot class # -# # -################################################################ - - +# Cluster plot class class ClusterPlot(SpatialBasePlot): def __init__( self, adata: AnnData, # plotting param title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "default", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - crop: Optional[bool] = True, - zoom_coord: Optional[float] = None, - margin: Optional[bool] = 100, - size: Optional[float] = 5, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "default", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + crop: bool = True, + zoom_coord: tuple[float, float, float, float] | None = None, + margin: float = 100, + size: float = 5, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + fname: str | None = 
None, + dpi: int = 120, # cluster plot param - show_subcluster: Optional[bool] = False, - show_cluster_labels: Optional[bool] = False, - show_trajectories: Optional[bool] = False, - reverse: Optional[bool] = False, - show_node: Optional[bool] = False, - threshold_spots: Optional[int] = 5, - text_box_size: Optional[float] = 5, - color_bar_size: Optional[float] = 10, - bbox_to_anchor: Optional[Tuple[float, float]] = (1, 1), + show_subcluster: bool = False, + show_cluster_labels: bool = False, + show_trajectories: bool = False, + reverse: bool = False, + show_node: bool = False, + threshold_spots: int = 5, + text_box_size: float = 5, + color_bar_size: float = 10, + bbox_to_anchor: tuple[float, float] | None = (1, 1), # trajectory - trajectory_node_size: Optional[int] = 10, - trajectory_alpha: Optional[float] = 1.0, - trajectory_width: Optional[float] = 2.5, - trajectory_edge_color: Optional[str] = "#f4efd3", - trajectory_arrowsize: Optional[int] = 17, + trajectory_node_size: int = 10, + trajectory_alpha: float = 1.0, + trajectory_width: float = 2.5, + trajectory_edge_color: str = "#f4efd3", + trajectory_arrowsize: int = 17, ): super().__init__( adata=adata, @@ -703,7 +675,6 @@ def __init__( self._add_sub_clusters() if show_trajectories: - self.trajectory_node_size = trajectory_node_size self.trajectory_alpha = trajectory_alpha self.trajectory_width = trajectory_width @@ -712,36 +683,32 @@ def __init__( self._add_trajectories() - if fname != None: + if fname is not None: self._save_output() def _add_cluster_colors(self): - if self.use_label + "_colors" not in self.adata[0].uns: - # self.adata[0].uns[self.use_label + "_set"] = [] - self.adata[0].uns[self.use_label + "_colors"] = [] + self.adata[0].uns[self.use_label + "_colors"] = [] - for i, cluster in enumerate(self.adata[0].obs.groupby(self.use_label)): - self.adata[0].uns[self.use_label + "_colors"].append( - matplotlib.colors.to_hex(self.cmap_(i / (self.cmap_n - 1))) - ) - # self.adata[0].uns[self.use_label + "_set"].append( cluster[0] ) + for i, cluster in enumerate( + self.adata[0].obs.groupby(self.use_label, observed=True) + ): + self.adata[0].uns[self.use_label + "_colors"].append( + matplotlib.colors.to_hex(self.cmap_(i / (self.cmap_n - 1))) + ) def _plot_clusters(self): # Plot scatter plot based on pixel of spots - # for i, cluster in enumerate(self.query_adata.obs[self.use_label].cat.categories): - for i, cluster in enumerate(self.query_adata.obs.groupby(self.use_label)): - + for i, cluster in enumerate( + self.query_adata.obs.groupby(self.use_label, observed=True) + ): # Plot scatter plot based on pixel of spots subset_spatial = self.query_adata.obsm["spatial"][ check_sublist(list(self.query_adata.obs.index), list(cluster[1].index)) ] if self.use_label + "_colors" in self.adata[0].uns: - # label_set = self.adata[0].uns[self.use_label+'_set'] - label_set = ( - self.adata[0].obs[self.use_label].cat.categories.values.astype(str) - ) + label_set = self.adata[0].obs[self.use_label].cat.categories.values col_index = np.where(label_set == cluster[0])[0][0] color = self.adata[0].uns[self.use_label + "_colors"][col_index] else: @@ -773,13 +740,11 @@ def _add_cluster_bar(self, bbox_to_anchor): handleheight=1.0, edgecolor="white", ) - for handle in lgnd.legendHandles: + for handle in lgnd.legend_handles: handle.set_sizes([20.0]) def _add_cluster_labels(self): - for i, label in enumerate(self.list_clusters): - label_index = list( self.query_adata.obs[ self.query_adata.obs[self.use_label] == str(label) @@ -818,9 +783,8 @@ def 
_add_cluster_labels(self): ) def _add_sub_clusters(self): - if "sub_cluster_labels" not in self.query_adata.obs.columns: - raise ValueError("Please run stlearn.spatial.cluster.localization") + raise ValueError("Please run stlearn.spatial.clustering.localization") for i, label in enumerate(self.list_clusters): label_index = list( @@ -941,20 +905,23 @@ def _add_trajectories(self): ) if self.show_node: - for x, y in centroid_dict.items(): - - if x in get_node(self.list_clusters, self.adata[0].uns["split_node"]): + for node, pos in centroid_dict.items(): + if str(node) in get_node( + self.list_clusters, self.adata[0].uns["split_node"] + ): self.ax.text( - y[0], - y[1], - get_cluster(str(x), self.adata[0].uns["split_node"]), + pos[0], + pos[1], + get_cluster(str(node), self.adata[0].uns["split_node"]), color="black", fontsize=8, zorder=100, bbox=dict( facecolor=cmap( int( - get_cluster(str(x), self.adata[0].uns["split_node"]) + get_cluster( + str(node), self.adata[0].uns["split_node"] + ) ) / (len(used_colors) - 1) ), @@ -977,29 +944,29 @@ def __init__( adata: AnnData, # plotting param title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "jet", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - crop: Optional[bool] = True, - zoom_coord: Optional[float] = None, - margin: Optional[bool] = 100, - size: Optional[float] = 5, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "jet", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + crop: bool = True, + zoom_coord: tuple[float, float, float, float] | None = None, + margin: float = 100, + size: float = 5, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + fname: str | None = None, + dpi: int = 120, # subcluster plot param - cluster: Optional[int] = 0, - threshold_spots: Optional[int] = 5, - text_box_size: Optional[float] = 5, - bbox_to_anchor: Optional[Tuple[float, float]] = (1, 1), + cluster: int = 0, + threshold_spots: int = 5, + text_box_size: float = 5, + bbox_to_anchor: tuple[float, float] | None = (1, 1), **kwargs, ): super().__init__( @@ -1032,7 +999,7 @@ def __init__( self._add_subclusters_label(subset) - if fname != None: + if fname is not None: self._save_output() def _plot_subclusters(self, threshold_spots): @@ -1060,13 +1027,13 @@ def _plot_subclusters(self, threshold_spots): colors = colors.replace(self.mapping) - plot = self.ax.scatter( + self.ax.scatter( self.imgcol_new, self.imgrow_new, edgecolor="none", s=self.size, marker="o", - cmap=plt.get_cmap(self.cmap) if type(self.cmap) == str else self.cmap, + cmap=plt.get_cmap(self.cmap) if isinstance(self.cmap, str) else self.cmap, c=colors, alpha=self.cell_alpha, ) @@ -1074,9 +1041,14 @@ def _plot_subclusters(self, threshold_spots): return subset def _add_subclusters_label(self, subset): - if len(subset["sub_cluster_labels"].unique()) < 2: - print("lower than 2") - centroids = 
[centroidpython(subset[["imagecol", "imagerow"]].values)]
+        n_subcluster_labels = len(subset["sub_cluster_labels"].unique())
+        if n_subcluster_labels == 0:
+            print("No subcluster labels found")
+            return
+        elif n_subcluster_labels == 1:
+            imgcol = subset["imagecol"].values
+            imgrow = subset["imagerow"].values
+            centroids = [centroidpython(imgcol, imgrow)]
             classes = np.array([subset["sub_cluster_labels"][0]])

         else:
@@ -1131,31 +1103,31 @@ def __init__(
         adata: AnnData,
         # plotting param
         title: Optional["str"] = None,
-        figsize: Optional[Tuple[float, float]] = None,
-        cmap: Optional[str] = "Spectral_r",
-        use_label: Optional[str] = None,
-        list_clusters: Optional[list] = None,
-        ax: Optional[matplotlib.axes.Axes] = None,
-        fig: Optional[matplotlib.figure.Figure] = None,
-        show_plot: Optional[bool] = True,
-        show_axis: Optional[bool] = False,
-        show_image: Optional[bool] = True,
-        show_color_bar: Optional[bool] = True,
-        crop: Optional[bool] = True,
-        zoom_coord: Optional[float] = None,
-        margin: Optional[bool] = 100,
-        size: Optional[float] = 7,
-        image_alpha: Optional[float] = 1.0,
-        cell_alpha: Optional[float] = 1.0,
-        use_raw: Optional[bool] = False,
-        fname: Optional[str] = None,
-        dpi: Optional[int] = 120,
+        figsize: tuple[float, float] | None = None,
+        cmap: str = "Spectral_r",
+        use_label: str | None = None,
+        list_clusters: str | list[str] | None = None,
+        ax: matplotlib.axes.Axes | None = None,
+        fig: matplotlib.figure.Figure | None = None,
+        show_plot: bool = True,
+        show_axis: bool = False,
+        show_image: bool = True,
+        show_color_bar: bool = True,
+        crop: bool = True,
+        zoom_coord: tuple[float, float, float, float] | None = None,
+        margin: float = 100,
+        size: float = 7,
+        image_alpha: float = 1.0,
+        cell_alpha: float = 1.0,
+        use_raw: bool = False,
+        fname: str | None = None,
+        dpi: int = 120,
         # cci_rank param
-        use_het: Optional[str] = "het",
+        use_het: str = "het",
         contour: bool = False,
-        step_size: Optional[int] = None,
-        vmin: float = None,
-        vmax: float = None,
+        step_size: int | None = None,
+        vmin: float | None = None,
+        vmax: float | None = None,
         **kwargs,
     ):
         super().__init__(
@@ -1202,44 +1174,42 @@ def __init__(
         use_result: Optional["str"] = "lr_sig_scores",
         # plotting param
         title: Optional["str"] = None,
-        figsize: Optional[Tuple[float, float]] = None,
-        cmap: Optional[str] = "Spectral_r",
-        list_clusters: Optional[list] = None,
-        ax: Optional[matplotlib.axes.Axes] = None,
-        fig: Optional[matplotlib.figure.Figure] = None,
-        show_plot: Optional[bool] = True,
-        show_axis: Optional[bool] = False,
-        show_image: Optional[bool] = True,
-        show_color_bar: Optional[bool] = True,
-        crop: Optional[bool] = True,
-        zoom_coord: Optional[float] = None,
-        margin: Optional[bool] = 100,
-        size: Optional[float] = 7,
-        image_alpha: Optional[float] = 1.0,
-        cell_alpha: Optional[float] = 1.0,
-        use_raw: Optional[bool] = False,
-        fname: Optional[str] = None,
-        dpi: Optional[int] = 120,
+        figsize: tuple[float, float] | None = None,
+        cmap: str = "Spectral_r",
+        list_clusters: str | list[str] | None = None,
+        ax: matplotlib.axes.Axes | None = None,
+        fig: matplotlib.figure.Figure | None = None,
+        show_plot: bool = True,
+        show_axis: bool = False,
+        show_image: bool = True,
+        show_color_bar: bool = True,
+        crop: bool = True,
+        zoom_coord: tuple[float, float, float, float] | None = None,
+        margin: float = 100,
+        size: float = 7,
+        image_alpha: float = 1.0,
+        cell_alpha: float = 1.0,
+        use_raw: bool = False,
+        fname: str | None = None,
+        dpi: int = 120,
         # cci_rank param
         contour: bool = False,
-        step_size: Optional[int] = None,
-        vmin: float = None,
-        vmax: float = None,
+        step_size: int | None = None,
+        vmin: float | None = None,
+        vmax: float | None = None,
         **kwargs,
     ):
         # Making sure cci_rank has been run first #
         if "lr_summary" not in adata.uns:
             raise Exception(
-                f"To visualise LR interaction results, must run" f"st.pl.cci.run first."
+                "To visualise LR interaction results, must run st.tl.cci.run first."
             )

         # By default, using the LR with most significant spots #
-        if type(use_lr) == type(None):
+        if use_lr is None:
             use_lr = adata.uns["lr_summary"].index.values[0]
         elif use_lr not in adata.uns["lr_summary"].index:
-            raise Exception(
-                f"use_lr must be one of:\n" f'{adata.uns["lr_summary"].index}'
-            )
+            raise Exception(f"use_lr must be one of:\n{adata.uns['lr_summary'].index}")
         else:
             use_lr = str(use_lr)
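When `use_lr` is left as `None`, the plot above falls back to the top-ranked pair in `adata.uns["lr_summary"]`, whose rows are ordered by the number of significant spots. A minimal sketch of selecting a pair explicitly, assuming `st.tl.cci.run` has already been called on an `adata` object and that the `st.pl.lr_result_plot` wrapper (not shown in this diff) is the plotting entry point:

```python
import stlearn as st

# Assumes st.tl.cci.run(adata, ...) has populated adata.uns["lr_summary"],
# whose rows are ranked by the number of significant spots.
top_lr = adata.uns["lr_summary"].index.values[0]
st.pl.lr_result_plot(adata, use_lr=top_lr)
```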
diff --git a/stlearn/plotting/classes_bokeh.py b/stlearn/pl/classes_bokeh.py
similarity index 96%
rename from stlearn/plotting/classes_bokeh.py
rename to stlearn/pl/classes_bokeh.py
index 7a75b2df..cd52d16a 100644
--- a/stlearn/plotting/classes_bokeh.py
+++ b/stlearn/pl/classes_bokeh.py
@@ -1,56 +1,50 @@
-from __future__ import division
+from collections import OrderedDict
+
 import numpy as np
 import pandas as pd
-from PIL import Image
-from stlearn.tools.microenv.cci.het import get_edges
-
-from bokeh.plotting import (
-    figure,
-    show,
-    ColumnDataSource,
-    curdoc,
-)
+import scanpy as sc
+from anndata import AnnData
+from bokeh.application import Application
+from bokeh.application.handlers import FunctionHandler
+from bokeh.layouts import column, row
 from bokeh.models import (
+    Arrow,
+    AutocompleteInput,
+    BasicTicker,
     BoxSelectTool,
-    LassoSelectTool,
+    Button,
+    CheckboxGroup,
+    ColorBar,
     CustomJS,
     Div,
-    Paragraph,
+    HoverTool,
+    LassoSelectTool,
     LinearColorMapper,
-    Slider,
+    Paragraph,
     Select,
-    AutocompleteInput,
-    ColorBar,
+    Slider,
     TextInput,
-    BasicTicker,
-    HoverTool,
-    ZoomOutTool,
-    CheckboxGroup,
-    Arrow,
     VeeHead,
-    Button,
-    Dropdown,
-    Div,
+    ZoomOutTool,
 )
-
-from bokeh.models.widgets import DataTable, DateFormatter, TableColumn
-from anndata import AnnData
+from bokeh.models.widgets import DataTable, TableColumn
 from bokeh.palettes import (
-    Spectral11,
-    Viridis256,
-    Reds256,
     Blues256,
-    Magma256,
     Category20,
+    Magma256,
+    Reds256,
+    Spectral11,
+    Viridis256,
 )
-from bokeh.layouts import column, row, grid
-from collections import OrderedDict
-from bokeh.application import Application
-from bokeh.application.handlers import FunctionHandler
+from bokeh.plotting import (
+    ColumnDataSource,
+    figure,
+)
+from PIL import Image
+
 from stlearn.classes import Spatial
-from typing import Optional
+from stlearn.tl.cci import get_edges
 from stlearn.utils import _read_graph
-import scanpy as sc


 class BokehGenePlot(Spatial):
@@ -63,7 +57,10 @@ def __init__(
             adata,
         )
         # Open image, and make sure it's RGB*A*
-        image = (self.img * 255).astype(np.uint8)
+        if self.img is None:
+            raise ValueError("self.img must be a numpy array")
+        else:
+            image = (self.img * 255).astype(np.uint8)

         img_pillow = Image.fromarray(image).convert("RGBA")

@@ -139,7 +136,6 @@ def __init__(
         # self.tab = Tabs(tabs = [Panel(child=self.layout, title="Gene plot")])

         def modify_fig(doc):
-
             doc.add_root(row(self.layout, width=800))

             self.data_alpha.on_change("value", self.update_data)
@@ -156,7 +152,6 @@ def modify_fig(doc):
         self.app = Application(handler)

     def make_fig(self):
-
         fig = figure(
             title=self.gene_select.value,
             x_range=(0, self.dim),
@@ -272,7 +267,6 @@ def add_violin(self):

         return p
def update_data(self, attrname, old, new): - if len(self.menu) != 0: self.layout.children[0].children[1] = self.make_fig() self.layout.children[1] = self.add_violin() @@ -280,7 +274,6 @@ def update_data(self, attrname, old, new): self.layout.children[1] = self.make_fig() def _get_gene_expression(self, gene_symbols): - if gene_symbols[0] not in self.adata[0].var_names: raise ValueError( gene_symbols[0] + " is not exist in the data, please try another gene" @@ -334,7 +327,10 @@ def __init__( super().__init__(adata) # Open image, and make sure it's RGB*A* - image = (self.img * 255).astype(np.uint8) + if self.img is None: + raise ValueError("self.img must be a numpy array") + else: + image = (self.img * 255).astype(np.uint8) img_pillow = Image.fromarray(image).convert("RGBA") @@ -361,7 +357,7 @@ def __init__( self.use_label = Select(title="Select use_label:", value=menu[0], options=menu) # Initialize the color - from stlearn.plotting.cluster_plot import cluster_plot + from stlearn.pl.cluster_plot import cluster_plot if len(adata.obs[self.use_label.value].cat.categories) <= 20: cluster_plot(adata, use_label=self.use_label.value, show_plot=False) @@ -508,18 +504,10 @@ def modify_fig(doc): self.app = Application(handler) def update_list(self, attrname, old, name): - # Initialize the color - from stlearn.plotting.cluster_plot import cluster_plot + from stlearn.pl.cluster_plot import cluster_plot cluster_plot(self.adata[0], use_label=self.use_label.value, show_plot=False) - - # self.list_cluster = CheckboxGroup( - # labels=list(self.adata[0].obs[self.use_label.value].cat.categories), - # active=list( - # np.array(range(0, len(self.adata[0].obs[self.use_label.value].unique()))) - # ), - # ) self.list_cluster.labels = list( self.adata[0].obs[self.use_label.value].cat.categories ) @@ -528,7 +516,6 @@ def update_list(self, attrname, old, name): ) def update_data(self, attrname, old, new): - if "rank_genes_groups" in self.adata[0].uns: if ( self.use_label.value @@ -548,7 +535,7 @@ def make_fig(self): title="Cluster plot", x_range=(0, self.dim - 150), y_range=(self.dim, 0), - output_backend=self.output_backend.value + output_backend=self.output_backend.value, # Specifying xdim/ydim isn't quire right :-( # width=xdim, height=ydim, ) @@ -784,7 +771,10 @@ def __init__( adata, ) # Open image, and make sure it's RGB*A* - image = (self.img * 255).astype(np.uint8) + if self.img is None: + raise ValueError("self.img must be a numpy array") + else: + image = (self.img * 255).astype(np.uint8) img_pillow = Image.fromarray(image).convert("RGBA") @@ -853,7 +843,6 @@ def __init__( # self.tab = Tabs(tabs = [Panel(child=self.layout, title="Gene plot")]) def modify_fig(doc): - doc.add_root(row(self.layout, width=800)) self.data_alpha.on_change("value", self.update_data) @@ -867,7 +856,6 @@ def modify_fig(doc): self.app = Application(handler) def make_fig(self): - fig = figure( title=self.lr_select.value, # self.het_select.value, x_range=(0, self.dim - 150), @@ -934,7 +922,6 @@ def update_data(self, attrname, old, new): self.layout.children[1] = self.make_fig() def _get_het(self, het): - if het not in self.adata[0].obsm: raise ValueError(het + " is not exist in the data, please try another het") @@ -963,7 +950,10 @@ def __init__( adata, ) # Open image, and make sure it's RGB*A* - image = (self.img * 255).astype(np.uint8) + if self.img is None: + raise ValueError("self.img must be a numpy array") + else: + image = (self.img * 255).astype(np.uint8) img_pillow = Image.fromarray(image).convert("RGBA") @@ -1045,7 +1035,6 @@ 
def __init__( # self.tab = Tabs(tabs = [Panel(child=self.layout, title="Gene plot")]) def modify_fig(doc): - doc.add_root(row(self.layout, width=800)) self.data_alpha.on_change("value", self.update_data) @@ -1061,7 +1050,6 @@ def modify_fig(doc): self.app = Application(handler) def make_fig(self): - fig = figure( title="Spatial CCI plot", x_range=(0, self.dim - 150), @@ -1165,10 +1153,10 @@ def _get_cci_lr_edges(self): selected = self.annot_select.value # Extracting the data # - l, r = lr.split("_") + ligand, receptor = lr.split("_") lr_index = np.where(adata.uns["lr_summary"].index.values == lr)[0][0] - L_bool = adata[:, l].X.toarray()[:, 0] > 0 - R_bool = adata[:, r].X.toarray()[:, 0] > 0 + L_bool = adata[:, ligand].X.toarray()[:, 0] > 0 + R_bool = adata[:, receptor].X.toarray()[:, 0] > 0 sig_bool = adata.obsm["lr_sig_scores"][:, lr_index] > 0 int_df = adata.uns[f"per_lr_cci_{selected}"][lr] @@ -1225,19 +1213,11 @@ def _add_edges(fig, adata, edges, arrow_size, forward=True, scale_factor=1): ) def update_list(self, attrname, old, name): - # Initialize the color - from stlearn.plotting.cluster_plot import cluster_plot + from stlearn.pl.cluster_plot import cluster_plot selected = self.annot_select.value.strip("raw_") cluster_plot(self.adata[0], use_label=selected, show_plot=False) - - # self.list_cluster = CheckboxGroup( - # labels=list(self.adata[0].obs[self.use_label.value].cat.categories), - # active=list( - # np.array(range(0, len(self.adata[0].obs[self.use_label.value].unique()))) - # ), - # ) self.list_cluster.labels = list(self.adata[0].obs[selected].cat.categories) self.list_cluster.active = list( np.array(range(0, len(self.adata[0].obs[selected].unique()))) @@ -1252,7 +1232,10 @@ def __init__( ): super().__init__(adata) # Open image, and make sure it's RGB*A* - image = (self.img * 255).astype(np.uint8) + if self.img is None: + raise ValueError("self.img must be a numpy array") + else: + image = (self.img * 255).astype(np.uint8) img_pillow = Image.fromarray(image).convert("RGBA") @@ -1392,7 +1375,9 @@ def make_fig(self): var new_data = source_data_2.data; - new_data = addRowToAccumulator(new_data,inds,color_index.data.index[0].toString(),color_index.data.index[0]) + ci = color_index.data.index[0]; + cs = ci.toString(); + new_data = addRowToAccumulator(new_data,inds,cs,ci) source_data_2.data = new_data @@ -1410,9 +1395,9 @@ def change_click(): empty_array[:] = np.NaN empty_array = empty_array.astype(object) for i in range(0, len(self.adata[0].uns["annotation"])): - empty_array[ - [np.array(self.adata[0].uns["annotation"]["spot"][i])] - ] = str(self.adata[0].uns["annotation"]["label"][i]) + empty_array[[np.array(self.adata[0].uns["annotation"]["spot"][i])]] = ( + str(self.adata[0].uns["annotation"]["label"][i]) + ) empty_array = pd.Series(empty_array).fillna("other") self.adata[0].obs["annotation"] = pd.Categorical(empty_array) diff --git a/stlearn/plotting/cluster_plot.py b/stlearn/pl/cluster_plot.py similarity index 50% rename from stlearn/plotting/cluster_plot.py rename to stlearn/pl/cluster_plot.py index 0cd6645a..288f0b7f 100644 --- a/stlearn/plotting/cluster_plot.py +++ b/stlearn/pl/cluster_plot.py @@ -1,67 +1,58 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np -import networkx as nx - -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes +from typing import ( + Optional, # Special +) +import matplotlib from anndata import 
AnnData -import warnings - -from stlearn.plotting.classes import ClusterPlot -from stlearn.plotting.classes_bokeh import BokehClusterPlot -from stlearn.plotting._docs import doc_spatial_base_plot, doc_cluster_plot -from stlearn.utils import _AxesSubplot, Axes, _docs_params - -from bokeh.io import push_notebook, output_notebook +from bokeh.io import output_notebook from bokeh.plotting import show +from stlearn.pl._docs import doc_cluster_plot, doc_spatial_base_plot +from stlearn.pl.classes import ClusterPlot +from stlearn.pl.classes_bokeh import BokehClusterPlot +from stlearn.utils import _docs_params + @_docs_params(spatial_base_plot=doc_spatial_base_plot, cluster_plot=doc_cluster_plot) def cluster_plot( adata: AnnData, # plotting param title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "default", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - zoom_coord: Optional[float] = None, - crop: Optional[bool] = True, - margin: Optional[bool] = 100, - size: Optional[float] = 5, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "default", + use_label: str | None = None, + list_clusters: str | list[str] | None = None, + ax: matplotlib.axes.Axes | None = None, + fig: matplotlib.figure.Figure | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + zoom_coord: tuple[float, float, float, float] | None = None, + crop: bool = True, + margin: float = 100, + size: float = 5, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + fname: str | None = None, + dpi: int = 120, # cluster plot param - show_subcluster: Optional[bool] = False, - show_cluster_labels: Optional[bool] = False, - show_trajectories: Optional[bool] = False, - reverse: Optional[bool] = False, - show_node: Optional[bool] = False, - threshold_spots: Optional[int] = 5, - text_box_size: Optional[float] = 5, - color_bar_size: Optional[float] = 10, - bbox_to_anchor: Optional[Tuple[float, float]] = (1, 1), + show_subcluster: bool = False, + show_cluster_labels: bool = False, + show_trajectories: bool = False, + reverse: bool = False, + show_node: bool = False, + threshold_spots: int = 5, + text_box_size: float = 5, + color_bar_size: float = 10, + bbox_to_anchor: tuple[float, float] | None = (1, 1), # trajectory - trajectory_node_size: Optional[int] = 10, - trajectory_alpha: Optional[float] = 1.0, - trajectory_width: Optional[float] = 2.5, - trajectory_edge_color: Optional[str] = "#f4efd3", - trajectory_arrowsize: Optional[int] = 17, -) -> Optional[AnnData]: - + trajectory_node_size: int = 10, + trajectory_alpha: float = 1.0, + trajectory_width: float = 2.5, + trajectory_edge_color: str = "#f4efd3", + trajectory_arrowsize: int = 17, +) -> AnnData | None: """\ Allows the visualization of a cluster results as the discretes values of dot points in the Spatial transcriptomics array. 
We also support to
@@ -76,13 +67,13 @@ def cluster_plot(
     Examples
     -------------------------------------
     >>> import stlearn as st
-    >>> adata = st.datasets.example_bcba()
+    >>> adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")
     >>> label = "louvain"
     >>> st.pl.cluster_plot(adata, use_label = label, show_trajectories = True)

     """
-    assert use_label != None, "Please select `use_label` parameter"
+    assert use_label is not None, "Please select `use_label` parameter"

     ClusterPlot(
         adata,
@@ -121,11 +112,12 @@ def cluster_plot(
         trajectory_arrowsize=trajectory_arrowsize,
     )

+    return adata
+

 def cluster_plot_interactive(
     adata: AnnData,
 ):
-
     bokeh_object = BokehClusterPlot(adata)
     output_notebook()
     show(bokeh_object.app, notebook_handle=True)
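The example above assumes `adata.obs["louvain"]` already exists. A minimal sketch of producing such a clustering with scanpy before plotting (standard scanpy preprocessing; the parameter values are illustrative only, and the louvain step needs the optional louvain package):

```python
import scanpy as sc
import stlearn as st

adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")
sc.pp.normalize_total(adata)  # depth-normalise counts per spot
sc.pp.log1p(adata)            # log-transform
sc.pp.pca(adata, n_comps=30)  # embedding used to build the neighbour graph
sc.pp.neighbors(adata)        # kNN graph consumed by louvain
sc.tl.louvain(adata)          # writes adata.obs["louvain"]
st.pl.cluster_plot(adata, use_label="louvain")
```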
diff --git a/stlearn/plotting/deconvolution_plot.py b/stlearn/pl/deconvolution_plot.py
similarity index 75%
rename from stlearn/plotting/deconvolution_plot.py
rename to stlearn/pl/deconvolution_plot.py
index e69c2b13..f0c92a5c 100644
--- a/stlearn/plotting/deconvolution_plot.py
+++ b/stlearn/pl/deconvolution_plot.py
@@ -1,41 +1,38 @@
-from typing import Optional, Union
-from anndata import AnnData
-import matplotlib.pyplot as plt
-from matplotlib import cm
 import matplotlib as mpl
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
 import numpy as np
-import stlearn.plotting.utils as utils
+from anndata import AnnData


 def deconvolution_plot(
     adata: AnnData,
-    library_id: str = None,
+    library_id: str | None = None,
     use_label: str = "louvain",
-    cluster: [int, str] = None,
-    celltype: str = None,
+    cluster: int | str | None = None,
+    celltype: str | None = None,
     celltype_threshold: float = 0,
     data_alpha: float = 1.0,
     threshold: float = 0.0,
     cmap: str = "tab20",
-    colors: list = None,  # The colors to use for each label...
-    tissue_alpha: float = 1.0,
-    title: str = None,
-    spot_size: Union[float, int] = 10,
+    colors: (
+        list[tuple[float, float, float, float]] | None
+    ) = None,  # The colors to use for each label...
+    spot_size: float | int = 10,
     show_axis: bool = False,
     show_legend: bool = True,
     show_donut: bool = True,
     cropped: bool = True,
     margin: int = 100,
-    name: str = None,
+    name: str | None = None,
     dpi: int = 150,
-    output: str = None,
-    copy: bool = False,
+    output: str | None = None,
     figsize: tuple = (6.4, 4.8),
     show=True,
-) -> Optional[AnnData]:
-
+) -> None:
     """\
-    Clustering plot for sptial transcriptomics data. Also it has a function to display trajectory inference.
+    Clustering plot for spatial transcriptomics data. Also, it has a function to
+    display trajectory inference.

     Parameters
     ----------
@@ -45,8 +42,8 @@ def deconvolution_plot(
         Library id stored in AnnData.
     use_label
         Use label result of cluster method.
-    list_cluster
-        Choose set of clusters that will display in the plot.
+    cluster
+        Choose a cluster (in adata.obs[use_label]) that will display in the plot.
     data_alpha
         Opacity of the spot.
     tissue_alpha
@@ -62,7 +59,8 @@ def deconvolution_plot(
     show_donut
         Whether to show the donut plot or not.
     show_trajectory
-        Show the spatial trajectory or not. It requires stlearn.spatial.trajectory.pseudotimespace.
+        Show the spatial trajectory or not. It requires
+        stlearn.spatial.trajectory.pseudotimespace.
     show_subcluster
        Show subcluster or not. It requires stlearn.spatial.trajectory.global_level.
     name
@@ -102,12 +100,11 @@ def deconvolution_plot(
     ]

     label_filter_ = label_filter[base.index]
-
-    if type(colors) == type(None):
-        color_vals = list(range(0, len(label_filter_), 1))
-        my_norm = mpl.colors.Normalize(0, len(label_filter_))
-        my_cmap = mpl.cm.get_cmap(cmap, len(color_vals))
-        colors = my_cmap.colors
+    color_vals: list[int] = list(range(0, len(label_filter_), 1))
+    my_norm: mcolors.Normalize = mpl.colors.Normalize(0, len(label_filter_))
+    my_cmap: mcolors.Colormap = plt.get_cmap(cmap, len(color_vals))
+    if colors is None:
+        colors = [my_cmap(my_norm(i)) for i in color_vals]

     for i, xy in enumerate(base.values):
         _ = ax.pie(
@@ -127,14 +124,14 @@ def my_autopct(pct):
     ]

     if show_donut:
-        ax_pie = fig.add_axes([0.5, -0.4, 0.03, 0.5])
+        ax_pie = fig.add_axes((0.5, -0.4, 0.03, 0.5))

         def my_autopct(pct):
             return ("%1.0f%%" % pct) if pct >= 4 else ""

         ax_pie.pie(
             label_filter_.sum(axis=1),
-            colors=my_cmap.colors,
+            colors=colors,
             radius=10,
             # frame=True,
             autopct=my_autopct,
@@ -144,9 +141,9 @@ def my_autopct(pct):
             textprops={"fontsize": 5},
         )

-    if show_legend == True:
-        ax_cb = fig.add_axes([0.9, 0.25, 0.03, 0.5], axisbelow=False)
+    if show_legend:
+        ax_cb = fig.add_axes((0.9, 0.25, 0.03, 0.5), axisbelow=False)
         cb = mpl.colorbar.ColorbarBase(
             ax_cb, cmap=my_cmap, norm=my_norm, ticks=color_vals
         )

@@ -167,11 +164,8 @@ def my_autopct(pct):

     if cropped:
         ax.set_xlim(imagecol.min() - margin, imagecol.max() + margin)
-
         ax.set_ylim(imagerow.min() - margin, imagerow.max() + margin)
-
         ax.set_ylim(ax.get_ylim()[::-1])
-        # plt.gca().invert_yaxis()

     if name is None:
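With the `colors` handling reworked above, a caller-supplied palette is a list of RGBA tuples rather than a colormap object, and the legend/donut still derive from `cmap`, so a custom palette pairs best with a matching `cmap`. A sketch, continuing from the `adata` of the previous example and assuming deconvolution results for the chosen label have already been attached in whatever form `deconvolution_plot` expects (not shown in this diff):

```python
import matplotlib.pyplot as plt
import stlearn as st

# Integer indexing into a discrete colormap returns (r, g, b, a) tuples.
tab20 = plt.get_cmap("tab20")
rgba_colors = [tab20(i) for i in range(20)]

st.pl.deconvolution_plot(
    adata,
    use_label="louvain",
    colors=rgba_colors,  # matches list[tuple[float, float, float, float]] | None
    show_donut=True,
)
```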
diff --git a/stlearn/pl/feat_plot.py b/stlearn/pl/feat_plot.py
new file mode 100644
index 00000000..26430cbb
--- /dev/null
+++ b/stlearn/pl/feat_plot.py
@@ -0,0 +1,94 @@
+"""
+Plotting of continuous features stored in adata.obs.
+"""
+
+from typing import (
+    Optional,  # Special
+)
+
+import matplotlib
+from anndata import AnnData
+
+from stlearn.pl.classes import FeaturePlot
+
+
+# @_docs_params(spatial_base_plot=doc_spatial_base_plot, gene_plot=doc_gene_plot)
+def feat_plot(
+    adata: AnnData,
+    feature: str | None = None,
+    threshold: float | None = None,
+    contour: bool = False,
+    step_size: int | None = None,
+    title: Optional["str"] = None,
+    figsize: tuple[float, float] | None = None,
+    cmap: str = "Spectral_r",
+    use_label: str | None = None,
+    list_clusters: list | None = None,
+    ax: matplotlib.axes.Axes | None = None,
+    fig: matplotlib.figure.Figure | None = None,
+    show_plot: bool = True,
+    show_axis: bool = False,
+    show_image: bool = True,
+    show_color_bar: bool = True,
+    color_bar_label: str = "",
+    zoom_coord: tuple[float, float, float, float] | None = None,
+    crop: bool = True,
+    margin: float = 100,
+    size: float = 7,
+    image_alpha: float = 1.0,
+    cell_alpha: float = 0.7,
+    use_raw: bool = False,
+    fname: str | None = None,
+    dpi: int = 120,
+    vmin: float | None = None,
+    vmax: float | None = None,
+) -> AnnData | None:
+    """\
+    Allows the visualization of continuous features stored in adata.obs
+    for Spatial transcriptomics array.
+
+
+    Parameters
+    -------------------------------------
+    {spatial_base_plot}
+    {feature_plot}
+
+    Examples
+    -------------------------------------
+    >>> import stlearn as st
+    >>> adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")
+    >>> st.pl.feat_plot(adata, feature='dpt_pseudotime')
+
+    """
+    FeaturePlot(
+        adata,
+        feature=feature,
+        threshold=threshold,
+        contour=contour,
+        step_size=step_size,
+        title=title,
+        figsize=figsize,
+        cmap=cmap,
+        use_label=use_label,
+        list_clusters=list_clusters,
+        ax=ax,
+        fig=fig,
+        show_plot=show_plot,
+        show_axis=show_axis,
+        show_image=show_image,
+        show_color_bar=show_color_bar,
+        color_bar_label=color_bar_label,
+        zoom_coord=zoom_coord,
+        crop=crop,
+        margin=margin,
+        size=size,
+        image_alpha=image_alpha,
+        cell_alpha=cell_alpha,
+        use_raw=use_raw,
+        fname=fname,
+        dpi=dpi,
+        vmin=vmin,
+        vmax=vmax,
+    )
+
+    return adata
diff --git a/stlearn/pl/gene_plot.py b/stlearn/pl/gene_plot.py
new file mode 100644
index 00000000..5b545170
--- /dev/null
+++ b/stlearn/pl/gene_plot.py
@@ -0,0 +1,99 @@
+import matplotlib
+from anndata import AnnData
+from bokeh.io import output_notebook
+from bokeh.plotting import show
+
+from stlearn.pl._docs import doc_gene_plot, doc_spatial_base_plot
+from stlearn.pl.classes import GenePlot
+from stlearn.pl.classes_bokeh import BokehGenePlot
+from stlearn.utils import _docs_params
+
+
+@_docs_params(spatial_base_plot=doc_spatial_base_plot, gene_plot=doc_gene_plot)
+def gene_plot(
+    adata: AnnData,
+    gene_symbols: str | list | None = None,
+    threshold: float | None = None,
+    method: str = "CumSum",
+    contour: bool = False,
+    step_size: int | None = None,
+    title: str | None = None,
+    figsize: tuple[float, float] | None = None,
+    cmap: str = "Spectral_r",
+    use_label: str | None = None,
+    list_clusters: list | None = None,
+    ax: matplotlib.axes.Axes | None = None,
+    fig: matplotlib.figure.Figure | None = None,
+    show_plot: bool = True,
+    show_axis: bool = False,
+    show_image: bool = True,
+    show_color_bar: bool = True,
+    color_bar_label: str = "",
+    zoom_coord: tuple[float, float, float, float] | None = None,
+    crop: bool = True,
+    margin: float = 100,
+    size: float = 7,
+    image_alpha: float = 1.0,
+    cell_alpha: float = 0.7,
+    use_raw: bool = False,
+    fname: str | None = None,
+    dpi: int = 120,
+    vmin: float | None = None,
+    vmax: float | None = None,
+) -> AnnData | None:
+    """\
+    Allows the visualization of a single gene or multiple genes as the values
+    of dot points or contour in the Spatial transcriptomics array.
+ + + Parameters + ------------------------------------- + {spatial_base_plot} + {gene_plot} + + Examples + ------------------------------------- + >>> import stlearn as st + >>> adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1") + >>> genes = ["BRCA1","BRCA2"] + >>> st.pl.gene_plot(adata, gene_symbols = genes) + + """ + GenePlot( + adata, + gene_symbols=gene_symbols, + threshold=threshold, + method=method, + contour=contour, + step_size=step_size, + title=title, + figsize=figsize, + cmap=cmap, + use_label=use_label, + list_clusters=list_clusters, + ax=ax, + fig=fig, + show_plot=show_plot, + show_axis=show_axis, + show_image=show_image, + show_color_bar=show_color_bar, + color_bar_label=color_bar_label, + zoom_coord=zoom_coord, + crop=crop, + margin=margin, + size=size, + image_alpha=image_alpha, + cell_alpha=cell_alpha, + use_raw=use_raw, + fname=fname, + dpi=dpi, + vmin=vmin, + vmax=vmax, + ) + return adata + + +def gene_plot_interactive(adata: AnnData): + bokeh_object = BokehGenePlot(adata) + output_notebook() + show(bokeh_object.app, notebook_handle=True) diff --git a/stlearn/plotting/mask_plot.py b/stlearn/pl/mask_plot.py similarity index 90% rename from stlearn/plotting/mask_plot.py rename to stlearn/pl/mask_plot.py index 483163a9..60762ccc 100644 --- a/stlearn/plotting/mask_plot.py +++ b/stlearn/pl/mask_plot.py @@ -1,26 +1,24 @@ import matplotlib -from matplotlib import pyplot as plt - -from typing import Optional, Union from anndata import AnnData +from matplotlib import pyplot as plt def plot_mask( adata: AnnData, - library_id: str = None, + library_id: str | None = None, show_spot: bool = True, spot_alpha: float = 1.0, - cmap: str = "vega_20_scanpy", + cmap_name: str = "vega_20_scanpy", tissue_alpha: float = 1.0, mask_alpha: float = 0.5, - spot_size: Union[float, int] = 6.5, + spot_size: float | int = 6.5, show_legend: bool = True, name: str = "mask_plot", dpi: int = 150, - output: str = None, + output: str | None = None, show_axis: bool = False, show_plot: bool = True, -) -> Optional[AnnData]: +) -> AnnData | None: """\ mask plot for sptial transcriptomics data. 
@@ -59,19 +57,20 @@ def plot_mask( Nothing """ from scanpy.plotting import palettes - from stlearn.plotting import palettes_st - if cmap == "vega_10_scanpy": + from stlearn.pl import palettes_st + + if cmap_name == "vega_10_scanpy": cmap = palettes.vega_10_scanpy - elif cmap == "vega_20_scanpy": + elif cmap_name == "vega_20_scanpy": cmap = palettes.vega_20_scanpy - elif cmap == "default_102": + elif cmap_name == "default_102": cmap = palettes.default_102 - elif cmap == "default_28": + elif cmap_name == "default_28": cmap = palettes.default_28 - elif cmap == "jana_40": + elif cmap_name == "jana_40": cmap = palettes_st.jana_40 - elif cmap == "default": + elif cmap_name == "default": cmap = palettes_st.default else: raise ValueError( @@ -171,5 +170,7 @@ def plot_mask( if output is not None: fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0) - if show_plot == True: + if show_plot: plt.show() + + return adata diff --git a/stlearn/plotting/non_spatial_plot.py b/stlearn/pl/non_spatial_plot.py similarity index 85% rename from stlearn/plotting/non_spatial_plot.py rename to stlearn/pl/non_spatial_plot.py index 4e6447d0..dcdf2307 100644 --- a/stlearn/plotting/non_spatial_plot.py +++ b/stlearn/pl/non_spatial_plot.py @@ -1,23 +1,12 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np - -from stlearn._compat import Literal -from typing import Optional, Union -from anndata import AnnData -import warnings - # from .utils import get_img_from_fig, checkType import scanpy +from anndata import AnnData def non_spatial_plot( adata: AnnData, use_label: str = "louvain", -) -> Optional[AnnData]: - +) -> None: """\ A wrap function to plot all the non-spatial plot from scanpy. @@ -56,7 +45,6 @@ def non_spatial_plot( scanpy.pl.draw_graph(adata, color="dpt_pseudotime") else: - scanpy.pl.draw_graph(adata) # adata.uns[use_label+"_colors"] = adata.uns["tmp_color"] diff --git a/stlearn/plotting/palettes_st.py b/stlearn/pl/palettes_st.py similarity index 100% rename from stlearn/plotting/palettes_st.py rename to stlearn/pl/palettes_st.py diff --git a/stlearn/plotting/stack_3d_plot.py b/stlearn/pl/stack_3d_plot.py similarity index 81% rename from stlearn/plotting/stack_3d_plot.py rename to stlearn/pl/stack_3d_plot.py index 4128c97f..c958575a 100644 --- a/stlearn/plotting/stack_3d_plot.py +++ b/stlearn/pl/stack_3d_plot.py @@ -1,36 +1,39 @@ -from typing import Optional, Union -from anndata import AnnData import pandas as pd +from anndata import AnnData def stack_3d_plot( adata: AnnData, slides, + height, + width, cmap="viridis", slide_col="sample_id", use_label=None, gene_symbol=None, -) -> Optional[AnnData]: - +) -> None: """\ - Clustering plot for sptial transcriptomics data. Also it has a function to display trajectory inference. + Clustering plot for spatial transcriptomics data. Also, it has a function to + display trajectory inference. Parameters ---------- - adata + adata: Annotated data matrix. - slides + slides: A list of slide id - cmap + width: + Width of the plot + height: + Height of the plot + cmap: Color map - use_label + slide_col: + Obs column to use for coloring. 
+ use_label: Choose label to plot (priotize) gene_symbol Choose gene symbol to plot - width - Witdh of the plot - height - Height of the plot Returns ------- Nothing @@ -46,14 +49,14 @@ def stack_3d_plot( list_df = [] for i, slide in enumerate(slides): - tmp = data.obs[data.obs[slide_col] == slide][["imagecol", "imagerow"]] + tmp = adata.obs[adata.obs[slide_col] == slide][["imagecol", "imagerow"]] tmp["sample_id"] = slide tmp["z-dimension"] = i list_df.append(tmp) df = pd.concat(list_df) - if use_label != None: + if use_label is not None: assert use_label in adata.obs.columns, "Please use the right `use_label`" df[use_label] = adata[df.index].obs[use_label].values diff --git a/stlearn/plotting/subcluster_plot.py b/stlearn/pl/subcluster_plot.py similarity index 51% rename from stlearn/plotting/subcluster_plot.py rename to stlearn/pl/subcluster_plot.py index 3714f6b9..2b7b0ec2 100644 --- a/stlearn/plotting/subcluster_plot.py +++ b/stlearn/pl/subcluster_plot.py @@ -1,19 +1,12 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np - -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes +from typing import ( + Optional, # Special +) from anndata import AnnData -import warnings -from stlearn.plotting.classes import SubClusterPlot -from stlearn.plotting._docs import doc_spatial_base_plot, doc_subcluster_plot -from stlearn.utils import _AxesSubplot, Axes, _docs_params +from stlearn.pl._docs import doc_spatial_base_plot, doc_subcluster_plot +from stlearn.pl.classes import SubClusterPlot +from stlearn.utils import _AxesSubplot, _docs_params @_docs_params( @@ -23,28 +16,28 @@ def subcluster_plot( adata: AnnData, # plotting param title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "jet", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[_AxesSubplot] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - crop: Optional[bool] = True, - margin: Optional[bool] = 100, - size: Optional[float] = 5, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 1.0, - fname: Optional[str] = None, - dpi: Optional[int] = 120, + figsize: tuple[float, float] | None = None, + cmap: str = "jet", + use_label: str | None = None, + list_clusters: list | None = None, + ax: _AxesSubplot | None = None, + show_plot: bool = True, + show_axis: bool = False, + show_image: bool = True, + show_color_bar: bool = True, + crop: bool = True, + margin: float = 100, + size: float = 5, + image_alpha: float = 1.0, + cell_alpha: float = 1.0, + fname: str | None = None, + dpi: int = 120, # subcluster plot param - cluster: Optional[int] = 0, - threshold_spots: Optional[int] = 5, - text_box_size: Optional[float] = 5, - bbox_to_anchor: Optional[Tuple[float, float]] = (1, 1), -) -> Optional[AnnData]: + cluster: int = 0, + threshold_spots: int = 5, + text_box_size: float = 5, + bbox_to_anchor: tuple[float, float] | None = (1, 1), +) -> AnnData | None: """\ Allows the visualization of a subclustering results as the discretes values of dot points in the Spatial transcriptomics array. 
@@ -57,14 +50,14 @@ def subcluster_plot( Examples ------------------------------------- >>> import stlearn as st - >>> adata = st.datasets.example_bcba() + >>> adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1") >>> label = "louvain" >>> cluster = 6 >>> st.pl.cluster_plot(adata, use_label = label, cluster = cluster) """ - assert use_label != None, "Please select `use_label` parameter" + assert use_label is not None, "Please select `use_label` parameter" assert ( use_label in adata.obs.columns ), "Please run `stlearn.spatial.cluster.localization` function!" @@ -93,3 +86,5 @@ def subcluster_plot( cluster=cluster, threshold_spots=threshold_spots, ) + + return adata diff --git a/stlearn/plotting/trajectory/DE_transition_plot.py b/stlearn/pl/trajectory/DE_transition_plot.py similarity index 95% rename from stlearn/plotting/trajectory/DE_transition_plot.py rename to stlearn/pl/trajectory/DE_transition_plot.py index 2fc82ebd..1ea91831 100644 --- a/stlearn/plotting/trajectory/DE_transition_plot.py +++ b/stlearn/pl/trajectory/DE_transition_plot.py @@ -1,6 +1,6 @@ -import matplotlib.pyplot as plt from decimal import Decimal -from typing import Optional, Union + +import matplotlib.pyplot as plt from anndata import AnnData @@ -8,11 +8,10 @@ def DE_transition_plot( adata: AnnData, top_genes: int = 10, font_size: int = 6, - name: str = None, + name: str | None = None, dpi: int = 150, - output: str = None, -) -> Optional[AnnData]: - + output: str | None = None, +) -> AnnData | None: """\ Differential expression between transition markers. @@ -136,7 +135,7 @@ def DE_transition_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=font_size + size=font_size, ) axes[0][1].text( rect.get_x() + 0.01, @@ -144,7 +143,7 @@ def DE_transition_plot( p_value, color="w", **alignment, - size=font_size + size=font_size, ) rects = axes[0][0].patches @@ -161,7 +160,7 @@ def DE_transition_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=font_size + size=font_size, ) axes[0][0].text( rect.get_x() - 0.01, @@ -169,7 +168,7 @@ def DE_transition_plot( p_value, color="w", **alignment, - size=font_size + size=font_size, ) rects = axes[1][1].patches @@ -186,7 +185,7 @@ def DE_transition_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=font_size + size=font_size, ) axes[1][1].text( rect.get_x() + 0.01, @@ -194,7 +193,7 @@ def DE_transition_plot( p_value, color="w", **alignment, - size=font_size + size=font_size, ) rects = axes[1][0].patches @@ -211,7 +210,7 @@ def DE_transition_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=font_size + size=font_size, ) axes[1][0].text( rect.get_x() - 0.01, @@ -219,7 +218,7 @@ def DE_transition_plot( p_value, color="w", **alignment, - size=font_size + size=font_size, ) plt.figtext( @@ -240,3 +239,5 @@ def DE_transition_plot( if output is not None: if name is not None: plt.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0) + + return adata diff --git a/stlearn/plotting/trajectory/__init__.py b/stlearn/pl/trajectory/__init__.py similarity index 59% rename from stlearn/plotting/trajectory/__init__.py rename to stlearn/pl/trajectory/__init__.py index 16681a51..2b5417d5 100644 --- a/stlearn/plotting/trajectory/__init__.py +++ b/stlearn/pl/trajectory/__init__.py @@ -1,7 +1,19 @@ -from .pseudotime_plot import pseudotime_plot +# stlearn/pl/trajectory/__init__.py + +from .check_trajectory import check_trajectory +from .DE_transition_plot import 
DE_transition_plot from .local_plot import local_plot -from .tree_plot_simple import tree_plot_simple -from .tree_plot import tree_plot +from .pseudotime_plot import pseudotime_plot from .transition_markers_plot import transition_markers_plot -from .DE_transition_plot import DE_transition_plot -from .check_trajectory import check_trajectory +from .tree_plot import tree_plot +from .tree_plot_simple import tree_plot_simple + +__all__ = [ + "pseudotime_plot", + "local_plot", + "tree_plot", + "transition_markers_plot", + "DE_transition_plot", + "tree_plot_simple", + "check_trajectory", +] diff --git a/stlearn/plotting/trajectory/check_trajectory.py b/stlearn/pl/trajectory/check_trajectory.py similarity index 80% rename from stlearn/plotting/trajectory/check_trajectory.py rename to stlearn/pl/trajectory/check_trajectory.py index 29037969..587c20e9 100644 --- a/stlearn/plotting/trajectory/check_trajectory.py +++ b/stlearn/pl/trajectory/check_trajectory.py @@ -1,27 +1,26 @@ -from anndata import AnnData -from typing import Optional, Union import matplotlib.pyplot as plt -import scanpy as sc import numpy as np +import scanpy as sc +from anndata import AnnData def check_trajectory( adata: AnnData, - library_id: str = None, + trajectory: list[int], + library_id: str | None = None, use_label: str = "louvain", basis: str = "umap", pseudotime_key: str = "dpt_pseudotime", - trajectory: list = None, figsize=(10, 4), size_umap: int = 50, - size_spatial: int = 1.5, + size_spatial: float = 1.5, img_key: str = "hires", -) -> Optional[AnnData]: +) -> None: trajectory = np.array(trajectory).astype(int) assert ( trajectory in adata.uns["available_paths"].values() ), "Please choose the right path!" - trajectory = trajectory.astype(str) + trajectory_str = [str(node) for node in trajectory] assert ( pseudotime_key in adata.obs.columns ), "Please run the pseudotime or choose the right one!" 
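Note on the `trajectory_str` conversion above: categorical cluster labels in `adata.obs` are strings, so an `isin()` against raw integers silently matches nothing. A minimal sketch of the failure mode the fix addresses:

```python
# Matching int node ids against string categoricals selects zero rows;
# casting to str first gives the intended selection.
import pandas as pd

labels = pd.Series(["0", "1", "2", "1"], dtype="category")
trajectory = [1, 2]

print(labels.isin(trajectory).sum())                    # 0 rows matched
print(labels.isin([str(n) for n in trajectory]).sum())  # 3 rows matched
```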
@@ -40,7 +39,7 @@ def check_trajectory( ax1 = sc.pl.umap(adata, size=size_umap, show=False, ax=ax1) sc.pl.umap( - adata[adata.obs[use_label].isin(trajectory)], + adata[adata.obs[use_label].isin(trajectory_str)], size=size_umap, color=pseudotime_key, ax=ax1, @@ -56,7 +55,7 @@ def check_trajectory( ax=ax2, ) sc.pl.spatial( - adata[adata.obs[use_label].isin(trajectory)], + adata[adata.obs[use_label].isin(trajectory_str)], size=size_spatial, ax=ax2, color=pseudotime_key, @@ -66,9 +65,7 @@ def check_trajectory( show=False, ) - im = ax2.imshow( - adata.uns["spatial"][library_id]["images"][img_key], alpha=0, zorder=-1 - ) + ax2.imshow(adata.uns["spatial"][library_id]["images"][img_key], alpha=0, zorder=-1) plt.show() diff --git a/stlearn/plotting/trajectory/local_plot.py b/stlearn/pl/trajectory/local_plot.py similarity index 92% rename from stlearn/plotting/trajectory/local_plot.py rename to stlearn/pl/trajectory/local_plot.py index 878d1666..4fa1644b 100644 --- a/stlearn/plotting/trajectory/local_plot.py +++ b/stlearn/pl/trajectory/local_plot.py @@ -1,38 +1,28 @@ -from mpl_toolkits.mplot3d import Axes3D -from matplotlib.patches import FancyArrowPatch -from mpl_toolkits.mplot3d import proj3d import matplotlib.pyplot as plt import numpy as np -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np -import networkx as nx -from stlearn._compat import Literal -from typing import Optional, Union from anndata import AnnData -import warnings +from matplotlib.patches import FancyArrowPatch +from mpl_toolkits.mplot3d import proj3d def local_plot( adata: AnnData, + use_cluster: int, use_label: str = "louvain", - use_cluster: int = None, reverse: bool = False, cluster: int = 0, data_alpha: float = 1.0, arrow_alpha: float = 1.0, branch_alpha: float = 0.2, - spot_size: Union[float, int] = 1, + spot_size: float | int = 1, show_color_bar: bool = True, show_axis: bool = False, show_plot: bool = True, - name: str = None, + name: str | None = None, dpi: int = 150, - output: str = None, + output: str | None = None, copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Local spatial trajectory inference plot. 
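Note on `use_cluster` in the hunk below: parameters with no meaningful default (it was `int = None`, contradicting its own annotation) become required and move ahead of the keyword defaults, so omissions fail at the call site instead of deep inside the plotting code. A sketch with a hypothetical stand-in:

```python
# `sketch` is a hypothetical stand-in for the reordered signature pattern.
def sketch(adata, use_cluster: int, use_label: str = "louvain") -> None:
    print(f"plotting cluster {use_cluster} from {use_label!r}")


sketch(object(), 6)  # OK
# sketch(object())   # TypeError: missing required argument 'use_cluster'
```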
@@ -114,9 +104,6 @@ def local_plot( x = np.linspace(centroids_[i][0], centroids_[i + j][0], 1000) z = np.linspace(centroids_[i][1], centroids_[i + j][1], 1000) - branch = ax.plot( - x, y, z, zorder=10, c="#333333", linewidth=1, alpha=branch_alpha - ) if reverse: dpt_distance = -dpt_distance if dpt_distance <= 0: @@ -198,6 +185,8 @@ def local_plot( name = use_label fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0) + return adata + def calculate_y(m): import math diff --git a/stlearn/plotting/trajectory/pseudotime_plot.py b/stlearn/pl/trajectory/pseudotime_plot.py similarity index 72% rename from stlearn/plotting/trajectory/pseudotime_plot.py rename to stlearn/pl/trajectory/pseudotime_plot.py index 54359c5a..63868009 100644 --- a/stlearn/plotting/trajectory/pseudotime_plot.py +++ b/stlearn/pl/trajectory/pseudotime_plot.py @@ -1,28 +1,25 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd import matplotlib -import numpy as np import networkx as nx -from stlearn._compat import Literal -from typing import Optional, Union +import numpy as np from anndata import AnnData -import warnings +from matplotlib import pyplot as plt +from numpy._typing import NDArray +from stlearn.pl.utils import get_cluster, get_node from stlearn.utils import _read_graph def pseudotime_plot( adata: AnnData, - library_id: str = None, + library_id: str | None = None, use_label: str = "louvain", - pseudotime_key: str = "pseudotime_key", - list_clusters: Union[str, list] = None, + pseudotime_key: str = "dpt_pseudotime", + list_clusters: str | list[str] | None = None, cell_alpha: float = 1.0, image_alpha: float = 1.0, edge_alpha: float = 0.8, node_alpha: float = 1.0, - spot_size: Union[float, int] = 6.5, + spot_size: float | int = 6.5, node_size: float = 5, show_color_bar: bool = True, show_axis: bool = False, @@ -34,12 +31,10 @@ def pseudotime_plot( cropped: bool = True, margin: int = 100, dpi: int = 150, - output: str = None, - name: str = None, - copy: bool = False, + output: str | None = None, + name: str | None = None, ax=None, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Global trajectory inference plot (Only DPT). @@ -80,35 +75,33 @@ def pseudotime_plot( dpi DPI of the output figure. output - Save the figure as file or not. - copy - Return a copy instead of writing to adata. + The output folder of the plot. + name + The filename of the plot. 
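Note on the `pseudotime_key` default changing to `"dpt_pseudotime"` in the hunk below: this matches where scanpy's DPT writes its result (an assumption based on scanpy's convention, not on this file), whereas the old placeholder `"pseudotime_key"` could never name a real column. A sketch using a scanpy demo dataset:

```python
# sc.tl.dpt() stores pseudotime in adata.obs["dpt_pseudotime"].
import scanpy as sc

adata = sc.datasets.pbmc68k_reduced()  # ships with a neighbors graph
adata.uns["iroot"] = 0                 # dpt needs a root cell
sc.tl.diffmap(adata)
sc.tl.dpt(adata)
print("dpt_pseudotime" in adata.obs.columns)  # True
```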
Returns ------- Nothing """ - # plt.rcParams['figure.dpi'] = dpi + checked_list_clusters: list[str] + if list_clusters is None: + checked_list_clusters = adata.obs[use_label].cat.categories + elif isinstance(list_clusters, str): + checked_list_clusters = [list_clusters] + else: + checked_list_clusters = list_clusters imagecol = adata.obs["imagecol"] imagerow = adata.obs["imagerow"] - - if list_clusters == None: - list_clusters = np.array(range(0, len(adata.obs[use_label].unique()))).astype( - int - ) - # Get query clusters - command = [] - # for i in list_clusters: - # command.append(use_label + ' == "' + str(i) + '"') - # tmp = adata.obs.query(" or ".join(command)) tmp = adata.obs G = _read_graph(adata, "global_graph") labels = nx.get_edge_attributes(G, "weight") result = [] - query_node = get_node(list_clusters, adata.uns["split_node"]) + query_node = list( + map(int, get_node(checked_list_clusters, adata.uns["split_node"])) + ) for edge in G.edges(query_node): if (edge[0] in query_node) and (edge[1] in query_node): result.append(edge) @@ -121,13 +114,11 @@ def pseudotime_plot( result2.append(labels[edge[::-1]] + 0.5) fig, a = plt.subplots() - if ax != None: + if ax is not None: a = ax - centroid_dict = adata.uns["centroid_dict"] + centroid_dict: dict[int, NDArray[np.float64]] = adata.uns["centroid_dict"] centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict} dpt = adata.obs[pseudotime_key] - - colors = adata.obs[use_label].astype(int) vmin = min(dpt) vmax = max(dpt) # Plot scatter plot based on pixel of spots @@ -150,19 +141,9 @@ def pseudotime_plot( cmap=plt.get_cmap("viridis"), c=scale.reshape(1, -1)[0], ) - - n_clus = len(colors.unique()) - used_colors = adata.uns[use_label + "_colors"] cmaps = matplotlib.colors.LinearSegmentedColormap.from_list("", used_colors) - cmap = plt.get_cmap(cmaps) - bounds = np.linspace(0, n_clus, n_clus + 1) - norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N) - - norm = matplotlib.colors.Normalize(vmin=min(colors), vmax=max(colors)) - m = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) - if show_graph: nx.draw_networkx( G, @@ -176,19 +157,18 @@ def pseudotime_plot( edge_color="#333333", ) - for x, y in centroid_dict.items(): - - if x in get_node(list_clusters, adata.uns["split_node"]): + for node, pos in centroid_dict.items(): + if str(node) in get_node(checked_list_clusters, adata.uns["split_node"]): a.text( - y[0], - y[1], - get_cluster(str(x), adata.uns["split_node"]), + pos[0], + pos[1], + get_cluster(str(node), adata.uns["split_node"]), color="white", fontsize=node_size, zorder=100, bbox=dict( facecolor=cmap( - int(get_cluster(str(x), adata.uns["split_node"])) + int(get_cluster(str(node), adata.uns["split_node"])) / (len(used_colors) - 1) ), boxstyle="circle", @@ -197,7 +177,6 @@ def pseudotime_plot( ) if show_trajectories: - used_colors = adata.uns[use_label + "_colors"] cmaps = matplotlib.colors.LinearSegmentedColormap.from_list("", used_colors) @@ -241,19 +220,20 @@ def pseudotime_plot( ) if show_node: - for x, y in centroid_dict.items(): - - if x in get_node(list_clusters, adata.uns["split_node"]): + for node, pos in centroid_dict.items(): + if str(node) in get_node( + checked_list_clusters, adata.uns["split_node"] + ): a.text( - y[0], - y[1], - get_cluster(str(x), adata.uns["split_node"]), + pos[0], + pos[1], + str(get_cluster(str(node), adata.uns["split_node"])), color="black", fontsize=8, zorder=100, bbox=dict( facecolor=cmap( - int(get_cluster(str(x), adata.uns["split_node"])) + get_cluster(str(node), 
adata.uns["split_node"]) / (len(used_colors) - 1) ), boxstyle="circle", @@ -289,21 +269,9 @@ def pseudotime_plot( a.set_ylim(a.get_ylim()[::-1]) # plt.gca().invert_yaxis() - if output is not None: + if output is not None and name is not None: fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0) - if show_plot == True: + if show_plot: plt.show() - -# get name of cluster by subcluster -def get_cluster(search, dictionary): - for cl, sub in dictionary.items(): - if search in sub: - return cl - - -def get_node(node_list, split_node): - result = np.array([]) - for node in node_list: - result = np.append(result, np.array(split_node[int(node)]).astype(int)) - return result.astype(int) + return adata diff --git a/stlearn/plotting/trajectory/transition_markers_plot.py b/stlearn/pl/trajectory/transition_markers_plot.py similarity index 91% rename from stlearn/plotting/trajectory/transition_markers_plot.py rename to stlearn/pl/trajectory/transition_markers_plot.py index 9f81d100..c816b193 100644 --- a/stlearn/plotting/trajectory/transition_markers_plot.py +++ b/stlearn/pl/trajectory/transition_markers_plot.py @@ -1,17 +1,17 @@ -import matplotlib.pyplot as plt from decimal import Decimal + +import matplotlib.pyplot as plt from anndata import AnnData -from typing import Optional, Union def transition_markers_plot( adata: AnnData, + trajectory: str, top_genes: int = 10, - trajectory: str = None, dpi: int = 150, - output: str = None, - name: str = None, -) -> Optional[AnnData]: + output: str | None = None, + name: str | None = None, +) -> AnnData | None: """\ Plot transition marker. @@ -19,10 +19,10 @@ def transition_markers_plot( ---------- adata Annotated data matrix. - top_genes - Top genes users want to display in the plot. trajectory Name of a clade/branch user wants to plot transition markers. + top_genes + Top genes users want to display in the plot. dpi The resolution of the plot. output @@ -34,10 +34,10 @@ def transition_markers_plot( Anndata """ - if trajectory == None: - raise ValueError("Please input the trajectory name!") if trajectory not in adata.uns: - raise ValueError("Please input the right trajectory name!") + raise ValueError( + "Please input the right trajectory name - not found in adata.uns!" 
+ ) pos = ( adata.uns[trajectory][adata.uns[trajectory]["score"] >= 0] @@ -100,7 +100,7 @@ def transition_markers_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=6 + size=6, ) axes[1].text( rect.get_x() + 0.01, @@ -108,7 +108,7 @@ def transition_markers_plot( p_value, color="w", **alignment, - size=6 + size=6, ) rects = axes[0].patches @@ -125,7 +125,7 @@ def transition_markers_plot( rect.get_y() + rect.get_height() / 2.0, gene_name, **alignment, - size=6 + size=6, ) axes[0].text( rect.get_x() - 0.01, @@ -133,7 +133,7 @@ def transition_markers_plot( p_value, color="w", **alignment, - size=6 + size=6, ) plt.figtext(0.5, 0.9, trajectory, ha="center", va="center") @@ -146,7 +146,9 @@ def transition_markers_plot( if name is None: name = trajectory - if output is not None: + if output is not None and name is not None: fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0) plt.show() + + return adata diff --git a/stlearn/plotting/trajectory/tree_plot.py b/stlearn/pl/trajectory/tree_plot.py similarity index 83% rename from stlearn/plotting/trajectory/tree_plot.py rename to stlearn/pl/trajectory/tree_plot.py index 90ade45f..ce753128 100644 --- a/stlearn/plotting/trajectory/tree_plot.py +++ b/stlearn/pl/trajectory/tree_plot.py @@ -1,38 +1,31 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np -import networkx as nx import math import random -from stlearn._compat import Literal -from typing import Optional, Union + +import networkx as nx from anndata import AnnData -import warnings -import io -from copy import deepcopy +from matplotlib import pyplot as plt + from stlearn.utils import _read_graph def tree_plot( adata: AnnData, - library_id: str = None, - figsize: Union[float, int] = (10, 4), + library_id: str | None = None, + figsize: tuple[float, float] = (10, 4), data_alpha: float = 1.0, use_label: str = "louvain", - spot_size: Union[float, int] = 50, + spot_size: float | int = 50, fontsize: int = 6, piesize: float = 0.15, zoom: float = 0.1, - name: str = None, - output: str = None, + name: str | None = None, + output: str | None = None, dpi: int = 180, show_all: bool = False, show_plot: bool = True, ncols: int = 4, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Hierarchical tree plot represent for the global spatial trajectory inference. @@ -108,9 +101,11 @@ def tree_plot( output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0 ) - if show_plot == True: + if show_plot: plt.show() + return adata + def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5): """ @@ -120,23 +115,24 @@ def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5 If the graph is a tree this will return the positions to plot this in a hierarchical layout. - G: the graph (must be a tree) - - root: the root node of current branch - - if the tree is directed and this is not given, - the root will be found and used - - if the tree is directed and this is given, then - the positions will be just for the descendants of this node. - - if the tree is undirected and not given, - then a random choice will be used. 
- - width: horizontal space allocated for this branch - avoids overlap with other branches - - vert_gap: gap between levels of hierarchy - - vert_loc: vertical location of root - - xcenter: horizontal location of root + G: + the graph (must be a tree) + root: + the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + width: + horizontal space allocated for this branch - avoids overlap with other branches + vert_gap: + gap between levels of hierarchy + vert_loc: + vertical location of root + xcenter: + horizontal location of root """ if not nx.is_tree(G): raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") diff --git a/stlearn/plotting/trajectory/tree_plot_simple.py b/stlearn/pl/trajectory/tree_plot_simple.py similarity index 83% rename from stlearn/plotting/trajectory/tree_plot_simple.py rename to stlearn/pl/trajectory/tree_plot_simple.py index 3b2395fd..0dd9902c 100644 --- a/stlearn/plotting/trajectory/tree_plot_simple.py +++ b/stlearn/pl/trajectory/tree_plot_simple.py @@ -1,38 +1,31 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np -import networkx as nx import math import random -from stlearn._compat import Literal -from typing import Optional, Union + +import networkx as nx from anndata import AnnData -import warnings -import io -from copy import deepcopy +from matplotlib import pyplot as plt + from stlearn.utils import _read_graph def tree_plot_simple( adata: AnnData, - library_id: str = None, - figsize: Union[float, int] = (10, 4), + library_id: str | None = None, + figsize: tuple[float, float] = (10, 4), data_alpha: float = 1.0, use_label: str = "louvain", - spot_size: Union[float, int] = 50, + spot_size: float | int = 50, fontsize: int = 6, piesize: float = 0.15, zoom: float = 0.1, - name: str = None, - output: str = None, + name: str | None = None, + output: str | None = None, dpi: int = 180, show_all: bool = False, show_plot: bool = True, ncols: int = 4, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Hierarchical tree plot represent for the global spatial trajectory inference. @@ -108,9 +101,11 @@ def tree_plot_simple( output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0 ) - if show_plot == True: + if show_plot: plt.show() + return adata + def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5): """ @@ -120,23 +115,24 @@ def hierarchy_pos(G, root=None, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5 If the graph is a tree this will return the positions to plot this in a hierarchical layout. - G: the graph (must be a tree) - - root: the root node of current branch - - if the tree is directed and this is not given, - the root will be found and used - - if the tree is directed and this is given, then - the positions will be just for the descendants of this node. - - if the tree is undirected and not given, - then a random choice will be used. 
- - width: horizontal space allocated for this branch - avoids overlap with other branches - - vert_gap: gap between levels of hierarchy - - vert_loc: vertical location of root - - xcenter: horizontal location of root + G: + the graph (must be a tree) + root: + the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + width: + horizontal space allocated for this branch - avoids overlap with other branches + vert_gap: + gap between levels of hierarchy + vert_loc: + vertical location of root + xcenter: + horizontal location of root """ if not nx.is_tree(G): raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") diff --git a/stlearn/plotting/trajectory/utils.py b/stlearn/pl/trajectory/utils.py similarity index 99% rename from stlearn/plotting/trajectory/utils.py rename to stlearn/pl/trajectory/utils.py index f7b46284..a8095616 100644 --- a/stlearn/plotting/trajectory/utils.py +++ b/stlearn/pl/trajectory/utils.py @@ -1,5 +1,4 @@ def checkType(arr, n=2): - # If the first two and the last two elements # of the array are in increasing order if arr[0] <= arr[1] and arr[n - 2] <= arr[n - 1]: diff --git a/stlearn/plotting/utils.py b/stlearn/pl/utils.py similarity index 75% rename from stlearn/plotting/utils.py rename to stlearn/pl/utils.py index fb22686a..30363049 100644 --- a/stlearn/plotting/utils.py +++ b/stlearn/pl/utils.py @@ -1,24 +1,12 @@ -import numpy as np -import pandas as pd - import io -from PIL import Image import matplotlib import matplotlib.pyplot as plt -from anndata import AnnData +import numpy as np +from PIL import Image from scanpy.plotting import palettes -from stlearn.plotting import palettes_st -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - -from enum import Enum - -from matplotlib import rcParams, ticker, gridspec, axes -from matplotlib.axes import Axes -from abc import ABC +from stlearn.pl import palettes_st def get_img_from_fig(fig, dpi=180): @@ -39,26 +27,22 @@ def get_img_from_fig(fig, dpi=180): def centroidpython(x, y): - l = len(x) - return sum(x) / l, sum(y) / l - - -def get_cluster(search, dictionary): - for ( - cl, - sub, - ) in ( - dictionary.items() - ): # for name, age in dictionary.iteritems(): (for Python 2.x) - if search in sub: + length_of_x = len(x) + return sum(x) / length_of_x, sum(y) / length_of_x + + +# get name of cluster by subcluster +def get_cluster(search: str, split_node: dict[str, list[str]]): + for cl, sub in split_node.items(): + if str(search) in sub: return cl -def get_node(node_list, split_node): - result = np.array([]) +def get_node(node_list: list[str], split_node: dict[str, list[str]]) -> list[str]: + all_values = [] for node in node_list: - result = np.append(result, np.array(split_node[node]).astype(int)) - return result.astype(int) + all_values.extend(split_node[node]) + return all_values def check_sublist(full, sub): @@ -82,10 +66,10 @@ def get_cmap(cmap): cmap = palettes_st.jana_40 elif cmap == "default": cmap = palettes_st.default - elif type(cmap) == str: # If refers to matplotlib cmap + elif isinstance(cmap, str): # If refers to matplotlib cmap cmap_n = plt.get_cmap(cmap).N return plt.get_cmap(cmap), cmap_n - elif type(cmap) == matplotlib.colors.LinearSegmentedColormap: 
# already cmap + elif isinstance(cmap, matplotlib.colors.LinearSegmentedColormap): # already cmap cmap_n = cmap.N return cmap, cmap_n @@ -106,9 +90,9 @@ def check_cmap(cmap): "cmap must be a matplotlib.colors.LinearSegmentedColormap OR" "one of these: " + str(cmap_available) ) - if type(cmap) == str: + if isinstance(cmap, str): assert cmap in cmap_available, error_msg - elif type(cmap) != matplotlib.colors.LinearSegmentedColormap: + elif not isinstance(cmap, matplotlib.colors.LinearSegmentedColormap): raise Exception(error_msg) return cmap @@ -137,7 +121,7 @@ def get_colors(adata, obs_key, cmap="default", label_set=None): adata.uns[col_key] = colors_ordered # Returning the colors of the desired labels in indicated order # - if type(label_set) != type(None): + if label_set is not None: colors_ordered = [ colors_ordered[np.where(labels_ordered == label)[0][0]] for label in label_set diff --git a/stlearn/plotting/feat_plot.py b/stlearn/plotting/feat_plot.py deleted file mode 100644 index 7df51516..00000000 --- a/stlearn/plotting/feat_plot.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Plotting of continuous features stored in adata.obs. -""" - -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np - -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - -from anndata import AnnData -import warnings - -from stlearn.plotting.classes import FeaturePlot -from stlearn.plotting.classes_bokeh import BokehGenePlot -from stlearn.plotting._docs import doc_spatial_base_plot, doc_gene_plot -from stlearn.utils import Empty, _empty, _AxesSubplot, _docs_params - -from bokeh.io import push_notebook, output_notebook -from bokeh.plotting import show - -# @_docs_params(spatial_base_plot=doc_spatial_base_plot, gene_plot=doc_gene_plot) -def feat_plot( - adata: AnnData, - feature: str = None, - threshold: Optional[float] = None, - contour: bool = False, - step_size: Optional[int] = None, - title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - color_bar_label: Optional[str] = "", - zoom_coord: Optional[float] = None, - crop: Optional[bool] = True, - margin: Optional[bool] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 0.7, - use_raw: Optional[bool] = False, - fname: Optional[str] = None, - dpi: Optional[int] = 120, - vmin: Optional[float] = None, - vmax: Optional[float] = None, -) -> Optional[AnnData]: - """\ - Allows the visualization of a continuous features stored in adata.obs - for Spatial transcriptomics array. 
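Note on the reworked `get_node`/`get_cluster` helpers in `stlearn/pl/utils.py` above: they now stay in string space instead of round-tripping through numpy int arrays. A self-contained sketch of their semantics (the `split_node` mapping here is hypothetical):

```python
# split_node maps a parent cluster to its sub-cluster node ids, as strings.
split_node = {"0": ["3", "4"], "1": ["5"]}


def get_node(node_list, split_node):
    all_values = []
    for node in node_list:
        all_values.extend(split_node[node])
    return all_values


def get_cluster(search, split_node):
    for cl, sub in split_node.items():
        if str(search) in sub:
            return cl


print(get_node(["0", "1"], split_node))  # ['3', '4', '5']
print(get_cluster("4", split_node))      # '0'
```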
- - - Parameters - ------------------------------------- - {spatial_base_plot} - {feature_plot} - - Examples - ------------------------------------- - >>> import stlearn as st - >>> adata = st.datasets.example_bcba() - >>> st.pl.gene_plot(adata, 'dpt_pseudotime') - - """ - FeaturePlot( - adata, - feature=feature, - threshold=threshold, - contour=contour, - step_size=step_size, - title=title, - figsize=figsize, - cmap=cmap, - use_label=use_label, - list_clusters=list_clusters, - ax=ax, - fig=fig, - show_plot=show_plot, - show_axis=show_axis, - show_image=show_image, - show_color_bar=show_color_bar, - color_bar_label=color_bar_label, - zoom_coord=zoom_coord, - crop=crop, - margin=margin, - size=size, - image_alpha=image_alpha, - cell_alpha=cell_alpha, - use_raw=use_raw, - fname=fname, - dpi=dpi, - vmin=vmin, - vmax=vmax, - ) diff --git a/stlearn/plotting/gene_plot.py b/stlearn/plotting/gene_plot.py deleted file mode 100644 index c755d12b..00000000 --- a/stlearn/plotting/gene_plot.py +++ /dev/null @@ -1,109 +0,0 @@ -from matplotlib import pyplot as plt -from PIL import Image -import pandas as pd -import matplotlib -import numpy as np - -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - -from anndata import AnnData -import warnings - -from stlearn.plotting.classes import GenePlot -from stlearn.plotting.classes_bokeh import BokehGenePlot -from stlearn.plotting._docs import doc_spatial_base_plot, doc_gene_plot -from stlearn.utils import Empty, _empty, _AxesSubplot, _docs_params - -from bokeh.io import push_notebook, output_notebook -from bokeh.plotting import show - - -@_docs_params(spatial_base_plot=doc_spatial_base_plot, gene_plot=doc_gene_plot) -def gene_plot( - adata: AnnData, - gene_symbols: Union[str, list] = None, - threshold: Optional[float] = None, - method: str = "CumSum", - contour: bool = False, - step_size: Optional[int] = None, - title: Optional["str"] = None, - figsize: Optional[Tuple[float, float]] = None, - cmap: Optional[str] = "Spectral_r", - use_label: Optional[str] = None, - list_clusters: Optional[list] = None, - ax: Optional[matplotlib.axes.Axes] = None, - fig: Optional[matplotlib.figure.Figure] = None, - show_plot: Optional[bool] = True, - show_axis: Optional[bool] = False, - show_image: Optional[bool] = True, - show_color_bar: Optional[bool] = True, - color_bar_label: Optional[str] = "", - zoom_coord: Optional[float] = None, - crop: Optional[bool] = True, - margin: Optional[bool] = 100, - size: Optional[float] = 7, - image_alpha: Optional[float] = 1.0, - cell_alpha: Optional[float] = 0.7, - use_raw: Optional[bool] = False, - fname: Optional[str] = None, - dpi: Optional[int] = 120, - vmin: Optional[float] = None, - vmax: Optional[float] = None, -) -> Optional[AnnData]: - """\ - Allows the visualization of a single gene or multiple genes as the values - of dot points or contour in the Spatial transcriptomics array. 
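Note on the deleted examples: both removed docstrings call `st.datasets.example_bcba()`, which this PR drops in favour of the scanpy-backed `visium_sge` wrapper (see the updated `subcluster_plot` example earlier in the diff). A hedged sketch of the replacement usage, assuming `gene_plot` keeps its name under the relocated `stlearn.pl` package; the dataset is downloaded on first call:

```python
import stlearn as st

adata = st.datasets.visium_sge(sample_id="V1_Breast_Cancer_Block_A_Section_1")
st.pl.gene_plot(adata, gene_symbols=["BRCA1", "BRCA2"])
```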
- - - Parameters - ------------------------------------- - {spatial_base_plot} - {gene_plot} - - Examples - ------------------------------------- - >>> import stlearn as st - >>> adata = st.datasets.example_bcba() - >>> genes = ["BRCA1","BRCA2"] - >>> st.pl.gene_plot(adata, gene_symbols = genes) - - """ - GenePlot( - adata, - gene_symbols=gene_symbols, - threshold=threshold, - method=method, - contour=contour, - step_size=step_size, - title=title, - figsize=figsize, - cmap=cmap, - use_label=use_label, - list_clusters=list_clusters, - ax=ax, - fig=fig, - show_plot=show_plot, - show_axis=show_axis, - show_image=show_image, - show_color_bar=show_color_bar, - color_bar_label=color_bar_label, - zoom_coord=zoom_coord, - crop=crop, - margin=margin, - size=size, - image_alpha=image_alpha, - cell_alpha=cell_alpha, - use_raw=use_raw, - fname=fname, - dpi=dpi, - vmin=vmin, - vmax=vmax, - ) - - -def gene_plot_interactive(adata: AnnData): - bokeh_object = BokehGenePlot(adata) - output_notebook() - show(bokeh_object.app, notebook_handle=True) diff --git a/stlearn/pp.py b/stlearn/pp.py index 87407591..9a191237 100644 --- a/stlearn/pp.py +++ b/stlearn/pp.py @@ -1,7 +1,18 @@ +from .image_preprocessing.feature_extractor import extract_feature +from .image_preprocessing.image_tiling import tiling +from .preprocessing.filter_cells import filter_cells from .preprocessing.filter_genes import filter_genes -from .preprocessing.normalize import normalize_total -from .preprocessing.log_scale import log1p -from .preprocessing.log_scale import scale from .preprocessing.graph import neighbors -from .image_preprocessing.image_tiling import tiling -from .image_preprocessing.feature_extractor import extract_feature +from .preprocessing.log_scale import log1p, scale +from .preprocessing.normalize import normalize_total + +__all__ = [ + "filter_cells", + "filter_genes", + "normalize_total", + "log1p", + "scale", + "neighbors", + "tiling", + "extract_feature", +] diff --git a/stlearn/preprocessing/filter_cells.py b/stlearn/preprocessing/filter_cells.py new file mode 100644 index 00000000..6736c5a0 --- /dev/null +++ b/stlearn/preprocessing/filter_cells.py @@ -0,0 +1,62 @@ +import numpy as np +import scanpy +from anndata import AnnData + + +def filter_cells( + adata: AnnData, + min_counts: int | None = None, + min_genes: int | None = None, + max_counts: int | None = None, + max_genes: int | None = None, + inplace: bool = True, +) -> AnnData | None | tuple[np.ndarray, np.ndarray]: + """\ + Wrap function scanpy.pp.filter_cells + + Filter cell outliers based on counts and numbers of genes expressed. + + For instance, only keep cells with at least `min_counts` counts or + `min_genes` genes expressed. This is to filter measurement outliers, + i.e. “unreliable” observations. + + Only provide one of the optional parameters `min_counts`, `min_genes`, + `max_counts`, `max_genes` per call. + + Parameters + ---------- + adata + The (annotated) data matrix of shape `n_obs` × `n_vars`. + Rows correspond to cells and columns to genes. + min_counts + Minimum number of counts required for a cell to pass filtering. + min_genes + Minimum number of genes expressed required for a cell to pass filtering. + max_counts + Maximum number of counts required for a cell to pass filtering. + max_genes + Maximum number of genes expressed required for a cell to pass filtering. + inplace + Perform computation inplace or return result. 
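Note on the new `filter_cells` wrapper above: it forwards to `scanpy.pp.filter_cells` and returns its result, and, as the docstring says, only one of the four thresholds may be passed per call. A minimal sketch on synthetic counts:

```python
import numpy as np
from anndata import AnnData

import stlearn as st

adata = AnnData(np.random.poisson(1.0, size=(50, 20)).astype(np.float32))
st.pp.filter_cells(adata, min_counts=10)  # default inplace=True: subsets adata
mask, per_cell = st.pp.filter_cells(adata, min_counts=10, inplace=False)
print(adata.n_obs, mask.sum())  # cells kept in place; cells passing again
```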
+ + Returns + ------- + Depending on `inplace`, returns the following arrays or directly subsets + and annotates the data matrix: + + cells_subset + Boolean index mask that does filtering. `True` means that the + cell is kept. `False` means the cell is removed. + number_per_cell + Depending on what was thresholded (`counts` or `genes`), + the array stores `n_counts` or `n_cells` per gene. + """ + + return scanpy.pp.filter_cells( + adata, + min_counts=min_counts, + min_genes=min_genes, + max_counts=max_counts, + max_genes=max_genes, + inplace=inplace, + ) diff --git a/stlearn/preprocessing/filter_genes.py b/stlearn/preprocessing/filter_genes.py index 5f102ea6..71bd4b58 100644 --- a/stlearn/preprocessing/filter_genes.py +++ b/stlearn/preprocessing/filter_genes.py @@ -1,18 +1,16 @@ -from typing import Union, Optional, Tuple, Collection, Sequence, Iterable -from anndata import AnnData import numpy as np -from scipy.sparse import issparse, isspmatrix_csr, csr_matrix, spmatrix import scanpy +from anndata import AnnData def filter_genes( adata: AnnData, - min_counts: Optional[int] = None, - min_cells: Optional[int] = None, - max_counts: Optional[int] = None, - max_cells: Optional[int] = None, + min_counts: int | None = None, + min_cells: int | None = None, + max_counts: int | None = None, + max_cells: int | None = None, inplace: bool = True, -) -> Union[AnnData, None, Tuple[np.ndarray, np.ndarray]]: +) -> AnnData | None | tuple[np.ndarray, np.ndarray]: """\ Wrap function scanpy.pp.filter_genes @@ -24,7 +22,7 @@ def filter_genes( `max_counts`, `max_cells` per call. Parameters ---------- - data + adata An annotated data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. min_counts @@ -49,7 +47,7 @@ def filter_genes( `n_counts` or `n_cells` per gene. """ - scanpy.pp.filter_genes( + return scanpy.pp.filter_genes( adata, min_counts=min_counts, min_cells=min_cells, diff --git a/stlearn/preprocessing/graph.py b/stlearn/preprocessing/graph.py index 8d7255c1..3f69bfee 100644 --- a/stlearn/preprocessing/graph.py +++ b/stlearn/preprocessing/graph.py @@ -1,14 +1,13 @@ +from collections.abc import Callable, Mapping from types import MappingProxyType -from typing import Union, Optional, Any, Mapping, Callable +from typing import Any, Literal import numpy as np -import scipy +import scanpy from anndata import AnnData from numpy.random import RandomState -from .._compat import Literal -import scanpy -_Method = Literal["umap", "gauss", "rapids"] +_Method = Literal["umap", "gauss"] _MetricFn = Callable[[np.ndarray, np.ndarray], float] # from sklearn.metrics.pairwise_distances.__doc__: _MetricSparseCapable = Literal[ @@ -33,21 +32,21 @@ "sqeuclidean", "yule", ] -_Metric = Union[_MetricSparseCapable, _MetricScipySpatial] +_Metric = _MetricSparseCapable | _MetricScipySpatial def neighbors( adata: AnnData, n_neighbors: int = 15, - n_pcs: Optional[int] = None, - use_rep: Optional[str] = None, + n_pcs: int | None = None, + use_rep: str | None = None, knn: bool = True, - random_state: Optional[Union[int, RandomState]] = 0, - method: Optional[_Method] = "umap", - metric: Union[_Metric, _MetricFn] = "euclidean", + random_state: int | RandomState | None = 0, + method: _Method = "umap", + metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Compute a neighborhood graph of observations [McInnes18]_. 
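Note on the tightened `_Method` alias in the `graph.py` hunk above: `typing.Literal` replaces the old `stlearn._compat` shim, and `"rapids"` is no longer an accepted method. A sketch of the pattern:

```python
from typing import Literal

_Method = Literal["umap", "gauss"]


def run(method: _Method = "umap") -> str:
    if method not in ("umap", "gauss"):
        raise ValueError(f"unsupported method: {method}")
    return method


print(run("gauss"))  # OK; run("rapids") is flagged by mypy and raises here
```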
The neighbor search efficiency of this heavily relies on UMAP [McInnes18]_, @@ -79,8 +78,6 @@ def neighbors( method Use 'umap' [McInnes18]_ or 'gauss' (Gauss kernel following [Coifman05]_ with adaptive width [Haghverdi16]_) for computing connectivities. - Use 'rapids' for the RAPIDS implementation of UMAP (experimental, GPU - only). metric A known metric’s name or a callable that returns a distance. metric_kwds @@ -98,7 +95,7 @@ def neighbors( neighbors. """ - scanpy.pp.neighbors( + adata = scanpy.pp.neighbors( adata, n_neighbors=n_neighbors, n_pcs=n_pcs, @@ -112,3 +109,5 @@ def neighbors( ) print("Created k-Nearest-Neighbor graph in adata.uns['neighbors'] ") + + return adata if copy else None diff --git a/stlearn/preprocessing/log_scale.py b/stlearn/preprocessing/log_scale.py index 8ba18ec2..4a434507 100644 --- a/stlearn/preprocessing/log_scale.py +++ b/stlearn/preprocessing/log_scale.py @@ -1,19 +1,16 @@ -from typing import Union, Optional, Tuple, Collection, Sequence, Iterable -from anndata import AnnData import numpy as np -from scipy.sparse import issparse, isspmatrix_csr, csr_matrix, spmatrix -from scipy import sparse -from stlearn import logging as logg import scanpy +from anndata import AnnData +from scipy.sparse import spmatrix def log1p( - adata: Union[AnnData, np.ndarray, spmatrix], + adata: AnnData | np.ndarray | spmatrix, copy: bool = False, chunked: bool = False, - chunk_size: Optional[int] = None, - base: Optional[float] = None, -) -> Optional[AnnData]: + chunk_size: int | None = None, + base: float | None = None, +) -> AnnData | None: """\ Wrap function of scanpy.pp.log1p Copyright (c) 2017 F. Alexander Wolf, P. Angerer, Theis Lab @@ -41,17 +38,19 @@ def log1p( Returns or updates `data`, depending on `copy`. """ - scanpy.pp.log1p(adata, copy=copy, chunked=chunked, chunk_size=chunk_size, base=base) - + result = scanpy.pp.log1p( + adata, copy=copy, chunked=chunked, chunk_size=chunk_size, base=base + ) print("Log transformation step is finished in adata.X") + return result def scale( - adata: Union[AnnData, np.ndarray, spmatrix], + data: AnnData | spmatrix | np.ndarray, zero_center: bool = True, - max_value: Optional[float] = None, + max_value: float | None = None, copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | spmatrix | np.ndarray | None: """\ Wrap function of scanpy.pp.scale @@ -62,7 +61,7 @@ def scale( the future, they might be set to NaNs. Parameters ---------- - data + data: The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. zero_center @@ -75,9 +74,11 @@ def scale( determines whether a copy is returned. Returns ------- - Depending on `copy` returns or updates `adata` with a scaled `adata.X`. + Depending on `copy` returns or updates `data` with a scaled `data.X`. 
""" - scanpy.pp.scale(adata, zero_center=zero_center, max_value=max_value, copy=copy) - + result = scanpy.pp.scale( + data, zero_center=zero_center, max_value=max_value, copy=copy + ) print("Scale step is finished in adata.X") + return result diff --git a/stlearn/preprocessing/normalize.py b/stlearn/preprocessing/normalize.py index 0bfe006a..376a2f04 100644 --- a/stlearn/preprocessing/normalize.py +++ b/stlearn/preprocessing/normalize.py @@ -1,23 +1,21 @@ -from typing import Optional, Union, Iterable, Dict +from collections.abc import Iterable +from typing import Literal import numpy as np -from anndata import AnnData -from scipy.sparse import issparse -from sklearn.utils import sparsefuncs -from stlearn._compat import Literal import scanpy +from anndata import AnnData def normalize_total( adata: AnnData, - target_sum: Optional[float] = None, + target_sum: float | None = None, exclude_highly_expressed: bool = False, max_fraction: float = 0.05, - key_added: Optional[str] = None, - layers: Union[Literal["all"], Iterable[str]] = None, - layer_norm: Optional[str] = None, + key_added: str | None = None, + layers: Literal["all"] | Iterable[str] | None = None, + layer_norm: str | None = None, inplace: bool = True, -) -> Optional[Dict[str, np.ndarray]]: +) -> dict[str, np.ndarray] | None: """\ Wrap function from scanpy.pp.log1p Normalize counts per cell. @@ -72,7 +70,7 @@ def normalize_total( `adata.X` and `adata.layers`, depending on `inplace`. """ - scanpy.pp.normalize_total( + t = scanpy.pp.normalize_total( adata, target_sum=target_sum, exclude_highly_expressed=exclude_highly_expressed, @@ -84,3 +82,5 @@ def normalize_total( ) print("Normalization step is finished in adata.X") + + return t diff --git a/stlearn/spatial.py b/stlearn/spatial.py deleted file mode 100644 index 8f62d633..00000000 --- a/stlearn/spatial.py +++ /dev/null @@ -1,5 +0,0 @@ -from .spatials import clustering -from .spatials import smooth -from .spatials import trajectory -from .spatials import morphology -from .spatials import SME diff --git a/stlearn/spatials/SME/__init__.py b/stlearn/spatial/SME/__init__.py similarity index 52% rename from stlearn/spatials/SME/__init__.py rename to stlearn/spatial/SME/__init__.py index 88321427..8fffb497 100644 --- a/stlearn/spatials/SME/__init__.py +++ b/stlearn/spatial/SME/__init__.py @@ -1,2 +1,8 @@ -from .normalize import SME_normalize from .impute import SME_impute0, pseudo_spot +from .normalize import SME_normalize + +__all__ = [ + "SME_normalize", + "SME_impute0", + "pseudo_spot", +] diff --git a/stlearn/spatials/SME/_weighting_matrix.py b/stlearn/spatial/SME/_weighting_matrix.py similarity index 96% rename from stlearn/spatials/SME/_weighting_matrix.py rename to stlearn/spatial/SME/_weighting_matrix.py index 49553763..12848161 100644 --- a/stlearn/spatials/SME/_weighting_matrix.py +++ b/stlearn/spatial/SME/_weighting_matrix.py @@ -1,8 +1,8 @@ -from sklearn.metrics import pairwise_distances -from typing import Optional, Union -from anndata import AnnData +from typing import Literal + import numpy as np -from ..._compat import Literal +from anndata import AnnData +from sklearn.metrics import pairwise_distances from tqdm import tqdm _PLATFORM = Literal["Visium", "Old_ST"] @@ -19,13 +19,15 @@ def calculate_weight_matrix( adata: AnnData, - adata_imputed: Union[AnnData, None] = None, + adata_imputed: AnnData | None = None, pseudo_spots: bool = False, platform: _PLATFORM = "Visium", -) -> Optional[AnnData]: - from sklearn.linear_model import LinearRegression +) -> AnnData | None: import 
math + from sklearn.linear_model import LinearRegression + + rate: float if platform == "Visium": img_row = adata.obs["imagerow"] img_col = adata.obs["imagecol"] @@ -101,14 +103,15 @@ def calculate_weight_matrix( adata.uns["gene_expression_correlation"] * adata.uns["morphological_distance"] ) + return adata def impute_neighbour( adata: AnnData, - count_embed: Union[np.ndarray, None] = None, + count_embed: np.ndarray, weights: _WEIGHTING_MATRIX = "weights_matrix_all", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: coor = adata.obs[["imagecol", "imagerow"]] weights_matrix = adata.uns[weights] @@ -123,7 +126,6 @@ def impute_neighbour( bar_format="{l_bar}{bar} [ time left: {remaining} ]", ) as pbar: for i in range(len(coor)): - main_weights = weights_matrix[i] if weights == "physical_distance": diff --git a/stlearn/spatials/SME/impute.py b/stlearn/spatial/SME/impute.py similarity index 89% rename from stlearn/spatials/SME/impute.py rename to stlearn/spatial/SME/impute.py index 12dcebac..68a20dc3 100644 --- a/stlearn/spatials/SME/impute.py +++ b/stlearn/spatial/SME/impute.py @@ -1,18 +1,20 @@ -from typing import Optional, Union -from anndata import AnnData from pathlib import Path +from typing import Literal + import numpy as np -from scipy.sparse import csr_matrix import pandas as pd +import scipy +from anndata import AnnData +from scipy.sparse import csr_matrix + +import stlearn + from ._weighting_matrix import ( + _PLATFORM, + _WEIGHTING_MATRIX, calculate_weight_matrix, impute_neighbour, - _WEIGHTING_MATRIX, - _PLATFORM, ) -import stlearn -import scipy -from ..._compat import Literal def SME_impute0( @@ -21,10 +23,10 @@ def SME_impute0( weights: _WEIGHTING_MATRIX = "weights_matrix_all", platform: _PLATFORM = "Visium", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ - using spatial location (S), tissue morphological feature (M) and gene expression (E) information to impute missing - values + using spatial location (S), tissue morphological feature (M) and gene + expression (E) information to impute missing values Parameters ---------- @@ -34,10 +36,10 @@ def SME_impute0( input data, can be `raw` counts or log transformed data weights weighting matrix for imputation. 
- if `weights_matrix_all`, matrix combined all information from spatial location (S), - tissue morphological feature (M) and gene expression (E) - if `weights_matrix_pd_md`, matrix combined information from spatial location (S), - tissue morphological feature (M) + if `weights_matrix_all`, matrix combined all information from spatial + location (S), tissue morphological feature (M) and gene expression (E) + if `weights_matrix_pd_md`, matrix combined information from spatial + location (S), tissue morphological feature (M) platform `Visium` or `Old_ST` copy @@ -46,6 +48,8 @@ def SME_impute0( ------- Anndata """ + adata = adata.copy() if copy else adata + if use_data == "raw": if isinstance(adata.X, csr_matrix): count_embed = adata.X.toarray() @@ -86,23 +90,25 @@ def SME_impute0( def pseudo_spot( adata: AnnData, - tile_path: Union[Path, str] = Path("/tmp/tiles"), + tile_path: Path | str = Path("/tmp/tiles"), use_data: str = "raw", - crop_size: int = "auto", + crop_size: str | int = "auto", platform: _PLATFORM = "Visium", weights: _WEIGHTING_MATRIX = "weights_matrix_all", copy: _COPY = "pseudo_spot_adata", -) -> Optional[AnnData]: +) -> AnnData | None: """\ - using spatial location (S), tissue morphological feature (M) and gene expression (E) information to impute - gap between spots and increase resolution for gene detection + using spatial location (S), tissue morphological feature (M) and gene + expression (E) information to impute gap between spots and increase resolution + for gene detection Parameters ---------- adata Annotated data matrix. use_data - Input data, can be `raw` counts, log transformed data or dimension reduced space(`X_pca` and `X_umap`) + Input data, can be `raw` counts, log transformed data or dimension + reduced space(`X_pca` and `X_umap`) tile_path Path to save spot image tiles crop_size @@ -110,10 +116,10 @@ def pseudo_spot( if `auto`, automatically detect crop size weights Weighting matrix for imputation. 
- if `weights_matrix_all`, matrix combined all information from spatial location (S), - tissue morphological feature (M) and gene expression (E) - if `weights_matrix_pd_md`, matrix combined information from spatial location (S), - tissue morphological feature (M) + if `weights_matrix_all`, matrix combined all information from spatial + location (S), tissue morphological feature (M) and gene expression (E) + if `weights_matrix_pd_md`, matrix combined information from spatial + location (S), tissue morphological feature (M) platform `Visium` or `Old_ST` copy @@ -124,15 +130,17 @@ def pseudo_spot( ------- Anndata """ - from sklearn.linear_model import LinearRegression import math + from sklearn.linear_model import LinearRegression + + adata = adata.copy() if copy else adata + if platform == "Visium": img_row = adata.obs["imagerow"] img_col = adata.obs["imagecol"] array_row = adata.obs["array_row"] array_col = adata.obs["array_col"] - rate = 3 obs_df_ = adata.obs[["array_row", "array_col"]].copy() obs_df_.loc[:, "array_row"] = obs_df_["array_row"].apply(lambda x: x - 2 / 3) obs_df = adata.obs[["array_row", "array_col"]].copy() @@ -145,7 +153,6 @@ def pseudo_spot( img_col = adata.obs["imagecol"] array_row = adata.obs_names.map(lambda x: x.split("x")[1]) array_col = adata.obs_names.map(lambda x: x.split("x")[0]) - rate = 1.5 obs_df_left = pd.DataFrame( {"array_row": array_row.to_list(), "array_col": array_col.to_list()}, dtype=np.float64, @@ -276,10 +283,14 @@ def pseudo_spot( pseudo_spot_adata = AnnData(impute_df, obs=obs_df) pseudo_spot_adata.uns["spatial"] = adata.uns["spatial"] + actual_crop_size: int if crop_size == "auto": - crop_size = round(unit / 2) - - stlearn.pp.tiling(pseudo_spot_adata, tile_path, crop_size=crop_size) + actual_crop_size = round(unit / 2) + elif isinstance(crop_size, int): + actual_crop_size = crop_size + else: + raise ValueError(f"crop_size must be 'auto' or an integer, got {crop_size}") + stlearn.pp.tiling(pseudo_spot_adata, tile_path, crop_size=actual_crop_size) stlearn.pp.extract_feature(pseudo_spot_adata) @@ -319,7 +330,7 @@ def _merge( adata1: AnnData, adata2: AnnData, copy: bool = True, -) -> Optional[AnnData]: +) -> AnnData | None: merged_df = adata1.to_df().append(adata2.to_df()) merged_df_obs = adata1.obs.append(adata2.obs) merged_adata = AnnData(merged_df, obs=merged_df_obs) diff --git a/stlearn/spatials/SME/normalize.py b/stlearn/spatial/SME/normalize.py similarity index 84% rename from stlearn/spatials/SME/normalize.py rename to stlearn/spatial/SME/normalize.py index 83b132f9..39f65207 100644 --- a/stlearn/spatials/SME/normalize.py +++ b/stlearn/spatial/SME/normalize.py @@ -1,13 +1,13 @@ -from typing import Optional -from anndata import AnnData import numpy as np -from scipy.sparse import csr_matrix import pandas as pd +from anndata import AnnData +from scipy.sparse import csr_matrix + from ._weighting_matrix import ( + _PLATFORM, + _WEIGHTING_MATRIX, calculate_weight_matrix, impute_neighbour, - _WEIGHTING_MATRIX, - _PLATFORM, ) @@ -17,30 +17,33 @@ def SME_normalize( weights: _WEIGHTING_MATRIX = "weights_matrix_all", platform: _PLATFORM = "Visium", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ - using spatial location (S), tissue morphological feature (M) and gene expression (E) information to normalize data. + using spatial location (S), tissue morphological feature (M) and gene + expression (E) information to normalize data. Parameters ---------- - adata + adata: Annotated data matrix. 
- use_data + use_data: Input data, can be `raw` counts or log transformed data - weights + weights: Weighting matrix for imputation. - if `weights_matrix_all`, matrix combined all information from spatial location (S), - tissue morphological feature (M) and gene expression (E) - if `weights_matrix_pd_md`, matrix combined information from spatial location (S), - tissue morphological feature (M) - platform + if `weights_matrix_all`, matrix combined all information from spatial + location (S), tissue morphological feature (M) and gene expression (E) + if `weights_matrix_pd_md`, matrix combined information from spatial + location (S), tissue morphological feature (M) + platform: `Visium` or `Old_ST` - copy + copy: Return a copy instead of writing to adata. Returns ------- Anndata """ + adata = adata.copy() if copy else adata + if use_data == "raw": if isinstance(adata.X, csr_matrix): count_embed = adata.X.toarray() diff --git a/stlearn/spatial/__init__.py b/stlearn/spatial/__init__.py new file mode 100644 index 00000000..d7034329 --- /dev/null +++ b/stlearn/spatial/__init__.py @@ -0,0 +1,11 @@ +# stlearn/spatial/__init__.py + +from . import SME, clustering, morphology, smooth, trajectory + +__all__ = [ + "clustering", + "smooth", + "trajectory", + "morphology", + "SME", +] diff --git a/stlearn/spatials/clustering/__init__.py b/stlearn/spatial/clustering/__init__.py similarity index 52% rename from stlearn/spatials/clustering/__init__.py rename to stlearn/spatial/clustering/__init__.py index be151100..7f1e86e7 100644 --- a/stlearn/spatials/clustering/__init__.py +++ b/stlearn/spatial/clustering/__init__.py @@ -1 +1,5 @@ from .localization import localization + +__all__ = [ + "localization", +] diff --git a/stlearn/spatials/clustering/localization.py b/stlearn/spatial/clustering/localization.py similarity index 83% rename from stlearn/spatials/clustering/localization.py rename to stlearn/spatial/clustering/localization.py index c91dd9c7..c45757e9 100644 --- a/stlearn/spatials/clustering/localization.py +++ b/stlearn/spatial/clustering/localization.py @@ -1,53 +1,53 @@ -from anndata import AnnData -from typing import Optional, Union import numpy as np import pandas as pd -from sklearn.cluster import DBSCAN +from anndata import AnnData from natsort import natsorted +from sklearn.cluster import DBSCAN def localization( adata: AnnData, use_label: str = "louvain", - eps: int = 20, - min_samples: int = 0, + eps: float = 20.0, + min_samples: int = 1, copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Perform local cluster by using DBSCAN. Parameters ---------- - adata + adata: AnnData Annotated data matrix. - use_label + use_label: str, default = "louvain" Use label result of cluster method. - eps + eps: float, default 20.0 The maximum distance between two samples for one to be considered as in the neighborhood of the other. This is not a maximum bound on the distances of points within a cluster. This is the most important DBSCAN parameter to choose appropriately for your data set and distance function. - min_samples + min_samples: int, default = 1 The number of samples (or total weight) in a neighborhood for a point to be - considered as a core point. This includes the point itself. - copy + considered as a core point. This includes the point itself. Passed into DBSCAN's + min_samples parameter. + copy: bool, default = False Return a copy instead of writing to adata. 
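Note on the `copy` parameter: the `adata = adata.copy() if copy else adata` lines added across these hunks give each function one contract, mutate the caller's object and return None by default, or work on a copy and return it when `copy=True`. A self-contained sketch (`transform` is a hypothetical stand-in):

```python
import numpy as np
from anndata import AnnData


def transform(adata: AnnData, copy: bool = False) -> AnnData | None:
    adata = adata.copy() if copy else adata
    adata.obs["flag"] = True  # stand-in for the real computation
    return adata if copy else None


orig = AnnData(np.zeros((3, 2), dtype=np.float32))
new = transform(orig, copy=True)
print("flag" in orig.obs, "flag" in new.obs)  # False True
```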
Returns ------- Anndata """ + adata = adata.copy() if copy else adata + if "sub_cluster_labels" in adata.obs.columns: adata.obs = adata.obs.drop("sub_cluster_labels", axis=1) pd.set_option("mode.chained_assignment", None) subclusters_list = [] for i in adata.obs[use_label].unique(): - tmp = adata.obs[adata.obs[use_label] == i] - clustering = DBSCAN(eps=eps, min_samples=1, algorithm="kd_tree").fit( + clustering = DBSCAN(eps=eps, min_samples=min_samples, algorithm="kd_tree").fit( tmp[["imagerow", "imagecol"]] ) @@ -81,7 +81,7 @@ def localization( ), ) - labels_cat = adata.obs[use_label].cat.categories + labels_cat = list(map(int, adata.obs[use_label].cat.categories)) cat_ind = {labels_cat[i]: i for i in range(len(labels_cat))} adata.uns[use_label + "_index_dict"] = cat_ind diff --git a/stlearn/spatial/morphology/__init__.py b/stlearn/spatial/morphology/__init__.py new file mode 100644 index 00000000..3e5b88f5 --- /dev/null +++ b/stlearn/spatial/morphology/__init__.py @@ -0,0 +1,5 @@ +from .adjust import adjust + +__all__ = [ + "adjust", +] diff --git a/stlearn/spatials/morphology/adjust.py b/stlearn/spatial/morphology/adjust.py similarity index 87% rename from stlearn/spatials/morphology/adjust.py rename to stlearn/spatial/morphology/adjust.py index 1128ae1e..a97ec258 100644 --- a/stlearn/spatials/morphology/adjust.py +++ b/stlearn/spatial/morphology/adjust.py @@ -1,11 +1,9 @@ -from typing import Optional import numpy as np -from anndata import AnnData -from ..._compat import Literal import scipy.spatial as spatial +from anndata import AnnData from tqdm import tqdm -_SIMILARITY_MATRIX = Literal["cosine", "euclidean", "pearson", "spearman"] +from stlearn.types import _METHOD, _SIMILARITY_MATRIX def adjust( @@ -13,39 +11,40 @@ def adjust( use_data: str = "X_pca", radius: float = 50.0, rates: int = 1, - method="mean", - copy: bool = False, + method: _METHOD = "mean", similarity_matrix: _SIMILARITY_MATRIX = "cosine", -) -> Optional[AnnData]: + copy: bool = False, +) -> AnnData | None: """\ SME normalisation: Using spot location information and tissue morphological features to correct spot gene expression Parameters ---------- - adata + adata : AnnData Annotated data matrix. - use_data + use_data : str, default "X_pca" Input date to be adjusted by morphological features. choose one from ["raw", "X_pca", "X_umap"] - radius + radius: float, default 50.0 Radius to select neighbour spots. - rates - Strength for adjustment. - method - Method for disk smoothing. - choose one from ["means", "median"] - copy + rates: int, default 1 + Number of times to add the aggregated neighbor contribution. + Higher values increase the strength of morphological adjustment. + method: {'mean', 'median', 'sum'}, default 'mean' + Method for aggregating neighbor contributions. + similarity_matrix : {'cosine', 'euclidean', 'pearson', 'spearman'}, default 'cosine' + Method to calculate morphological similarity between spots. + copy : bool, default False Return a copy instead of writing to adata. - similarity_matrix - Matrix to calculate morphological similarity of two spots - choose one from ["cosine", "euclidean", "pearson", "spearman"] Returns ------- Depending on `copy`, returns or updates `adata` with the following fields. 
**[use_data]_morphology** : `adata.obsm` field Add SME normalised gene expression matrix """ + adata = adata.copy() if copy else adata + if "X_morphology" not in adata.obsm: raise ValueError("Please run the function stlearn.pp.extract_feature") coor = adata.obs[["imagecol", "imagerow"]] diff --git a/stlearn/spatial/smooth/__init__.py b/stlearn/spatial/smooth/__init__.py new file mode 100644 index 00000000..3e663461 --- /dev/null +++ b/stlearn/spatial/smooth/__init__.py @@ -0,0 +1,5 @@ +from .disk import disk + +__all__ = [ + "disk", +] diff --git a/stlearn/spatials/smooth/disk.py b/stlearn/spatial/smooth/disk.py similarity index 88% rename from stlearn/spatials/smooth/disk.py rename to stlearn/spatial/smooth/disk.py index 0517b267..a259aee4 100644 --- a/stlearn/spatials/smooth/disk.py +++ b/stlearn/spatial/smooth/disk.py @@ -1,8 +1,6 @@ -from typing import Optional, Union import numpy as np -from anndata import AnnData -import logging as logg import scipy.spatial as spatial +from anndata import AnnData def disk( @@ -12,7 +10,8 @@ def disk( rates: int = 1, method: str = "mean", copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: + adata = adata.copy() if copy else adata coor = adata.obs[["imagecol", "imagerow"]] count_embed = adata.obsm[use_data] @@ -48,7 +47,8 @@ def disk( adata.obsm[new_embed] = np.array(lag_coor) print( - 'Disk smoothing function is applied! The new data are stored in adata.obsm["X_diffmap_disk"]' + "Disk smoothing function is applied! The new data are stored in " + + 'adata.obsm["X_diffmap_disk"]' ) return adata if copy else None diff --git a/stlearn/spatials/trajectory/__init__.py b/stlearn/spatial/trajectory/__init__.py similarity index 59% rename from stlearn/spatials/trajectory/__init__.py rename to stlearn/spatial/trajectory/__init__.py index bd6c4820..f2922e65 100644 --- a/stlearn/spatials/trajectory/__init__.py +++ b/stlearn/spatial/trajectory/__init__.py @@ -1,14 +1,30 @@ +from .compare_transitions import compare_transitions +from .detect_transition_markers import ( + detect_transition_markers_branches, + detect_transition_markers_clades, +) from .global_level import global_level from .local_level import local_level from .pseudotime import pseudotime -from .weight_optimization import weight_optimizing_global, weight_optimizing_local -from .utils import lambda_dist, resistance_distance from .pseudotimespace import pseudotimespace_global, pseudotimespace_local -from .detect_transition_markers import ( - detect_transition_markers_clades, - detect_transition_markers_branches, -) -from .compare_transitions import compare_transitions - from .set_root import set_root from .shortest_path_spatial_PAGA import shortest_path_spatial_PAGA +from .utils import lambda_dist, resistance_distance +from .weight_optimization import weight_optimizing_global, weight_optimizing_local + +__all__ = [ + "global_level", + "local_level", + "pseudotime", + "weight_optimizing_global", + "weight_optimizing_local", + "lambda_dist", + "resistance_distance", + "pseudotimespace_global", + "pseudotimespace_local", + "detect_transition_markers_clades", + "detect_transition_markers_branches", + "compare_transitions", + "set_root", + "shortest_path_spatial_PAGA", +] diff --git a/stlearn/spatials/trajectory/compare_transitions.py b/stlearn/spatial/trajectory/compare_transitions.py similarity index 100% rename from stlearn/spatials/trajectory/compare_transitions.py rename to stlearn/spatial/trajectory/compare_transitions.py diff --git 
a/stlearn/spatials/trajectory/detect_transition_markers.py b/stlearn/spatial/trajectory/detect_transition_markers.py similarity index 72% rename from stlearn/spatials/trajectory/detect_transition_markers.py rename to stlearn/spatial/trajectory/detect_transition_markers.py index 56497ada..d41d493e 100644 --- a/stlearn/spatials/trajectory/detect_transition_markers.py +++ b/stlearn/spatial/trajectory/detect_transition_markers.py @@ -1,41 +1,56 @@ -from scipy.stats import spearmanr +import warnings + import numpy as np import pandas as pd -import warnings -import networkx as nx +from anndata import AnnData +from scipy.stats import spearmanr + from ...utils import _read_graph warnings.filterwarnings("ignore", category=RuntimeWarning) def detect_transition_markers_clades( - adata, - clade, - cutoff_spearman=0.4, - cutoff_pvalue=0.05, - screening_genes=None, - use_raw_count=False, + adata: AnnData, + clade: int, + cutoff_spearman: float = 0.4, + cutoff_pvalue: float = 0.05, + screening_genes: None | list[str] = None, + use_raw_count: bool = False, ): """\ Transition markers detection of a clade. Parameters ---------- - adata - Annotated data matrix. - clade - Name of a clade user wants to detect transition markers. - cutoff_spearman - The threshold of correlation coefficient. - cutoff_pvalue - The threshold of p-value. - screening_genes - List of customised genes. - use_raw_count + adata : AnnData + Annotated data matrix containing spatial transcriptomics data with + computed pseudotime and clade information. + clade : int + Numeric identifier of the clade for which to detect transition markers. + Should correspond to a clade ID present in the trajectory analysis. + cutoff_spearman : float, default 0.4 + The minimum Spearman correlation coefficient threshold for identifying + significant gene-pseudotime correlations. Must be between 0 and 1. + cutoff_pvalue : float, default 0.05 + The maximum p-value threshold for statistical significance testing. + Must be between 0 and 1. Lower values result in more stringent + statistical filtering. + screening_genes : list of str, optional + Custom list of gene names to restrict the analysis to. If None, + all genes in the dataset will be considered. Useful for focusing + on specific gene sets or reducing computational time. + use_raw_count : bool, default False True if user wants to use raw layer data. 
Returns ------- - Anndata + AnnData + The input AnnData object with additional information stored in + adata.uns about the detected transition markers, including: + - Correlation coefficients + - P-values + - Gene rankings + - Clade-specific marker information """ print("Detecting the transition markers of clade_" + str(clade) + "...") @@ -152,7 +167,7 @@ def get_rank_cor(adata, screening_genes=None, use_raw_count=True): tmp = tmp.to_df() else: tmp = adata.to_df() - if screening_genes != None: + if screening_genes is not None: tmp = tmp[screening_genes] dpt = adata.obs["dpt_pseudotime"].values genes = [] diff --git a/stlearn/spatials/trajectory/global_level.py b/stlearn/spatial/trajectory/global_level.py similarity index 79% rename from stlearn/spatials/trajectory/global_level.py rename to stlearn/spatial/trajectory/global_level.py index 6a898238..1a77985b 100644 --- a/stlearn/spatials/trajectory/global_level.py +++ b/stlearn/spatial/trajectory/global_level.py @@ -1,24 +1,22 @@ -from anndata import AnnData -from typing import Optional, Union -import numpy as np -import pandas as pd +import networkx import networkx as nx +import numpy as np +from anndata import AnnData from scipy.spatial.distance import cdist + from stlearn.utils import _read_graph -from sklearn.metrics import pairwise_distances def global_level( adata: AnnData, + list_clusters: list[str], + w: float, use_label: str = "louvain", use_rep: str = "X_pca", n_dims: int = 40, - list_clusters: list = [], return_graph: bool = False, - w: float = None, verbose: bool = True, - copy: bool = False, -) -> Optional[AnnData]: +) -> networkx.Graph | None: """\ Perform global spatial trajectory inference. @@ -28,17 +26,18 @@ def global_level( Annotated data matrix. list_clusters Set up a list of clusters to perform pseudo-space-time + w + Pseudo-spatio-temporal distance weight (balance between spatial effect and DPT) use_label Use label result of cluster method. return_graph Return PTS graph - w - Pseudo-spatio-temporal distance weight (balance between spatial effect and DPT) - copy - Return a copy instead of writing to adata.
Returns ------- - Anndata + networkx.Graph: + The pseudo-time-space (PTS) graph connecting sub-clusters, returned when + `return_graph=True`, otherwise None. + + adata.uns["PTS_graph"]["graph"]: the stored PTS graph. + adata.uns["PTS_graph"]["node_dict"]: mapping between graph node ids and cluster labels. """ assert w <= 1, "w should be in range 0 to 1" @@ -51,9 +50,13 @@ def global_level( inds_cat = {v: k for (k, v) in cat_inds.items()} # Query cluster - if type(list_clusters[0]) == str: - list_clusters = [cat_inds[label] for label in list_clusters] - query_nodes = list_clusters + if len(list_clusters) == 0: + print("No clusters specified, using all available clusters") + query_nodes = list(cat_inds.values()) + else: + if isinstance(list_clusters[0], str): + list_clusters = [cat_inds[int(label)] for label in list_clusters] + query_nodes = list_clusters query_nodes = ordering_nodes(query_nodes, use_label, adata) if verbose: @@ -72,19 +75,19 @@ def global_level( ].unique(): query_dict[int(j)] = int(i) order_dict[int(j)] = int(order) - order += 1 dm_list = [] sdm_list = [] order_big_dict = {} edge_list = [] + split_node = adata.uns["split_node"] for i, j in enumerate(query_nodes): order_big_dict[j] = int(i) if i == len(query_nodes) - 1: break - for j in adata.uns["split_node"][query_nodes[i]]: - for k in adata.uns["split_node"][query_nodes[i + 1]]: + for j in split_node[str(query_nodes[i])]: + for k in split_node[str(query_nodes[i + 1])]: edge_list.append((int(j), int(k))) # Calculate DPT distance matrix @@ -112,14 +115,15 @@ def global_level( centroid_dict = adata.uns["centroid_dict"] centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict} - H_sub = H.edge_subgraph(edge_list) + H_sub: networkx.Graph = H.edge_subgraph(edge_list) if not nx.is_connected(H_sub.to_undirected()): raise ValueError( - "The chosen clusters are not available to construct the spatial trajectory! Please choose other path." + "The chosen clusters are not available to construct the spatial " + + "trajectory! Please choose another path."
) H_sub = nx.DiGraph(H_sub) prepare_root = [] - for node in adata.uns["split_node"][query_nodes[0]]: + for node in split_node[str(query_nodes[0])]: H_sub.add_edge(9999, int(node)) prepare_root.append(centroid_dict[int(node)]) @@ -137,7 +141,7 @@ def global_level( H_sub = nx.DiGraph(H_sub) prepare_root = [] - for node in adata.uns["split_node"][query_nodes[0]]: + for node in split_node[str(query_nodes[0])]: H_sub.add_edge(9999, int(node)) prepare_root.append(centroid_dict[int(node)]) @@ -177,13 +181,11 @@ def global_level( if return_graph: return H_sub + else: + return None -######################## -## Global level PTS ## -######################## - - +# Global level PTS def get_node(node_list, split_node): result = np.array([]) for node in node_list: @@ -201,42 +203,6 @@ def ordering_nodes(node_list, use_label, adata): return list(np.array(node_list)[np.argsort(mean_dpt)]) -# def dpt_distance_matrix(adata, cluster1, cluster2, use_label): -# tmp = adata.obs[adata.obs[use_label] == str(cluster1)] -# chosen_adata1 = adata[list(tmp.index)] -# tmp = adata.obs[adata.obs[use_label] == str(cluster2)] -# chosen_aadata = adata[list(tmp.index)] - -# sub_dpt1 = [] -# chosen_sub1 = chosen_adata1.obs["sub_cluster_labels"].unique() -# for i in range(0, len(chosen_sub1)): -# sub_dpt1.append( -# chosen_adata1.obs[ -# chosen_adata1.obs["sub_cluster_labels"] == chosen_sub1[i] -# ]["dpt_pseudotime"].median() -# ) - -# sub_dpt2 = [] -# chosen_sub2 = chosen_aadata.obs["sub_cluster_labels"].unique() -# for i in range(0, len(chosen_sub2)): -# sub_dpt2.append( -# chosen_aadata.obs[ -# chosen_aadata.obs["sub_cluster_labels"] == chosen_sub2[i] -# ]["dpt_pseudotime"].median() -# ) - -# dm = cdist( -# np.array(sub_dpt1).reshape(-1, 1), -# np.array(sub_dpt2).reshape(-1, 1), -# lambda u, v: v - u, -# ) -# from sklearn.preprocessing import MinMaxScaler -# scaler = MinMaxScaler() -# scale_dm = scaler.fit_transform(dm) -# # scale_dm = (dm + (-np.min(dm))) / np.max(dm) -# return scale_dm - - def spatial_distance_matrix(adata, cluster1, cluster2, use_label): tmp = adata.obs[adata.obs[use_label] == str(cluster1)] chosen_adata1 = adata[list(tmp.index)] @@ -258,8 +224,6 @@ def spatial_distance_matrix(adata, cluster1, cluster2, use_label): sdm = cdist(np.array(sub_coord1), np.array(sub_coord2), "euclidean") - from sklearn.preprocessing import MinMaxScaler - # scaler = MinMaxScaler() # scale_sdm = scaler.fit_transform(sdm) scale_sdm = sdm / np.max(sdm) @@ -304,8 +268,6 @@ def ge_distance_matrix(adata, cluster1, cluster2, use_label, use_rep, n_dims): results.append(cdist(sub_coord1[i], sub_coord2[j], "cosine").mean()) results = np.array(results).reshape(len(sub_coord1), len(sub_coord2)) - from sklearn.preprocessing import MinMaxScaler - # scaler = MinMaxScaler() # scale_sdm = scaler.fit_transform(results) scale_sdm = results / np.max(results) diff --git a/stlearn/spatials/trajectory/local_level.py b/stlearn/spatial/trajectory/local_level.py similarity index 79% rename from stlearn/spatials/trajectory/local_level.py rename to stlearn/spatial/trajectory/local_level.py index c68888f9..56f1b3ec 100644 --- a/stlearn/spatials/trajectory/local_level.py +++ b/stlearn/spatial/trajectory/local_level.py @@ -1,8 +1,5 @@ -from anndata import AnnData -from typing import Optional, Union import numpy as np -from stlearn.em import run_pca, run_diffmap -from stlearn.pp import neighbors +from anndata import AnnData from scipy.spatial.distance import cdist @@ -13,31 +10,34 @@ def local_level( w: float = 0.5, return_matrix: bool = False, verbose: 
bool = True, - copy: bool = False, -) -> Optional[AnnData]: - +) -> np.ndarray | None: """\ Perform local spatial trajectory inference (requires running pseudotime first). Parameters ---------- - adata + adata: Annotated data matrix. - use_label + use_label: Use label result of cluster method. - cluster + cluster: Choose cluster to perform local spatial trajectory inference. - threshold - Threshold to find the significant connection for PAGA graph. - w + w: float, default=0.5 Pseudo-spatio-temporal distance weight (balance between spatial effect and DPT) - return_matrix + return_matrix: Return PTS matrix for local level - copy - Return a copy instead of writing to adata. + verbose : bool, default=True + Whether to print progress information. Returns ------- - Anndata + np.ndarray: the STDM (spatio-temporal distance matrix) - weighted combination of + spatial and temporal distances, returned when `return_matrix=True`. + + adata.uns["nonabs_dpt_distance_matrix"]: np.ndarray + Pseudotime distance (difference between values) matrix """ if verbose: print("Start construct trajectory for subcluster " + str(cluster)) @@ -81,5 +81,5 @@ def local_level( if return_matrix: return stdm - - return adata if copy else None + else: + return None diff --git a/stlearn/spatials/trajectory/pseudotime.py b/stlearn/spatial/trajectory/pseudotime.py similarity index 85% rename from stlearn/spatials/trajectory/pseudotime.py rename to stlearn/spatial/trajectory/pseudotime.py index 0c9df496..4615617e 100644 --- a/stlearn/spatials/trajectory/pseudotime.py +++ b/stlearn/spatial/trajectory/pseudotime.py @@ -1,21 +1,25 @@ -from anndata import AnnData -from typing import Optional, Union +import networkx as nx import numpy as np import pandas as pd -import networkx as nx -from scipy.spatial.distance import cdist -import scanpy +import scanpy as sc +from anndata import AnnData +from sklearn.neighbors import NearestCentroid + +from stlearn.pp import neighbors +from stlearn.spatial.clustering import localization +from stlearn.spatial.morphology import adjust +from stlearn.types import _METHOD def pseudotime( adata: AnnData, - use_label: str = None, + use_label: str = "louvain", eps: float = 20, n_neighbors: int = 25, use_rep: str = "X_pca", threshold: float = 0.01, radius: int = 50, - method: str = "mean", + method: _METHOD = "mean", threshold_spots: int = 5, use_sme: bool = False, reverse: bool = False, @@ -23,37 +27,36 @@ max_nodes: int = 4, run_knn: bool = False, copy: bool = False, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Perform pseudotime analysis. Parameters ---------- - adata + adata: Annotated data matrix. - use_label + use_label: Use label result of cluster method. - eps + eps: The maximum distance between two samples for one to be considered as in the neighborhood of the other. This is not a maximum bound on the distances of points within a cluster. This is the most important DBSCAN parameter to choose appropriately for your data set and distance function. - threshold + threshold: Threshold to find the significant connection for PAGA graph. - radius + radius: radius to adjust data for diffusion map - method + method: method to adjust the data. - use_sme + use_sme: Use adjusted feature by SME normalization or not - reverse + reverse: Reverse the pseudotime score - pseudotime_key + pseudotime_key: Key to store pseudotime - max_nodes + max_nodes: Maximum number of nodes in available paths - copy + copy: Return a copy instead of writing to adata.
Returns ------- @@ -69,30 +72,21 @@ def pseudotime( except: pass - assert use_label != None, "Please choose the right `use_label`!" - - # Localize - from stlearn.spatials.clustering import localization - - if "sub_clusters_laber" not in adata.obs.columns: + if "sub_cluster_labels" not in adata.obs.columns: localization(adata, use_label=use_label, eps=eps) # Running knn if run_knn: - from stlearn.pp import neighbors - neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep, random_state=0) # Running paga - scanpy.tl.paga(adata, groups=use_label) + sc.tl.paga(adata, groups=use_label) # Denoising the graph - scanpy.tl.diffmap(adata) + sc.tl.diffmap(adata) if use_sme: - from stlearn.spatials.morphology import adjust - adjust(adata, use_data="X_diffmap", radius=radius, method=method) adata.obsm["X_diffmap"] = adata.obsm["X_diffmap_morphology"] @@ -100,16 +94,13 @@ cnt_matrix = adata.uns["paga"]["connectivities"].toarray() # Filter by threshold - cnt_matrix[cnt_matrix < threshold] = 0.0 cnt_matrix = pd.DataFrame(cnt_matrix) # Mapping louvain label to subcluster - - cat_ind = adata.uns[use_label + "_index_dict"] - + cat_inds = adata.uns[use_label + "_index_dict"] split_node = {} - for label in adata.obs[use_label].unique(): + for label in adata.obs[use_label].cat.categories: meaningful_sub = [] for i in adata.obs[adata.obs[use_label] == label][ "sub_cluster_labels" @@ -120,10 +111,12 @@ ): meaningful_sub.append(i) - split_node[cat_ind[label]] = meaningful_sub + label = cat_inds[int(label)] + split_node[label] = meaningful_sub adata.uns["threshold_spots"] = threshold_spots - adata.uns["split_node"] = split_node + # split_node has string keys for rest of code/plotting (names as strings) + adata.uns["split_node"] = {str(k): v for k, v in split_node.items()} # Replicate louvain label row to prepare for subcluster connection # matrix construction @@ -159,8 +152,6 @@ adata.uns["global_graph"]["node_dict"] = node_convert # Create centroid dict for subclusters - from sklearn.neighbors import NearestCentroid - clf = NearestCentroid() clf.fit(adata.obs[["imagecol", "imagerow"]].values, adata.obs["sub_cluster_labels"]) centroid_dict = dict(zip(clf.classes_.astype(int), clf.centroids_)) @@ -178,10 +169,9 @@ def closest_node(node, nodes): centroid_dict[int(cl)] = new_centroid adata.uns["centroid_dict"] = centroid_dict - centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict} # Running diffusion pseudo-time - scanpy.tl.dpt(adata) + sc.tl.dpt(adata) if reverse: adata.obs[pseudotime_key] = 1 - adata.obs[pseudotime_key] @@ -191,9 +181,7 @@ return adata if copy else None -######## utils ######## - - +# Utils def replace_with_dict(ar, dic): # Extract out keys and values k = np.array(list(dic.keys()), dtype=object) @@ -213,7 +201,6 @@ def selection_sort(x): def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key): - # Read original PAGA graph G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray()) edge_weights = nx.get_edge_attributes(G, "weight") @@ -248,7 +235,8 @@ def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key adata.uns["available_paths"] = all_paths print( - "All available trajectory paths are stored in adata.uns['available_paths'] with length < " + "All available trajectory paths are stored in adata.uns['available_paths'] " + + "with length < " + str(max_nodes) + " nodes" )
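For orientation, the rewritten `pseudotime` step above chains localization, optional SME adjustment, PAGA and diffusion pseudotime behind one call. A hedged sketch of how a caller might drive it under the new defaults; the cluster key, root cluster, and the hand-off of the root index via `uns["iroot"]` (scanpy's convention for `dpt`) are illustrative rather than fixed by this diff:

```python
from stlearn.pp import neighbors
from stlearn.spatial.trajectory import pseudotime, set_root

# Illustrative pipeline sketch; assumes adata.obs["louvain"] already holds clusters.
neighbors(adata, n_neighbors=25, use_rep="X_pca")

# set_root returns the index of the selected root cell (root cluster "0" is illustrative).
adata.uns["iroot"] = set_root(adata, use_label="louvain", cluster="0")

pseudotime(
    adata,
    use_label="louvain",  # new default instead of None
    eps=20,               # DBSCAN radius used by the localization sub-clustering
    use_sme=True,         # smooth the diffusion map with morphological features
)
```

diff --git a/stlearn/spatials/trajectory/pseudotimespace.py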
b/stlearn/spatial/trajectory/pseudotimespace.py similarity index 55% rename from stlearn/spatials/trajectory/pseudotimespace.py rename to stlearn/spatial/trajectory/pseudotimespace.py index 230d0ff0..60573d10 100644 --- a/stlearn/spatials/trajectory/pseudotimespace.py +++ b/stlearn/spatial/trajectory/pseudotimespace.py @@ -1,8 +1,10 @@ +from typing import Literal + from anndata import AnnData -from typing import Optional, Union -from .weight_optimization import weight_optimizing_global, weight_optimizing_local + from .global_level import global_level from .local_level import local_level +from .weight_optimization import weight_optimizing_global, weight_optimizing_local def pseudotimespace_global( @@ -10,36 +12,43 @@ use_label: str = "louvain", use_rep: str = "X_pca", n_dims: int = 40, - list_clusters: list = [], - model: str = "spatial", + list_clusters=None, + model: Literal["spatial", "gene_expression", "mixed"] = "spatial", step=0.01, k=10, -) -> Optional[AnnData]: - +) -> AnnData | None: """\ Perform pseudo-time-space analysis with global level. Parameters ---------- - adata + adata: AnnData Annotated data matrix. - use_label + use_label: str, default = "louvain" Use label result of cluster method. - list_clusters - List of cluster used to reconstruct spatial trajectory. - w - Weighting factor to balance between spatial data and gene expression - step - Step for screeing weighting factor - k - The number of eigenvalues to be compared + use_rep: str, default = "X_pca" + Which obsm location to use. + n_dims: int, default = 40 + Number of dimensions to use in PCA + list_clusters: list, optional + List of clusters used to reconstruct spatial trajectory. If None, uses all + clusters. + model: Literal["spatial", "gene_expression", "mixed"], default = "spatial" + Can be mixed, spatial or gene expression. spatial sets the weight to 0, + gene expression sets it to 1, and mixed optimises the weight using + list_clusters, step and k. + step: float, default = 0.01 + Step for screening weighting factor. + k: int, default = 10 + The number of eigenvalues to be compared. Returns ------- AnnData """ - if model == "mixed": + if list_clusters is None: + list_clusters = [] + if model == "mixed": w = weight_optimizing_global( adata, use_label=use_label, list_clusters=list_clusters, step=step, k=k ) @@ -48,8 +57,9 @@ elif model == "gene_expression": w = 1 else: - raise ValidationError( - "Please choose the right model! Available models: 'mixed', 'spatial' and 'gene_expression' " + raise ValueError( + "Please choose the right model! Available models: 'mixed', 'spatial' " + + "and 'gene_expression' " ) global_level( @@ -61,33 +71,38 @@ n_dims=n_dims, ) + return adata + def pseudotimespace_local( adata: AnnData, use_label: str = "louvain", - cluster: list = [], - w: float = None, -) -> Optional[AnnData]: - + cluster=None, + w: float | None = None, +) -> AnnData | None: """\ Perform pseudo-time-space analysis with local level. Parameters ---------- - adata + adata: AnnData Annotated data matrix. - use_label + use_label: str, default = "louvain" Use label result of cluster method. - cluster - Cluster used to reconstruct intraregional spatial trajectory. - w + cluster: + Cluster used to reconstruct intraregional spatial trajectory.
+ w: Weighting factor to balance between spatial data and gene expression Returns ------- Anndata """ + if cluster is None: + cluster = [] if w is None: w = weight_optimizing_local(adata, use_label=use_label, cluster=cluster) local_level(adata, use_label=use_label, cluster=cluster, w=w) + + return adata diff --git a/stlearn/spatials/trajectory/set_root.py b/stlearn/spatial/trajectory/set_root.py similarity index 60% rename from stlearn/spatials/trajectory/set_root.py rename to stlearn/spatial/trajectory/set_root.py index b26c7909..65287ec1 100644 --- a/stlearn/spatials/trajectory/set_root.py +++ b/stlearn/spatial/trajectory/set_root.py @@ -1,37 +1,50 @@ -from anndata import AnnData -from typing import Optional, Union import numpy as np -from stlearn.spatials.trajectory.utils import _correlation_test_helper +from anndata import AnnData +from stlearn.spatial.trajectory.utils import _correlation_test_helper -def set_root(adata: AnnData, use_label: str, cluster: str, use_raw: bool = False): +def set_root(adata: AnnData, use_label: str, cluster: str, use_raw: bool = False): """\ - Automatically set the root index. + Automatically set the root index for trajectory analysis. Parameters ---------- - adata + adata: AnnData Annotated data matrix. - use_label + use_label: str Use label result of cluster method. - cluster - Choose cluster to use as root - use_raw - Use the raw layer + cluster: str + Cluster identifier to use as the root cluster. Must exist in + `adata.obs[use_label]`. + use_raw: bool, default False + If True, use `adata.raw.X` for calculations; otherwise use `adata.X`. Returns ------- - Root index + int + Index of the selected root cell in the AnnData object + Raises + ------ + ValueError + If the specified cluster is not found in the clustering results. + ZeroDivisionError + If the specified cluster contains no cells. 
""" tmp_adata = adata.copy() # Subset the data based on the chosen cluster + available_clusters = tmp_adata.obs[use_label].unique() + if str(cluster) not in available_clusters.astype(str): + raise ValueError( + f"Cluster '{cluster}' not found in available clusters: " + + "{sorted(available_clusters)}" + ) tmp_adata = tmp_adata[ tmp_adata.obs[tmp_adata.obs[use_label] == str(cluster)].index, : ] - if use_raw == True: + if use_raw: tmp_adata = tmp_adata.raw.to_adata() # Borrow from Cellrank to calculate CytoTrace score diff --git a/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py b/stlearn/spatial/trajectory/shortest_path_spatial_PAGA.py similarity index 86% rename from stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py rename to stlearn/spatial/trajectory/shortest_path_spatial_PAGA.py index bfd6b359..958907bc 100644 --- a/stlearn/spatials/trajectory/shortest_path_spatial_PAGA.py +++ b/stlearn/spatial/trajectory/shortest_path_spatial_PAGA.py @@ -1,5 +1,6 @@ import networkx as nx -import numpy as np + +from stlearn.pl.utils import get_node from stlearn.utils import _read_graph @@ -25,7 +26,7 @@ def shortest_path_spatial_PAGA( key ].max() - # Force original PAGA to directed PAGA based on pseudotime + # Force original PAGA to a directed PAGA based on pseudotime edge_to_remove = [] for edge in H.edges: if node_pseudotime[edge[0]] - node_pseudotime[edge[1]] > 0: @@ -71,20 +72,6 @@ def shortest_path_spatial_PAGA( return shortest_path.split(",") -# get name of cluster by subcluster -def get_cluster(search, dictionary): - for cl, sub in dictionary.items(): - if search in sub: - return cl - - -def get_node(node_list, split_node): - result = np.array([]) - for node in node_list: - result = np.append(result, np.array(split_node[int(node)]).astype(int)) - return result.astype(int) - - def find_min_max_node(adata, key="dpt_pseudotime", use_label="leiden"): min_cluster = int(adata.obs[adata.obs[key] == 0][use_label].values[0]) max_cluster = int(adata.obs[adata.obs[key] == 1][use_label].values[0]) diff --git a/stlearn/spatials/trajectory/utils.py b/stlearn/spatial/trajectory/utils.py similarity index 86% rename from stlearn/spatials/trajectory/utils.py rename to stlearn/spatial/trajectory/utils.py index 54ea41be..d8cc4277 100644 --- a/stlearn/spatials/trajectory/utils.py +++ b/stlearn/spatial/trajectory/utils.py @@ -1,9 +1,18 @@ +import warnings + +import networkx as nx +import numpy as np from numpy import linalg as la +from scipy import linalg as spla +from scipy import sparse as sps +from scipy.sparse import csr_matrix, issparse, isspmatrix_csr, spmatrix +from scipy.sparse import linalg as sparse_spla +from scipy.stats import norm def lambda_dist(A1, A2, k=None, p=2, kind="laplacian"): - """The function is migrated from NetComp package. The lambda distance between graphs, which is defined as - d(G1,G2) = norm(L_1 - L_2) + """The function is migrated from NetComp package. The lambda distance between + graphs, which is defined as d(G1,G2) = norm(L_1 - L_2) where L_1 is a vector of the top k eigenvalues of the appropriate matrix associated with G1, and L2 is defined similarly. Parameters @@ -99,11 +108,11 @@ def resistance_distance( ] try: distance_vector = np.sum((R1 - R2) ** p, axis=1) - except ValueError: - raise InputError( + except ValueError as e: + raise ValueError( "Input matrices are different sizes. Please use " "renormalized resistance distance." 
- ) + ) from e if attributed: return distance_vector ** (1 / p) else: @@ -114,20 +123,7 @@ def resistance_distance( # Eigenstuff # ********** # Functions for calculating eigenstuff of graphs. - - -from scipy import sparse as sps -import numpy as np -from scipy.sparse import linalg as spla -from numpy import linalg as la - -from scipy.sparse import issparse - -###################### -## Helper Functions ## -###################### - - +# Helper Functions def _eigs(M, which="SR", k=None): """Helper function for getting eigenstuff. Parameters @@ -155,7 +151,7 @@ def _eigs(M, which="SR", k=None): raise ValueError("which must be either 'LR' or 'SR'.") M = M.astype(float) if issparse(M) and k < n - 1: - evals, evecs = spla.eigs(M, k=k, which=which) + evals, evecs = sparse_spla.eigs(M, k=k, which=which) else: try: M = M.todense() @@ -174,11 +170,7 @@ def _eigs(M, which="SR", k=None): return np.real(evals), np.real(evecs) -##################### -## Get Eigenstuff ## -##################### - - +# Get Eigenstuff def normalized_laplacian_eig(A, k=None): """Return the eigenstuff of the normalized Laplacian matrix of graph associated with adjacency matrix A. @@ -213,9 +205,7 @@ def normalized_laplacian_eig(A, k=None): nx.normalized_laplacian_matrix """ n, m = A.shape - ## - ## TODO: implement checks on the adjacency matrix - ## + # TODO: implement checks on the adjacency matrix degs = _flat(A.sum(axis=1)) # the below will break if inv_root_degs = [d ** (-1 / 2) if d > _eps else 0 for d in degs] @@ -234,18 +224,10 @@ def normalized_laplacian_eig(A, k=None): # Matrices associated with graphs. Also contains linear algebraic helper functions. # """ - -from scipy import sparse as sps -from scipy.sparse import issparse -import numpy as np - _eps = 10 ** (-10) # a small parameter -###################### -## Helper Functions ## -###################### - +# Helper Functions def _flat(D): """Flatten column or row matrices, as well as arrays.""" if issparse(D): @@ -274,11 +256,7 @@ def _pad(A, N): return A_pad -######################## -## Matrices of Graphs ## -######################## - - +# Matrices of Graphs def degree_matrix(A): """Diagonal degree matrix of graph with adjacency matrix A Parameters @@ -338,15 +316,6 @@ class UndefinedException(Exception): # Resistance matrix. Renormalized version, as well as conductance and commute matrices. # """ -import networkx as nx -from numpy import linalg as la -from scipy import linalg as spla -import numpy as np -from scipy.sparse import issparse - -# from netcomp.linalg.matrices import laplacian_matrix -# from netcomp.exception import UndefinedException - def resistance_matrix(A, check_connected=True): """Return the resistance matrix of G. @@ -391,7 +360,7 @@ def resistance_matrix(A, check_connected=True): G = nx.from_numpy_array(A) if not nx.is_connected(G): raise UndefinedException( - "Graph is not connected. " "Resistance matrix is undefined." + "Graph is not connected. Resistance matrix is undefined." 
) L = laplacian_matrix(A) try: @@ -543,34 +512,7 @@ def conductance_matrix(A): return C -######################## -## CytoTrace wrapper ## -######################## - -from typing import ( - Any, - Dict, - List, - Tuple, - Union, - TypeVar, - Hashable, - Iterable, - Optional, - Sequence, -) -import numpy as np -import pandas as pd -from pandas import Series -from scipy.stats import norm -from numpy.linalg import norm as d_norm -from scipy.sparse import eye as speye -from scipy.sparse import diags, issparse, spmatrix, csr_matrix, isspmatrix_csr -from sklearn.cluster import KMeans -from pandas.api.types import infer_dtype, is_categorical_dtype -from scipy.sparse.linalg import norm as sparse_norm - - +# CytoTrace wrapper def _mat_mat_corr_sparse( X: csr_matrix, Y: np.ndarray, @@ -581,28 +523,24 @@ def _mat_mat_corr_sparse( n = X.shape[1] X_bar = np.reshape(np.array(X.mean(axis=1)), (-1, 1)) - X_std = np.reshape( - np.sqrt(np.array(X.power(2).mean(axis=1)) - (X_bar**2)), (-1, 1) - ) + X_std = np.reshape(np.sqrt(np.array(X.power(2).mean(axis=1)) - (X_bar**2)), (-1, 1)) y_bar = np.reshape(np.mean(Y, axis=0), (1, -1)) y_std = np.reshape(np.std(Y, axis=0), (1, -1)) - with np.warnings.catch_warnings(): - np.warnings.filterwarnings( - "ignore", r"invalid value encountered in true_divide" - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"invalid value encountered in true_divide") return (X @ Y - (n * X_bar * y_bar)) / ((n - 1) * X_std * y_std) def _correlation_test_helper( - X: Union[np.ndarray, spmatrix], + X: np.ndarray | spmatrix, Y: np.ndarray, - n_perms: Optional[int] = None, - seed: Optional[int] = None, + n_perms: int | None = None, + seed: int | None = None, confidence_level: float = 0.95, **kwargs, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ This function is borrowed from cellrank. Compute the correlation between rows of matrix ``X`` and columns of matrix ``Y``. @@ -624,30 +562,17 @@ Keyword arguments for :func:`cellrank.ul._parallelize.parallelize`. Returns ------- - Correlations, p-values, corrected p-values, lower and upper bound of 95% confidence interval. - Each array if of shape ``(n_genes, n_lineages)``. + Correlations, p-values, corrected p-values, lower and upper bound of 95% + confidence interval. Each array is of shape ``(n_genes, n_lineages)``. """ - def perm_test_extractor( - res: Sequence[Tuple[np.ndarray, np.ndarray]] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - pvals, corr_bs = zip(*res) - pvals = np.sum(pvals, axis=0) / float(n_perms) - - corr_bs = np.concatenate(corr_bs, axis=0) - corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile( - corr_bs, q=qh, axis=0 - ) - - return pvals, corr_ci_low, corr_ci_high -
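The confidence bounds computed a few lines below rest on the Fisher transformation: `arctanh` maps a correlation to an approximately normal variable with standard error 1/sqrt(n - 3). A self-contained sketch of that arithmetic, with illustrative values rather than stlearn API:

```python
import numpy as np
from scipy.stats import norm

r, n, confidence_level = 0.6, 100, 0.95  # illustrative correlation and sample size

z = np.arctanh(r)              # Fisher transformation of the correlation
se = 1.0 / np.sqrt(n - 3)      # approximate standard error on the z scale
z_crit = norm.ppf(1 - (1 - confidence_level) / 2)

# Transform the symmetric z-interval back to the correlation scale.
ci_low, ci_high = np.tanh(z - z_crit * se), np.tanh(z + z_crit * se)
```

if not (0 <= confidence_level <= 1): raise ValueError( - f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`." + "Expected `confidence_level` to be in interval `[0, 1]`, " + + f"found `{confidence_level}`."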
) n = X.shape[1] # genes x cells - ql = 1 - confidence_level - (1 - confidence_level) / 2.0 qh = confidence_level + (1 - confidence_level) / 2.0 if issparse(X) and not isspmatrix_csr(X): @@ -655,7 +580,8 @@ def perm_test_extractor( corr = _mat_mat_corr_sparse(X, Y) if issparse(X) else _mat_mat_corr_dense(X, Y) - # see: https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Using_the_Fisher_transformation + # see: + # https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Using_the_Fisher_transformation mean, se = np.arctanh(corr), 1.0 / np.sqrt(n - 3) z_score = (np.arctanh(corr) - np.arctanh(0)) * np.sqrt(n - 3) @@ -676,10 +602,8 @@ def _mat_mat_corr_dense(X: np.ndarray, Y: np.ndarray) -> np.ndarray: y_bar = np.reshape(np_mean(Y, axis=0), (1, -1)) y_std = np.reshape(np_std(Y, axis=0), (1, -1)) - with np.warnings.catch_warnings(): - np.warnings.filterwarnings( - "ignore", r"invalid value encountered in true_divide" - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"invalid value encountered in true_divide") return (X @ Y - (n * X_bar * y_bar)) / ((n - 1) * X_std * y_std) diff --git a/stlearn/spatial/trajectory/weight_optimization.py b/stlearn/spatial/trajectory/weight_optimization.py new file mode 100644 index 00000000..7c47405e --- /dev/null +++ b/stlearn/spatial/trajectory/weight_optimization.py @@ -0,0 +1,188 @@ +import networkx as nx +import numpy as np +import pandas as pd +from anndata import AnnData +from tqdm import tqdm + +from .global_level import global_level +from .local_level import local_level +from .utils import lambda_dist, resistance_distance + + +def weight_optimizing_global( + adata: AnnData, + use_label: str = "louvain", + list_clusters=None, + step=0.01, + k=10, + use_rep="X_pca", + n_dims=40, +): + if k <= 0: + raise ValueError(f"k must be positive, got {k}") + + # Determine effective k value based on available sub-clusters + actual_k = k + if use_label and list_clusters: + if "sub_cluster_labels" not in adata.obs.columns: + print( + "Warning: 'sub_cluster_labels' column not found. Using provided " + + "k value." + ) + else: + try: + filtered_data = adata.obs[adata.obs[use_label].isin(list_clusters)] + if len(filtered_data) == 0: + raise ValueError( + f"No cells found for clusters {list_clusters} " + + f"in column '{use_label}'" + ) + + # Minimum 1 cluster, use K or max available sub-clusters + n_subclusters = len(filtered_data["sub_cluster_labels"].unique()) + actual_k = max(1, min(k, n_subclusters)) + + if actual_k != k: + print( + f"Adjusted k from {k} to {actual_k} based on available " + + f"sub-clusters ({n_subclusters})" + ) + + except Exception as e: + print( + f"Warning: Could not determine sub-cluster count: {e}. " + + "Using provided k value."
+ ) + actual_k = k + + # Screening PTS graph + print("Screening PTS global graph...") + Gs = [] + j = 0 + total_iterations = int(1 / step + 1) + with tqdm( + total=total_iterations, + desc="Screening", + bar_format="{l_bar}{bar} [ time left: {remaining} ]", + ) as pbar: + for i in range(0, total_iterations): + weight = round(i * step, 2) + matrix = global_level( + adata, + use_label=use_label, + list_clusters=list_clusters, + use_rep=use_rep, + n_dims=n_dims, + w=weight, + return_graph=True, + verbose=False, + ) + Gs.append(nx.to_scipy_sparse_array(matrix)) + j = j + step + pbar.update(1) + + # Calculate the graph dissimilarity using Laplacian matrix + print("Calculate the graph dissimilarity using Laplacian matrix...") + result = [] + a1_list = [] + a2_list = [] + index = [] + w = 0 + with tqdm( + total=int(1 / step - 1), + desc="Calculating", + bar_format="{l_bar}{bar} [ time left: {remaining} ]", + ) as pbar: + for i in range(1, int(1 / step)): + w += step + a1 = lambda_dist(Gs[i], Gs[0], k=actual_k) + a2 = lambda_dist(Gs[i], Gs[-1], k=actual_k) + a1_list.append(a1) + a2_list.append(a2) + index.append(w) + result.append(np.absolute(1 - a1 / a2)) + pbar.update(1) + + screening_result = pd.DataFrame( + {"w": index, "A1": a1_list, "A2": a2_list, "Dissmilarity_Score": result} + ) + + adata.uns["screening_result_global"] = screening_result + + normalised_result = normalize_data(result) + + try: + optimized_ind = np.where(normalised_result == np.amin(normalised_result))[0][0] + opt_w = round(index[optimized_ind], 2) + print("The optimized weighting is:", str(opt_w)) + return opt_w + except (ValueError, IndexError): + print("The optimized weighting is: 0.5") + return 0.5 + + +def weight_optimizing_local( + adata: AnnData, use_label: str = "louvain", cluster=None, step=0.01 +): + # Screening PTS graph + print("Screening PTS local graph...") + Gs = [] + j = 0 + with tqdm( + total=int(1 / step + 1), + desc="Screening", + bar_format="{l_bar}{bar} [ time left: {remaining} ]", + ) as pbar: + for i in range(0, int(1 / step + 1)): + matrix = local_level( + adata, + use_label=use_label, + cluster=cluster, + w=round(j, 2), + verbose=False, + return_matrix=True, + ) + Gs.append(matrix) + j = j + step + pbar.update(1) + + # Calculate the graph dissimilarity using Resistance distance + print("Calculate the graph dissimilarity using Resistance distance...") + result = [] + a1_list = [] + a2_list = [] + index = [] + w = 0 + + with tqdm( + total=int(1 / step - 1), + desc="Calculating", + bar_format="{l_bar}{bar} [ time left: {remaining} ]", + ) as pbar: + for i in range(1, int(1 / step)): + w += step + a1 = resistance_distance(Gs[i], Gs[0]) + a2 = resistance_distance(Gs[i], Gs[-1]) + a1_list.append(a1) + a2_list.append(a2) + index.append(w) + result.append(np.absolute(1 - a1 / a2)) + pbar.update(1) + + screening_result = pd.DataFrame( + {"w": index, "A1": a1_list, "A2": a2_list, "Dissmilarity_Score": result} + ) + + adata.uns["screening_result_local"] = screening_result + + normalised_result = normalize_data(result) + + optimized_ind = np.where(normalised_result == np.amin(normalised_result))[0][0] + opt_w = round(index[optimized_ind], 2) + print("The optimized weighting is:", str(opt_w)) + + return opt_w + + +def normalize_data(data): + return (data - np.min(data)) / (np.max(data) - np.min(data)) diff --git a/stlearn/spatials/__init__.py b/stlearn/spatials/__init__.py deleted file mode 100644 index e69de29b..00000000
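The screening above builds one graph per candidate w, compares each against the two extreme graphs (w=0 and w=1) with `lambda_dist`, and keeps the w that minimises the normalised dissimilarity. A hedged sketch of the public entry point that triggers it (the cluster labels are illustrative):

```python
from stlearn.spatial.trajectory import pseudotimespace_global

# model="mixed" runs the weight screening defined above; "spatial" pins w=0
# and "gene_expression" pins w=1, skipping the screen entirely.
pseudotimespace_global(
    adata,
    use_label="louvain",
    list_clusters=["6", "8"],  # illustrative cluster labels
    model="mixed",
    step=0.01,                 # w grid: 0.0, 0.01, ..., 1.0
    k=10,                      # eigenvalues compared by lambda_dist
)
```

diff --git a/stlearn/spatials/morphology/__init__.py b/stlearn/spatials/morphology/__init__.py deleted file mode 100644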
index 115a5979..00000000 --- a/stlearn/spatials/morphology/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .adjust import adjust diff --git a/stlearn/spatials/smooth/__init__.py b/stlearn/spatials/smooth/__init__.py deleted file mode 100644 index 70f1149d..00000000 --- a/stlearn/spatials/smooth/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .disk import disk diff --git a/stlearn/spatials/trajectory/weight_optimization.py b/stlearn/spatials/trajectory/weight_optimization.py deleted file mode 100644 index 8bc2b363..00000000 --- a/stlearn/spatials/trajectory/weight_optimization.py +++ /dev/null @@ -1,161 +0,0 @@ -import numpy as np -import pandas as pd -import networkx as nx -from .global_level import global_level -from .local_level import local_level -from .utils import lambda_dist, resistance_distance -from tqdm import tqdm - - -def weight_optimizing_global( - adata, - use_label=None, - list_clusters=None, - step=0.01, - k=10, - use_rep="X_pca", - n_dims=40, -): - # Screening PTS graph - print("Screening PTS global graph...") - Gs = [] - j = 0 - - with tqdm( - total=int(1 / step + 1), - desc="Screening", - bar_format="{l_bar}{bar} [ time left: {remaining} ]", - ) as pbar: - for i in range(0, int(1 / step + 1)): - - Gs.append( - nx.to_scipy_sparse_array( - global_level( - adata, - use_label=use_label, - list_clusters=list_clusters, - use_rep=use_rep, - n_dims=n_dims, - w=round(j, 2), - return_graph=True, - verbose=False, - ) - ) - ) - - j = j + step - pbar.update(1) - - # Calculate the graph dissimilarity using Laplacian matrix - print("Calculate the graph dissimilarity using Laplacian matrix...") - result = [] - a1_list = [] - a2_list = [] - indx = [] - w = 0 - k = len( - adata.obs[adata.obs[use_label].isin(list_clusters)][ - "sub_cluster_labels" - ].unique() - ) - with tqdm( - total=int(1 / step - 1), - desc="Calculating", - bar_format="{l_bar}{bar} [ time left: {remaining} ]", - ) as pbar: - for i in range(1, int(1 / step)): - w += step - a1 = lambda_dist(Gs[i], Gs[0], k=k) - a2 = lambda_dist(Gs[i], Gs[-1], k=k) - a1_list.append(a1) - a2_list.append(a2) - indx.append(w) - result.append(np.absolute(1 - a1 / a2)) - pbar.update(1) - - screening_result = pd.DataFrame( - {"w": indx, "A1": a1_list, "A2": a2_list, "Dissmilarity_Score": result} - ) - - adata.uns["screening_result_global"] = screening_result - - def NormalizeData(data): - return (data - np.min(data)) / (np.max(data) - np.min(data)) - - result = NormalizeData(result) - - try: - optimized_ind = np.where(result == np.amin(result))[0][0] - opt_w = round(indx[optimized_ind], 2) - print("The optimized weighting is:", str(opt_w)) - return opt_w - except: - print("The optimized weighting is: 0.5") - return 0.5 - - -def weight_optimizing_local(adata, use_label=None, cluster=None, step=0.01): - # Screening PTS graph - print("Screening PTS local graph...") - Gs = [] - j = 0 - with tqdm( - total=int(1 / step + 1), - desc="Screening", - bar_format="{l_bar}{bar} [ time left: {remaining} ]", - ) as pbar: - for i in range(0, int(1 / step + 1)): - - Gs.append( - local_level( - adata, - use_label=use_label, - cluster=cluster, - w=round(j, 2), - verbose=False, - return_matrix=True, - ) - ) - - j = j + step - pbar.update(1) - - # Calculate the graph dissimilarity using Laplacian matrix - print("Calculate the graph dissimilarity using Resistance distance...") - result = [] - a1_list = [] - a2_list = [] - indx = [] - w = 0 - - with tqdm( - total=int(1 / step - 1), - desc="Calculating", - bar_format="{l_bar}{bar} [ time left: {remaining} ]", - ) as pbar: - 
for i in range(1, int(1 / step)): - w += step - a1 = resistance_distance(Gs[i], Gs[0]) - a2 = resistance_distance(Gs[i], Gs[-1]) - a1_list.append(a1) - a2_list.append(a2) - indx.append(w) - result.append(np.absolute(1 - a1 / a2)) - pbar.update(1) - - screening_result = pd.DataFrame( - {"w": indx, "A1": a1_list, "A2": a2_list, "Dissmilarity_Score": result} - ) - - adata.uns["screening_result_local"] = screening_result - - def NormalizeData(data): - return (data - np.min(data)) / (np.max(data) - np.min(data)) - - result = NormalizeData(result) - - optimized_ind = np.where(result == np.amin(result))[0][0] - opt_w = round(indx[optimized_ind], 2) - print("The optimized weighting is:", str(opt_w)) - - return opt_w diff --git a/stlearn/tl.py b/stlearn/tl.py deleted file mode 100644 index 073ef289..00000000 --- a/stlearn/tl.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tools import clustering -from .tools.microenv import cci -from .tools.label import label diff --git a/stlearn/tl/__init__.py b/stlearn/tl/__init__.py new file mode 100644 index 00000000..eb55fc20 --- /dev/null +++ b/stlearn/tl/__init__.py @@ -0,0 +1,10 @@ +# stlearn/tl/__init__.py + +from . import cache, cci, clustering, label + +__all__ = [ + "cache", + "clustering", + "cci", + "label", +] diff --git a/stlearn/tl/cache/__init__.py b/stlearn/tl/cache/__init__.py new file mode 100644 index 00000000..9fe80973 --- /dev/null +++ b/stlearn/tl/cache/__init__.py @@ -0,0 +1,6 @@ +from .anndata import merge_h5ad_into_adata, write_subset_h5ad + +__all__ = [ + "write_subset_h5ad", + "merge_h5ad_into_adata", +] diff --git a/stlearn/tl/cache/anndata.py b/stlearn/tl/cache/anndata.py new file mode 100644 index 00000000..a208feb8 --- /dev/null +++ b/stlearn/tl/cache/anndata.py @@ -0,0 +1,51 @@ +import anndata as ad +import numpy as np +import pandas as pd + + +def write_subset_h5ad(adata, filename, obsm_keys=None, uns_keys=None): + """Write only specific obsm and uns components to H5AD""" + + # Create a minimal AnnData object with the same structure + minimal_adata = ad.AnnData( + X=np.zeros((adata.n_obs, 1)), + obs=adata.obs.index.to_frame(name="cell_id"), + var=pd.DataFrame(index=["placeholder"]), + ) + + if obsm_keys: + for key in obsm_keys: + if key in adata.obsm: + value = adata.obsm[key] + if isinstance(value, list): + value = np.array(value) + minimal_adata.obsm[key] = value + print(f"Added obsm['{key}'] with shape {value.shape}") + else: + print(f"Warning: obsm['{key}'] not found") + + if uns_keys: + for key in uns_keys: + if key in adata.uns: + minimal_adata.uns[key] = adata.uns[key] + print(f"Added uns['{key}']") + else: + print(f"Warning: uns['{key}'] not found") + + minimal_adata.write_h5ad(filename, compression="gzip", compression_opts=9) + print(f"Wrote subset to {filename}") + + +def merge_h5ad_into_adata(adata_main, h5ad_file): + adata_subset = ad.read_h5ad(h5ad_file) + print(f"Reading {h5ad_file}") + + for key, value in adata_subset.obsm.items(): + adata_main.obsm[key] = value + print(f"Added obsm['{key}'] with shape {value.shape}") + + for key, value in adata_subset.uns.items(): + adata_main.uns[key] = value + print(f"Added uns['{key}']") + + return adata_main diff --git a/stlearn/tl/cci/__init__.py b/stlearn/tl/cci/__init__.py new file mode 100644 index 00000000..5716bf1d --- /dev/null +++ b/stlearn/tl/cci/__init__.py @@ -0,0 +1,4 @@ +from .analysis import adj_pvals, grid, load_lrs, run, run_cci, run_lr_go +from .het import get_edges + +__all__ = ["load_lrs", "grid", "run", "adj_pvals", "run_lr_go", "run_cci", "get_edges"] diff --git 
a/stlearn/tools/microenv/cci/analysis.py b/stlearn/tl/cci/analysis.py similarity index 86% rename from stlearn/tools/microenv/cci/analysis.py rename to stlearn/tl/cci/analysis.py index adbb2d19..0905202c 100644 --- a/stlearn/tools/microenv/cci/analysis.py +++ b/stlearn/tl/cci/analysis.py @@ -1,60 +1,64 @@ -""" Wrapper function for performing CCI analysis, varrying the analysis based on - the inputted data / state of the anndata object. +"""Wrapper function for performing CCI analysis, varying the analysis based on +the inputted data / state of the anndata object. """ import os + import numba -from numba import types -from numba.typed import List import numpy as np import pandas as pd -from typing import Union from anndata import AnnData +from statsmodels.stats.multitest import multipletests from tqdm import tqdm -from .base import calc_neighbours, get_lrs_scores, calc_distance -from .permutation import perform_spot_testing + +from .base import calc_distance, calc_neighbours, get_lrs_scores from .go import run_GO from .het import ( count, - get_neighbourhoods, get_data_for_counting, get_interaction_matrix, get_interaction_pvals, + get_neighbourhoods, grid_parallel, ) -from statsmodels.stats.multitest import multipletests +from .permutation import perform_spot_testing + -################################################################################ -# Functions related to Ligand-Receptor interactions # -################################################################################ -def load_lrs(names: Union[str, list, None] = None, species: str = "human") -> np.array: - """Loads inputted LR database, & concatenates into consistent database set of pairs without duplicates. If None loads 'connectomeDB2020_lit'. +# Functions related to Ligand-Receptor interactions +def load_lrs(names: str | list | None = None, species: str = "human") -> np.ndarray: + """Loads the inputted LR databases & concatenates them into a consistent set of + pairs without duplicates. If None loads 'connectomeDB2020_lit'. Parameters ---------- - names: list Databases to load, options: 'connectomeDB2020_lit' (literature verified), 'connectomeDB2020_put' (putative). If more than one specified, loads all & removes duplicates. - species: str Format of the LR genes, either 'human' or 'mouse'. + names: list + Databases to load, options: 'connectomeDB2020_lit' (literature verified), + 'connectomeDB2020_put' (putative). If more than one specified, loads all & + removes duplicates. + species: str + Format of the LR genes, either 'human' or 'mouse'.
Returns ------- - lrs: np.array lr pairs from the database in format ['L1_R1', 'LN_RN'] + lrs: np.ndarray + lr pairs from the database in format ['L1_R1', 'LN_RN'] """ - if type(names) == type(None): + if names is None: names = ["connectomeDB2020_lit"] - if type(names) == str: + if isinstance(names, str): names = [names] path = os.path.dirname(os.path.realpath(__file__)) dbs = [pd.read_csv(f"{path}/databases/{name}.txt", sep="\t") for name in names] lrs_full = [] for db in dbs: - lrs = [f"{db.values[i,0]}_{db.values[i,1]}" for i in range(db.shape[0])] + lrs = [f"{db.values[i, 0]}_{db.values[i, 1]}" for i in range(db.shape[0])] lrs_full.extend(lrs) - lrs_full = np.unique(lrs_full) + lrs_full_arr = np.unique(np.array(lrs_full)) # If dealing with mouse, need to reformat # if species == "mouse": genes1 = [lr_.split("_")[0] for lr_ in lrs_full_arr] genes2 = [lr_.split("_")[1] for lr_ in lrs_full_arr] - lrs_full = np.array( + lrs_full_arr = np.array( [ genes1[i][0] + genes1[i][1:].lower() @@ -65,14 +69,14 @@ def load_lrs(names: Union[str, list, None] = None, species: str = "human") -> np ] ) - return lrs_full + return lrs_full_arr def grid( adata, n_row: int = 10, n_col: int = 10, - use_label: str = None, + use_label: str | None = None, n_cpus: int = 1, verbose: bool = True, ): @@ -102,15 +106,16 @@ def grid( print("Gridding...") # Setting threads for parallelisation # - if type(n_cpus) != type(None): + if n_cpus is not None: numba.set_num_threads(n_cpus) # Retrieving the coordinates of each grid # n_squares = n_row * n_col cell_bcs = adata.obs_names.values.astype(str) - xs, ys = adata.obs["imagecol"].values.astype(int), adata.obs[ - "imagerow" - ].values.astype(int) + xs, ys = ( + adata.obs["imagecol"].values.astype(int), + adata.obs["imagerow"].values.astype(int), + ) grid_counts, xedges, yedges = np.histogram2d(xs, ys, bins=[n_col, n_row]) grid_counts, xedges, yedges = ( @@ -121,10 +126,10 @@ def grid( grid_expr = np.zeros((n_squares, adata.shape[1])) grid_coords = np.zeros((n_squares, 2)) - grid_cell_counts = np.zeros((n_squares), dtype=np.int64) - # If use_label specified, then will generate deconvolution information + grid_cell_counts = np.zeros(n_squares, dtype=np.int64) + # If use_label is specified, then it will generate deconvolution information cell_labels, cell_set, cell_info = None, None, None - if type(use_label) != type(None): + if use_label is not None: cell_labels = adata.obs[use_label].values.astype(str) cell_set = np.unique(cell_labels).astype(str) cell_info = np.zeros((n_squares, len(cell_set)), dtype=np.float64) @@ -142,7 +147,7 @@ def grid( grid_cell_counts, grid_expr, adata.X, - type(use_label) != type(None), + use_label is not None, cell_labels, cell_info, cell_set, @@ -161,7 +166,7 @@ def grid( grid_data.obsm["spatial"] = grid_coords grid_data.uns["spatial"] = adata.uns["spatial"] - if type(use_label) != type(None): + if use_label is not None and cell_info is not None and cell_set is not None: grid_data.uns[use_label] = pd.DataFrame( cell_info, index=grid_data.obs_names.values.astype(str), columns=cell_set ) @@ -169,14 +174,14 @@ def grid( grid_data.obs[use_label] = [cell_set[index] for index in max_indices] grid_data.obs[use_label] = grid_data.obs[use_label].astype("category") grid_data.obs[use_label] = grid_data.obs[use_label].cat.set_categories( - list(adata.obs[use_label].cat.categories) + adata.obs[use_label].cat.categories ) if f"{use_label}_colors" in adata.uns: grid_data.uns[f"{use_label}_colors"] = adata.uns[f"{use_label}_colors"]
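Before the subsetting step below, it may help to see the gridding wrapper from the caller's side; a hedged sketch in which the grid dimensions and label key are illustrative:

```python
from stlearn.tl.cci import grid

# Pool spots into a 20 x 20 grid; with use_label set, per-square cell-type
# proportions are kept so CCI scoring can stay deconvolution-aware.
grid_data = grid(adata, n_row=20, n_col=20, use_label="cell_type", n_cpus=4)
```

# Subsetting to only gridded spots that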
contain cells # grid_data = grid_data[grid_data.obs["n_cells"] > 0, :].copy() - if type(use_label) != type(None): + if use_label is not None: grid_data.uns[use_label] = grid_data.uns[use_label].loc[grid_data.obs_names, :] grid_data.uns["grid_counts"] = grid_counts @@ -188,12 +193,12 @@ def grid( def run( adata: AnnData, - lrs: np.array, + lrs: np.ndarray, min_spots: int = 10, - distance: int = None, + distance: float | None = None, n_pairs: int = 1000, - n_cpus: int = None, - use_label: str = None, + n_cpus: int | None = None, + use_label: str | None = None, adj_method: str = "fdr_bh", pval_adj_cutoff: float = 0.05, min_expr: float = 0, @@ -207,7 +212,7 @@ def run( ----------- adata: AnnData The data object. - lrs: np.array + lrs: np.ndarray The LR pairs to score/test for enrichment (in format 'L1_R1'). min_spots: int Minimum number of spots with an LR score for an LR to be considered for @@ -249,18 +254,20 @@ def run( adata.uns['lr_summary'] Summary of significant spots detected per LR, the LRs listed in the index is the same order of LRs in the columns of - results stored in adata.obsm below. Hence the order of this must be maintained. + results stored in adata.obsm below. Hence, the order of this must be + maintained. adata.obsm Additional keys are added; 'lr_scores', 'lr_sig_scores', 'p_vals', 'p_adjs', '-log10(p_adjs)'. All are numpy matrices, with columns referring to the LRs listed in adata.uns['lr_summary']. 'lr_scores' is the raw scores, while 'lr_sig_scores' is the same except only for significant scores; non-significant scores are set to zero. - adata.obsm['het'] - Only if use_label specified; contains the counts of the cell types found per spot. + adata.obsm['cci_het'] + Only if use_label specified; contains the counts of the cell types found + per spot. """ - # Setting threads for paralellisation # - if type(n_cpus) != type(None): + # Setting threads for parallelisation + if n_cpus is not None: numba.set_num_threads(n_cpus) # Making sure none of the var_names contains '_' already, these will need @@ -305,13 +312,13 @@ def run( ) # Conduct with cell heterogeneity info if label_transfer provided # - cell_het = type(use_label) != type(None) and use_label in adata.uns.keys() + cell_het = use_label is not None and use_label in adata.uns.keys() if cell_het: if verbose: - print("Calculating cell hetereogeneity...") + print("Calculating cell heterogeneity...") # Calculating cell heterogeneity # - count(adata, distance=distance, use_label=use_label, use_het=use_label) + count(adata, distance=distance, use_label=use_label) het_vals = ( np.array([1] * len(adata)) @@ -322,13 +329,13 @@ def run( """ 1. Filter any LRs without stored expression. 
""" # Calculating the lr_scores across spots for the inputted lrs # - lr_scores, lrs = get_lrs_scores(adata, lrs, neighbours, het_vals, min_expr) + lr_scores, new_lrs = get_lrs_scores(adata, lrs, neighbours, het_vals, min_expr) lr_bool = (lr_scores > 0).sum(axis=0) > min_spots - lrs = lrs[lr_bool] + new_lrs = new_lrs[lr_bool] lr_scores = lr_scores[:, lr_bool] if verbose: - print("Altogether " + str(len(lrs)) + " valid L-R pairs") - if len(lrs) == 0: + print("Altogether " + str(len(new_lrs)) + " valid L-R pairs") + if len(new_lrs) == 0: print("Exiting due to lack of valid LR pairs.") return @@ -337,7 +344,7 @@ def run( perform_spot_testing( adata, lr_scores, - lrs, + new_lrs, n_pairs, neighbours, het_vals, @@ -401,12 +408,12 @@ def adj_pvals( spot_padjs = multipletests(lr_ps, method=adj_method)[1] padjs[spot_indices, lr_i] = spot_padjs sig_scores[spot_indices[spot_padjs >= pval_adj_cutoff], lr_i] = 0 - elif type(correct_axis) == type(None): + elif correct_axis is None: padjs = ps.copy() sig_scores[padjs >= pval_adj_cutoff] = 0 else: raise Exception( - f"Invalid correct_axis input, must be one of: " f"'LR', 'spot', or None" + "Invalid correct_axis input, must be one of: 'LR', 'spot', or None" ) # Counting spots significant per lr # @@ -418,7 +425,7 @@ def adj_pvals( adata.uns["lr_summary"].loc[:, "n_spots_sig_pval"] = lr_counts_pval new_order = np.argsort(-adata.uns["lr_summary"].loc[:, "n_spots_sig"].values) adata.uns["lr_summary"] = adata.uns["lr_summary"].iloc[new_order, :] - print(f"Updated adata.uns[lr_summary]") + print("Updated adata.uns[lr_summary]") scores_ordered = scores[:, new_order] sig_scores_ordered = sig_scores[:, new_order] ps_ordered = ps[:, new_order] @@ -436,7 +443,7 @@ def run_lr_go( adata: AnnData, r_path: str, n_top: int = 100, - bg_genes: np.array = None, + bg_genes: np.ndarray | None = None, min_sig_spots: int = 1, species: str = "human", p_cutoff: float = 0.01, @@ -468,18 +475,19 @@ def run_lr_go( q_cutoff: float Q-value cutoff below which results will be returned. onts: str - As per clusterProfiler; One of "BP", "MF", and "CC" subontologies, or "ALL" for all three. + As per clusterProfiler; One of "BP", "MF", and "CC" subontologies, or "ALL" + for all three. 
Returns ------- adata: AnnData Relevant information stored in adata.uns['lr_go'] """ - #### Making sure inputted correct species #### + # Make sure the inputted species is valid all_species = ["human", "mouse"] if species not in all_species: - raise Exception(f"Got {species} for species, must be one of " f"{all_species}") + raise Exception(f"Got {species} for species, must be one of {all_species}") - #### Getting the genes from the top LR pairs #### + # Getting the genes from the top LR pairs if "lr_summary" not in adata.uns: raise Exception("Need to run st.tl.cci.run first.") lrs = adata.uns["lr_summary"].index.values.astype(str) @@ -487,12 +495,13 @@ top_lrs = lrs[n_sig > min_sig_spots][0:n_top] top_genes = np.unique([lr.split("_") for lr in top_lrs]) - ## Determining the background genes if not inputted ## - if type(bg_genes) == type(None): + # Determining the background genes if not inputted + if bg_genes is None: all_lrs = load_lrs("connectomeDB2020_put") - bg_genes = np.unique([lr_.split("_") for lr_ in all_lrs]) + all_genes = [lr_.split("_") for lr_ in all_lrs] + bg_genes = np.unique(all_genes) - #### Running the GO analysis #### + # Running the GO analysis go_results = run_GO( top_genes, bg_genes, @@ -507,9 +516,7 @@ print("GO results saved to adata.uns['lr_go']") -################################################################################ -# Functions for calling Celltype-Celltype interactions # -################################################################################ +# Functions for calling Celltype-Celltype interactions def run_cci( adata: AnnData, use_label: str, @@ -522,7 +529,8 @@ n_cpus: int = 1, verbose: bool = True, ): - """Calls significant celltype-celltype interactions based on cell-type data randomisation. + """Calls significant celltype-celltype interactions based on cell-type data + randomisation. Parameters ---------- @@ -590,14 +598,14 @@ subsetted to significant CCIs. """ # Setting threads for parallelisation # - if type(n_cpus) != type(None): + if n_cpus is not None: numba.set_num_threads(n_cpus) ran_lr = "lr_summary" in adata.uns ran_sig = False if not ran_lr else "n_spots_sig" in adata.uns["lr_summary"].columns if not ran_lr and not ran_sig: raise Exception( - "No LR results testing results found, " "please run st.tl.cci.run first" + "No LR testing results found, please run st.tl.cci.run first" ) # Ensuring compatibility with current way of adding label_transfer to object @@ -609,8 +617,7 @@
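Seen end to end, the CCI wrappers in this module chain together as sketched below; a hedged outline of the intended call order, in which the cluster key is illustrative and the parameter names are taken from the signatures in this diff (assuming the package still re-exports the module as `st.tl.cci`):

```python
import stlearn as st

# Score and test ligand-receptor pairs, then call celltype-celltype interactions.
lrs = st.tl.cci.load_lrs(["connectomeDB2020_lit"], species="human")
st.tl.cci.run(adata, lrs, min_spots=10, n_pairs=1000, n_cpus=4)
st.tl.cci.adj_pvals(adata, correct_axis="spot", pval_adj_cutoff=0.05)
st.tl.cci.run_cci(adata, use_label="cell_type", n_perms=100, n_cpus=4)
```

# Getting the cell/tissue types that we are actually testing # if obs_key not in adata.obs: raise Exception( - f"Missing {obs_key} from adata.obs, need this even if " - f"using mixture mode." + f"Missing {obs_key} from adata.obs, need this even if using mixture mode."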
) tissue_types = adata.obs[obs_key].values.astype(str) all_set = np.unique(tissue_types) @@ -638,12 +645,12 @@ def run_cci( msg = msg + "Rows do not correspond to adata.obs_names.\n" raise Exception(msg) - #### Checking for case where have cell types that are never dominant - #### in a spot, so need to include these in all_set + # Checking for case where have cell types that are never dominant + # in a spot, so need to include these in all_set if len(all_set) < adata.uns[uns_key].shape[1]: all_set = adata.uns[uns_key].columns.values.astype(str) - #### Getting minimum necessary information for edge counting #### + # Getting minimum necessary information for edge counting if verbose: print("Getting cached neighbourhood information...") # Getting the neighbourhoods # @@ -675,20 +682,21 @@ def run_cci( per_lr_cci = {} # Per LR significant CCI counts # per_lr_cci_pvals = {} # Per LR CCI p-values # per_lr_cci_raw = {} # Per LR raw CCI counts # - lr_n_spot_cci = np.zeros((lr_summary.shape[0])) - lr_n_spot_cci_sig = np.zeros((lr_summary.shape[0])) - lr_n_cci_sig = np.zeros((lr_summary.shape[0])) + lr_n_spot_cci = np.zeros(lr_summary.shape[0]) + lr_n_spot_cci_sig = np.zeros(lr_summary.shape[0]) + lr_n_cci_sig = np.zeros(lr_summary.shape[0]) with tqdm( total=len(best_lrs), - desc=f"Counting celltype-celltype interactions per LR and permutating {n_perms} times.", + desc="Counting celltype-celltype interactions per LR and permuting " + + f"{n_perms} times.", bar_format="{l_bar}{bar} [ time left: {remaining} ]", - disable=verbose == False, + disable=not verbose, ) as pbar: for i, best_lr in enumerate(best_lrs): - l, r = best_lr.split("_") + ligand, receptor = best_lr.split("_") - L_bool = lr_expr.loc[:, l].values > 0 - R_bool = lr_expr.loc[:, r].values > 0 + L_bool = lr_expr.loc[:, ligand].values > 0 + R_bool = lr_expr.loc[:, receptor].values > 0 lr_index = np.where(adata.uns["lr_summary"].index.values == best_lr)[0][0] sig_bool = adata.obsm[col][:, lr_index] > 0 diff --git a/stlearn/tools/microenv/cci/base.py b/stlearn/tl/cci/base.py similarity index 67% rename from stlearn/tools/microenv/cci/base.py rename to stlearn/tl/cci/base.py index ecea7ecc..7824a526 100644 --- a/stlearn/tools/microenv/cci/base.py +++ b/stlearn/tl/cci/base.py @@ -1,36 +1,47 @@ import numpy as np import pandas as pd import scipy as sc -from numba import njit, prange -from numba.typed import List import scipy.spatial as spatial from anndata import AnnData +from numba import njit, prange +from numba.typed import List + from .het import create_grids def lr( adata: AnnData, use_lr: str = "cci_lr", - distance: float = None, + distance: float | None = None, verbose: bool = True, - neighbours: list = None, + neighbours: list | None = None, fast: bool = True, ) -> AnnData: - """Calculate the proportion of known ligand-receptor co-expression among the neighbouring spots or within spots + """Calculate the proportion of known ligand-receptor co-expression among the + neighbouring spots or within spots Parameters ---------- - adata: AnnData The data object to scan - use_lr: str object to keep the result (default: adata.uns['cci_lr']) - distance: float Distance to determine the neighbours (default: closest), distance=0 means within spot - neighbours: list List of the neighbours for each spot, if None then computed. Useful for speeding up function. - fast: bool Whether to use the fast implimentation or not. 
+ adata: AnnData
+ The data object to scan
+ use_lr: str
+ object to keep the result (default: adata.uns['cci_lr'])
+ distance: float
+ Distance to determine the neighbours (default: closest), distance=0 means
+ within spot. If distance is None gets it from adata.uns["spatial"]
+ neighbours: list
+ List of the neighbours for each spot, if None then computed. Useful for
+ speeding up function.
+ fast: bool
+ Whether to use the fast implementation or not.

Returns
-------
- adata: AnnData The data object including the results
+ adata: AnnData
+ The data object including the results
"""
- # automatically calculate distance if not given, won't overwrite distance=0 which is within-spot
+ # automatically calculate distance if not given, won't overwrite distance=0
+ # which is within-spot
distance = calc_distance(adata, distance)

# # expand the LR pairs list by swapping ligand-receptor positions
@@ -41,7 +52,7 @@ def lr(
print("Altogether " + str(spot_lr1.shape[1]) + " valid L-R pairs")

# get neighbour spots for each spot according to the specified distance
- if type(neighbours) == type(None):
+ if neighbours is None:
neighbours = calc_neighbours(adata, distance, index=fast)

# Calculating the scores, can have either the fast or the pandas version #
@@ -60,17 +71,21 @@ def lr(
# return adata


-def calc_distance(adata: AnnData, distance: float):
+def calc_distance(adata: AnnData, distance: float | None) -> float:
"""Automatically calculate distance if not given, won't overwrite \
distance=0 which is within-spot.

Parameters
----------
- adata: AnnData The data object to scan
- distance: float Distance to determine the neighbours (default: closest), distance=0 means within spot
+ adata: AnnData
+ The data object to scan
+ distance: float
+ Distance to determine the neighbours (default: closest), distance=0 means
+ within spot

Returns
-------
- distance: float The automatically calcualted distance (or inputted distance)
+ distance: float
+ The automatically calculated distance (or the inputted distance)
"""
if not distance and distance != 0:
# for arranged-spots
@@ -88,28 +103,36 @@ def calc_distance(adata: AnnData, distance: float):

def get_lrs_scores(
adata: AnnData,
- lrs: np.array,
- neighbours: np.array,
- het_vals: np.array,
+ lrs: np.ndarray,
+ neighbours: np.ndarray,
+ het_vals: np.ndarray,
min_expr: float,
filter_pairs: bool = True,
- spot_indices: np.array = None,
+ spot_indices: np.ndarray | None = None,
):
"""Gets the scores for the indicated set of LR pairs & the heterogeneity values.

Parameters
----------
- adata: AnnData See run() doc-string.
- lrs: np.array See run() doc-string.
- neighbours: np.array Array of arrays with indices specifying neighbours of each spot.
- het_vals: np.array Cell heterogeneity counts per spot.
- min_expr: float Minimum gene expression of either L or R for spot to be considered to have reasonable score.
- filter_pairs: bool Whether to filter to valid pairs or not.
- spot_indices: np.array Array of integers speci
+ adata: AnnData
+ See run() doc-string.
+ lrs: np.ndarray
+ See run() doc-string.
+ neighbours: np.ndarray
+ Array of arrays with indices specifying neighbours of each spot.
+ het_vals: np.ndarray
+ Cell heterogeneity counts per spot.
+ min_expr: float
+ Minimum gene expression of either L or R for spot to be considered to
+ have reasonable score.
+ filter_pairs: bool
+ Whether to filter to valid pairs or not.
+ spot_indices: np.ndarray
+ Array of integers specifying the spots to compute scores for;
+ if None, scores are calculated for every spot.

Returns
-------
- lrs: np.ndarray lr pairs from the database in format ['L1_R1', 'LN_RN']
"""
- if type(spot_indices) == type(None):
+ if spot_indices is None:
spot_indices = np.array(list(range(len(adata))), dtype=np.int32)

spot_lr1s = get_spot_lrs(
@@ -118,45 +141,47 @@ def get_lrs_scores(
spot_lr2s = get_spot_lrs(
adata, lr_pairs=lrs, lr_order=False, filter_pairs=filter_pairs
)
- if filter_pairs:
- lrs = np.array(
- [
- "_".join(spot_lr1s.columns.values[i : i + 2])
- for i in range(0, spot_lr1s.shape[1], 2)
- ]
- )
- # Calculating the lr_scores across spots for the inputted lrs #
lr_scores = get_scores(
spot_lr1s.values, spot_lr2s.values, neighbours, het_vals, min_expr, spot_indices
)
- if filter_pairs:
- return lr_scores, lrs
- else:
- return lr_scores
+ new_lrs = np.array(
+ [
+ "_".join(spot_lr1s.columns.values[i : i + 2])
+ for i in range(0, spot_lr1s.shape[1], 2)
+ ]
+ )
+
+ return lr_scores, new_lrs


def get_spot_lrs(
adata: AnnData,
- lr_pairs: list,
+ lr_pairs: np.ndarray,
lr_order: bool,
filter_pairs: bool = True,
):
"""
Parameters
----------
- adata: AnnData The adata object to scan
- lr_pairs: list List of the lr pairs (e.g. ['L1_R1', 'L2_R2',...]
- lr_order: bool Forward version of the spot lr pairs (L1_R1), False indicates reverse (R1_L1)
- filter_pairs: bool Whether to filter the pairs or not (check if present before subsetting).
+ adata (AnnData):
+ The adata object to scan
+ lr_pairs (np.ndarray):
+ np.ndarray of the lr pairs (e.g. ['L1_R1', 'L2_R2', ...])
+ lr_order (bool):
+ Forward version of the spot lr pairs (L1_R1), False indicates reverse (R1_L1)
+ filter_pairs (bool):
+ Whether to filter the pairs or not (check if present before subsetting).

Returns
-------
- spot_lrs: pd.DataFrame Spots*GeneOrder, in format l1, r1, ... ln, rn if lr_order True, else r1, l1, ... rn, ln
+ spot_lrs: pd.DataFrame
+ Spots*GeneOrder, in format l1, r1, ... ln, rn if lr_order True, else r1,
+ l1, ... rn, ln
"""
df = adata.to_df()
- pairs_rev = [f'{pair.split("_")[1]}_{pair.split("_")[0]}' for pair in lr_pairs]
+ pairs_rev = [f"{pair.split('_')[1]}_{pair.split('_')[0]}" for pair in lr_pairs]
pairs_wRev = []
for i in range(len(lr_pairs)):
pairs_wRev.extend([lr_pairs[i], pairs_rev[i]])
@@ -168,27 +193,39 @@ def get_spot_lrs(
if lr.split("_")[0] in df.columns and lr.split("_")[1] in df.columns
]

- lr_cols = [pair.split("_")[int(lr_order == False)] for pair in pairs_wRev]
+ if lr_order:
+ lr_cols = [pair.split("_")[0] for pair in pairs_wRev]  # Get ligand
+ else:
+ lr_cols = [pair.split("_")[1] for pair in pairs_wRev]  # Get receptor
spot_lrs = df[lr_cols]
return spot_lrs


def calc_neighbours(
adata: AnnData,
- distance: float = None,
+ distance: float | None = None,
index: bool = True,
verbose: bool = True,
) -> List:
- """Calculate the proportion of known ligand-receptor co-expression among the neighbouring spots or within spots
+ """Calculate the proportion of known ligand-receptor co-expression among the
+ neighbouring spots or within spots

Parameters
----------
- adata: AnnData The data object to scan
- distance: float Distance to determine the neighbours (default: closest), distance=0 means within spot
- index: bool Indicates whether to return neighbours as indices to other spots or names of other spots.
+ adata (AnnData): + The data object to scan + distance (float): + Distance to determine the neighbours (default: closest), distance=0 means + within spot + index (bool): + Indicates whether to return neighbours as indices to other spots or names of + other spots. + verbose (bool): + Display debugging information Returns ------- - neighbours: numba.typed.List List of np.array's indicating neighbours by indices for each spot. + neighbours (numba.typed.List): + List of np.array's indicating neighbours by indices for each spot. """ if verbose: print("Calculating neighbours...") @@ -219,7 +256,7 @@ def calc_neighbours( n_neighs = np.array([len(neigh) for neigh in neighbours]) if verbose: print( - f"{len(np.where(n_neighs==0)[0])} spots with no neighbours, " + f"{len(np.where(n_neighs == 0)[0])} spots with no neighbours, " f"{int(np.median(n_neighs))} median spot neighbours." ) @@ -243,13 +280,18 @@ def lr_core( """Calculate the lr scores for each spot. Parameters ---------- - spot_lr1: np.ndarray Spots*Ligands - spot_lr2: np.ndarray Spots*Receptors - neighbours: numba.typed.List List of np.array's indicating neighbours by indices for each spot. - min_expr: float Minimum expression for gene to be considered expressed. + spot_lr1: np.ndarray + Spots*Ligands + spot_lr2: np.ndarray + Spots*Receptors + neighbours: numba.typed.List + List of np.array's indicating neighbours by indices for each spot. + min_expr: float + Minimum expression for gene to be considered expressed. Returns ------- - lr_scores: numpy.ndarray Cells*LR-scores. + lr_scores: numpy.ndarray + Cells*LR-scores. """ # Calculating mean of lr2 expressions from neighbours of each spot nb_lr2 = np.zeros((len(spot_indices), spot_lr2.shape[1]), np.float64) @@ -278,15 +320,20 @@ def lr_pandas( """Calculate the lr scores for each spot. Parameters ---------- - spot_lr1: pd.DataFrame Cells*Ligands - spot_lr2: pd.DataFrame Cells*Receptors - neighbours: list List of neighbours by indices for each spot. + spot_lr1 (pd.DataFrame): + Cells*Ligands + spot_lr2 (pd.DataFrame): + Cells*Receptors + neighbours (list): + List of neighbours by indices for each spot. Returns ------- - lr_scores: numpy.ndarray Cells*LR-scores. + lr_scores (numpy.ndarray): + Cells*LR-scores. """ - # function to calculate mean of lr2 expression between neighbours or within spot (distance==0) for each spot + # function to calculate mean of lr2 expression between neighbours or within + # spot (distance==0) for each spot def mean_lr2(x): # get lr2 expressions from the neighbour(s) n_spots = neighbours[spot_lr2.index.tolist().index(x.name)] @@ -317,21 +364,27 @@ def get_scores( spot_lr1s: np.ndarray, spot_lr2s: np.ndarray, neighbours: List, - het_vals: np.array, + het_vals: np.ndarray, min_expr: float, - spot_indices: np.array, -) -> np.array: + spot_indices: np.ndarray, +) -> np.ndarray: """Calculates the scores. Parameters ---------- - spot_lr1s: np.ndarray Spots*GeneOrder1, in format l1, r1, ... ln, rn - spot_lr2s: np.ndarray Spots*GeneOrder2, in format r1, l1, ... rn, ln - het_vals: np.ndarray Spots*Het counts - neighbours: numba.typed.List List of np.array's indicating neighbours by indices for each spot. - min_expr: float Minimum expression for gene to be considered expressed. + spot_lr1s: np.ndarray + Spots*GeneOrder1, in format l1, r1, ... ln, rn + spot_lr2s: np.ndarray + Spots*GeneOrder2, in format r1, l1, ... rn, ln + het_vals: np.ndarray + Spots*Het counts + neighbours: numba.typed.List + List of np.array's indicating neighbours by indices for each spot. 
+ min_expr: float
+ Minimum expression for gene to be considered expressed.

Returns
-------
- spot_scores: np.ndarray Spots*LR pair of the LR scores per spot.
+ spot_scores: np.ndarray
+ Spots*LR pair of the LR scores per spot.
"""
spot_scores = np.zeros((len(spot_indices), spot_lr1s.shape[1] // 2), np.float64)
for i in prange(0, spot_lr1s.shape[1] // 2):
@@ -352,18 +405,26 @@ def lr_grid(
radius: int = 1,
verbose: bool = True,
) -> AnnData:
- """Calculate the proportion of known ligand-receptor co-expression among the neighbouring grids or within each grid
+ """Calculate the proportion of known ligand-receptor co-expression among the
+ neighbouring grids or within each grid

Parameters
----------
- adata: AnnData The data object to scan
- num_row: int Number of grids on height
- num_col: int Number of grids on width
- use_lr: str object to keep the result (default: adata.uns['cci_lr'])
- radius: int Distance to determine the neighbour grids (default: 1=nearest), radius=0 means within grid
+ adata: AnnData
+ The data object to scan
+ num_row: int
+ Number of grids on height
+ num_col: int
+ Number of grids on width
+ use_lr: str
+ object to keep the result (default: adata.uns['cci_lr'])
+ radius: int
+ Distance to determine the neighbour grids (default: 1=nearest),
+ radius=0 means within grid

Returns
-------
- adata: AnnData The data object with the cci_lr grid result updated
+ adata: AnnData
+ The data object with the cci_lr grid result updated
"""

# prepare data as pd.dataframe
@@ -406,7 +467,8 @@ def lr_grid(
if verbose:
print("Altogether " + str(len(avail)) + " valid L-R pairs")

- # function to calculate mean of lr2 expression between neighbours or within spot (distance==0) for each spot
+ # function to calculate mean of lr2 expression between neighbours or within spot
+ # (distance==0) for each spot
def mean_lr2(x):
# get the neighbour(s)' lr2 expressions
nbs = grid_lr2.loc[neighbours[df_grid.index.tolist().index(x.name)], :]
diff --git a/stlearn/tools/microenv/cci/base_grouping.py b/stlearn/tl/cci/base_grouping.py
similarity index 84%
rename from stlearn/tools/microenv/cci/base_grouping.py
rename to stlearn/tl/cci/base_grouping.py
index a24b21e8..5e229efe 100644
--- a/stlearn/tools/microenv/cci/base_grouping.py
+++ b/stlearn/tl/cci/base_grouping.py
@@ -1,15 +1,16 @@
-""" Performs LR analysis by grouping LR pairs which having hotspots across
- similar tissues.
+"""Performs LR analysis by grouping LR pairs which have hotspots across
+similar tissues.
"""

-from stlearn.pl import het_plot
-from sklearn.cluster import DBSCAN, AgglomerativeClustering
-from anndata import AnnData
-from tqdm import tqdm
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import matplotlib.pyplot as plt
import seaborn as sb
+from anndata import AnnData
+from sklearn.cluster import DBSCAN, AgglomerativeClustering
+from tqdm import tqdm
+
+from stlearn.pl import het_plot


def get_hotspots(
@@ -22,18 +23,28 @@ def get_hotspots(
plot_diagnostics: bool = False,
show_plot: bool = False,
):
- """Determines the hotspots for the inputted scores by progressively setting more stringent cutoffs & cluster in space, chooses point which maximises number of clusters.
+ """Determines the hotspots for the inputted scores by progressively setting
+ more stringent cutoffs & clustering in space, choosing the cutoff which
+ maximises the number of clusters.

Parameters
----------
- adata: AnnData The data object
- lr_scores: np.ndarray LR_pair*Spots containing the LR scores.
- lrs: np.array The LR_pairs, in-line with the rows of scores. - eps: float The eps parameter used in DBScan to get the number of clusters. - quantile: float The quantiles to use for the cutoffs, if 0.05 then will take non-zero quantiles of 0.05, 0.1,..., 1 quantiles to cluster. + adata: AnnData + The data object + lr_scores: np.ndarray + LR_pair*Spots containing the LR scores. + lrs: np.array + The LR_pairs, in-line with the rows of scores. + eps: float + The eps parameter used in DBScan to get the number of clusters. + quantile: float + The quantiles to use for the cutoffs, if 0.05 then will take non-zero + quantiles of 0.05, 0.1,..., 1 quantiles to cluster. Returns ------- - lr_hot_scores: np.ndarray, lr_cutoffs: np.array First is the LR scores for just the hotspots, second is the cutoff used to get those LR_scores. + lr_hot_scores: np.ndarray, lr_cutoffs: np.array + First is the LR scores for just the hotspots, second is the cutoff used to + get those LR_scores. """ coors = adata.obs[["imagerow", "imagecol"]].values lr_summary, lr_hot_scores = hotspot_core( @@ -107,14 +118,13 @@ def get_hotspots( adata.obsm["cluster_scores"] = cluster_scores if verbose: - print(f"\tSummary values of lrs in adata.uns['lr_summary'].") - print( - f"\tMatrix of lr scores in same order as the summary in adata.obsm['lr_scores']." - ) - print(f"\tMatrix of the hotspot scores in adata.obsm['lr_hot_scores'].") + print("\tSummary values of lrs in adata.uns['lr_summary'].") print( - f"\tMatrix of the mean LR cluster scores in adata.obsm['cluster_scores']." + "\tMatrix of lr scores in same order as the summary in " + + "adata.obsm['lr_scores']." ) + print("\tMatrix of the hotspot scores in adata.obsm['lr_hot_scores'].") + print("\tMatrix of the mean LR cluster scores in adata.obsm['cluster_scores'].") def hotspot_core( @@ -137,7 +147,7 @@ def hotspot_core( # cols: spot_counts, cutoff, hotspot_counts, lr_cluster lr_summary = np.zeros((score_copy.shape[0], 4)) - ### Also creating grouping lr_pairs by quantiles to plot diagnostics ### + # Also creating grouping lr_pairs by quantiles to plot diagnostics if plot_diagnostics: lr_quantiles = [(i / 6) for i in range(1, 7)][::-1] lr_mean_scores = np.apply_along_axis(non_zero_mean, 1, score_copy) @@ -152,7 +162,7 @@ def hotspot_core( total=len(lrs), desc="Removing background lr scores...", bar_format="{l_bar}{bar}", - disable=verbose == False, + disable=not verbose, ) as pbar: for i, lr_ in enumerate(lrs): lr_score_ = score_copy[i, :] @@ -185,7 +195,7 @@ def hotspot_core( lr_summary[i, 2] = len(np.where(lr_score_ > 0)[0]) # Adding the diagnostic plots # - if plot_diagnostics and lr_ in quant_lrs and type(adata) != type(None): + if plot_diagnostics and lr_ in quant_lrs and adata is not None: add_diagnostic_plots( adata, i, @@ -230,7 +240,7 @@ def add_diagnostic_plots( # Scatter plot # axes[q_i][0].scatter(cutoffs, n_clusters) - axes[q_i][0].set_title(f"n_clusts*mean_spot_score vs cutoff") + axes[q_i][0].set_title("n_clusts*mean_spot_score vs cutoff") axes[q_i][0].set_xlabel("cutoffs") axes[q_i][0].set_ylabel("n_clusts*mean_spot_score") diff --git a/stlearn/plotting/__init__.py b/stlearn/tl/cci/databases/__init__.py similarity index 100% rename from stlearn/plotting/__init__.py rename to stlearn/tl/cci/databases/__init__.py diff --git a/stlearn/tools/microenv/cci/databases/connectomeDB2020_lit.txt b/stlearn/tl/cci/databases/connectomeDB2020_lit.txt similarity index 100% rename from stlearn/tools/microenv/cci/databases/connectomeDB2020_lit.txt rename to 
stlearn/tl/cci/databases/connectomeDB2020_lit.txt diff --git a/stlearn/tools/microenv/cci/databases/connectomeDB2020_put.txt b/stlearn/tl/cci/databases/connectomeDB2020_put.txt similarity index 100% rename from stlearn/tools/microenv/cci/databases/connectomeDB2020_put.txt rename to stlearn/tl/cci/databases/connectomeDB2020_put.txt diff --git a/stlearn/tools/microenv/cci/go.R b/stlearn/tl/cci/go.R similarity index 100% rename from stlearn/tools/microenv/cci/go.R rename to stlearn/tl/cci/go.R diff --git a/stlearn/tools/microenv/cci/go.py b/stlearn/tl/cci/go.py similarity index 87% rename from stlearn/tools/microenv/cci/go.py rename to stlearn/tl/cci/go.py index 617a7fe5..6abfd71a 100644 --- a/stlearn/tools/microenv/cci/go.py +++ b/stlearn/tl/cci/go.py @@ -1,8 +1,8 @@ -""" Wrapper for performing the LR GO analysis. -""" +"""Wrapper for performing the LR GO analysis.""" import os -import stlearn.tools.microenv.cci.r_helpers as rhs + +import stlearn.tl.cci.r_helpers as rhs def run_GO(genes, bg_genes, species, r_path, p_cutoff=0.01, q_cutoff=0.5, onts="BP"): @@ -21,7 +21,7 @@ def run_GO(genes, bg_genes, species, r_path, p_cutoff=0.01, q_cutoff=0.5, onts=" # Running the function on the genes # genes_r = rhs.ro.StrVector(genes) - if type(bg_genes) != type(None): + if bg_genes is not None: bg_genes_r = rhs.ro.StrVector(bg_genes) else: bg_genes_r = rhs.ro.r["as.null"]() diff --git a/stlearn/tools/microenv/cci/het.py b/stlearn/tl/cci/het.py similarity index 91% rename from stlearn/tools/microenv/cci/het.py rename to stlearn/tl/cci/het.py index bc6fb221..30043659 100644 --- a/stlearn/tools/microenv/cci/het.py +++ b/stlearn/tl/cci/het.py @@ -1,16 +1,17 @@ +from collections.abc import Iterable + import numpy as np import pandas as pd -from anndata import AnnData import scipy.spatial as spatial - +from anndata import AnnData +from numba import jit, njit, prange from numba.typed import List -from numba import njit, jit, prange -from stlearn.tools.microenv.cci.het_helpers import ( +from stlearn.tl.cci.het_helpers import ( + add_unique_edges, edge_core, get_between_spot_edge_array, get_data_for_counting, - add_unique_edges, get_neighbourhoods, init_edge_list, ) @@ -18,28 +19,36 @@ def count( adata: AnnData, - use_label: str = None, + use_label: str | None = None, use_het: str = "cci_het", verbose: bool = True, - distance: float = None, + distance: float | None = None, ) -> AnnData: """Count the cell type densities Parameters ---------- - adata: AnnData The data object including the cell types to count - use_label: The cell type results to use in counting - use_het: The stoarge place for result - distance: int Distance to determine the neighbours (default is the nearest neighbour), distance=0 means within spot + adata: AnnData + The data object including the cell types to count + use_label: + The cell type results to use in counting + use_het: + The storage place for result + distance: int + Distance to determine the neighbours (default is the nearest neighbour), + distance=0 means within spot Returns ------- - adata: AnnData With the counts of specified clusters in nearby spots stored as adata.uns['het'] + adata: AnnData + With the counts of specified clusters in nearby spots stored as + adata.uns['het'] """ library_id = list(adata.uns["spatial"].keys())[0] # between spot if distance != 0: - # automatically calculate distance if not given, won't overwrite distance=0 which is within-spot + # automatically calculate distance if not given, won't overwrite distance=0 + # which is within-spot if not distance: # 
calculate default neighbour distance
scalefactors = next(iter(adata.uns["spatial"].values()))["scalefactors"]
@@ -92,10 +101,15 @@ def get_edges(adata: AnnData, L_bool: np.array, R_bool: np.array, sig_bool: np.a

Parameters
----------
- adata: AnnData
- L_bool: np.array len(L_bool)==len(adata), True if ligand expressed in that spot.
- R_bool: np.array len(R_bool)==len(adata), True if receptor expressed in that spot.
- sig_bool np.array: len(sig_bool)==len(adata), True if spot has significant LR interactions.
+ adata : AnnData
+ Annotated data object containing spatial transcriptomics data.
+ L_bool : np.ndarray of bool, shape (n_spots,)
+ Boolean array indicating spots where the ligand is expressed.
+ R_bool : np.ndarray of bool, shape (n_spots,)
+ Boolean array indicating spots where the receptor is expressed.
+ sig_bool : np.ndarray of bool, shape (n_spots,)
+ Boolean array indicating spots with significant ligand-receptor interactions.
+

Returns
-------
edge_list_unique: list<tuple> Either a list of tuples (directed), or
@@ -266,7 +280,8 @@ def get_interaction_matrix(
# 1) sig spot with ligand, only neighbours with receptor relevant
# 2) sig spot with receptor, only neighbours with ligand relevant
# NOTE, A<->B is double counted, but on different side of matrix.
- # (if bidirectional interaction between two spots, counts as two seperate interactions).
+ # (if bidirectional interaction between two spots, counts as two separate
+ # interactions).
LR_edges = get_interactions(
cell_data,
neighbourhood_bcs,
@@ -341,7 +356,6 @@ def get_interactions(

# Now retrieving the interaction edges #
for i in range(all_set.shape[0]):
-
# Determining which spots have cell type A #
A_bool_2 = cell_data[:, i] > cell_prop_cutoff
A_gene1_bool = np.logical_and(A_bool_2, gene1_bool)
@@ -401,7 +415,7 @@ def create_grids(adata: AnnData, num_row: int, num_col: int, radius: int = 1):
grids, neighbours = [], []
# generate grids from top to bottom and left to right
for n in range(num_row * num_col):
- neighbour = []
+ neighbour: Iterable[float] = []
x = min_x + n // num_row * width  # left side
y = min_y + n % num_row * height  # upper side
grids.append([x, y])
@@ -435,7 +449,7 @@ def count_grid(
adata: AnnData,
num_row: int = 30,
num_col: int = 30,
- use_label: str = None,
+ use_label: str | None = None,
use_het: str = "cci_het_grid",
radius: int = 1,
verbose: bool = True,
@@ -446,13 +460,15 @@ def count_grid(
adata: AnnData The data object including the cell types to count
num_row: int Number of grids on height
num_col: int Number of grids on width
- use_label: The cell type results to use in counting
- use_het: The stoarge place for result
- radius: int Distance to determine the neighbour grids (default: 1=nearest), radius=0 means within grid
+ use_label: The cell type results to use in counting
+ use_het: The storage place for result
+ radius: int Distance to determine the neighbour grids
+ (default: 1=nearest), radius=0 means within grid

Returns
-------
- adata: AnnData With the counts of specified clusters in each grid of the tissue stored as adata.uns['het']
+ adata (AnnData): With the counts of specified clusters in each grid of the
+ tissue stored as adata.uns['het']
"""

coor = adata.obs[["imagerow", "imagecol"]]
diff --git a/stlearn/tools/microenv/cci/het_helpers.py b/stlearn/tl/cci/het_helpers.py
similarity index 70%
rename from stlearn/tools/microenv/cci/het_helpers.py
rename to stlearn/tl/cci/het_helpers.py
index 270e811c..e5761f15 100644
--- a/stlearn/tools/microenv/cci/het_helpers.py
+++
b/stlearn/tl/cci/het_helpers.py @@ -3,10 +3,8 @@ """ import numpy as np -import numba -from numba import types +from numba import njit from numba.typed import List -from numba import njit, jit @njit @@ -32,7 +30,7 @@ def edge_core( cell_type_index: int Column of cell_data that contains the \ cell type of interest. - neighbourhood_bcs: List List of lists, inner list for each \ + neighbourhood_bcs (List): List of lists, inner list for each \ spot. First element of inner list is \ spot barcode, second element is array \ of neighbourhood spot barcodes. @@ -77,7 +75,7 @@ def edge_core( elif len(spot_indices) == 0: return edge_list[1:] - ### Within-spot mode + # Within-spot mode # within-spot, will have only itself as a neighbour in this mode within_mode = edge_list[0][0] == edge_list[0][1] if within_mode: @@ -86,7 +84,7 @@ def edge_core( if neigh_bool[i] and cell_data[i] > cutoff: edge_list.append((neighbourhood_bcs[i][0], neighbourhood_bcs[i][1][0])) - ### Between-spot mode + # Between-spot mode else: # Subsetting the neighbourhoods to relevant spots # neighbourhood_bcs_sub = List() @@ -228,128 +226,6 @@ def get_data_for_counting(adata, use_label, mix_mode, all_set): ) # neighbourhood_bcs, neighbourhood_indices -def get_data_for_counting_OLD(adata, use_label, mix_mode, all_set): - """Retrieves the minimal information necessary to perform edge counting.""" - # First determining how the edge counting needs to be performed # - # Ensuring compatibility with current way of adding label_transfer to object - if use_label == "label_transfer" or use_label == "predictions": - obs_key, uns_key = "predictions", "label_transfer" - else: - obs_key, uns_key = use_label, use_label - - # Getting the neighbourhoods # - neighbours, neighbourhood_bcs, neighbourhood_indices = get_neighbourhoods(adata) - - # Getting the cell type information; if not mixtures then populate - # matrix with one's indicating pure spots. 
- if mix_mode: - cell_props = adata.uns[uns_key] - cols = cell_props.columns.values.astype(str) - col_order = [ - np.where([cell_type in col for col in cols])[0][0] for cell_type in all_set - ] - cell_data = adata.uns[uns_key].iloc[:, col_order].values.astype(np.float64) - else: - cell_labels = adata.obs.loc[:, obs_key].values - cell_data = np.zeros((len(cell_labels), len(all_set)), dtype=np.float64) - for i, cell_type in enumerate(all_set): - cell_data[:, i] = ( - (cell_labels == cell_type).astype(np.int32).astype(np.float64) - ) - - spot_bcs = adata.obs_names.values.astype(str) - return spot_bcs, cell_data, neighbourhood_bcs, neighbourhood_indices - - -# @njit -def get_neighbourhoods_FAST( - spot_bcs: np.array, - spot_neigh_bcs: np.ndarray, - n_spots: int, - str_dtype: str, - neigh_indices: np.array, - neigh_bcs: np.array, -): - """Gets the neighbourhood information, njit compiled.""" - - # Determining the neighbour spots used for significance testing # - # neighbours = List( numba.int64[:] ) - # neighbourhood_bcs = List((numba.int64, numba.int64[:])) - # neighbourhood_indices = List( (types.unicode_type, types.unicode_type[:]) ) - - ### Numba version - # neighbours = List([neigh_indices])[1:] - # neighbourhood_bcs = List() - # neighbourhood_indices = List([(0, neigh_indices)])[1:] - - #### Trying normal lists - neighbours, neighbourhood_bcs, neighbourhood_indices = [], [], [] - - for i in range(spot_neigh_bcs.shape[0]): - neigh_bcs = np.array(spot_neigh_bcs[i, :][0].split(",")) - neigh_bcs = neigh_bcs[neigh_bcs != ""] - # neigh_bcs_sub = List() - # for neigh_bc in neigh_bcs: - # if neigh_bc in spot_bcs: - # neigh_bcs_sub.append( neigh_bc ) - - # neigh_bcs_array = np.empty((len(neigh_bcs_sub)), str_dtype) - # neigh_bcs_array = np.empty(len(neigh_bcs_sub), dtype=str_dtype) - # neigh_indices = np.zeros((len(neigh_bcs_sub)), dtype=np.int64) - neigh_bcs_array, neigh_indices = [], [] - neigh_bcs_sub = List() - for j, neigh_bc in enumerate(neigh_bcs): - - bc_indices = np.where(spot_bcs == neigh_bc)[0] - if len(bc_indices) > 0: - - neigh_bcs_array.append(neigh_bc) - neigh_indices.append(bc_indices[0]) - - neigh_bcs_array = np.array(neigh_bcs_array, dtype=str_dtype) - neigh_indices = np.array(neigh_indices, dtype=np.int64) - - neighbours.append(neigh_indices) - neighbourhood_indices.append((i, neigh_indices)) - neighbourhood_bcs.append((spot_bcs[i], neigh_bcs_array)) - - # return neighbours, neighbourhood_bcs, neighbourhood_indices - return List(neighbours), List(neighbourhood_bcs), List(neighbourhood_indices) - - -def get_data_for_counting_OLD(adata, use_label, mix_mode, all_set): - """Retrieves the minimal information necessary to perform edge counting.""" - # First determining how the edge counting needs to be performed # - # Ensuring compatibility with current way of adding label_transfer to object - if use_label == "label_transfer" or use_label == "predictions": - obs_key, uns_key = "predictions", "label_transfer" - else: - obs_key, uns_key = use_label, use_label - - # Getting the neighbourhoods # - neighbours, neighbourhood_bcs, neighbourhood_indices = get_neighbourhoods(adata) - - # Getting the cell type information; if not mixtures then populate - # matrix with one's indicating pure spots. 
- if mix_mode: - cell_props = adata.uns[uns_key] - cols = cell_props.columns.values.astype(str) - col_order = [ - np.where([cell_type in col for col in cols])[0][0] for cell_type in all_set - ] - cell_data = adata.uns[uns_key].iloc[:, col_order].values.astype(np.float64) - else: - cell_labels = adata.obs.loc[:, obs_key].values - cell_data = np.zeros((len(cell_labels), len(all_set)), dtype=np.float64) - for i, cell_type in enumerate(all_set): - cell_data[:, i] = ( - (cell_labels == cell_type).astype(np.int_).astype(np.float64) - ) - - spot_bcs = adata.obs_names.values.astype(str) - return spot_bcs, cell_data, neighbourhood_bcs, neighbourhood_indices - - def get_neighbourhoods_FAST( spot_bcs: np.array, spot_neigh_bcs: np.ndarray, @@ -368,9 +244,7 @@ def get_neighbourhoods_FAST( neigh_bcs = neigh_bcs[neigh_bcs != ""] neigh_bcs_array, neigh_indices = [], [] - neigh_bcs_sub = List() for j, neigh_bc in enumerate(neigh_bcs): - bc_indices = np.where(spot_bcs == neigh_bc)[0] if len(bc_indices) > 0: neigh_bcs_array.append(neigh_bc) @@ -391,7 +265,7 @@ def get_neighbourhoods(adata): # Old stlearn version where didn't store neighbourhood barcodes, not good # for anndata subsetting!! - if not "spot_neigh_bcs" in adata.obsm: + if "spot_neigh_bcs" not in adata.obsm: # Determining the neighbour spots used for significance testing # neighbours = List() for i in range(adata.obsm["spot_neighbours"].shape[0]): @@ -410,7 +284,6 @@ def get_neighbourhoods(adata): neighbourhood_indices.append((spot_i, neighbours[spot_i])) neighbourhood_bcs.append((spot_bcs[spot_i], spot_bcs[neighbours[spot_i]])) else: # Newer version - spot_bcs = adata.obs_names.values.astype(str) spot_neigh_bcs = adata.obsm["spot_neigh_bcs"].values.astype(str) diff --git a/stlearn/tools/microenv/cci/merge.py b/stlearn/tl/cci/merge.py similarity index 93% rename from stlearn/tools/microenv/cci/merge.py rename to stlearn/tl/cci/merge.py index 6f25908b..4eb91dc4 100644 --- a/stlearn/tools/microenv/cci/merge.py +++ b/stlearn/tl/cci/merge.py @@ -1,5 +1,4 @@ import numpy as np -import pandas as pd from anndata import AnnData @@ -25,7 +24,8 @@ def merge( if verbose: print( - "Results of spatial interaction analysis has been written to adata.uns['merged']" + "Results of spatial interaction analysis has been written to " + + "adata.uns['merged']" ) return adata diff --git a/stlearn/tools/microenv/cci/perm_utils.py b/stlearn/tl/cci/perm_utils.py similarity index 82% rename from stlearn/tools/microenv/cci/perm_utils.py rename to stlearn/tl/cci/perm_utils.py index 083ae1ef..6bc84d4d 100644 --- a/stlearn/tools/microenv/cci/perm_utils.py +++ b/stlearn/tl/cci/perm_utils.py @@ -1,10 +1,9 @@ import numpy as np import pandas as pd -from scipy.spatial.distance import euclidean, canberra -from sklearn.preprocessing import MinMaxScaler - from numba import njit, prange from numba.typed import List +from scipy.spatial.distance import canberra +from sklearn.preprocessing import MinMaxScaler from .base import get_lrs_scores @@ -13,7 +12,7 @@ def nonzero_quantile(expr, q, interpolation): """Calculating the non-zero quantiles.""" nonzero_expr = expr[expr > 0] quants = np.quantile(nonzero_expr, q=q, interpolation=interpolation) - if type(quants) != np.array and type(quants) != np.ndarray: + if not isinstance(quants, np.ndarray) or quants.ndim == 0: quants = np.array([quants]) return quants @@ -36,7 +35,9 @@ def get_lr_quants( """Gets the quantiles per gene in the LR pair, & then concatenates. 
Returns ------- - lr_quants, l_quants, r_quants: np.ndarray First is concatenation of two latter. Each row is a quantile value, each column is a LR pair. + lr_quants, l_quants, r_quants (np.ndarray): First is concatenation of two latter. + Each row is a quantile value, each + column is an LR pair. """ quant_func = nonzero_quantile if method != "quantiles" else np.quantile @@ -58,7 +59,9 @@ def get_lr_zeroprops(lr_expr: pd.DataFrame, l_indices: list, r_indices: list): """Gets the proportion of zeros per gene in the LR pair, & then concatenates. Returns ------- - lr_props, l_props, r_props: np.ndarray First is concatenation of two latter. Each row is a prop value, each column is a LR pair. + lr_props, l_props, r_props (np.ndarray): First is concatenation of two latter. + Each row is a prop value, each column + is an LR pair. """ # First getting the quantiles of gene expression # @@ -76,7 +79,8 @@ def get_lr_bounds(lr_value: float, bin_bounds: np.array): """For the given lr_value, returns the bin where it belongs. Returns ------- - lr_bin: tuple Tuple of length 2, first is the lower bound of the bin, second is upper bound of the bin. + lr_bin (tuple): Tuple of length 2, first is the lower bound of the bin, second + is upper bound of the bin. """ if np.any(bin_bounds == lr_value): # If sits on a boundary lr_i = np.where(bin_bounds == lr_value)[0][0] @@ -105,17 +109,17 @@ def get_similar_genes( by measuring distance between the gene expression quantiles. Parameters ---------- - ref_quants: np.array The pre-calculated quantiles. - ref_props: np.array The query zero proportions. - n_genes: int Number of equivalent genes to select. + ref_quants: np.array The pre-calculated quantiles. + ref_props: np.array The query zero proportions. + n_genes: int Number of equivalent genes to select. candidate_expr: np.ndarray Expression of gene candidates (cells*genes). candidate_genes: np.array Same as candidate_expr.shape[1], indicating gene names. - quantiles: tuple The quantile to use + quantiles: tuple The quantile to use Returns ------- similar_genes: np.array Array of strings for gene names. """ - if type(quantiles) == float: + if isinstance(quantiles, float): quantiles = np.array([quantiles]) else: quantiles = np.array(quantiles) @@ -168,17 +172,21 @@ def get_similar_genes_Quantiles( by measuring distance between the gene expression quantiles. Parameters ---------- - gene_expr: np.array Expression of the gene of interest, or, if the same length as quantiles, then assumes is the pre-calculated quantiles. - n_genes: int Number of equivalent genes to select. - candidate_quants: np.ndarray Expression quantiles of gene candidates (quantiles*genes). - candidate_genes: np.array Same as candidate_expr.shape[1], indicating gene names. - quantiles: tuple The quantile to use + gene_expr: np.array Expression of the gene of interest, or, if the + same length as quantiles, then assumes is the + pre-calculated quantiles. + n_genes: int Number of equivalent genes to select. + candidate_quants: np.ndarray Expression quantiles of gene candidates + (quantiles*genes). + candidate_genes: np.array Same as candidate_expr.shape[1], indicating gene + names. + quantiles: tuple The quantile to use Returns ------- similar_genes: np.array Array of strings for gene names. 
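To make the selection concrete, here is a small self-contained sketch of the matching idea with synthetic arrays (the actual function measures distance over both quantiles and zero-proportions, using a Canberra-style metric; plain absolute differences are used here for brevity):

import numpy as np

ref_quants = np.array([0.0, 1.2, 3.5])           # quantiles of the reference gene
candidate_quants = np.random.rand(3, 100) * 4.0  # quantiles*genes for 100 candidates
candidate_genes = np.array([f"gene{i}" for i in range(100)])

# Distance of each candidate's quantiles to the reference quantiles
dists = np.abs(candidate_quants - ref_quants[:, None]).sum(axis=0)
n_genes = 10
similar_genes = candidate_genes[np.argsort(dists)[:n_genes]]  # closest matches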
""" - if type(quantiles) == float: + if isinstance(quantiles, float): quantiles = np.array([quantiles]) else: quantiles = np.array(quantiles) @@ -218,7 +226,7 @@ def get_similar_genesFAST( ref_quants: np.array, n_genes: int, candidate_quants: np.ndarray, - candidate_genes: np.array, + candidate_genes: np.ndarray, ): """Fast version of the above with parallelisation.""" @@ -295,7 +303,7 @@ def get_lr_features(adata, lr_expr, lrs, quantiles): # Calculating the zero proportions, for grouping based on median/zeros # lr_props, l_props, r_props = get_lr_zeroprops(lr_expr, l_indices, r_indices) - ######## Getting lr features for later diagnostics ####### + # Getting lr features for later diagnostics lr_meds, l_meds, r_meds = get_lr_quants( lr_expr, l_indices, r_indices, quantiles=np.array([0.5]), method="" ) @@ -311,17 +319,19 @@ def get_lr_features(adata, lr_expr, lrs, quantiles): # Saving the lrfeatures... cols = ["nonzero-median", "zero-prop", "median_rank", "prop_rank", "mean_rank"] - lr_features = pd.DataFrame(index=lrs, columns=cols) - lr_features.iloc[:, 0] = lr_median_means - lr_features.iloc[:, 1] = lr_prop_means - lr_features.iloc[:, 2] = np.array(median_ranks) - lr_features.iloc[:, 3] = np.array(prop_ranks) - lr_features.iloc[:, 4] = np.array(mean_ranks) + lr_features_data = { + cols[0]: np.array(lr_median_means, dtype=np.float64), + cols[1]: np.array(lr_prop_means, dtype=np.float64), + cols[2]: np.array(median_ranks, dtype=np.float64), + cols[3]: np.array(prop_ranks, dtype=np.float64), + cols[4]: np.array(mean_ranks, dtype=np.float64), + } + lr_features = pd.DataFrame(lr_features_data, index=lrs) lr_features = lr_features.iloc[np.argsort(mean_ranks), :] lr_cols = [f"L_{quant}" for quant in quantiles] + [ f"R_{quant}" for quant in quantiles ] - quant_df = pd.DataFrame(lr_quants, columns=lr_cols, index=lrs) + quant_df = pd.DataFrame(lr_quants, columns=lr_cols, index=lrs, dtype=np.float64) lr_features = pd.concat((lr_features, quant_df), axis=1) adata.uns["lrfeatures"] = lr_features @@ -347,7 +357,10 @@ def get_lr_bg( l_, r_ = lr_.split("_") if l_ not in gene_bg_genes: l_genes = get_similar_genesFAST( - l_quant, n_genes, candidate_quants, genes # group_l_props, + l_quant, + n_genes, + candidate_quants, + genes, # group_l_props, ) gene_bg_genes[l_] = l_genes else: @@ -355,7 +368,10 @@ def get_lr_bg( if r_ not in gene_bg_genes: r_genes = get_similar_genesFAST( - r_quant, n_genes, candidate_quants, genes # group_r_props, + r_quant, + n_genes, + candidate_quants, + genes, # group_r_props, ) gene_bg_genes[r_] = r_genes else: @@ -364,7 +380,7 @@ def get_lr_bg( rand_pairs = gen_rand_pairs(l_genes, r_genes, n_pairs) spot_indices = np.where(lr_score > 0)[0] - background = get_lrs_scores( + background, _ = get_lrs_scores( adata, rand_pairs, neighbours, diff --git a/stlearn/tools/microenv/cci/permutation.py b/stlearn/tl/cci/permutation.py similarity index 83% rename from stlearn/tools/microenv/cci/permutation.py rename to stlearn/tl/cci/permutation.py index 6ca6ce12..ad9a5f99 100644 --- a/stlearn/tools/microenv/cci/permutation.py +++ b/stlearn/tl/cci/permutation.py @@ -1,17 +1,21 @@ -import sys, os, random, scipy +import os +import random +import sys +from typing import Any + import numpy as np import pandas as pd -from numba.typed import List +import scipy import statsmodels.api as sm +from anndata import AnnData +from numba.typed import List +from sklearn.cluster import AgglomerativeClustering from statsmodels.stats.multitest import multipletests - from tqdm import tqdm -from sklearn.cluster 
import AgglomerativeClustering
-from anndata import AnnData

-from .base import lr, calc_neighbours, get_spot_lrs, get_lrs_scores, get_scores
+from .base import calc_neighbours, get_lrs_scores, get_scores, get_spot_lrs, lr
from .merge import merge
-from .perm_utils import get_lr_features, get_lr_bg
+from .perm_utils import get_lr_bg, get_lr_features


# Newest method #
@@ -43,8 +47,7 @@ def perform_spot_testing(
n_genes = round(np.sqrt(n_pairs) * 2)
if len(genes) < n_genes:
print(
- "Exiting since need atleast "
- f"{n_genes} genes to generate {n_pairs} pairs."
+ f"Exiting since need at least {n_genes} genes to generate {n_pairs} pairs."
)
return

@@ -55,7 +58,7 @@
)
return

- ####### Quantiles to select similar gene to LRs to gen. rand-pairs #######
+ # Quantiles to select genes similar to the LRs to generate random pairs
lr_expr = adata[:, lr_genes].to_df()
lr_feats = get_lr_features(adata, lr_expr, lrs, quantiles)
l_quants = lr_feats.loc[
@@ -72,7 +75,7 @@
r_quants = r_quants.astype("
) -> AnnData:
@@ -356,16 +360,23 @@
adata: AnnData The data object including the cell types to count
n_pairs: int Number of gene pairs to run permutation test (default: 1000)
distance: int Distance between spots (default: 30)
- use_lr: str LR cluster used for permutation test (default: 'lr_neighbours_louvain_max')
- use_het: str cell type diversity counts used for permutation test (default 'het')
- neg_binom: bool Whether to fit neg binomial paramaters to bg distribution for p-val est.
- adj_method: str Method used by statsmodels.stats.multitest.multipletests for MHT correction.
- neighbours: list List of the neighbours for each spot, if None then computed. Useful for speeding up function.
+ use_lr: str LR cluster used for permutation test
+ (default: 'lr_neighbours_louvain_max')
+ use_het: str cell type diversity counts used for permutation test
+ (default 'het')
+ neg_binom: bool Whether to fit neg binomial parameters to bg distribution
+ for p-val est.
+ adj_method: str Method used by statsmodels.stats.multitest.multipletests
+ for MHT correction.
+ neighbours: list List of the neighbours for each spot, if None then
+ computed. Useful for speeding up function.
**kwargs: Extra arguments passed to lr.
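A minimal calling sketch for this permutation test (parameter values are illustrative assumptions, and the st.tl.cci.permutation entry point is inferred from the tl.cci rename; the LR pair must already be set on the object):

# Assumes adata.uns["lr"] = ["L1_R1"] was set and st.tl.cci.lr has been run,
# so that adata.obsm[use_lr] holds per-spot LR scores.
import stlearn as st

st.tl.cci.permutation(
    adata,
    n_pairs=200,      # background gene pairs to sample
    distance=30,
    use_het=None,     # skip heterogeneity weighting
    neg_binom=False,  # p-values taken straight from the empirical background
)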
Returns ------- - adata: AnnData Data Frame of p-values from permutation test for each window stored in adata.uns['merged_pvalues'] - Final significant merged scores stored in adata.uns['merged_sign'] + adata: AnnData Data Frame of p-values from permutation test for each + window stored in adata.uns['merged_pvalues'] + Final significant merged scores stored in + adata.uns['merged_sign'] """ # blockPrint() @@ -374,7 +385,7 @@ def permutation( genes = get_valid_genes(adata, n_pairs) if len(adata.uns["lr"]) > 1: raise ValueError("Permutation test only supported for one LR pair scenario.") - elif type(bg_pairs) == type(None): + elif bg_pairs is None: pairs = get_rand_pairs(adata, genes, n_pairs, lrs=adata.uns["lr"]) else: pairs = bg_pairs @@ -383,11 +394,13 @@ def permutation( # generate random pairs lr1 = adata.uns['lr'][0].split('_')[0] lr2 = adata.uns['lr'][0].split('_')[1] - genes = [item for item in adata.var_names.tolist() if not (item.startswith('MT-') or item.startswith('MT_') or item==lr1 or item==lr2)] + genes = [item for item in adata.var_names.tolist() if not + (item.startswith('MT-') or item.startswith('MT_') or + item==lr1 or item==lr2)] random.shuffle(genes) pairs = [i + '_' + j for i, j in zip(genes[:n_pairs], genes[-n_pairs:])] """ - if use_het != None: + if use_het is not None: scores = adata.obsm["merged"] else: scores = adata.obsm[use_lr] @@ -396,12 +409,11 @@ def permutation( query_pair = adata.uns["lr"] # If neighbours not inputted, then compute # - if type(neighbours) == type(None): + if neighbours is None: neighbours = calc_neighbours(adata, distance, index=run_fast) - if not run_fast and type(background) == type( - None - ): # Run original way if 'fast'=False argument inputted. + if not run_fast and background is None: + # Run original way if 'fast'=False argument inputted. background = [] for item in pairs: adata.uns["lr"] = [item] @@ -413,19 +425,19 @@ def permutation( neighbours=neighbours, **kwargs, ) - if use_het != None: + if use_het is not None: merge(adata, use_lr=use_lr, use_het=use_het, verbose=False) background += adata.obsm["merged"].tolist() else: background += adata.obsm[use_lr].tolist() background = np.array(background) - elif type(background) == type(None): # Run fast if background not inputted + elif background is None: # Run fast if background not inputted spot_lr1s = get_spot_lrs(adata, pairs, lr_order=True, filter_pairs=False) spot_lr2s = get_spot_lrs(adata, pairs, lr_order=False, filter_pairs=False) het_vals = ( - np.array([1] * len(adata)) if use_het == None else adata.obsm[use_het] + np.array([1] * len(adata)) if use_het is None else adata.obsm[use_het] ) background = get_scores( spot_lr1s.values, spot_lr2s.values, neighbours, het_vals @@ -434,12 +446,12 @@ def permutation( # log back the original query adata.uns["lr"] = query_pair - #### Negative Binomial fit + # Negative Binomial fit pvals, pvals_adj, log10_pvals, lr_sign = get_stats( - scores, background, neg_binom, adj_method + scores, background, neg_binom, adj_method=adj_method ) - if use_het != None: + if use_het is not None: adata.obsm["merged"] = scores adata.obsm["merged_pvalues"] = log10_pvals adata.obsm["merged_sign"] = lr_sign @@ -477,19 +489,23 @@ def get_stats( scores: np.array Per spot scores for a particular LR pair. background: np.array Background distribution for non-zero scores. total_bg: int Total number of background values calculated. 
- neg_binom: bool Whether to use neg-binomial distribution to estimate p-values, NOT appropriate with log1p data, alternative is to use background distribution itself (recommend higher number of n_pairs for this).
- adj_method: str Parsed to statsmodels.stats.multitest.multipletests for multiple hypothesis testing correction.
+ neg_binom: bool Whether to use neg-binomial distribution to estimate
+ p-values, NOT appropriate with log1p data, alternative is
+ to use background distribution itself (recommend higher
+ number of n_pairs for this).
+ adj_method: str Passed to statsmodels.stats.multitest.multipletests for
+ multiple hypothesis testing correction.

Returns
-------
- stats: tuple Per spot pvalues, pvals_adj, log10_pvals_adj, lr_sign (the LR scores for significant spots).
+ stats: tuple Per spot pvalues, pvals_adj, log10_pvals_adj, lr_sign
+ (the LR scores for significant spots).
"""
- ##### Negative Binomial fit
+ # Negative Binomial fit
if neg_binom:
# Need to make full background for fitting !!!
background = np.array(list(background) + [0] * (total_bg - len(background)))
- pmin, pmax = min(background), max(background)
+ pmin = min(background)
background2 = [item - pmin for item in background]
- x = np.linspace(pmin, pmax, 1000)
res = sm.NegativeBinomial(
background2, np.ones(len(background2)), loglike_method="nb2"
).fit(start_params=[0.1, 0.3], disp=0)
@@ -506,11 +522,12 @@
# Calculate probability for all spots
pvals = 1 - scipy.stats.nbinom.cdf(scores - pmin, size, prob)

- else: ###### Using the actual values to estimate p-values
+ else:
+ # Using the actual values to estimate p-values
pvals = np.zeros((1, len(scores)), dtype=np.float64)[0, :]
nonzero_score_bool = scores > 0
nonzero_score_indices = np.where(nonzero_score_bool)[0]
- zero_score_indices = np.where(nonzero_score_bool == False)[0]
+ zero_score_indices = np.where(~nonzero_score_bool)[0]
pvals[zero_score_indices] = (total_bg - len(background)) / total_bg
pvals[nonzero_score_indices] = [
len(np.where(background >= scores[i])[0]) / total_bg
@@ -560,28 +577,30 @@
def get_rand_pairs(
adata: AnnData,
genes: np.array,
n_pairs: int,
- lrs: list = None,
- im: int = None,
+ lrs: list,
+ im: int | None = None,
):
"""Gets equivalent random gene pairs for the inputted lr pair.

Parameters
----------
- adata: AnnData The data object including the cell types to count
- lr: int The lr pair string to get equivalent random pairs for (e.g. 'L_R')
- genes: np.array Candidate genes to use as pairs.
- n_pairs: int Number of random pairs to generate.
+ adata (AnnData): The data object including the cell types to count
+ genes (np.array): Candidate genes to use as pairs.
+ n_pairs (int): Number of random pairs to generate.
+ lrs (list): The LR pair strings to get equivalent random pairs
+ for (e.g. ['L_R'])

Returns
-------
- pairs: list List of random gene pairs with equivalent mean expression (e.g. ['L_R'])
+ pairs (list): List of random gene pairs with equivalent mean expression
+ (e.g.
['L_R'])
"""
lr_genes = [lr.split("_")[0] for lr in lrs]
lr_genes += [lr.split("_")[1] for lr in lrs]

# get the position of the median of the means between the two genes
means_ordered, genes_ordered = get_ordered(adata, genes)
- if type(im) == type(None):  # Single background per lr pair mode
- l, r = lrs[0].split("_")
- im = get_median_index(l, r, means_ordered.values, genes_ordered)
+ if im is None:  # Single background per lr pair mode
+ ligand, receptor = lrs[0].split("_")
+ im = get_median_index(ligand, receptor, means_ordered.values, genes_ordered)

# get n_pair genes sorted by distance to im
selected = (
@@ -605,21 +624,23 @@ def get_ordered(adata, genes):
return means_ordered, genes_ordered


-def get_median_index(l, r, means_ordered, genes_ordered):
- """ "Retrieves the index of the gene with a mean expression between the two genes in the lr pair.
+def get_median_index(ligand, receptor, means_ordered, genes_ordered):
+ """Retrieves the index of the gene with a mean expression between the two genes
+ in the lr pair.

Parameters
----------
- X: np.ndarray Spot*Gene expression.
- l: str Ligand gene.
- r: str Receptor gene.
- genes: np.array Candidate genes to use as pairs.
+ ligand: Ligand gene.
+ receptor: Receptor gene.
+ genes_ordered: Gene names sorted by mean expression (from get_ordered).
+ means_ordered: The sorted mean expression values matching genes_ordered.

Returns
-------
- pairs: list List of random gene pairs with equivalent mean expression (e.g. ['L_R'])
+ im (int): Index into the ordered genes of the gene whose mean expression
+ lies between the means of the ligand and receptor.
"""
# sort the mean of each gene expression
- i1 = np.where(genes_ordered == l)[0][0]
- i2 = np.where(genes_ordered == r)[0][0]
+ i1 = np.where(genes_ordered == ligand)[0][0]
+ i2 = np.where(genes_ordered == receptor)[0][0]
if means_ordered[i1] > means_ordered[i2]:
it = i1
i1 = i2
diff --git a/stlearn/tools/microenv/cci/r_helpers.py b/stlearn/tl/cci/r_helpers.py
similarity index 100%
rename from stlearn/tools/microenv/cci/r_helpers.py
rename to stlearn/tl/cci/r_helpers.py
diff --git a/stlearn/tools/clustering/__init__.py b/stlearn/tl/clustering/__init__.py
similarity index 57%
rename from stlearn/tools/clustering/__init__.py
rename to stlearn/tl/clustering/__init__.py
index 391d4b0e..d68d6df2 100644
--- a/stlearn/tools/clustering/__init__.py
+++ b/stlearn/tl/clustering/__init__.py
@@ -1,3 +1,9 @@
+from .annotate import annotate_interactive
from .kmeans import kmeans
from .louvain import louvain
-from .annotate import annotate_interactive
+
+__all__ = [
+ "kmeans",
+ "louvain",
+ "annotate_interactive",
+]
diff --git a/stlearn/tools/clustering/annotate.py b/stlearn/tl/clustering/annotate.py
similarity index 88%
rename from stlearn/tools/clustering/annotate.py
rename to stlearn/tl/clustering/annotate.py
index e195351b..40cea6e8 100644
--- a/stlearn/tools/clustering/annotate.py
+++ b/stlearn/tl/clustering/annotate.py
@@ -1,8 +1,9 @@
from anndata import AnnData
-from stlearn.plotting.classes_bokeh import Annotate
from bokeh.io import output_notebook
from bokeh.plotting import show

+from stlearn.pl.classes_bokeh import Annotate
+

def annotate_interactive(
adata: AnnData,
diff --git a/stlearn/tools/clustering/kmeans.py b/stlearn/tl/clustering/kmeans.py
similarity index 70%
rename from stlearn/tools/clustering/kmeans.py
rename to stlearn/tl/clustering/kmeans.py
index 0b451cb1..e689c70f 100644
--- a/stlearn/tools/clustering/kmeans.py
+++ b/stlearn/tl/clustering/kmeans.py
@@ -1,9 +1,8 @@
-from sklearn.cluster import KMeans
-from anndata import AnnData
-from typing import Optional, Union
-import pandas as pd
import numpy as np
+import pandas as pd
+from
anndata import AnnData
from natsort import natsorted
+from sklearn.cluster import KMeans


def kmeans(
@@ -14,36 +13,38 @@ def kmeans(
n_init: int = 10,
max_iter: int = 300,
tol: float = 0.0001,
- random_state: str = None,
+ random_state: int | np.random.RandomState | None = None,
copy_x: bool = True,
- algorithm: str = "auto",
+ algorithm: str = "lloyd",
key_added: str = "kmeans",
copy: bool = False,
-) -> Optional[AnnData]:
-
+) -> AnnData | None:
"""\
Perform kmeans cluster for spatial transcriptomics data

Parameters
----------
- adata
+ adata: AnnData
Annotated data matrix.
- n_clusters
+ n_clusters: int, default = 20
The number of clusters to form as well as the number of
centroids to generate.
- use_data
+ use_data: str, default = "X_pca"
Use dimensionality reduction result.
- init
- Method for initialization, defaults to 'k-means++'
- max_iter
+ init: str, default = "k-means++"
+ Method for initialization, defaults to 'k-means++'.
+ n_init: int, default = 10
+ Number of times the k-means algorithm will be run with different
+ centroid seeds.
+ max_iter: int, default = 300
Maximum number of iterations of the k-means algorithm for a
single run.
- tol
- Relative tolerance with regards to inertia to declare convergence.
- random_state
+ tol: float, default = 0.0001
+ Relative tolerance with regard to inertia to declare convergence.
+ random_state: int | np.random.RandomState | None, default = None
Determines random number generation for centroid initialization. Use
an int to make the randomness deterministic.
- copy_x
+ copy_x: bool, default = True
When pre-computing distances it is more numerically accurate to center
the data first. If copy_x is True (default), then the original data is
not modified, ensuring X is C-contiguous. If False, the original data
@@ -51,14 +52,13 @@ def kmeans(
numerical differences may be introduced by subtracting and then adding
the data mean, in this case it will also not ensure that data is
C-contiguous which may cause a significant slowdown.
- algorithm
- K-means algorithm to use. The classical EM-style algorithm is "full".
- The "elkan" variation is more efficient by using the triangle
- inequality, but currently doesn't support sparse data. "auto" chooses
- "elkan" for dense data and "full" for sparse data.
- key_added
+ algorithm: str, default = "lloyd"
+ K-means algorithm to use. The classical EM-style algorithm is "lloyd".
+ The "elkan" variation can be more efficient on some datasets with
+ well-defined clusters, by using the triangle inequality.
+ key_added: str, default = "kmeans"
Key added to adata.obs
- copy
+ copy: bool, default = False
Return a copy instead of writing to adata.
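A short usage sketch of this wrapper (illustrative only; it assumes PCA was computed first so that adata.obsm["X_pca"] exists, and that the function is exposed as st.tl.clustering.kmeans after this rename):

# e.g. scanpy.pp.pca(adata) beforehand; the labels land in adata.obs["kmeans"].
import stlearn as st

st.tl.clustering.kmeans(
    adata,
    n_clusters=10,
    use_data="X_pca",
    random_state=0,
)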
Returns ------- diff --git a/stlearn/tools/clustering/louvain.py b/stlearn/tl/clustering/louvain.py similarity index 75% rename from stlearn/tools/clustering/louvain.py rename to stlearn/tl/clustering/louvain.py index 8f5ea899..78e973dd 100644 --- a/stlearn/tools/clustering/louvain.py +++ b/stlearn/tl/clustering/louvain.py @@ -1,39 +1,28 @@ +from collections.abc import Mapping, Sequence from types import MappingProxyType -from typing import Optional, Tuple, Sequence, Type, Mapping, Any, Union +from typing import Any, Literal -import numpy as np -import pandas as pd +import scanpy from anndata import AnnData -from natsort import natsorted +from louvain.VertexPartition import MutableVertexPartition from numpy.random.mtrand import RandomState from scipy.sparse import spmatrix -from stlearn._compat import Literal - -try: - from louvain.VertexPartition import MutableVertexPartition -except ImportError: - - class MutableVertexPartition: - pass - - MutableVertexPartition.__module__ = "louvain.VertexPartition" -import scanpy def louvain( adata: AnnData, - resolution: Optional[float] = None, - random_state: Optional[Union[int, RandomState]] = 0, - restrict_to: Optional[Tuple[str, Sequence[str]]] = None, + resolution: float | None = None, + random_state: int | RandomState | None = 0, + restrict_to: tuple[str, Sequence[str]] | None = None, key_added: str = "louvain", - adjacency: Optional[spmatrix] = None, - flavor: Literal["vtraag", "igraph", "rapids"] = "vtraag", + adjacency: spmatrix | None = None, + flavor: Literal["vtraag", "igraph", "rapids"] = "vtraag", # noqa: F821 directed: bool = True, use_weights: bool = False, - partition_type: Optional[Type[MutableVertexPartition]] = None, + partition_type: type[MutableVertexPartition] | None = None, partition_kwargs: Mapping[str, Any] = MappingProxyType({}), copy: bool = False, -) -> Optional[AnnData]: +) -> AnnData | None: """\ Wrap function scanpy.tl.louvain Cluster cells into subgroups [Blondel08]_ [Levine15]_ [Traag17]_. @@ -45,37 +34,37 @@ def louvain( or explicitly passing an ``adjacency`` matrix. Parameters ---------- - adata + adata: The annotated data matrix. - resolution + resolution: For the default flavor (``'vtraag'``), you can provide a resolution (higher resolution means finding more and smaller clusters), which defaults to 1.0. See “Time as a resolution parameter” in [Lambiotte09]_. - random_state + random_state: Change the initialization of the optimization. - restrict_to + restrict_to: Restrict the clustering to the categories within the key for sample annotation, tuple needs to contain ``(obs_key, list_of_categories)``. - key_added + key_added: Key under which to add the cluster labels. (default: ``'louvain'``) - adjacency + adjacency: Sparse adjacency matrix of the graph, defaults to ``adata.uns['neighbors']['connectivities']``. - flavor + flavor: Choose between two packages for computing the clustering. ``'vtraag'`` is much more powerful, and the default. - directed + directed: Interpret the ``adjacency`` matrix as a directed graph? - use_weights + use_weights: Use weights from knn graph. - partition_type + partition_type: Type of partition to use. Only a valid argument if ``flavor`` is ``'vtraag'``. - partition_kwargs + partition_kwargs: Keyword arguments to pass to partitioning, if ``vtraag`` method is being used. - copy + copy: Copy adata or modify it in place. Returns ------- @@ -88,7 +77,7 @@ def louvain( When ``copy=True`` is set, a copy of ``adata`` with those fields is returned. 
""" - scanpy.tl.louvain( + adata = scanpy.tl.louvain( adata, resolution=resolution, random_state=random_state, @@ -107,3 +96,5 @@ def louvain( print( "Louvain cluster is done! The labels are stored in adata.obs['%s']" % key_added ) + + return adata diff --git a/stlearn/tl/label/__init__.py b/stlearn/tl/label/__init__.py new file mode 100644 index 00000000..f07ffcf5 --- /dev/null +++ b/stlearn/tl/label/__init__.py @@ -0,0 +1,7 @@ +from .label import run_label_transfer, run_rctd, run_singleR + +__all__ = [ + "run_singleR", + "run_rctd", + "run_label_transfer", +] diff --git a/stlearn/tools/label/label.py b/stlearn/tl/label/label.py similarity index 90% rename from stlearn/tools/label/label.py rename to stlearn/tl/label/label.py index 2fb1960d..96b5e84b 100644 --- a/stlearn/tools/label/label.py +++ b/stlearn/tl/label/label.py @@ -3,18 +3,18 @@ """ import os + import numpy as np -import pandas as pd import scanpy as sc -import stlearn.tools.microenv.cci.r_helpers as rhs +import stlearn.tl.cci.r_helpers as rhs def run_label_transfer( st_data, sc_data, sc_label_col, r_path, st_label_col=None, n_highly_variable=2000 ): """Runs Seurat label transfer.""" - st_label_col = sc_label_col if type(st_label_col) == type(None) else st_label_col + st_label_col = sc_label_col if st_label_col is None else st_label_col # Setting up the R environment # rhs.rpy2_setup(r_path) @@ -90,18 +90,20 @@ def run_label_transfer( def get_counts(data): """Gets count data from anndata if available.""" # Standard layer has counts # - if type(data.X) != np.ndarray and np.all(np.mod(data.X[0, :].todense(), 1) == 0): + if not isinstance(data.X, np.ndarray) and np.all( + np.mod(data.X[0, :].todense(), 1) == 0 + ): counts = data.to_df().transpose() - elif type(data.X) == np.ndarray and np.all(np.mod(data.X[0, :], 1) == 0): + elif isinstance(data.X, np.ndarray) and np.all(np.mod(data.X[0, :], 1) == 0): counts = data.to_df().transpose() elif ( - type(data.X) != np.ndarray + not isinstance(data.X, np.ndarray) and hasattr(data, "raw") and np.all(np.mod(data.raw.X[0, :].todense(), 1) == 0) ): counts = data.raw.to_adata()[data.obs_names, data.var_names].to_df().transpose() elif ( - type(data.X) == np.ndarray + isinstance(data.X, np.ndarray) and hasattr(data, "raw") and np.all(np.mod(data.raw.X[0, :], 1) == 0) ): @@ -127,9 +129,9 @@ def run_rctd( n_cores=1, ): """Runs RCTD for deconvolution.""" - st_label_col = sc_label_col if type(st_label_col) == type(None) else st_label_col + st_label_col = sc_label_col if st_label_col is None else st_label_col - ########### Setting up the R environment ############# + # Setting up the R environment rhs.rpy2_setup(r_path) # Adding the source R code # @@ -160,7 +162,7 @@ def run_rctd( sc_data.var["highly_variable"].values, st_data.var["highly_variable"].values ) - ###### Getting the count data (if available) ############ + # Getting the count data (if available) st_counts = get_counts(st_data) sc_counts = get_counts(sc_data) @@ -169,9 +171,9 @@ def run_rctd( st_coords = st_data.obs.loc[:, ["imagecol", "imagerow"]] sc_labels = sc_data.obs[sc_label_col].values.astype(str) - print(f"Finished extracting counts data.") + print("Finished extracting counts data.") - ####### Converting to R objects ######### + # Converting to R objects sc_labels_r = rhs.ro.StrVector(sc_labels) with rhs.localconverter(rhs.ro.default_converter + rhs.pandas2ri.converter): st_coords_r = rhs.ro.conversion.py2rpy(st_coords) @@ -179,7 +181,7 @@ def run_rctd( sc_counts_r = rhs.ro.conversion.py2rpy(sc_counts) print("Finished py->rpy 
conversion.") - ######## Running RCTD ########## + # Running RCTD print("Running RCTD...") rctd_proportions_r = rctd_r( st_counts_r, @@ -220,8 +222,8 @@ def run_singleR( de_method="t", ): """Runs SingleR spot annotation.""" - st_label_col = sc_label_col if type(st_label_col) == type(None) else st_label_col - ########### Setting up the R environment ############# + st_label_col = sc_label_col if st_label_col is None else st_label_col + # Setting up the R environment rhs.rpy2_setup(r_path) # Adding the source R code # @@ -253,13 +255,13 @@ def run_singleR( ) sc_data = sc_data[:, genes_bool] st_data = st_data[:, genes_bool] - print(f"Finished selecting & subsetting to hvgs.") + print("Finished selecting & subsetting to hvgs.") # Extracting the relevant information from anndatas # st_expr_df = st_data.to_df().transpose() sc_expr_df = sc_data.to_df().transpose() sc_labels = sc_data.obs[sc_label_col].values.astype(str) - print(f"Finished extracting data.") + print("Finished extracting data.") # R conversion of the data # sc_labels_r = rhs.ro.StrVector(sc_labels) diff --git a/stlearn/tools/label/label_transfer.R b/stlearn/tl/label/label_transfer.R similarity index 100% rename from stlearn/tools/label/label_transfer.R rename to stlearn/tl/label/label_transfer.R diff --git a/stlearn/tools/label/rctd.R b/stlearn/tl/label/rctd.R similarity index 100% rename from stlearn/tools/label/rctd.R rename to stlearn/tl/label/rctd.R diff --git a/stlearn/tools/label/singleR.R b/stlearn/tl/label/singleR.R similarity index 100% rename from stlearn/tools/label/singleR.R rename to stlearn/tl/label/singleR.R diff --git a/stlearn/tools/__init__.py b/stlearn/tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/stlearn/tools/label/__init__.py b/stlearn/tools/label/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/stlearn/tools/microenv/__init__.py b/stlearn/tools/microenv/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/stlearn/tools/microenv/cci/__init__.py b/stlearn/tools/microenv/cci/__init__.py deleted file mode 100644 index 343fa4f3..00000000 --- a/stlearn/tools/microenv/cci/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# from .base import lr -# from .base_grouping import get_hotspots -# from . 
import het -# from .het import edge_core, get_between_spot_edge_array -# from .merge import merge -# from .permutation import get_rand_pairs -from .analysis import load_lrs, grid, run, adj_pvals, run_lr_go, run_cci diff --git a/stlearn/tools/microenv/cci/databases/__init__.py b/stlearn/tools/microenv/cci/databases/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/stlearn/types.py b/stlearn/types.py new file mode 100644 index 00000000..3006b748 --- /dev/null +++ b/stlearn/types.py @@ -0,0 +1,8 @@ +from typing import Literal + +_SIMILARITY_MATRIX = Literal["cosine", "euclidean", "pearson", "spearman"] +_METHOD = Literal["mean", "median", "sum"] +_QUALITY = Literal["fulres", "hires", "lowres"] +_BACKGROUND = Literal["black", "white"] + +__all__ = ["_SIMILARITY_MATRIX", "_METHOD", "_QUALITY", "_BACKGROUND"] diff --git a/stlearn/utils.py b/stlearn/utils.py index 0ea54262..6845a23c 100644 --- a/stlearn/utils.py +++ b/stlearn/utils.py @@ -1,18 +1,12 @@ -import numpy as np -import pandas as pd -import io -from PIL import Image -import matplotlib -from anndata import AnnData -import networkx as nx - -from typing import Optional, Union, Mapping # Special -from typing import Sequence, Iterable # ABCs -from typing import Tuple # Classes - +from collections.abc import Mapping +from enum import Enum from textwrap import dedent -from enum import Enum +import networkx as nx +import numpy as np +from anndata import AnnData +from matplotlib import axes +from matplotlib.axes import Axes class Empty(Enum): @@ -21,37 +15,32 @@ class Empty(Enum): _empty = Empty.token -from matplotlib import rcParams, ticker, gridspec, axes -from matplotlib.axes import Axes -from abc import ABC - class _AxesSubplot(Axes, axes.SubplotBase): """Intersection between Axes and SubplotBase: Has methods of both""" -def _check_spot_size( - spatial_data: Optional[Mapping], spot_size: Optional[float] -) -> float: +def _check_spot_size(spatial_data: Mapping | None, spot_size: float | None) -> float: """ Resolve spot_size value. This is a required argument for spatial plots. """ - if spatial_data is None and spot_size is None: + if spot_size is not None: + return spot_size + + if spatial_data is None: raise ValueError( "When .uns['spatial'][library_id] does not exist, spot_size must be " "provided directly." ) - elif spot_size is None: - return spatial_data["scalefactors"]["spot_diameter_fullres"] - else: - return spot_size + + return spatial_data["scalefactors"]["spot_diameter_fullres"] def _check_scale_factor( - spatial_data: Optional[Mapping], - img_key: Optional[str], - scale_factor: Optional[float], + spatial_data: Mapping | None, + img_key: str | None, + scale_factor: float | None, ) -> float: """Resolve scale_factor, defaults to 1.""" if scale_factor is not None: @@ -63,11 +52,16 @@ def _check_scale_factor( def _check_spatial_data( - uns: Mapping, library_id: Union[Empty, None, str] -) -> Tuple[Optional[str], Optional[Mapping]]: + uns: Mapping, library_id: Empty | None | str +) -> tuple[str | Empty | None, Mapping | None]: """ Given a mapping, try and extract a library id/ mapping with spatial data. Assumes this is `.uns` from how we parse visium data. + + Parameters + ---------- + library_id : None | str | Empty + A str selects that library; Empty infers the library id from the spatial + mapping; None means no library is selected. 
""" spatial_mapping = uns.get("spatial", {}) if library_id is _empty: @@ -88,38 +82,85 @@ def _check_spatial_data( def _check_img( - spatial_data: Optional[Mapping], - img: Optional[np.ndarray], - img_key: Union[None, str, Empty], + spatial_data: Mapping | None, + img: np.ndarray | None, + img_key: None | str | Empty, bw: bool = False, -) -> Tuple[Optional[np.ndarray], Optional[str]]: +) -> tuple[np.ndarray | None, str | None]: """ Resolve image for spatial plots. + + Parameters + ---------- + img : np.ndarray | None + If given an image will not look for another image and not check to see if it + was in spatial_data. + img_key : None | str | Empty + If None - don't find an image. Empty - find best image, or specify with str. + + Returns + ------- + tuple[np.ndarray | None, str | None] + The image found or nothing, str of the key of image found or None if none found. + + """ - if img is None and spatial_data is not None and img_key is _empty: - img_key = next( - (k for k in ["hires", "lowres", "fulres"] if k in spatial_data["images"]), - ) # Throws StopIteration Error if keys not present - if img is None and spatial_data is not None and img_key is not None: - img = spatial_data["images"][img_key] - if bw: - img = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140]) - return img, img_key + + # Return [None, None] if there's no anndata mapping or img + if spatial_data is None and img is None: + return None, None + else: + # Find image and key + new_img_key: str | None = None + new_img: np.ndarray | None = None + + # Return the img if not None and convert the key to Empty -> None if Empty + # otherwise keep. + if img is not None: + new_img = img + new_img_key = img_key if img_key is not _empty else None + # Find key if empty or use key. + elif spatial_data is not None: + if img_key is _empty: + # Looks for image - or None if not found. 
+ new_img_key = next( + ( + k + for k in ["hires", "lowres", "fulres"] + if k in spatial_data["images"] + ), + None, + ) + else: + new_img_key = img_key + + if new_img_key is not None: + new_img = spatial_data["images"][new_img_key] + + if new_img is not None and bw: + new_img = np.dot(new_img[..., :3], [0.2989, 0.5870, 0.1140]) + + return new_img, new_img_key def _check_coords( - obsm: Optional[Mapping], scale_factor: Optional[float] -) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + obsm: Mapping | None, scale_factor: float | None +) -> tuple[np.ndarray, np.ndarray]: + if obsm is None: + raise ValueError("obsm cannot be None") + if scale_factor is None: + raise ValueError("scale_factor cannot be None") + if "spatial" not in obsm: + raise ValueError("'spatial' key not found in obsm") image_coor = obsm["spatial"] * scale_factor imagecol = image_coor[:, 0] imagerow = image_coor[:, 1] - return [imagecol, imagerow] - + return (imagecol, imagerow) -def _read_graph(adata: AnnData, graph_type: Optional[str]): +def _read_graph(adata: AnnData, graph_type: str | None): if graph_type == "PTS_graph": graph = nx.from_scipy_sparse_array( adata.uns[graph_type]["graph"], create_using=nx.DiGraph diff --git a/stlearn/wrapper/concatenate_spatial_adata.py b/stlearn/wrapper/concatenate_spatial_adata.py index c5d1ae07..a1c8b8ce 100644 --- a/stlearn/wrapper/concatenate_spatial_adata.py +++ b/stlearn/wrapper/concatenate_spatial_adata.py @@ -15,7 +15,6 @@ def transform_spatial(coordinates, original, resized): def correct_size(adata, fixed_size): - image = adata.uns["spatial"][list(adata.uns["spatial"].keys())[0]]["images"][ "hires" ] @@ -121,7 +120,7 @@ def concatenate_spatial_adata(adata_list, ncols=2, fixed_size=(2000, 2000)): for min_id in range(0, len(use_adata_list), ncols): img_row = np.hstack(imgs[min_id : min_id + ncols]) img_rows.append(img_row) - imgs_comb = np.vstack((i for i in img_rows)) + imgs_comb = np.vstack(i for i in img_rows) adata_concat = use_adata_list[0].concatenate(use_adata_list[1:]) adata_concat.uns["spatial"] = use_adata_list[0].uns["spatial"] diff --git a/stlearn/wrapper/convert_scanpy.py b/stlearn/wrapper/convert_scanpy.py index 4cf7e288..aac9c6a9 100644 --- a/stlearn/wrapper/convert_scanpy.py +++ b/stlearn/wrapper/convert_scanpy.py @@ -1,15 +1,10 @@ -from typing import Optional, Union from anndata import AnnData -from matplotlib import pyplot as plt -from pathlib import Path -import os def convert_scanpy( adata: AnnData, use_quality: str = "hires", -) -> Optional[AnnData]: - +) -> AnnData | None: adata.var_names_make_unique() library_id = list(adata.uns["spatial"].keys())[0] diff --git a/stlearn/wrapper/read.py b/stlearn/wrapper/read.py index a66bf512..02f75209 100644 --- a/stlearn/wrapper/read.py +++ b/stlearn/wrapper/read.py @@ -1,58 +1,57 @@ -"""Reading and Writing -""" +"""Reading and Writing""" -from pathlib import Path, PurePath -from typing import Optional, Union -from anndata import AnnData +import json +import logging as logg +from collections.abc import Iterator +from os import PathLike +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np -from PIL import Image import pandas as pd -import stlearn -from .._compat import Literal import scanpy -import scipy -import matplotlib.pyplot as plt +from anndata import AnnData from matplotlib.image import imread -import json +from PIL import Image -_QUALITY = Literal["fulres", "hires", "lowres"] -_background = ["black", "white"] +import stlearn +from stlearn.types import _BACKGROUND, _QUALITY +from 
stlearn.wrapper.xenium_alignment import apply_alignment_transformation def Read10X( - path: Union[str, Path], - genome: Optional[str] = None, + path: str | Path, + genome: str | None = None, count_file: str = "filtered_feature_bc_matrix.h5", - library_id: str = None, - load_images: Optional[bool] = True, + library_id: str | None = None, + load_images: bool = True, quality: _QUALITY = "hires", - image_path: Union[str, Path] = None, + image_path: str | Path | None = None, ) -> AnnData: """\ - Read Visium data from 10X (wrap read_visium from scanpy) - - In addition to reading regular 10x output, - this looks for the `spatial` folder and loads images, - coordinates and scale factors. - Based on the `Space Ranger output docs`_. + Read data from 10X. - .. _Space Ranger output docs: https://support.10xgenomics.com/spatial-gene-expression/software/pipelines/latest/output/overview + In addition to reading regular 10x output, this looks for the `spatial` folder + and loads images, coordinates and scale factors. Based on the + https://support.10xgenomics.com/spatial-gene-expression/software/pipelines/latest/output/overview Parameters ---------- path - Path to directory for visium datafiles. + The path to directory for the datafiles. genome Filter expression to genes within this genome. count_file - Which file in the passed directory to use as the count file. Typically would be one of: - 'filtered_feature_bc_matrix.h5' or 'raw_feature_bc_matrix.h5'. + Which file in the directory to use as the count file. Typically, it would be one + of: 'filtered_feature_bc_matrix.h5' or 'raw_feature_bc_matrix.h5'. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating multiple + adata objects. load_images Load image or not. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] image_path Path to image. Only need when loading full resolution image. @@ -92,9 +91,9 @@ def Read10X( with File(path / count_file, mode="r") as f: attrs = dict(f.attrs) + if library_id is None: library_id = str(attrs.pop("library_ids")[0], "utf-8") - adata.uns["spatial"][library_id] = dict() tissue_positions_file = ( @@ -116,8 +115,7 @@ def Read10X( if not f.exists(): if any(x in str(f) for x in ["hires_image", "lowres_image"]): logg.warning( - f"You seem to be missing an image file.\n" - f"Could not find '{f}'." + f"You seem to be missing an image file.\nCould not find '{f}'." 
) else: raise OSError(f"Could not find '{f}'") @@ -159,39 +157,38 @@ def Read10X( adata.obsm["spatial"] = ( adata.obs[["pxl_row_in_fullres", "pxl_col_in_fullres"]] .to_numpy() - .astype(int) + .astype(float) ) adata.obs.drop( columns=["barcode", "pxl_row_in_fullres", "pxl_col_in_fullres"], inplace=True, ) - # put image path in uns - if image_path is not None: - # get an absolute path - image_path = str(Path(image_path).resolve()) - adata.uns["spatial"][library_id]["metadata"]["source_image_path"] = str( - image_path - ) - - adata.var_names_make_unique() + if quality == "fulres": + # put image path in uns + if image_path is not None: + # get an absolute path + image_path = str(Path(image_path).resolve()) + adata.uns["spatial"][library_id]["metadata"]["source_image_path"] = str( + image_path + ) + else: + raise ValueError("Trying to load fulres but no image_path set.") - if library_id is None: - library_id = list(adata.uns["spatial"].keys())[0] + image_coor = adata.obsm["spatial"] + img = plt.imread(image_path, None) + adata.uns["spatial"][library_id]["images"]["fulres"] = img + else: + scale = adata.uns["spatial"][library_id]["scalefactors"][ + "tissue_" + quality + "_scalef" + ] + image_coor = adata.obsm["spatial"] * scale - if quality == "fulres": - image_coor = adata.obsm["spatial"] - img = plt.imread(image_path, 0) - adata.uns["spatial"][library_id]["images"]["fulres"] = img - else: - scale = adata.uns["spatial"][library_id]["scalefactors"][ - "tissue_" + quality + "_scalef" - ] - image_coor = adata.obsm["spatial"] * scale + adata.obs["imagecol"] = image_coor[:, 0] + adata.obs["imagerow"] = image_coor[:, 1] + adata.uns["spatial"][library_id]["use_quality"] = quality - adata.obs["imagecol"] = image_coor[:, 0] - adata.obs["imagerow"] = image_coor[:, 1] - adata.uns["spatial"][library_id]["use_quality"] = quality + adata.var_names_make_unique() adata.obs["array_row"] = adata.obs["array_row"].astype(int) adata.obs["array_col"] = adata.obs["array_col"].astype(int) @@ -201,9 +198,9 @@ def Read10X( def ReadOldST( - count_matrix_file: Union[str, Path] = None, - spatial_file: Union[str, Path] = None, - image_file: Union[str, Path] = None, + count_matrix_file: PathLike[str] | str | Iterator[str], + spatial_file: int | str | bytes | PathLike[str] | PathLike[bytes], + image_file: str | Path | None = None, library_id: str = "OldST", scale: float = 1.0, quality: str = "hires", @@ -217,15 +214,17 @@ def ReadOldST( count_matrix_file Path to count matrix file. spatial_file - Path to spatial location file. + Path to the spatial location file. image_file Path to the tissue image file library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating multiple + adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. 
Store in + anndata.obs['imagecol' & 'imagerow'] spot_diameter_fullres Diameter of spot in full resolution @@ -249,13 +248,13 @@ def ReadOldST( def ReadSlideSeq( - count_matrix_file: Union[str, Path], - spatial_file: Union[str, Path], - library_id: str = None, - scale: float = None, + count_matrix_file: str | Path, + spatial_file: str | Path, + library_id: str | None = None, + scale: float | None = None, quality: str = "hires", spot_diameter_fullres: float = 50, - background_color: _background = "white", + background_color: _BACKGROUND = "white", ) -> AnnData: """\ Read Slide-seq data @@ -265,17 +264,19 @@ def ReadSlideSeq( count_matrix_file Path to count matrix file. spatial_file - Path to spatial location file. + Path to the spatial location file. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating + multiple adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] spot_diameter_fullres Diameter of spot in full resolution background_color - Color of the backgound. Only `black` or `white` is allowed. + Color of the background. Only `black` or `white` is allowed. Returns ------- @@ -291,7 +292,7 @@ def ReadSlideSeq( adata.obs["index"] = meta["index"].values - if scale == None: + if scale is None: max_coor = np.max(meta[["x", "y"]].values) scale = 2000 / max_coor @@ -330,13 +331,13 @@ def ReadSlideSeq( def ReadMERFISH( - count_matrix_file: Union[str, Path], - spatial_file: Union[str, Path], - library_id: str = None, - scale: float = None, + count_matrix_file: str | Path, + spatial_file: str | Path, + library_id: str | None = None, + scale: float | None = None, quality: str = "hires", spot_diameter_fullres: float = 50, - background_color: _background = "white", + background_color: _BACKGROUND = "white", ) -> AnnData: """\ Read MERFISH data @@ -346,17 +347,19 @@ def ReadMERFISH( count_matrix_file Path to count matrix file. spatial_file - Path to spatial location file. + Path to the spatial location file. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating + multiple adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] spot_diameter_fullres Diameter of spot in full resolution background_color - Color of the backgound. Only `black` or `white` is allowed. + Color of the background. Only `black` or `white` is allowed. 
Returns ------- @@ -373,7 +376,7 @@ def ReadMERFISH( adata_merfish = counts[coordinates.index, :] adata_merfish.obsm["spatial"] = coordinates.to_numpy() - if scale == None: + if scale is None: max_coor = np.max(adata_merfish.obsm["spatial"]) scale = 2000 / max_coor @@ -411,14 +414,14 @@ def ReadMERFISH( def ReadSeqFish( - count_matrix_file: Union[str, Path], - spatial_file: Union[str, Path], - library_id: str = None, + count_matrix_file: str | Path, + spatial_file: str | Path, + library_id: str | None = None, scale: float = 1.0, quality: str = "hires", field: int = 0, spot_diameter_fullres: float = 50, - background_color: _background = "white", + background_color: _BACKGROUND = "white", ) -> AnnData: """\ Read SeqFish data @@ -430,17 +433,19 @@ def ReadSeqFish( spatial_file Path to spatial location file. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating multiple + adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] field Set field of view for SeqFish data spot_diameter_fullres Diameter of spot in full resolution background_color - Color of the backgound. Only `black` or `white` is allowed. + Color of the background. Only `black` or `white` is allowed. Returns ------- AnnData @@ -458,7 +463,7 @@ def ReadSeqFish( adata = AnnData(count) - if scale == None: + if scale is None: max_coor = np.max(spatial[["X", "Y"]]) scale = 2000 / max_coor @@ -497,14 +502,17 @@ def ReadSeqFish( def ReadXenium( - feature_cell_matrix_file: Union[str, Path], - cell_summary_file: Union[str, Path], - image_path: Optional[Path] = None, - library_id: str = None, + feature_cell_matrix_file: str | Path, + cell_summary_file: str | Path, + image_path: Path | None = None, + library_id: str | None = None, scale: float = 1.0, quality: str = "hires", spot_diameter_fullres: float = 15, - background_color: _background = "white", + background_color: _BACKGROUND = "white", + alignment_matrix_file: str | Path | None = None, + experiment_xenium_file: str | Path | None = None, + default_pixel_size_microns: float = 0.2125, ) -> AnnData: """\ Read Xenium data @@ -518,15 +526,25 @@ def ReadXenium( image_path Path to image. Only need when loading full resolution image. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the Xenium library. Can be modified when concatenating multiple + adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] spot_diameter_fullres Diameter of spot in full resolution background_color - Color of the backgound. Only `black` or `white` is allowed. + Color of the background. Only `black` or `white` is allowed. + alignment_matrix_file + Path to transformation matrix CSV file exported from Xenium Explorer. + If provided, coordinates will be transformed according to coordinate_space. + experiment_xenium_file + Path to experiment.xenium JSON file. If provided, pixel_size will be read from + here. + default_pixel_size_microns + Pixel size in microns (default 0.2125 for Xenium data). 
Returns ------- AnnData @@ -536,12 +554,34 @@ def ReadXenium( adata = scanpy.read_10x_h5(feature_cell_matrix_file) - spatial = metadata[["x_centroid", "y_centroid"]] - spatial.columns = ["imagecol", "imagerow"] + # Get original spatial coordinates + spatial = metadata[["x_centroid", "y_centroid"]].copy() + + # Get pixel size from experiment.xenium file or use parameter + if experiment_xenium_file is not None: + with open(experiment_xenium_file) as f: + experiment_data = json.load(f) + pixel_size_microns = experiment_data.get("pixel_size") + else: + pixel_size_microns = default_pixel_size_microns + print( + f"Warning: Using default pixel size of {pixel_size_microns} microns. " + "Consider providing experiment_xenium_file for accurate pixel size." + ) + + # Get and apply alignment transformation if provided + if alignment_matrix_file is not None: + transform_mat = pd.read_csv(alignment_matrix_file, header=None).values + spatial = apply_alignment_transformation( + spatial, + transform_mat, + pixel_size_microns, + ) + spatial.columns = ["imagecol", "imagerow"] adata.obsm["spatial"] = spatial.values - if scale == None: + if scale is None: max_coor = np.max(adata.obsm["spatial"]) scale = 2000 / max_coor @@ -551,7 +591,7 @@ def ReadXenium( adata.obs["imagecol"] = spatial["imagecol"].values * scale adata.obs["imagerow"] = spatial["imagerow"].values * scale - if image_path != None: + if image_path is not None: stlearn.add.image( adata, library_id=library_id, @@ -591,11 +631,11 @@ def create_stlearn( count: pd.DataFrame, spatial: pd.DataFrame, library_id: str, - image_path: Optional[Path] = None, - scale: float = None, + image_path: Path | None = None, + scale: float | None = None, quality: str = "hires", spot_diameter_fullres: float = 50, - background_color: _background = "white", + background_color: _BACKGROUND = "white", ): """\ Create AnnData object for stLearn @@ -607,15 +647,17 @@ def create_stlearn( spatial Pandas Dataframe of spatial location of cells/spots. library_id - Identifier for the visium library. Can be modified when concatenating multiple adata objects. + Identifier for the library. Can be modified when concatenating multiple + adata objects. scale Set scale factor. quality - Set quality that convert to stlearn to use. Store in anndata.obs['imagecol' & 'imagerow'] + Set quality that convert to stlearn to use. Store in + anndata.obs['imagecol' & 'imagerow'] spot_diameter_fullres Diameter of spot in full resolution background_color - Color of the backgound. Only `black` or `white` is allowed. + Color of the background. Only `black` or `white` is allowed. Returns ------- AnnData @@ -624,14 +666,14 @@ def create_stlearn( adata.obsm["spatial"] = spatial.values - if scale == None: + if scale is None: max_coor = np.max(adata.obsm["spatial"]) scale = 2000 / max_coor adata.obs["imagecol"] = spatial["imagecol"].values * scale adata.obs["imagerow"] = spatial["imagerow"].values * scale - if image_path != None: + if image_path is not None: stlearn.add.image( adata, library_id=library_id, diff --git a/stlearn/wrapper/xenium_alignment.py b/stlearn/wrapper/xenium_alignment.py new file mode 100644 index 00000000..368565db --- /dev/null +++ b/stlearn/wrapper/xenium_alignment.py @@ -0,0 +1,38 @@ +import numpy as np +import pandas as pd + + +def apply_alignment_transformation( + coordinates: pd.DataFrame, + transform_mat: np.ndarray, + pixel_size_microns: float = 0.2125, +) -> pd.DataFrame: + """ + Apply transformation matrix to convert coordinates between spaces. 
+ + From https://kb.10xgenomics.com/hc/en-us/articles/35386990499853-How-can-I-convert-coordinates-between-H-E-image-and-Xenium-data + + Parameters + ---------- + coordinates + DataFrame with columns ['x_centroid', 'y_centroid'] in microns + transform_mat + Transformation matrix from Xenium project. + pixel_size_microns + Pixel size in microns + + Returns + ------- + pd.DataFrame + Transformed coordinates + """ + + # Microns to pixels and use inverse transformation matrix + coords_pixels = coordinates.values / pixel_size_microns + transform_mat_inv = np.linalg.inv(transform_mat) + coords_homogeneous = np.column_stack([coords_pixels, np.ones(len(coords_pixels))]) + transformed_coords = np.dot(coords_homogeneous, transform_mat_inv.T) + + # Extract x, y coordinates (ignore homogeneous coordinate) + result_coords = transformed_coords[:, :2] + return pd.DataFrame(result_coords, columns=coordinates.columns) diff --git a/tests/test_CCI.py b/tests/test_CCI.py index b17af940..7dc639f5 100644 --- a/tests/test_CCI.py +++ b/tests/test_CCI.py @@ -2,20 +2,16 @@ """Tests for `stlearn` package.""" -import os - import unittest import numpy as np from numba.typed import List import stlearn as st -import scanpy as sc +import stlearn.tl.cci.het as het +import stlearn.tl.cci.het_helpers as het_hs from tests.utils import read_test_data -import stlearn.tools.microenv.cci.het_helpers as het_hs -import stlearn.tools.microenv.cci.het as het - global adata adata = read_test_data() @@ -26,7 +22,7 @@ class TestCCI(unittest.TestCase): def setUp(self) -> None: """Setup some basic test-cases as sanity checks.""" - ##### Unit neighbourhood, containing just 1 spot and 6 neighbours ###### + # Unit neighbourhood, containing just 1 spot and 6 neighbours """ * A is the middle spot, B/C/D/E/F/G are the neighbouring spots clock- wise starting at the top-left. @@ -55,7 +51,7 @@ def setUp(self) -> None: self.neighbourhood_indices = neighbourhood_indices self.neigh_dict = neigh_dict - ##### Basic tests ####### + # Basic tests def test_load_lrs(self): """Testing loading lr database.""" sizes = [2293, 4071] # lit lr db size, putative lr db size. @@ -71,7 +67,7 @@ def test_load_lrs(self): lrs = st.tl.cci.load_lrs() self.assertEqual(len(lrs), sizes[0]) - ### Testing loading mouse as species #### + # Testing loading mouse as species lrs = st.tl.cci.load_lrs(species="mouse") genes1 = [lr_.split("_")[0] for lr_ in lrs] genes2 = [lr_.split("_")[1] for lr_ in lrs] @@ -80,9 +76,9 @@ def test_load_lrs(self): self.assertTrue(np.all([gene[0].isupper() for gene in genes2])) self.assertTrue(np.all([gene[1:] == gene[1:].lower() for gene in genes2])) - ####### Important, granular tests related to LR scoring ######### + # Important, granular tests related to LR scoring - ###### Important, granular tests related to CCI counting ####### + # Important, granular tests related to CCI counting def test_edge_retrieval_basic(self): """ Basic test of functionality to retrieve edges via \ get_between_spot_edge_array. 
@@ -93,7 +89,7 @@ def test_edge_retrieval_basic(self): # Initialising the edge list # edge_list = het_hs.init_edge_list(neighbourhood_bcs) - ############# Basic case, should populate with all edges ############### + # Basic case, should populate with all edges neigh_bool = np.array([True] * len(neighbourhood_bcs)) cell_data = np.array([1] * len(neighbourhood_bcs), dtype=np.float64) het_hs.get_between_spot_edge_array( @@ -115,7 +111,7 @@ def test_edge_retrieval_basic(self): np.all([edge in all_edges or edge[::-1] in all_edges for edge in edge_list]) ) - ########### Some neighbours not valid but no effect on edge list ####### + # Some neighbours not valid but no effect on edge list # No effect since though not a valid neighbour, still a valid spot # edge_list = het_hs.init_edge_list(neighbourhood_bcs) invalid_neighs = ["B", "E"] @@ -130,7 +126,7 @@ def test_edge_retrieval_basic(self): np.all([edge in all_edges or edge[::-1] in all_edges for edge in edge_list]) ) - ########### Some neighbours not valid, effects the edge list ########### + # Some neighbours not valid, effects the edge list # Two neighbouring spots no longer valid neighbours # edge_list = het_hs.init_edge_list(neighbourhood_bcs) invalid_neighs = ["B", "C"] @@ -152,7 +148,7 @@ def test_edge_retrieval_basic(self): ) ) - ######### Middle spot not neighbour, cell type, or spot of interest #### + # Middle spot not neighbour, cell type, or spot of interest # Removing the centre-spot as being the cell type of interest # neigh_bool = np.array([True] * len(neighbourhood_bcs)) neigh_bool[0] = False @@ -177,7 +173,7 @@ def test_edge_retrieval_basic(self): ) ) - ### Corner spot valid neighbour, not cell type, not spot of interest ### + # Corner spot valid neighbour, not cell type, not spot of interest neigh_bool = np.array([True] * len(neighbourhood_bcs)) cell_data = np.array([1] * len(neighbourhood_bcs), dtype=np.float64) cell_data[1] = 0 @@ -209,7 +205,7 @@ def test_get_interactions(self): and spots of another cell type expressing the receptor. """ - ####### Case 1 ###### + # Case 1 """ Middle spot only spot of interest. Cell type 1, 2, or 3. Middle spot expresses ligand. 
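The tests above pin down the behaviour of the relocated CCI entry points (stlearn.tools.microenv.cci is now stlearn.tl.cci, re-exported from stlearn.tl.cci.analysis). A minimal sketch of the database loading those tests assert, using only calls that appear in this patch; the database sizes quoted are the test expectations, and the example pair name is illustrative:

import stlearn as st

# Default load: the literature-supported LR database (2293 pairs per the
# test above; the putative database holds 4071).
lrs = st.tl.cci.load_lrs()

# Entries are "LIGAND_RECEPTOR" strings, e.g. a mouse pair like "Tgfb1_Tgfbr1".
ligands = [lr.split("_")[0] for lr in lrs]
receptors = [lr.split("_")[1] for lr in lrs]

# Mouse gene symbols are title-case: first letter upper, the rest lower.
mouse_lrs = st.tl.cci.load_lrs(species="mouse")
assert all(
    gene[0].isupper() and gene[1:] == gene[1:].lower()
    for lr in mouse_lrs
    for gene in lr.split("_")
)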
diff --git a/tests/test_PSTS.py b/tests/test_PSTS.py index 1a6b7676..08ceba53 100644 --- a/tests/test_PSTS.py +++ b/tests/test_PSTS.py @@ -2,13 +2,14 @@ """Tests for `stlearn` package.""" - import unittest -import stlearn as st +import numpy as np import scanpy as sc + +import stlearn as st + from .utils import read_test_data -import numpy as np global adata adata = read_test_data() diff --git a/tests/test_SME.py b/tests/test_SME.py index 96eec81f..49ea98ec 100644 --- a/tests/test_SME.py +++ b/tests/test_SME.py @@ -2,11 +2,12 @@ """Tests for `stlearn` package.""" - import unittest -import stlearn as st import scanpy as sc + +import stlearn as st + from .utils import read_test_data global adata diff --git a/tests/test_Spatial.py b/tests/test_Spatial.py new file mode 100644 index 00000000..7177466a --- /dev/null +++ b/tests/test_Spatial.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +"""Tests for `stlearn` package.""" + +import unittest + +import numpy.testing as npt + +from stlearn.classes import Spatial + +from .utils import read_test_data + +global adata +adata = read_test_data() + + +class TestSpatial(unittest.TestCase): + """Tests for `stlearn` package.""" + + def test_setup_Spatial(self): + spatial = Spatial(adata) + self.assertIsNotNone(spatial) + self.assertEqual("V1_Breast_Cancer_Block_A_Section_1", spatial.library_id) + self.assertEqual("hires", spatial.img_key) + self.assertEqual(177.4829519178534, spatial.spot_size) + self.assertEqual(True, spatial.crop_coord) + self.assertEqual(False, spatial.use_raw) + npt.assert_array_almost_equal( + [896.782, 1370.627, 1483.498, 1178.713, 1584.901], + spatial.imagecol[:5], + decimal=3, + ) + npt.assert_array_almost_equal( + [1549.092, 1158.003, 1040.594, 1373.267, 1021.205], + spatial.imagerow[:5], + decimal=3, + ) diff --git a/tests/test_cluster_plot.py b/tests/test_cluster_plot.py new file mode 100644 index 00000000..6711ea05 --- /dev/null +++ b/tests/test_cluster_plot.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python + +"""Tests for ClusterPlot.""" + +import unittest +from unittest.mock import MagicMock, patch + +import matplotlib.colors +import numpy as np +import pandas as pd + +from stlearn.pl.classes import ClusterPlot + +from .utils import read_test_data + +global adata +adata = read_test_data() + + +class TestClusterPlot(unittest.TestCase): + """Tests for ClusterPlot.""" + + def setUp(self): + """Set up test data with known clusters.""" + self.adata = adata.copy() + + # Create test clustering data + n_spots = len(self.adata.obs) + cluster_labels = np.random.choice( + ["Cluster_0", "Cluster_1", "Cluster_2"], n_spots + ) + self.adata.obs["test_clusters"] = pd.Categorical(cluster_labels) + + # Ensure we have a clean slate + if "test_clusters_colors" in self.adata.uns: + del self.adata.uns["test_clusters_colors"] + + def test_color_generation_first_call(self): + """Test that colors are generated correctly on first call.""" + with ( + patch("matplotlib.pyplot.subplots") as mock_subplots, + patch.object(ClusterPlot, "_plot_clusters") as _, + patch.object(ClusterPlot, "_add_image"), + ): + # Mock matplotlib components + mock_fig, mock_ax = MagicMock(), MagicMock() + mock_subplots.return_value = (mock_fig, mock_ax) + + # Create ClusterPlot + label_name = "test_clusters" + plot = ClusterPlot( + adata=self.adata, + use_label=label_name, + show_image=False, + show_color_bar=False, + ) + + # Check that colors were generated + colors = plot.adata[0].uns[f"{label_name}_colors"] + self.assertIsNotNone(colors) + self.assertEqual(len(colors), 3) # 3 clusters 
+ + # Check that all colors are valid hex colors + for color in colors: + self.assertTrue(matplotlib.colors.is_color_like(color)) + self.assertTrue(color.startswith("#")) + self.assertEqual(len(color), 7) # #RRGGBB format + + def test_multiple_calls_same_adata(self): + """Test that multiple calls with same adata work correctly.""" + with ( + patch("matplotlib.pyplot.subplots") as mock_subplots, + patch.object(ClusterPlot, "_plot_clusters") as _, + patch.object(ClusterPlot, "_add_image"), + ): + mock_fig, mock_ax = MagicMock(), MagicMock() + mock_subplots.return_value = (mock_fig, mock_ax) + + label_name = "test_clusters" + + # First call + plot1 = ClusterPlot( + adata=self.adata, + use_label=label_name, + show_image=False, + show_color_bar=False, + ) + + # Second call with same adata + plot2 = ClusterPlot( + adata=self.adata, + use_label=label_name, + show_image=False, + show_color_bar=False, + ) + + # Both should succeed and generate consistent colors + colors1 = plot1.adata[0].uns[f"{label_name}_colors"] + colors2 = plot2.adata[0].uns[f"{label_name}_colors"] + + self.assertEqual(len(colors1), len(colors2)) + self.assertEqual(colors1, colors2) + + def test_insufficient_existing_colors_extended(self): + """Test that insufficient existing colors are extended.""" + # Pre-populate adata with insufficient colors (only 2 colors for 3 clusters) + existing_colors = ["#FF0000", "#00FF00"] + label_name = "test_clusters" + self.adata.uns[f"{label_name}_colors"] = existing_colors + + with ( + patch("matplotlib.pyplot.subplots") as mock_subplots, + patch.object(ClusterPlot, "_plot_clusters") as _, + patch.object(ClusterPlot, "_add_image"), + ): + mock_fig, mock_ax = MagicMock(), MagicMock() + mock_subplots.return_value = (mock_fig, mock_ax) + + plot = ClusterPlot( + adata=self.adata, + use_label=label_name, + show_image=False, + show_color_bar=False, + ) + + # Should extend existing colors + colors = plot.adata[0].uns[f"{label_name}_colors"] + self.assertEqual(len(colors), 3) + self.assertNotEqual(colors[:2], existing_colors) + + def tearDown(self): + """Clean up after each test.""" + # Clear any test artifacts + if hasattr(self, "adata") and "test_clusters_colors" in self.adata.uns: + del self.adata.uns["test_clusters_colors"] diff --git a/tests/test_extract_features.py b/tests/test_extract_features.py new file mode 100644 index 00000000..baaa2d06 --- /dev/null +++ b/tests/test_extract_features.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import os +import shutil +import tempfile +import unittest + +import numpy as np +import scanpy as sc + +import stlearn as st + +from .utils import read_test_data + +global adata +adata = read_test_data() + + +class TestFeatureExtractionPerformance(unittest.TestCase): + """Comprehensive tests for feature extraction.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_data = adata.copy() + self.temp_dir = tempfile.mkdtemp() + sc.pp.pca(self.test_data) + st.pp.tiling(self.test_data, self.temp_dir) + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_deterministic_behavior(self): + """Test that results are deterministic with same seed.""" + data1 = self.test_data.copy() + data2 = self.test_data.copy() + + st.pp.extract_feature(data1, seeds=42) + st.pp.extract_feature(data2, seeds=42) + + np.testing.assert_array_equal( + data1.obsm["X_morphology"], + data2.obsm["X_morphology"], + err_msg="Results should be deterministic with same seed", + ) + np.testing.assert_array_equal( 
+ data1.obsm["X_tile_feature"], + data2.obsm["X_tile_feature"], + err_msg="Results should be deterministic with same seed", + ) + + def test_copy_behavior(self): + """Test copy=True vs copy=False behavior.""" + original_data = self.test_data.copy() + + # Test copy=True + result_copy = st.pp.extract_feature(original_data, copy=True) + self.assertIsNotNone(result_copy) + self.assertNotIn("X_morphology", original_data.obsm) + self.assertIn("X_morphology", result_copy.obsm) + + # Test copy=False + result_inplace = st.pp.extract_feature(original_data, copy=False) + self.assertIsNone(result_inplace) + self.assertIn("X_morphology", original_data.obsm) diff --git a/tests/test_install.py b/tests/test_install.py deleted file mode 100644 index 5ff4160e..00000000 --- a/tests/test_install.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Tests that everything is installed correctly. -""" - -import unittest - - -class TestCCI(unittest.TestCase): - """Tests for `stlearn` importability, i.e. correct installation.""" - - def test_SME(self): - import stlearn.spatials.SME.normalize as sme_normalise - - def test_cci(self): - """Tests CCI can be imported.""" - import stlearn.tools.microenv.cci.analysis as an diff --git a/tests/test_tiling.py b/tests/test_tiling.py new file mode 100644 index 00000000..fa6ce0cc --- /dev/null +++ b/tests/test_tiling.py @@ -0,0 +1,77 @@ +# !/usr/bin/env python + +"""Tests for tiling function.""" + +import os +import shutil +import tempfile +import unittest +from pathlib import Path + +import numpy as np + +import stlearn as st + +from .utils import read_test_data + +global adata +adata = read_test_data() + + +class TestTiling(unittest.TestCase): + """Tests for `stlearn` package.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_data = adata.copy() + self.temp_dir = tempfile.mkdtemp() + self.temp_dir_orig = tempfile.mkdtemp(suffix="_orig") + + # Ensure we have required spatial data + if "spatial" not in self.test_data.uns: + self.skipTest("Test data missing spatial information") + + # Add imagerow/imagecol if missing (for testing) + if "imagerow" not in self.test_data.obs: + # Create synthetic coordinates for testing + n_spots = len(self.test_data) + self.test_data.obs["imagerow"] = np.random.randint(50, 450, n_spots) + self.test_data.obs["imagecol"] = np.random.randint(50, 450, n_spots) + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + if os.path.exists(self.temp_dir_orig): + shutil.rmtree(self.temp_dir_orig) + + def test_directory_creation(self): + """Test directory creation behavior.""" + + # Test nested directory creation + nested_path = Path(self.temp_dir) / "level1" / "level2" / "tiles" + data = self.test_data.copy() + + st.pp.tiling(data, nested_path) + + self.assertTrue(nested_path.exists(), "Nested directories not created") + self.assertGreater( + len(list(nested_path.glob("*"))), 0, "No files in nested directory" + ) + + def test_quality_parameter(self): + """Test JPEG quality parameter.""" + data = self.test_data[:3].copy() # Small subset + + # Test different quality settings + for quality in [50, 95]: + temp_quality = tempfile.mkdtemp(suffix=f"_q{quality}") + test_data = data.copy() + + st.pp.tiling(test_data, temp_quality, img_fmt="JPEG", quality=quality) + + # Verify files exist + jpeg_files = list(Path(temp_quality).glob("*.jpeg")) + self.assertEqual(len(jpeg_files), len(data)) + + shutil.rmtree(temp_quality) diff --git a/tests/utils.py b/tests/utils.py index a10b5d21..98482f96 100644 
--- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,8 @@ import os + +import numpy as np import scanpy as sc from PIL import Image -import numpy as np def read_test_data(): @@ -10,7 +11,7 @@ def read_test_data(): path = os.path.dirname(os.path.realpath(__file__)) adata = sc.read_h5ad(f"{path}/test_data/test_data.h5") im = Image.open(f"{path}/test_data/test_image.jpg") - adata.uns["spatial"]["V1_Breast_Cancer_Block_A_Section_1"]["images"][ - "hires" - ] = np.array(im) + adata.uns["spatial"]["V1_Breast_Cancer_Block_A_Section_1"]["images"]["hires"] = ( + np.array(im) + ) return adata diff --git a/tox.ini b/tox.ini index 9aae612f..dcdf7115 100644 --- a/tox.ini +++ b/tox.ini @@ -1,20 +1,32 @@ [tox] -envlist = py35, py36, py37, py38, flake8 +requires = + tox>=4 +env_list = lint, type, 3.10, ruff -[travis] -python = - 3.8: py38 - 3.7: py37 - 3.6: py36 - 3.5: py35 +[testenv:lint] +description = run linters +skip_install = true +deps = + black +commands = black {posargs:.} -[testenv:flake8] -basepython = python -deps = flake8 -commands = flake8 stlearn +[testenv:type] +description = run type checks +deps = + mypy +commands = + mypy {posargs:stlearn tests} + +[testenv:ruff] +description = run ruff linting and formatting +skip_install = true +deps = ruff +commands = + ruff check stlearn tests [testenv] setenv = PYTHONPATH = {toxinidir} - -commands = python setup.py test +deps = + pytest +commands = pytest {posargs}
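Taken together, the reader changes above add an aligned Xenium loading path: ReadXenium converts micron centroids to pixels, optionally maps them through the inverse of a Xenium Explorer alignment matrix (see apply_alignment_transformation), and reads pixel_size from experiment.xenium, falling back to the 0.2125 micron default with a printed warning. A minimal usage sketch, assuming ReadXenium is re-exported at the package top level like the other readers; the bundle path and file names are placeholders for a typical Xenium output directory:

from pathlib import Path

import stlearn as st

base = Path("xenium_output")  # placeholder: your Xenium bundle directory

adata = st.ReadXenium(
    feature_cell_matrix_file=base / "cell_feature_matrix.h5",
    cell_summary_file=base / "cells.csv.gz",
    image_path=base / "he_image.ome.tif",
    library_id="xenium_sample",
    # Transformation matrix exported from Xenium Explorer; centroids are
    # converted from microns to pixels and mapped through its inverse.
    alignment_matrix_file=base / "he_imagealignment.csv",
    # Source of pixel_size; omit to fall back to default_pixel_size_microns
    # (0.2125), in which case a warning is printed.
    experiment_xenium_file=base / "experiment.xenium",
)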