From 3b5be43ac1913ee30232edee76bcb9728246dbeb Mon Sep 17 00:00:00 2001 From: Hong Qin Date: Tue, 17 Feb 2026 14:41:42 -0800 Subject: [PATCH] Expand LaTeX algorithm doc with symbol and LayerNorm explanations --- docs/algorithm_math.tex | 223 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 docs/algorithm_math.tex diff --git a/docs/algorithm_math.tex b/docs/algorithm_math.tex new file mode 100644 index 0000000..633ed5e --- /dev/null +++ b/docs/algorithm_math.tex @@ -0,0 +1,223 @@ +\documentclass[11pt]{article} +\usepackage[a4paper,margin=1in]{geometry} +\usepackage{amsmath,amssymb,amsfonts,bm,mathtools} +\usepackage{booktabs} +\usepackage{enumitem} +\usepackage{hyperref} +\usepackage{graphicx} +\usepackage{float} + +\title{SC-MAMBA2 Algorithm: Mathematical Formulation with Tensor Dimensions} +\author{} +\date{} + +\begin{document} +\maketitle + +\noindent +This document summarizes the core SC-MAMBA2 forward pass and objectives implemented in +\texttt{scmamba2/model/model.py} and \texttt{scmamba2/trainer.py}. + +\section{Inputs and notation} +For a mini-batch with batch size $B$, sequence length $L$, and model width $d$: +\begin{itemize}[leftmargin=2em] + \item Gene token IDs: $\mathbf{G} \in \mathbb{N}^{B \times L}$ + \item Expression values: $\mathbf{V} \in \mathbb{R}^{B \times L}$ + \item Padding mask: $\mathbf{M} \in \{0,1\}^{B \times L}$, where $M_{b\ell}=1$ means padding +\end{itemize} + +\paragraph{Important clarification: $L$ is \emph{not} vocabulary size.} +In this document, $L$ always means the \textbf{token sequence length per cell} after preprocessing/padding (roughly, number of gene tokens fed to the model for one cell). Vocabulary size is a separate quantity, denoted here as $|\mathcal{V}|$, and controls the size of the embedding table. + +Optional metadata: +\begin{itemize}[leftmargin=2em] + \item Batch labels: $\mathbf{b} \in \{1,\dots,K\}^{B}$ + \item Modality IDs (multi-omics): $\mathbf{T} \in \{1,\dots,R\}^{B\times L}$ +\end{itemize} + +\subsection*{Symbol glossary (quick reference)} +\begin{itemize}[leftmargin=2em] + \item $B$: mini-batch size (number of cells per optimization step) + \item $L$: sequence length (tokens per cell), \textbf{not} vocabulary size + \item $|\mathcal{V}|$: gene-token vocabulary size + \item $d$: hidden/embedding width (\texttt{d\_model}) + \item $N$: number of stacked Mamba encoder layers +\end{itemize} + +\section{Token/value encoding} +\subsection{Gene embedding} +\begin{equation} +\mathbf{E}_{g} = \mathrm{LN}(\mathrm{Embed}_{g}(\mathbf{G})) \in \mathbb{R}^{B\times L\times d}. +\end{equation} +Here, $\mathrm{Embed}_{g}:\{1,\dots,|\mathcal{V}|\}\to\mathbb{R}^{d}$ maps each gene token id to a learnable vector. + +\paragraph{What is LN?} +$\mathrm{LN}$ is \textbf{LayerNorm} (layer normalization). For each token vector $\mathbf{x}\in\mathbb{R}^{d}$, it normalizes across the feature dimension: +\begin{equation} +\mathrm{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}-\mu(\mathbf{x})}{\sqrt{\sigma^2(\mathbf{x})+\epsilon}} + \beta, +\end{equation} +where $\mu(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d x_j$, $\sigma^2(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d(x_j-\mu)^2$, and $\gamma,\beta\in\mathbb{R}^{d}$ are learnable parameters. Intuitively, LN stabilizes optimization by keeping each token representation at a controlled scale. 
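+
+\paragraph{Minimal code sketch (gene embedding with LN).}
+For concreteness, the following PyTorch-style sketch mirrors the gene-embedding and LayerNorm equations above: an embedding lookup followed by \texttt{LayerNorm}. It is an illustrative reconstruction under assumed names and sizes (\texttt{GeneEncoder}, the padding index, and the toy dimensions are hypothetical), not the repository's actual module.
+\begin{verbatim}
+import torch
+import torch.nn as nn
+
+class GeneEncoder(nn.Module):
+    """Embed_g followed by LayerNorm: gene IDs (B, L) -> E_g (B, L, d)."""
+    def __init__(self, vocab_size: int, d_model: int, padding_idx: int = 0):
+        super().__init__()
+        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+        self.norm = nn.LayerNorm(d_model)   # normalizes over the feature dim d
+
+    def forward(self, gene_ids: torch.Tensor) -> torch.Tensor:
+        return self.norm(self.embed(gene_ids))
+
+# Toy usage: B = 2 cells, L = 5 tokens, |V| = 100 genes, d = 16
+enc = GeneEncoder(vocab_size=100, d_model=16)
+E_g = enc(torch.randint(1, 100, (2, 5)))    # shape (2, 5, 16)
+\end{verbatim}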
+
+\subsection{Value embedding (continuous style)}
+For the default continuous value pathway:
+\begin{equation}
+\mathbf{E}_{v} = \mathrm{Dropout}\!\left(\mathrm{LN}\!\left(W_2\,\sigma\!\left(W_1\,\mathrm{clip}(\mathbf{V})\right)\right)\right)
+\in \mathbb{R}^{B\times L\times d},
+\end{equation}
+where $W_1\in\mathbb{R}^{d\times 1}$ lifts each clipped scalar value to $\mathbb{R}^{d}$ token-wise, $W_2\in\mathbb{R}^{d\times d}$, and $\sigma$ is an elementwise ReLU.
+Values are clipped at the configured \texttt{max\_value}.
+
+\paragraph{Interpretation.}
+Gene identity and expression magnitude are encoded separately and then fused. This helps the model distinguish ``\emph{which gene}'' (from $\mathbf{E}_g$) from ``\emph{how much expression}'' (from $\mathbf{E}_v$).
+
+\subsection{Fusion and normalization}
+Default additive fusion:
+\begin{equation}
+\mathbf{H}^{(0)} = \mathbf{E}_{g} + \mathbf{E}_{v} \in \mathbb{R}^{B\times L\times d}.
+\end{equation}
+Then channel-wise normalization is applied (BN or DSBN):
+\begin{equation}
+\mathbf{H}^{(0)} \leftarrow \mathrm{BN}(\mathbf{H}^{(0)}).
+\end{equation}
+
+\paragraph{BN vs DSBN.}
+BN is shared across all batches/domains. DSBN (domain-specific BN) keeps separate normalization statistics (and optionally affine parameters) per batch/domain label, which can reduce domain-shift effects.
+
+\section{Bidirectional Mamba encoder stack}
+Let the number of Mamba blocks be $N$. For each layer $i\in\{1,\dots,N\}$:
+
+\paragraph{Forward branch}
+\begin{equation}
+\mathbf{F}^{(i)} = \mathrm{Mamba}_i\!\left(\mathbf{H}^{(i-1)}\right) + \mathbf{H}^{(i-1)}.
+\end{equation}
+
+\paragraph{Backward branch}
+SC-MAMBA2 flips the non-padding portion of each sequence (padding-aware \texttt{smart\_flip}):
+\begin{align}
+\widetilde{\mathbf{H}}^{(i-1)} &= \mathrm{flip}\!\left(\mathbf{H}^{(i-1)}\right),\\
+\widetilde{\mathbf{B}}^{(i)} &= \mathrm{Mamba}_i\!\left(\widetilde{\mathbf{H}}^{(i-1)}\right) + \widetilde{\mathbf{H}}^{(i-1)},\\
+\mathbf{B}^{(i)} &= \mathrm{flip}^{-1}\!\left(\widetilde{\mathbf{B}}^{(i)}\right).
+\end{align}
+
+\paragraph{Bidirectional fusion}
+\begin{equation}
+\mathbf{H}^{(i)} = \frac{1}{2}\left(\mathbf{F}^{(i)} + \mathbf{B}^{(i)}\right)
+\in \mathbb{R}^{B\times L\times d}.
+\end{equation}
+Final encoder output:
+\begin{equation}
+\mathbf{H} = \mathbf{H}^{(N)} \in \mathbb{R}^{B\times L\times d}.
+\end{equation}
+
+\paragraph{Why two directions?}
+The forward stream captures left-to-right context, while the flipped stream approximates right-to-left context. Averaging both gives a bidirectional representation without Transformer self-attention.
+
+\section{Cell embedding extraction}
+Given $\mathbf{H}\in\mathbb{R}^{B\times L\times d}$, the cell embedding $\mathbf{c}\in\mathbb{R}^{B\times d}$ is computed in one of the following ways:
+\begin{align}
+\texttt{cls}:\quad &\mathbf{c}=\mathbf{H}_{:,0,:},\\
+\texttt{avg-pool}:\quad &\mathbf{c}=\frac{1}{L}\sum_{\ell=1}^{L}\mathbf{H}_{:,\ell,:},\\
+\texttt{w-pool}:\quad &\mathbf{c}=\mathrm{norm}\!\left(\sum_{\ell=1}^{L} w_{:,\ell}\,\mathbf{H}_{:,\ell,:}\right),
+\end{align}
+where $w$ are per-token pooling weights and $\mathrm{norm}(\cdot)$ re-normalizes the pooled vector.
+
+\paragraph{Practical meaning.}
+$\mathbf{c}$ is the cell-level summary vector used by downstream tasks (annotation, integration, adversarial batch correction, etc.).
+
+\section{Decoder heads}
+\subsection{Expression decoder (MLM/GEP head)}
+A token-wise MLP predicts expression values:
+\begin{equation}
+\widehat{\mathbf{Y}} = f_{\mathrm{expr}}(\mathbf{H}) \in \mathbb{R}^{B\times L}.
+\end{equation}
+When enabled, batch/modality embeddings are concatenated with the token states, changing the decoder input width from $d$ to $2d$.
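+
+\paragraph{Minimal code sketch (expression decoder).}
+The sketch below illustrates this head under the same caveat: it is a PyTorch-style reconstruction, not the repository's implementation, and the class name \texttt{ExprDecoder}, the hidden sizes, and the activation choice are assumptions. It shows the $d \to 2d$ widening when a condition embedding is concatenated, plus an optional zero-probability output as described next.
+\begin{verbatim}
+from typing import Optional
+import torch
+import torch.nn as nn
+
+class ExprDecoder(nn.Module):
+    """Token-wise MLP: H (B, L, d) [+ optional condition (B, L, d)] -> Yhat (B, L)."""
+    def __init__(self, d_model: int, use_condition: bool = False,
+                 explicit_zero_prob: bool = False):
+        super().__init__()
+        d_in = 2 * d_model if use_condition else d_model   # d -> 2d with condition
+        self.use_condition = use_condition
+        self.explicit_zero_prob = explicit_zero_prob
+        self.value_head = nn.Sequential(
+            nn.Linear(d_in, d_model), nn.LeakyReLU(), nn.Linear(d_model, 1))
+        if explicit_zero_prob:
+            self.zero_head = nn.Sequential(
+                nn.Linear(d_in, d_model), nn.LeakyReLU(), nn.Linear(d_model, 1))
+
+    def forward(self, h: torch.Tensor, cond: Optional[torch.Tensor] = None):
+        if self.use_condition and cond is not None:
+            h = torch.cat([h, cond], dim=-1)               # (B, L, 2d)
+        y_hat = self.value_head(h).squeeze(-1)             # (B, L)
+        if not self.explicit_zero_prob:
+            return y_hat
+        p_zero = torch.sigmoid(self.zero_head(h)).squeeze(-1)  # (B, L) in [0, 1]
+        return y_hat, p_zero
+\end{verbatim}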
+
+Optional explicit zero-probability head:
+\begin{equation}
+\mathbf{P}_{0} = \sigma\!\left(f_{0}(\mathbf{H})\right) \in [0,1]^{B\times L},
+\end{equation}
+where $f_{0}$ is a token-wise head and $\sigma(\cdot)$ here denotes the logistic sigmoid (not the ReLU used in the value embedding).
+
+\paragraph{Why explicit zero modeling?}
+Single-cell data are sparse (many zeros). Predicting a dedicated zero probability can model dropout/zero inflation better than pure regression alone.
+
+\subsection{Cell-type classification head}
+\begin{equation}
+\widehat{\mathbf{z}} = f_{\mathrm{cls}}(\mathbf{c}) \in \mathbb{R}^{B\times C},
+\end{equation}
+where $C$ is the number of cell-type classes.
+
+\subsection{MVC head (masked value prediction with cell embedding)}
+For the default inner-product MVC variant:
+\begin{align}
+\mathbf{Q} &= \phi\!\left(W_q\,\mathbf{E}_{g}\right) \in \mathbb{R}^{B\times L\times d'},\\
+\widehat{Y}^{\mathrm{mvc}}_{b\ell} &= \left\langle W\mathbf{Q}_{b\ell,:},\, \mathbf{c}_{b}\right\rangle,
+\end{align}
+where $W_q\in\mathbb{R}^{d'\times d}$ and $W\in\mathbb{R}^{d\times d'}$ are learnable projections and $\phi$ is an elementwise nonlinearity; thus $\widehat{\mathbf{Y}}^{\mathrm{mvc}}\in\mathbb{R}^{B\times L}$.
+An optional zero-probability head is analogously defined.
+
+\paragraph{Why MVC uses gene embeddings.}
+The query is produced from raw gene token embeddings to avoid leaking masked expression values directly through the query path.
+
+\subsection{Adversarial batch discriminator (optional)}
+A gradient-reversal classifier predicts batch labels from $\mathbf{c}$, encouraging batch-invariant latent representations. Its batch-classification cross-entropy, back-propagated through the gradient-reversal layer, is the $\mathcal{L}_{\mathrm{DAB}}$ term in the total objective below.
+
+\section{Training objectives}
+Let
+\begin{equation}
+\Omega = \{(b,\ell)\,|\,V_{b\ell}\ \text{is masked}\}.
+\end{equation}
+Throughout, $Y_{b\ell}$ denotes the ground-truth expression value at position $(b,\ell)$ (the entry of $\mathbf{V}$ before masking).
+
+\subsection{Masked expression regression (GEP)}
+\begin{equation}
+\mathcal{L}_{\mathrm{GEP}} = \frac{1}{|\Omega|}\sum_{(b,\ell)\in\Omega}
+\left(\widehat{Y}_{b\ell}-Y_{b\ell}\right)^2.
+\end{equation}
+
+\subsection{Explicit zero modeling (optional)}
+Bernoulli negative log-likelihood of the zero indicator:
+\begin{equation}
+\mathcal{L}_{\mathrm{NZLP}} = -\frac{1}{|\Omega|}\sum_{(b,\ell)\in\Omega}
+\left[\mathbf{1}(Y_{b\ell}=0)\log P_{0,b\ell} + \mathbf{1}(Y_{b\ell}>0)\log(1-P_{0,b\ell})\right],
+\end{equation}
+where $\mathbf{1}(\cdot)$ is the indicator function.
+
+\subsection{MVC regression (optional)}
+\begin{equation}
+\mathcal{L}_{\mathrm{MVC}} = \frac{1}{|\Omega|}\sum_{(b,\ell)\in\Omega}
+\left(\widehat{Y}^{\mathrm{mvc}}_{b\ell}-Y_{b\ell}\right)^2.
+\end{equation}
+
+\subsection{Cell-type classification (optional)}
+\begin{equation}
+\mathcal{L}_{\mathrm{CLS}} = \mathrm{CE}(\widehat{\mathbf{z}},\mathbf{y}_{\mathrm{cell}}).
+\end{equation}
+
+\subsection{ECS regularizer (optional)}
+With normalized cell embeddings and their pairwise cosine-similarity matrix $S\in\mathbb{R}^{B\times B}$:
+\begin{equation}
+\mathcal{L}_{\mathrm{ECS}} = \mathbb{E}_{i,j}\left[1-(S_{ij}-\tau)^2\right],
+\end{equation}
+where the expectation is the mean over cell pairs in the mini-batch and $\tau$ is \texttt{ecs\_threshold}.
+
+\subsection{Total objective}
+Enabled terms are summed during training, each scaled by a weight $\lambda_k$:
+\begin{equation}
+\mathcal{L} =
+\lambda_1\mathcal{L}_{\mathrm{GEP}} +
+\lambda_2\mathcal{L}_{\mathrm{NZLP}} +
+\lambda_3\mathcal{L}_{\mathrm{MVC}} +
+\lambda_4\mathcal{L}_{\mathrm{MVC\text{-}NZLP}} +
+\lambda_5\mathcal{L}_{\mathrm{CLS}} +
+\lambda_6\mathcal{L}_{\mathrm{ECS}} +
+\lambda_7\mathcal{L}_{\mathrm{DAB}}.
+\end{equation}
+
+\paragraph{Interpretation of the objective.}
+The loss combines token-level reconstruction (GEP/MVC), sparsity-aware likelihood (NZLP), and cell-level supervision/regularization (CLS/ECS/DAB). This multitask formulation shapes representations that are useful both for recovering masked expression and for downstream biological tasks.
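+
+\paragraph{Minimal code sketch (masked losses and the weighted sum).}
+The helpers below make the masked set $\Omega$ and the weighted total concrete. They are illustrative stand-ins assuming PyTorch, boolean masks marking masked positions, and dictionary-based bookkeeping of the active terms; they are not the trainer's actual API.
+\begin{verbatim}
+import torch
+
+def masked_mse(y_hat, y, mask):
+    """L_GEP / L_MVC: mean squared error over the masked positions Omega.
+    y_hat, y: (B, L); mask: (B, L) bool, True where the value was masked."""
+    diff = (y_hat - y)[mask]
+    return (diff ** 2).mean()
+
+def zero_bernoulli_nll(p_zero, y, mask, eps=1e-8):
+    """L_NZLP: Bernoulli NLL of the zero indicator over Omega."""
+    is_zero = (y[mask] == 0).float()
+    p = p_zero[mask].clamp(eps, 1.0 - eps)
+    return -(is_zero * p.log() + (1.0 - is_zero) * (1.0 - p).log()).mean()
+
+def total_loss(terms, lambdas):
+    """L = sum_k lambda_k * L_k over whichever terms are enabled this step.
+    terms: dict name -> scalar loss tensor; lambdas: dict name -> float weight."""
+    return sum(lambdas.get(name, 1.0) * value for name, value in terms.items())
+\end{verbatim}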
+
+\section{Why this scales to ultra-long transcriptome sequences}
+\begin{itemize}[leftmargin=2em]
+  \item Mamba2 replaces quadratic self-attention with state-space sequence modeling, whose cost grows linearly in the sequence length $L$.
+  \item SC-MAMBA2 adds explicit bidirectionality via forward and reversed passes through shared layers.
+  \item A single encoder supports both token-level reconstruction and cell-level downstream objectives.
+\end{itemize}
+
+\end{document}