docs/algorithm_math.tex (36 additions, 0 deletions)
@@ -25,17 +25,38 @@ \section{Inputs and notation}
\item Expression values: $\mathbf{V} \in \mathbb{R}^{B \times L}$
\item Padding mask: $\mathbf{M} \in \{0,1\}^{B \times L}$, where $M_{b\ell}=1$ means padding
\end{itemize}

\paragraph{Important clarification: $L$ is \emph{not} vocabulary size.}
In this document, $L$ always means the \textbf{token sequence length per cell} after preprocessing/padding (roughly, the number of gene tokens fed to the model for one cell). Vocabulary size is a separate quantity, denoted here as $|\mathcal{V}|$, and controls the size of the embedding table.

Optional metadata:
\begin{itemize}[leftmargin=2em]
\item Batch labels: $\mathbf{b} \in \{1,\dots,K\}^{B}$
\item Modality IDs (multi-omics): $\mathbf{T} \in \{1,\dots,R\}^{B\times L}$
\end{itemize}

\subsection*{Symbol glossary (quick reference)}
\begin{itemize}[leftmargin=2em]
\item $B$: mini-batch size (number of cells per optimization step)
\item $L$: sequence length (tokens per cell), \textbf{not} vocabulary size
\item $|\mathcal{V}|$: gene-token vocabulary size
\item $d$: hidden/embedding width (\texttt{d\_model})
\item $N$: number of stacked Mamba encoder layers
\end{itemize}
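\paragraph{Shape sketch (illustrative).}
The following PyTorch-style snippet is only meant to pin down the shapes above; the sizes are made up and the variable names are ours, not the repository's.
\begin{verbatim}
import torch

# Illustrative sizes only (not taken from any config):
B, L, vocab_size, d = 4, 1024, 60000, 512

G = torch.randint(0, vocab_size, (B, L))   # gene token ids, shape (B, L)
V = torch.rand(B, L)                       # expression values, shape (B, L)
M = torch.zeros(B, L, dtype=torch.bool)    # padding mask, True/1 = padding

# |V| (vocab_size) only sizes the embedding table; L is the tokens-per-cell axis.
embed_g = torch.nn.Embedding(vocab_size, d)
E_g = embed_g(G)                           # shape (B, L, d)
\end{verbatim}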

\section{Token/value encoding}
\subsection{Gene embedding}
\begin{equation}
\mathbf{E}_{g} = \mathrm{LN}(\mathrm{Embed}_{g}(\mathbf{G})) \in \mathbb{R}^{B\times L\times d}.
\end{equation}
Here, $\mathrm{Embed}_{g}:\{1,\dots,|\mathcal{V}|\}\to\mathbb{R}^{d}$ maps each gene token id to a learnable vector.
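\paragraph{Sketch: gene embedding module.}
A minimal PyTorch-style sketch of the equation above (embedding lookup followed by LayerNorm); the class and attribute names are assumptions, not the repository's API.
\begin{verbatim}
import torch.nn as nn

class GeneEncoder(nn.Module):
    # Embed_g followed by LN, as in the equation above.
    def __init__(self, vocab_size: int, d: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.norm = nn.LayerNorm(d)

    def forward(self, gene_ids):                 # gene_ids: (B, L) integer tokens
        return self.norm(self.embed(gene_ids))   # (B, L, d)
\end{verbatim}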

\paragraph{What is LN?}
$\mathrm{LN}$ is \textbf{LayerNorm} (layer normalization). For each token vector $\mathbf{x}\in\mathbb{R}^{d}$, it normalizes across the feature dimension:
\begin{equation}
\mathrm{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x}-\mu(\mathbf{x})}{\sqrt{\sigma^2(\mathbf{x})+\epsilon}} + \beta,
\end{equation}
where $\mu(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d x_j$, $\sigma^2(\mathbf{x})=\frac{1}{d}\sum_{j=1}^d\bigl(x_j-\mu(\mathbf{x})\bigr)^2$, and $\gamma,\beta\in\mathbb{R}^{d}$ are learnable parameters. Intuitively, LN stabilizes optimization by keeping each token representation at a controlled scale.
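\paragraph{Numerical check of the LN formula.}
The snippet below verifies (under default initialization, $\gamma=\mathbf{1}$, $\beta=\mathbf{0}$) that the formula above matches \texttt{torch.nn.LayerNorm}; it is a sanity check, not part of the model.
\begin{verbatim}
import torch

d, eps = 8, 1e-5
x = torch.randn(d)

mu = x.mean()
var = x.var(unbiased=False)                   # (1/d) * sum_j (x_j - mu)^2
ln_manual = (x - mu) / torch.sqrt(var + eps)  # gamma = 1, beta = 0 at init

ln_module = torch.nn.LayerNorm(d, eps=eps)
assert torch.allclose(ln_manual, ln_module(x), atol=1e-6)
\end{verbatim}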

\subsection{Value embedding (continuous style)}
For the default continuous value pathway:
@@ -46,6 +67,9 @@ \subsection{Value embedding (continuous style)}
where $W_1\in\mathbb{R}^{1\times d}$, $W_2\in\mathbb{R}^{d\times d}$, and $\sigma$ is ReLU.
Values are clipped to a configured \texttt{max\_value}.

\paragraph{Interpretation.}
Gene identity and expression magnitude are encoded separately and then fused. This helps the model distinguish ``\emph{which gene}'' (from $\mathbf{E}_g$) from ``\emph{how much expression}'' (from $\mathbf{E}_v$).
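\paragraph{Sketch: value embedding pathway.}
The exact equation is elided from this diff, so the sketch below only mirrors the stated shapes ($W_1\colon 1\to d$, $W_2\colon d\to d$, ReLU, clipping at \texttt{max\_value}); treat it as a plausible reading, not the implementation.
\begin{verbatim}
import torch.nn as nn

class ValueEncoder(nn.Module):
    # Continuous value pathway: clip, 1 -> d, ReLU, d -> d.
    def __init__(self, d: int, max_value: float = 10.0):  # max_value is illustrative
        super().__init__()
        self.fc1 = nn.Linear(1, d)   # plays the role of W_1
        self.fc2 = nn.Linear(d, d)   # plays the role of W_2
        self.act = nn.ReLU()
        self.max_value = max_value

    def forward(self, values):                              # values: (B, L)
        v = values.clamp(max=self.max_value).unsqueeze(-1)  # (B, L, 1)
        return self.fc2(self.act(self.fc1(v)))              # (B, L, d)
\end{verbatim}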

\subsection{Fusion and normalization}
Default additive fusion:
\begin{equation}
@@ -56,6 +80,9 @@ \subsection{Fusion and normalization}
\mathbf{H}^{(0)} \leftarrow \mathrm{BN}(\mathbf{H}^{(0)}).
\end{equation}

\paragraph{BN vs DSBN.}
BN is shared across all batches/domains. DSBN (domain-specific BN) keeps separate normalization statistics (and optionally affine parameters) per batch/domain label, which can reduce domain-shift effects.
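\paragraph{Sketch: domain-specific BatchNorm.}
A minimal sketch of the DSBN idea described above, assuming token-level input of shape $(B,L,d)$ and integer domain labels of shape $(B)$; module and argument names are ours.
\begin{verbatim}
import torch
import torch.nn as nn

class DSBN(nn.Module):
    # One BatchNorm1d(d) per batch/domain label; statistics (and affine
    # parameters) are therefore not shared across domains.
    def __init__(self, d: int, num_domains: int):
        super().__init__()
        self.bns = nn.ModuleList(nn.BatchNorm1d(d) for _ in range(num_domains))

    def forward(self, h, domain):                 # h: (B, L, d), domain: (B,)
        out = torch.empty_like(h)
        for k in domain.unique():
            idx = (domain == k)
            x = h[idx].transpose(1, 2)            # (b_k, d, L) for BatchNorm1d
            out[idx] = self.bns[int(k)](x).transpose(1, 2)
        return out
\end{verbatim}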

\section{Bidirectional Mamba encoder stack}
Let the number of Mamba blocks be $N$. For each layer $i\in\{1,\dots,N\}$:

@@ -82,6 +109,9 @@ \section{Bidirectional Mamba encoder stack}
\mathbf{H} = \mathbf{H}^{(N)} \in \mathbb{R}^{B\times L\times d}.
\end{equation}

\paragraph{Why two directions?}
The forward stream captures left-to-right context, while the flipped stream approximates right-to-left context. Averaging both gives a bidirectional representation without Transformer self-attention.
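\paragraph{Sketch: two-direction averaging.}
A minimal sketch of the scheme described above; \texttt{layer} stands in for any sequence block mapping $(B,L,d)\to(B,L,d)$ (e.g.\ a Mamba block), and whether the two directions share parameters is not fixed by the text, so the copy below is an assumption.
\begin{verbatim}
import copy
import torch
import torch.nn as nn

class BidirectionalWrapper(nn.Module):
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.fwd = layer
        self.bwd = copy.deepcopy(layer)  # separate parameters per direction (assumed)

    def forward(self, h):                # h: (B, L, d)
        out_fwd = self.fwd(h)
        out_bwd = torch.flip(self.bwd(torch.flip(h, dims=[1])), dims=[1])
        return 0.5 * (out_fwd + out_bwd)
\end{verbatim}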

\section{Cell embedding extraction}
Given $\mathbf{H}\in\mathbb{R}^{B\times L\times d}$, the cell embedding $\mathbf{c}\in\mathbb{R}^{B\times d}$ is computed by one of:
\begin{align}
@@ -118,6 +148,9 @@ \subsection{MVC head (masked value prediction with cell embedding)}
thus $\widehat{\mathbf{Y}}^{\mathrm{mvc}}\in\mathbb{R}^{B\times L}$.
An optional zero-probability head is defined analogously.

\paragraph{Why MVC uses gene embeddings.}
The query is produced from raw gene token embeddings to avoid leaking masked expression values directly through the query path.
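\paragraph{Sketch: MVC scoring.}
A hedged sketch of the idea above: build the query from the raw gene embeddings $\mathbf{E}_g$ and score it against the cell embedding $\mathbf{c}$ by an inner product; the query network's exact form is not shown in this diff, so the two-layer MLP is an assumption.
\begin{verbatim}
import torch
import torch.nn as nn

class MVCHead(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.query = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d))

    def forward(self, e_g, c):                   # e_g: (B, L, d), c: (B, d)
        q = self.query(e_g)                      # (B, L, d)
        return torch.einsum("bld,bd->bl", q, c)  # predicted values, (B, L)
\end{verbatim}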

\subsection{Adversarial batch discriminator (optional)}
A gradient-reversal classifier predicts batch labels from $\mathbf{c}$, encouraging batch-invariant latent representations.
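\paragraph{Sketch: gradient reversal.}
The gradient-reversal layer itself is standard; the discriminator below is a minimal linear sketch, and \texttt{lambd} (the reversal strength) is an assumed hyperparameter name.
\begin{verbatim}
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; negate (and scale) gradients in backward.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

class BatchDiscriminator(nn.Module):
    def __init__(self, d: int, num_batches: int, lambd: float = 1.0):
        super().__init__()
        self.clf = nn.Linear(d, num_batches)
        self.lambd = lambd

    def forward(self, c):                        # c: (B, d)
        return self.clf(GradReverse.apply(c, self.lambd))
\end{verbatim}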

@@ -171,6 +204,9 @@ \subsection{Total objective}
\lambda_7\mathcal{L}_{\mathrm{DAB}}.
\end{equation}

\paragraph{Interpretation of the objective.}
The loss combines token-level reconstruction (GEP/MVC), sparsity-aware likelihood (NZLP), and cell-level supervision/regularization (CLS/ECS/DAB). This multitask formulation shapes representations that are useful both for recovering masked expression and for downstream biological tasks.
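\paragraph{Sketch: combining the terms.}
Operationally the total objective is just a weighted sum over whichever terms are enabled; the key names below are ours, not the repository's.
\begin{verbatim}
import torch

def total_loss(losses: dict, lambdas: dict) -> torch.Tensor:
    # e.g. keys like "gep", "mvc", "nzlp", "cls", "ecs", "dab" (illustrative)
    return sum(lambdas[k] * losses[k] for k in losses)
\end{verbatim}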

\section{Why this scales to ultra-long transcriptome sequences}
\begin{itemize}[leftmargin=2em]
\item Mamba2 replaces quadratic self-attention with state-space sequence modeling.