Embedding Stability Module (ESM)

A PyTorch / NumPy utility class for stabilizing hyperbolic (Poincaré ball) embeddings in transformer models and related deep learning systems.


Why Hyperbolic Embeddings?

Hyperbolic space is ideal for representing hierarchies and tree-like data (ontologies, linguistic structures, taxonomies). Compared to Euclidean embeddings:

  • 🌳 Trees can be embedded compactly and with low distortion, even in two dimensions.
  • ⚡ Volume grows exponentially with radius, matching the exponential branching of trees.

But hyperbolic geometry is numerically fragile:

  • Points near the unit boundary (‖z‖ → 1) cause distances and gradients to explode.
  • Standard vector operations (u + v) aren't valid; you need Möbius (gyro) operations (see the sketch after this list).
  • Encoders/decoders can suffer Jacobian collapse or explosion, destroying trainability.
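
For concreteness, the standard Möbius addition on the unit Poincaré ball (curvature c = 1) is

u ⊕ v = ((1 + 2⟨u,v⟩ + ‖v‖²)·u + (1 − ‖u‖²)·v) / (1 + 2⟨u,v⟩ + ‖u‖²‖v‖²)

Below is a minimal PyTorch sketch of this textbook formula, for orientation only; it is not this library's implementation, which additionally clamps the result:

import torch

def mobius_add_reference(u, v, eps=1e-6):
    # Textbook Möbius addition on the unit Poincaré ball (curvature c = 1).
    # u, v: (..., d) tensors with norms < 1.
    uv = (u * v).sum(dim=-1, keepdim=True)   # <u, v>
    uu = (u * u).sum(dim=-1, keepdim=True)   # ||u||^2
    vv = (v * v).sum(dim=-1, keepdim=True)   # ||v||^2
    num = (1 + 2 * uv + vv) * u + (1 - uu) * v
    den = 1 + 2 * uv + uu * vv
    return num / den.clamp_min(eps)          # guard the denominator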

Why This Library?

PyTorch and Hugging Face’s transformers do not support gyro operations natively. This library provides a lightweight drop-in guardrail for transformer pipelines:

  • ✅ Correct Möbius addition with clamping
  • ✅ Per-sample safe gyroaddition (Euclidean fallback only when necessary; sketched below)
  • ✅ Curvature checks + projection to a configurable soft threshold
  • ✅ Accurate Jacobian operator norm monitoring (J/Jᵀ power iteration)
  • ✅ Ricci-flow-like smoothing with pre/post projection to keep embeddings valid
  • ✅ NumPy/PyTorch interoperability
  • ✅ Encode/decode wrappers for safely surrounding risky Euclidean ops
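
The per-sample fallback behind safe_gyroadd can be pictured as follows. This is a hypothetical sketch of the idea rather than the library's code: the boundary test and rescaling rule are assumptions, and mobius_add_reference is the helper from the sketch above.

import torch

def safe_gyroadd_sketch(u, v, boundary=0.95):
    # Hypothetical per-sample fallback (illustrative only): samples whose
    # Möbius sum would leave the safe radius fall back to a rescaled
    # Euclidean sum; all other samples keep the true Möbius result.
    mob = mobius_add_reference(u, v)                       # see sketch above
    euc = u + v
    euc = euc * (boundary / euc.norm(dim=-1, keepdim=True).clamp_min(boundary))
    unsafe = mob.norm(dim=-1, keepdim=True) >= boundary    # per-sample mask
    return torch.where(unsafe, euc, mob)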

Visual Primer

Safe vs Unsafe Dynamics in the Poincaré Disk

(Figure: red and blue spiral trajectories in the Poincaré disk)

  • Red: collapsing spiral drifting toward boundary → unstable
  • Blue: stabilized orbit maintained inside the ball

Boundary Instability and Projection

(Figure: projection of escaping points back inside the disk)

The Poincaré disk is a projection of the hyperboloid model. Crossing the disk boundary corresponds to crossing the light cone, which is geometrically invalid. EmbeddingStability prevents this by projecting vectors back inside a safe radius.
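
The projection itself is a radial rescale: points whose norm exceeds the safe radius are pulled back onto it, and everything else is left untouched. A minimal sketch of the idea, assuming a hard threshold (the library's stabilize uses a configurable soft threshold and may differ in detail):

import torch

def project_to_ball(z, threshold=0.95, eps=1e-6):
    # Radial projection into the safe radius: points with ||z|| > threshold
    # are rescaled onto it; points already inside are left unchanged.
    norm = z.norm(dim=-1, keepdim=True)
    scale = (threshold / norm.clamp_min(eps)).clamp_max(1.0)
    return z * scale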


Installation

git clone https://github.com/stevenk42/embedding-stability.git
cd embedding-stability
pip install -r requirements.txt

Dependencies:

  • torch >= 1.12
  • numpy

Usage

import torch
from embedding_stability import EmbeddingStability

esm = EmbeddingStability(
    curvature_threshold=0.95,
    recon_eps=1e-6,
    jacobian_opnorm_eps=5.0,
    ricci_steps=3,
    ricci_lr=1e-2,
    use_autograd=True,
    device="cuda"
)

# Example embeddings (scaled by 0.1 so they start well inside the unit ball)
u = torch.randn(4, 16) * 0.1
v = torch.randn(4, 16) * 0.1

# Möbius addition
res = esm.mobius_add(u, v)

# Stabilize embeddings
stable_u = esm.stabilize(u)

# Example encoder/decoder
E = torch.nn.Linear(16, 16)
D = torch.nn.Linear(16, 16)

# Jacobian safeguard
smoothed = esm.jacobian_safeguard(u, E, D)

# Safe encode-decode block
out = esm.hyperbolic_safe_block(E, D, u)

API

  • mobius_add(u, v) – Correct Möbius addition.
  • safe_gyroadd(u, v) – Per-sample Euclidean fallback when near boundary.
  • check_curvature(z) – Returns True if ‖z‖ <= threshold.
  • stabilize(embedding) – Projects embeddings back inside the soft wall.
  • jacobian_op_norm(f, x, iters=10) – Estimates the Jacobian spectral norm via J/Jᵀ power iteration (see the sketch after this list).
  • jacobian_safeguard(C, E, D, max_iter=10) – Detects instability and applies Ricci-like smoothing.
  • hyperbolic_safe_encode(E, x) – Encode and immediately project inside the ball.
  • hyperbolic_safe_block(E, D, x) – Encode → project → (external ops) → project → decode.
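
The J/Jᵀ power iteration behind jacobian_op_norm can be sketched with torch.autograd.functional.jvp and vjp. Names and details here are illustrative, and the library's internals may differ:

import torch
from torch.autograd.functional import jvp, vjp

def jacobian_op_norm_sketch(f, x, iters=10, eps=1e-12):
    # Estimate the largest singular value of the Jacobian of f at x by power
    # iteration on J^T J, alternating Jacobian-vector and vector-Jacobian
    # products so the Jacobian is never materialized.
    v = torch.randn_like(x)
    v = v / v.norm().clamp_min(eps)
    sigma = x.new_tensor(0.0)
    for _ in range(iters):
        _, Jv = jvp(f, x, v)                            # forward product: J v
        sigma = Jv.norm()                               # current estimate of ||J||
        _, JTu = vjp(f, x, Jv / sigma.clamp_min(eps))   # backward product: J^T u
        v = JTu / JTu.norm().clamp_min(eps)
    return sigma

A handful of iterations is usually enough to detect an exploding operator norm.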

Positioning: Lightweight vs Geoopt

This library is a guardrail, not a geometry engine.

  • Use this module when your model is mostly Euclidean but touches hyperbolic embeddings, and you want stability against runaway norms and invalid operations.
  • Use Geoopt when you need full manifold training: exponential/log maps, Möbius matmul, Riemannian optimizers, and multi-manifold support.

Think of ESM as a safety harness: drop it into existing pipelines to prevent collapse without restructuring your whole training loop.


Notes

  • With use_autograd=True, E and D must be torch-native functions/modules.
  • With use_autograd=False, you can pass NumPy functions, but Jacobian monitoring is skipped.
  • Thresholds (recon_eps, jacobian_opnorm_eps) should be tuned to your model scale.
  • For safety-critical use, set curvature_threshold < 1.0 (e.g. 0.90–0.95) to keep a margin from the boundary (see the worked example below).
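
To see why the margin matters: in the unit ball, the hyperbolic distance from the origin to Euclidean radius r is d(0, r) = 2·artanh(r), which diverges as r → 1. A small Euclidean margin therefore still leaves a large, finite hyperbolic radius:

import numpy as np

# Hyperbolic distance from the origin to Euclidean radius r (curvature c = 1).
for r in (0.90, 0.95, 0.999):
    print(f"r = {r}: d(0, r) = {2 * np.arctanh(r):.2f}")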

Roadmap

  • Add Möbius scalar multiplication + exp/log maps
  • Add manifold-aware optimizers for Ricci flow
  • Hugging Face transformer integration examples

License

MIT License. See LICENSE for details.