diff --git a/docs/algorithm_math.tex b/docs/algorithm_math.tex
index 14432ef..bb135a6 100644
--- a/docs/algorithm_math.tex
+++ b/docs/algorithm_math.tex
@@ -25,17 +25,38 @@ \section{Inputs and notation}
 \item Expression values: $\mathbf{V} \in \mathbb{R}^{B \times L}$
 \item Padding mask: $\mathbf{M} \in \{0,1\}^{B \times L}$, where $M_{b\ell}=1$ means padding
 \end{itemize}
+
+\paragraph{Important clarification: $L$ is \emph{not} vocabulary size.}
+In this document, $L$ always means the \textbf{token sequence length per cell} after preprocessing/padding (roughly, number of gene tokens fed to the model for one cell). Vocabulary size is a separate quantity, denoted here as $|\mathcal{V}|$, and controls the size of the embedding table.
+
 Optional metadata:
 \begin{itemize}[leftmargin=2em]
 \item Batch labels: $\mathbf{b} \in \{1,\dots,K\}^{B}$
 \item Modality IDs (multi-omics): $\mathbf{T} \in \{1,\dots,R\}^{B\times L}$
 \end{itemize}
 
+\subsection*{Symbol glossary (quick reference)}
+\begin{itemize}[leftmargin=2em]
+  \item $B$: mini-batch size (number of cells per optimization step)
+  \item $L$: sequence length (tokens per cell), \textbf{not} vocabulary size
+  \item $|\mathcal{V}|$: gene-token vocabulary size
+  \item $d$: hidden/embedding width (\texttt{d\_model})
+  \item $N$: number of stacked Mamba encoder layers
+\end{itemize}
+
 \section{Token/value encoding}
 
 \subsection{Gene embedding}
 \begin{equation}
 \mathbf{E}_{g} = \mathrm{LN}(\mathrm{Embed}_{g}(\mathbf{G})) \in \mathbb{R}^{B\times L\times d}.
 \end{equation}
+Here, $\mathrm{Embed}_{g}:\{1,\dots,|\mathcal{V}|\}\to\mathbb{R}^{d}$ maps each gene token id to a learnable vector.
+
+\paragraph{What is LN?}
+$\mathrm{LN}$ is \textbf{LayerNorm} (layer normalization). For each token vector $\mathbf{x}\in\mathbb{R}^{d}$, it normalizes across the feature dimension:
+\begin{equation}
+\mathrm{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}-\mu(\mathbf{x})}{\sqrt{\sigma^2(\mathbf{x})+\epsilon}} + \beta,
+\end{equation}
+where $\mu(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d x_j$, $\sigma^2(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d(x_j-\mu)^2$, and $\gamma,\beta\in\mathbb{R}^{d}$ are learnable parameters. Intuitively, LN stabilizes optimization by keeping each token representation at a controlled scale.
 \subsection{Value embedding (continuous style)}
 For the default continuous value pathway:
@@ -46,6 +67,9 @@ \subsection{Value embedding (continuous style)}
 where $W_1\in\mathbb{R}^{1\times d}$, $W_2\in\mathbb{R}^{d\times d}$, and $\sigma$ is ReLU.
 Values are clipped by a configured \texttt{max\_value}.
 
+\paragraph{Interpretation.}
+Gene identity and expression magnitude are encoded separately and then fused. This helps the model distinguish ``\emph{which gene}'' (from $\mathbf{E}_g$) from ``\emph{how much expression}'' (from $\mathbf{E}_v$).
+
 \subsection{Fusion and normalization}
 Default additive fusion:
 \begin{equation}
@@ -56,6 +80,9 @@ \subsection{Fusion and normalization}
 \mathbf{H}^{(0)} \leftarrow \mathrm{BN}(\mathbf{H}^{(0)}).
 \end{equation}
 
+\paragraph{BN vs DSBN.}
+BN is shared across all batches/domains. DSBN (domain-specific BN) keeps separate normalization statistics (and optionally affine parameters) per batch/domain label, which can reduce domain-shift effects.
+
 \section{Bidirectional Mamba encoder stack}
 
 Let the number of Mamba blocks be $N$. For each layer $i\in\{1,\dots,N\}$:
@@ -82,6 +109,9 @@ \section{Bidirectional Mamba encoder stack}
 \mathbf{H} = \mathbf{H}^{(N)} \in \mathbb{R}^{B\times L\times d}.
 \end{equation}
 
+\paragraph{Why two directions?}
+The forward stream captures left-to-right context, while the flipped stream approximates right-to-left context. Averaging both gives a bidirectional representation without Transformer self-attention.
+
 \section{Cell embedding extraction}
 Given $\mathbf{H}\in\mathbb{R}^{B\times L\times d}$, cell embedding $\mathbf{c}\in\mathbb{R}^{B\times d}$ is computed by one of:
 \begin{align}
@@ -118,6 +148,9 @@ \subsection{MVC head (masked value prediction with cell embedding)}
 thus $\widehat{\mathbf{Y}}^{\mathrm{mvc}}\in\mathbb{R}^{B\times L}$.
 An optional zero-probability head is analogously defined.
 
+\paragraph{Why MVC uses gene embeddings.}
+The query is produced from raw gene token embeddings to avoid leaking masked expression values directly through the query path.
+
 \subsection{Adversarial batch discriminator (optional)}
 A gradient-reversal classifier predicts batch labels from $\mathbf{c}$, encouraging batch-invariant latent representations.
 
@@ -171,6 +204,9 @@ \subsection{Total objective}
 \lambda_7\mathcal{L}_{\mathrm{DAB}}.
 \end{equation}
 
+\paragraph{Interpretation of the objective.}
+The loss combines token-level reconstruction (GEP/MVC), sparsity-aware likelihood (NZLP), and cell-level supervision/regularization (CLS/ECS/DAB). This multitask formulation shapes representations that are useful both for recovering masked expression and for downstream biological tasks.
+
 \section{Why this scales to ultra-long transcriptome sequences}
 \begin{itemize}[leftmargin=2em]
 \item Mamba2 replaces quadratic self-attention with state-space sequence modeling.
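
For readers who prefer code to notation, the following is a minimal NumPy sketch of a few pieces defined in the added text above: the LayerNorm formula, additive fusion of gene and value embeddings, forward/flipped averaging, and one plausible mask-aware mean pooling for the cell embedding. It is illustrative only, not the reference implementation; the function names, the np.tanh stand-in for a Mamba block, and the mean-pooling choice are assumptions.

    import numpy as np

    def layer_norm(x, gamma, beta, eps=1e-5):
        # LN over the feature (last) axis, per the LN equation above.
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return gamma * (x - mu) / np.sqrt(var + eps) + beta

    def fuse_additive(E_g, E_v):
        # Default additive fusion: H0 = E_g + E_v, shape (B, L, d).
        return E_g + E_v

    def bidirectional_average(block, H):
        # Run a sequence block forward and on the flipped sequence, then
        # average, mirroring the "Why two directions?" note (whether the
        # two directions share weights is an assumption here).
        fwd = block(H)
        bwd = block(H[:, ::-1, :])[:, ::-1, :]
        return 0.5 * (fwd + bwd)

    def masked_mean_cell_embedding(H, M):
        # One plausible pooling choice (assumption): mean over non-padding
        # tokens, with M == 1 marking padding as in the notation section.
        keep = (1.0 - M)[..., None]                    # (B, L, 1)
        return (H * keep).sum(axis=1) / np.maximum(keep.sum(axis=1), 1.0)

    # Toy shapes for illustration: B=2 cells, L=4 tokens, d=8 features.
    B, L, d = 2, 4, 8
    rng = np.random.default_rng(0)
    E_g = layer_norm(rng.normal(size=(B, L, d)),       # stands in for LN(Embed_g(G))
                     np.ones(d), np.zeros(d))
    E_v = rng.normal(size=(B, L, d))                   # stands in for the value-embedding MLP output
    M = np.array([[0., 0., 0., 1.],
                  [0., 0., 1., 1.]])                   # 1 = padding
    H0 = fuse_additive(E_g, E_v)                       # BN/DSBN step omitted in this sketch
    H = bidirectional_average(np.tanh, H0)             # np.tanh stands in for a Mamba block
    c = masked_mean_cell_embedding(H, M)               # (B, d) cell embeddings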