Nizben/glass


GLASS: Official Implementation (NeurIPS 2025 NEGEL workshop paper)

This repository contains the official implementation of the GLASS layer from the paper:

  • GLASS: A Differentiable Geometric Alignment Layer on Manifolds and Graphs via Learned Slices and Soft Optimal Transport

GLASS is a lightweight, fully differentiable alignment head for compact manifolds and graphs. It learns intrinsic one-dimensional projections, performs entropic 1D optimal transport (a smooth soft sorting), and lifts the soft plan back to the ambient geometry to compute a geometry-aware loss while exposing the coupling for downstream reuse.

  • Manifolds (S²): gated projections on log maps, geodesic lifted loss
  • Graphs (SBM): eigenvector-free diffusion features, linear projection, spectral lifted loss
  • Fast and scalable banded Sinkhorn implementation with near-linear runtime in practice

Installation

  • Python ≥ 3.10 recommended
  • Install dependencies:
pip install -r requirements.txt

How GLASS works in a nutshell

  1. Learned intrinsic projections
    • S²: anchors $a_k$ with tangent directions $u_k$, gated mixture over log maps $\mathrm{Log}_{a_k}(x)$
    • Graphs: diffusion features $\psi_\tau(v)$ via Chebyshev polynomials (no eigendecomposition) and a learned linear form
  2. Soft one-dimensional alignment
    • Compute scores $s_i$, $t_j$ and solve entropic 1D OT with Sinkhorn (dense or banded)
  3. Lift to non-Euclidean space
    • Manifolds: $L_{\mathrm{geo}} = \langle P, d_M^2(X, Y) \rangle$
    • Graphs: $L_{\mathrm{spec}} = \langle P, \|\psi_\tau^X - \psi_\tau^Y\|^2 \rangle$
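For intuition, steps 2 and 3 can be sketched in plain Python. This is a minimal sketch, not the repo's `sinkhorn_dense` or `lifted_loss_*`: it assumes a squared-difference cost on the 1D scores, uniform marginals, and a log-domain dense Sinkhorn. The names `sinkhorn_dense_sketch`, `plan_from_potentials`, and `lifted_loss` are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def plan_from_potentials(C, f, g, gamma):
    """Entropic plan P_ij = exp((f_i + g_j - C_ij) / gamma)."""
    return [[math.exp((fi + gj - cij) / gamma) for gj, cij in zip(g, row)]
            for fi, row in zip(f, C)]


def sinkhorn_dense_sketch(C, mu, nu, gamma, max_iter=200, tol=1e-6):
    """Log-domain Sinkhorn for min_P <P, C> - gamma * H(P) with marginals mu, nu."""
    n, m = len(mu), len(nu)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(max_iter):
        # Row update: makes sum_j P_ij = mu_i exactly.
        f = [gamma * (math.log(mu[i]) - logsumexp([(g[j] - C[i][j]) / gamma for j in range(m)]))
             for i in range(n)]
        # Column update: makes sum_i P_ij = nu_j exactly.
        g = [gamma * (math.log(nu[j]) - logsumexp([(f[i] - C[i][j]) / gamma for i in range(n)]))
             for j in range(m)]
        P = plan_from_potentials(C, f, g, gamma)
        # Columns are exact after the g-update, so only row marginals can be off.
        if sum(abs(sum(row) - a) for row, a in zip(P, mu)) < tol:
            break
    return plan_from_potentials(C, f, g, gamma)


def lifted_loss(P, D):
    """<P, D> = sum_ij P_ij D_ij for an ambient cost matrix D (e.g. squared geodesics)."""
    return sum(p * d for prow, drow in zip(P, D) for p, d in zip(prow, drow))


# Usage: 1D scores from a (hypothetical) projection, squared-difference cost.
s = [0.10, 0.50, 0.90]
t = [0.12, 0.48, 0.95]
C = [[(si - tj) ** 2 for tj in t] for si in s]
mu = nu = [1.0 / 3.0] * 3
P = sinkhorn_dense_sketch(C, mu, nu, gamma=0.01)
loss = lifted_loss(P, C)  # stand-in: in GLASS, D would be the ambient cost, not C
```

Working in the log domain avoids underflow when gamma is small, which matters if gamma is annealed toward `gamma_min`. As gamma shrinks, the plan approaches the hard sorting permutation; a larger gamma gives a smoother, more diffuse coupling.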

Repository structure

src/
  data/
    graph_data.py          # SBM pair + Laplacians
    sphere_data.py         # Pair of S² point sets with optional rotation
  geometry/
    graph_poly.py          # Chebyshev diffusion features (eigen-free)
    sphere.py              # S² geometry utilities (log map, geodesics, gates)
  models/
    projections_graph.py   # Linear projection on diffusion features
    projections_manifold.py # Gated manifold projection on S²
  ot/
    sinkhorn.py            # Dense and banded Sinkhorn solvers
    losses.py              # Lifted costs (sphere, graph)
  utils/
    logging.py             # CSV logging + artifact helpers
    seed.py                # Reproducible seeds

scripts/
  run_manifold.py          # Train/eval on S² experiments
  run_graph.py             # Train/eval on SBM experiments
  benchmark.sinkhorn.py    # Microbenchmark dense vs banded Sinkhorn
  compile_runs.py          # Aggregate run folders into CSV

bash/                      # Sweep helpers
configs/                   # Base and auto-generated configs
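The S² utilities listed for src/geometry/sphere.py rest on two standard formulas: the geodesic distance d(x, y) = arccos⟨x, y⟩ and the log map. The sketch below is an illustration with hypothetical function names (`geodesic_dist`, `log_map`, `slice_score`) that uses plain lists for points in R³. It shows a single ungated slice ⟨u, Log_a(x)⟩ and does not reproduce the repo's gate parameterization.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def geodesic_dist(x, y):
    """Great-circle distance d(x, y) = arccos(<x, y>) between unit vectors in R^3."""
    return math.acos(max(-1.0, min(1.0, dot(x, y))))


def log_map(a, x, eps=1e-12):
    """Log_a(x) on S^2: tangent vector at a, pointing toward x, with length d(a, x).

    Undefined at the antipode (x = -a); this sketch returns the zero vector there and at x = a.
    """
    c = max(-1.0, min(1.0, dot(a, x)))
    v = [xi - c * ai for xi, ai in zip(x, a)]  # component of x orthogonal to a
    norm_v = math.sqrt(dot(v, v))
    if norm_v < eps:
        return [0.0] * len(a)
    theta = math.acos(c)
    return [theta * vi / norm_v for vi in v]


def slice_score(a, u, x):
    """One intrinsic 1D projection s(x) = <u, Log_a(x)> for anchor a and tangent direction u."""
    return dot(u, log_map(a, x))


north = [0.0, 0.0, 1.0]
east = [1.0, 0.0, 0.0]  # tangent to the sphere at the north pole
print(log_map(north, east))            # ≈ [pi/2, 0, 0]
print(slice_score(north, east, east))  # ≈ pi/2
```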

Quickstart

S² (manifold) example

python scripts/run_manifold.py --config configs/_auto_s2_M1_tie_K8_g0.05.yaml

Graph (SBM) example

python scripts/run_graph.py --config configs/_auto_sbm_G2_q12_tau0.5.yaml

Logs and artifacts are saved to runs/<exp_name>_<timestamp>/.

Reproducing paper experiments

We provide concrete YAMLs under configs/ matching the experiments reported in the paper. Each config can be run directly with the listed command by substituting its path for <CONFIG>.

Manifold on S²

  • M1 — Capacity and gate tying

    • Tied gates (recommended)
      • configs/_auto_s2_M1_tie_K8_g0.05.yaml
      • configs/_auto_s2_M1_tie_K16_g0.05.yaml
      • configs/_auto_s2_M1_tie_K32_g0.05.yaml
      • configs/_auto_s2_M1_tie_K8_g0.10.yaml
      • configs/_auto_s2_M1_tie_K16_g0.10.yaml
      • configs/_auto_s2_M1_tie_K32_g0.10.yaml
      • configs/_auto_s2_M1_tie_K8_g0.15.yaml
    • Fixed gates (baseline)
      • configs/_auto_s2_M1_K8_g0.10.yaml
      • configs/_auto_s2_M1_K16_g0.10.yaml
      • configs/_auto_s2_M1_K32_g0.10.yaml
    • Run: python scripts/run_manifold.py --config <CONFIG>
  • M2 — Chart radius / curvature ablation

    • configs/_auto_s2_M2_r10.yaml
    • configs/_auto_s2_M2_r20.yaml
    • configs/_auto_s2_M2_r30.yaml
    • configs/_auto_s2_M2_r40.yaml
    • Run: python scripts/run_manifold.py --config <CONFIG>
  • M4 — Runtime scaling with n

    • configs/_auto_s2_M4_n64.yaml, _n128.yaml, _n256.yaml, _n512.yaml, _n1024.yaml
    • Run: python scripts/run_manifold.py --config <CONFIG>

Graphs (SBM)

  • G1 — Size and noise sweeps

    • Sizes: configs/_auto_sbm_G1_n200.yaml, _n500.yaml, _n1000.yaml
    • Noise: configs/_auto_sbm_G1_noise_rw0.00.yaml, _rw0.02.yaml, _rw0.05.yaml, _rw0.10.yaml
    • Run: python scripts/run_graph.py --config <CONFIG>
  • G2 — Diffusion scale and degree

    • Grid (q × τ):
      • τ=0.2 (Q = 8, 12, 16, 24): configs/_auto_sbm_G2_q{Q}_tau0.2.yaml
      • τ=0.5 (Q = 8, 12, 16, 24): configs/_auto_sbm_G2_q{Q}_tau0.5.yaml
      • τ=1.0 (Q = 8, 12, 16, 24): configs/_auto_sbm_G2_q{Q}_tau1.0.yaml
    • Fine τ (q=12): configs/_auto_sbm_G2_tau0.10.yaml, _0.15.yaml, _0.20.yaml, _0.25.yaml, _0.30.yaml, _0.40.yaml
    • Run: python scripts/run_graph.py --config <CONFIG>
  • G3 — Runtime scaling with n

    • configs/_auto_sbm_G3_n500.yaml, _n1000.yaml, _n2000.yaml, _n4000.yaml
    • Run: python scripts/run_graph.py --config <CONFIG>
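The G2 sweep varies the Chebyshev degree q (`cheby_degree`) and the diffusion scale τ (`tau`). Below is a minimal sketch of how both enter an eigendecomposition-free diffusion operator. It assumes the heat kernel exp(-τL) and a known upper bound `lam_max` on the Laplacian spectrum (for example 2 for the normalized Laplacian). The function names are hypothetical, and graph_poly.py may normalize differently.

```python
import math


def matvec(L, x):
    """Dense matrix-vector product; a real implementation would use sparse tensors."""
    return [sum(lij * xj for lij, xj in zip(row, x)) for row in L]


def cheb_coeffs_heat(tau, lam_max, degree, n_quad=64):
    """Chebyshev coefficients of lam -> exp(-tau * lam) on [0, lam_max] (Chebyshev-Gauss quadrature)."""
    coeffs = []
    for k in range(degree + 1):
        total = 0.0
        for j in range(n_quad):
            theta = math.pi * (j + 0.5) / n_quad
            lam = 0.5 * lam_max * (math.cos(theta) + 1.0)  # map node in [-1, 1] to [0, lam_max]
            total += math.exp(-tau * lam) * math.cos(k * theta)
        coeffs.append(2.0 * total / n_quad)
    coeffs[0] *= 0.5  # the constant term carries half weight
    return coeffs


def cheb_heat_apply(L, x, tau, degree, lam_max):
    """Approximate exp(-tau * L) @ x with a degree-`degree` Chebyshev recurrence."""
    c = cheb_coeffs_heat(tau, lam_max, degree, n_quad=max(64, 2 * degree + 2))

    def scaled(v):
        # L_tilde = (2 / lam_max) L - I has its spectrum in [-1, 1] when lam_max bounds L's spectrum.
        return [2.0 * lv / lam_max - vi for lv, vi in zip(matvec(L, v), v)]

    t_prev, t_curr = list(x), scaled(x)  # T_0(L_tilde) x and T_1(L_tilde) x
    out = [c[0] * a for a in t_prev]
    if degree >= 1:
        out = [o + c[1] * a for o, a in zip(out, t_curr)]
    for k in range(2, degree + 1):
        # Three-term recurrence: T_k = 2 L_tilde T_{k-1} - T_{k-2}.
        t_next = [2.0 * a - b for a, b in zip(scaled(t_curr), t_prev)]
        out = [o + c[k] * a for o, a in zip(out, t_next)]
        t_prev, t_curr = t_curr, t_next
    return out


# Two-node graph Laplacian (eigenvalues 0 and 2), q = 12, tau = 0.5.
L = [[1.0, -1.0], [-1.0, 1.0]]
approx = cheb_heat_apply(L, [1.0, 0.0], tau=0.5, degree=12, lam_max=2.0)
```

Each recurrence step costs one multiplication by L, so q sparse matrix-vector products replace an eigendecomposition. Applying the operator to several probe vectors (for instance `feat_dim` random ones, which is an assumption here) would give one feature vector per node.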

Notes:

  • Base templates: configs/manifold_s2_small.yaml, configs/graph_sbm_small.yaml (used by sweep scripts; you can also run them directly).
  • You can edit any _auto_*.yaml to tweak parameters without using the bash sweep helpers.

Configuration reference

Minimal S² config:

data:
  n: 256
  noise_deg: 2.0
  rotation: true
exp_name: s2_M1_tie_K8_g0.05
log:
  out_dir: runs
  save_every: 50
model:
  manifold:
    K: 8
    chart_radius_deg: 20.0
    gate_mode: tie      # fixed | tie | learn
    gate_c_per_sqrtk_deg: 45.0
    init_u_scale: 0.2
  ot:
    use_banded: true
    band_eps: 1.0e-6
    gamma: 0.05
    max_iter: 200
    tol: 1.0e-6
  train:
    lr: 1.0e-3
    steps: 400
    slices: 2
    anneal_gamma: true
    gamma_decay: 0.98
    gamma_min: 0.01
seed: 0

Minimal SBM graph config:

data:
  n: 400
  blocks: 3
  p_in: 0.15
  p_out: 0.02
  rewire_prob: 0.03
  planted_perm: true
exp_name: sbm_G1_example
log:
  out_dir: runs
  save_every: 50
model:
  graph:
    tau: 0.5
    cheby_degree: 12
    feat_dim: 32
  ot:
    use_banded: true
    band_eps: 1.0e-6
    gamma: 0.10
    max_iter: 200
    tol: 1.0e-6
  train:
    lr: 1.0e-3
    steps: 300
    slices: 2
    anneal_gamma: true
    gamma_decay: 0.98
    gamma_min: 0.01
seed: 0
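The `train` block exposes `anneal_gamma`, `gamma_decay`, and `gamma_min`, but the README does not spell out the schedule. The sketch below assumes per-step multiplicative decay floored at `gamma_min`. `gamma_at` is a hypothetical helper that reads the same keys as the YAML.

```python
def gamma_at(step, ot_cfg, train_cfg):
    """Entropic regularization at a given training step.

    Assumed semantics (not confirmed by the README): multiplicative decay per step,
    floored at gamma_min, applied only when anneal_gamma is true.
    """
    gamma0 = ot_cfg["gamma"]
    if not train_cfg.get("anneal_gamma", False):
        return gamma0
    return max(train_cfg["gamma_min"], gamma0 * train_cfg["gamma_decay"] ** step)


# Values copied from the minimal S² config.
ot_cfg = {"gamma": 0.05}
train_cfg = {"anneal_gamma": True, "gamma_decay": 0.98, "gamma_min": 0.01, "steps": 400}
schedule = [gamma_at(t, ot_cfg, train_cfg) for t in range(train_cfg["steps"])]
```

Under this assumption, the S² settings reach the floor at step 80 and spend the remaining 320 of 400 steps at gamma = 0.01.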

Solvers and losses

  • Dense Sinkhorn: sinkhorn_dense(C, mu, nu, gamma, max_iter, tol)
  • Banded (sparse) Sinkhorn: sinkhorn_banded(s_sorted, t_sorted, mu_sorted, nu_sorted, gamma, max_iter, tol, band_eps)
  • Band radius ≈ sqrt(-gamma * log(eps_in))
  • Lifted costs: lifted_loss_sphere (geodesic on S²), lifted_loss_graph (diffusion features)
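The band-radius rule follows from truncating the Gibbs kernel. Read `eps_in` as the `band_eps` threshold and assume the cost on 1D scores is (s - t)². Then the kernel entry exp(-(s - t)²/γ) drops below ε exactly when |s - t| > sqrt(-γ log ε); because log ε < 0, the root is real. Both score vectors are sorted, so the kept columns for each row form one contiguous range, which the hypothetical helper below finds by binary search.

```python
import bisect
import math


def band_radius(gamma, band_eps):
    """Largest |s - t| for which exp(-(s - t)**2 / gamma) >= band_eps."""
    return math.sqrt(-gamma * math.log(band_eps))


def band_ranges(s_sorted, t_sorted, radius):
    """For each sorted source score, the half-open column range [lo, hi) kept in the band."""
    ranges = []
    for s in s_sorted:
        lo = bisect.bisect_left(t_sorted, s - radius)
        hi = bisect.bisect_right(t_sorted, s + radius)
        ranges.append((lo, hi))
    return ranges


r = band_radius(gamma=0.05, band_eps=1e-6)
ranges = band_ranges([0.0, 1.0, 2.0], [0.0, 1.0, 2.0], radius=1.0)
nnz = sum(hi - lo for lo, hi in ranges)  # number of stored plan entries
```

With γ = 0.05 and band_eps = 1e-6 the radius is about 0.83 in score units. The number of stored entries (`nnz`) therefore depends on how spread out the scores are relative to that width.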

Logging and analysis

Each run creates runs/<exp_name>_<timestamp>/ with:

  • config.json (exact config)
  • metrics.csv (step, loss, plan_entropy, gamma, runtime_sec, nnz, sinkhorn_iter, band_radius, perm_acc)
  • periodic checkpoints (e.g., state_stepXXX.pt)
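This README does not define `plan_entropy` and `perm_acc` in metrics.csv. One plausible reading, stated here as an assumption, is the Shannon entropy of the coupling and the fraction of rows whose highest-mass column matches the planted permutation (the SBM config sets `planted_perm: true`). Both helpers below are hypothetical.

```python
import math


def plan_entropy(P):
    """Shannon entropy -sum P_ij log P_ij of a coupling, with 0 log 0 taken as 0."""
    return -sum(p * math.log(p) for row in P for p in row if p > 0.0)


def perm_accuracy(P, true_perm):
    """Fraction of rows whose argmax column equals the planted match true_perm[i]."""
    hits = 0
    for i, row in enumerate(P):
        j_star = max(range(len(row)), key=row.__getitem__)
        hits += int(j_star == true_perm[i])
    return hits / len(P)


# A hard matching with uniform mass 1/3 per row.
P_hard = [[1 / 3, 0.0, 0.0], [0.0, 1 / 3, 0.0], [0.0, 0.0, 1 / 3]]
```

A hard matching of n points has entropy log n, and the fully diffuse plan has 2 log n, so plan entropy gives a quick read on how sharp the soft sorting is at the current gamma.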

Aggregate final metrics across runs (no plotting):

python scripts/compile_runs.py --logdir runs
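`scripts/compile_runs.py` does the actual aggregation. The sketch below is a hypothetical stand-in for ad hoc inspection. It uses the standard library to collect the last row of each runs/*/metrics.csv and write one summary CSV. The function names and output layout are assumptions.

```python
import csv
from pathlib import Path


def collect_final_metrics(logdir="runs"):
    """One dict per run: the last row of <logdir>/<run>/metrics.csv plus the run folder name."""
    rows = []
    for metrics_path in sorted(Path(logdir).glob("*/metrics.csv")):
        with metrics_path.open(newline="") as fh:
            records = list(csv.DictReader(fh))
        if not records:
            continue  # skip runs that never logged a step
        final = dict(records[-1])
        final["run"] = metrics_path.parent.name
        rows.append(final)
    return rows


def write_summary(rows, out_path):
    """Write rows to a CSV whose columns are the union of all keys, with 'run' first."""
    if not rows:
        return
    fieldnames = ["run"]
    for row in rows:
        for key in row:
            if key not in fieldnames:
                fieldnames.append(key)
    with open(out_path, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)  # missing keys become ""
        writer.writeheader()
        writer.writerows(rows)
```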

Benchmark: dense vs banded Sinkhorn

python scripts/benchmark.sinkhorn.py --n 2048 --gamma 0.05 --band_eps 1e-6 --max_iter 200
