Bug fixes, performance improvements, and new features#7

Merged
fcarli merged 16 commits into dev from improvements
Feb 22, 2026
Conversation

@fcarli fcarli (Owner) commented Feb 22, 2026

Summary

  • Bug fixes:
      • Prevent NaN/crash in training loop
      • Fix FAISS self-match filtering by index
      • Restore the correct L^(2b) norm in the UMAP loss
      • Use an adaptive upper bound for the sigma binary search
      • Save and validate input_dim in checkpoints
      • Forward all fit() params in fit_transform()
      • Vary random_state per epoch
      • Accept torch.device in the device parameter
  • Performance:
      • Replace O(N^2) negative sampling with rejection sampling
      • Store edge weights alongside edges to eliminate sparse lookups
      • Precompute input-space distances for all edges once
      • Use torch.as_tensor() to avoid unnecessary copies
      • Remove ProcessPoolExecutor from negative sampling
  • Features:
      • Add a batch_size parameter to transform() for large inputs
      • Add optional torch.compile support for the MLP

Test plan

  • uv run pytest passes
  • uv run ruff check . passes
  • Swiss roll example notebook runs end-to-end

🤖 Generated with Claude Code

fcarli and others added 15 commits February 22, 2026 19:02
- Guard correlation loss against zero-variance inputs by clamping stds
- Clamp qs values before BCELoss to prevent log(0) = -inf
- Handle negative FAISS squared distances before sqrt
- Reuse precomputed Z_distances instead of redundant torch.norm call
- Suppress tqdm progress bars when verbose=False
- Use optimizer.zero_grad(set_to_none=True) for minor perf gain

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
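A minimal sketch of these guards (hypothetical tensor and function names; assumes a UMAP-style correlation plus BCE loss):

```python
import numpy as np
import torch

def guarded_losses(X_batch, Z_batch, qs, ps, eps=1e-7):
    # Correlation-loss guard: clamp standard deviations away from zero
    # so a zero-variance input column cannot produce 0/0 -> NaN.
    x_std = X_batch.std(dim=0).clamp(min=eps)
    z_std = Z_batch.std(dim=0).clamp(min=eps)

    # BCE guard: clamp membership strengths qs into (0, 1) so the loss
    # never evaluates log(0) = -inf.
    qs = qs.clamp(min=eps, max=1.0 - eps)
    bce = torch.nn.functional.binary_cross_entropy(qs, ps)
    return x_std, z_std, bce

def safe_neighbor_dists(sq_dists):
    # FAISS can report tiny negative squared distances due to floating
    # point error; clip at zero before taking the square root.
    return np.sqrt(np.maximum(sq_dists, 0.0))
```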
The previous approach allocated np.arange(n_nodes) and ran np.isin for
every node, making negative edge sampling O(N) per node and O(N^2)
total. Rejection sampling draws random candidates and checks against
the (small) adjacency set, reducing cost to O(k) per node for sparse
graphs. Also adds input validation for negative k values.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
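A sketch of the rejection-sampling loop described above (hypothetical function shape; `adj_sets` holds each node's neighbor set):

```python
import numpy as np

def sample_negative_edges(adj_sets, k, n_nodes, rng):
    """Draw k non-neighbors per node by rejection sampling.

    Instead of materializing np.arange(n_nodes) and running np.isin
    per node (O(N) each), draw random candidates and test membership
    in the small adjacency set: O(k) per node for sparse graphs.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    negatives = np.empty((n_nodes, k), dtype=np.int64)
    for u in range(n_nodes):
        forbidden = adj_sets[u]
        found = 0
        while found < k:
            cand = int(rng.integers(n_nodes))
            if cand != u and cand not in forbidden:
                negatives[u, found] = cand
                found += 1
    return negatives
```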
Replace list[tuple] edge storage with numpy arrays of shape (n, 2)
and co-located weight arrays. EdgeBatchIterator now yields numpy
slices + indices, enabling direct tensor indexing into a pre-placed
device tensor of weights.

This eliminates TorchSparseDataset from the training loop entirely:
- Per-batch weight lookup: O(B log nnz) searchsorted → O(B) index
- Edge unpacking: Python list comprehensions → numpy column slices
- Shuffle: Python list copy → numpy permutation indexing
- Edge memory: ~72 bytes/edge (Python tuples) → 16 bytes/edge

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
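The iterator described above might look like this (hypothetical sketch: numpy-backed edges, permutation-based shuffling):

```python
import numpy as np

class EdgeBatchIterator:
    """Yield (heads, tails, indices) batches from an (n, 2) edge array.

    Shuffling is a numpy permutation of indices rather than a Python
    list copy; heads/tails come from column slices rather than tuple
    unpacking; the yielded indices let the caller index directly into a
    pre-placed device tensor of edge weights.
    """

    def __init__(self, edges, batch_size, rng):
        self.edges = np.asarray(edges, dtype=np.int64)  # shape (n, 2)
        self.batch_size = batch_size
        self.rng = rng

    def __iter__(self):
        order = self.rng.permutation(len(self.edges))
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            batch = self.edges[idx]
            yield batch[:, 0], batch[:, 1], idx
```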
Commit 426be57 incorrectly refactored the q-value computation from
torch.norm(..., p=2*b) to torch.pow(torch.norm(...), 2*b). These are
not equivalent: the former computes the L^(2b) norm, while the latter
computes the L2 norm raised to power 2b. With default b=1.0 this
squares the distance, compressing embeddings and degrading quality.

Restore the original p=2*b norm parameter and use Z_distances directly
without the redundant pow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
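A worked example of the non-equivalence (illustrative values; b is the UMAP embedding curve parameter):

```python
import torch

diff = torch.tensor([3.0, 4.0])  # embedding-space difference vector
b = 1.0

correct = torch.norm(diff, p=2 * b)         # L^(2b) norm: 5.0 when b = 1
buggy = torch.pow(torch.norm(diff), 2 * b)  # (L2 norm)^(2b): 25.0 when b = 1
```

With b = 1 the buggy form squares every distance, which is exactly the compression described above.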
Explicitly persist input_dim in save dict instead of inferring it from
weight matrix key names, and check feature dimensions in transform()
to give a clear ValueError on mismatch. Old checkpoints still load
via fallback to weight introspection.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
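A sketch of the load-side logic under an assumed checkpoint layout (the keys `state_dict` and `input_dim` are illustrative):

```python
import torch

def load_input_dim(ckpt):
    # New checkpoints carry input_dim explicitly; old checkpoints fall
    # back to introspecting the first weight matrix.
    if "input_dim" in ckpt:
        return ckpt["input_dim"]
    first_weight = next(
        v for k, v in ckpt["state_dict"].items() if k.endswith("weight")
    )
    return first_weight.shape[1]

def check_transform_input(X, input_dim):
    # Raise a clear ValueError on feature-dimension mismatch instead of
    # a shape error deep inside the forward pass.
    if X.shape[1] != input_dim:
        msg = f"expected {input_dim} input features, got {X.shape[1]}"
        raise ValueError(msg)
```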
fit_transform() was missing resample_negatives, n_processes, and
random_state, silently ignoring them if passed by callers. Also fixes
pre-existing EM102 lint violation in transform().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Input-space L2 distances between edge endpoints never change across
epochs, yet were recomputed via torch.norm every batch. Now computed
once with np.linalg.norm and placed on device alongside edge weights,
then indexed per batch. Also fixes fit_transform test assertions to
match the full parameter forwarding from the earlier commit.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace hardcoded upper bound (10.0) with a doubling strategy that
finds the correct bracket per-sample before binary search begins.
Prevents silent convergence to incorrect sigma on unscaled data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
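A self-contained sketch of the doubling bracket plus binary search (the smooth-kNN objective shown is the standard UMAP form, simplified; names are hypothetical):

```python
import numpy as np

def find_sigma(dists, target, tol=1e-5, max_iter=64):
    """Solve for sigma on one sample's kNN distances.

    `target` is typically log2(k). The objective is monotonically
    increasing in sigma, so doubling hi until it exceeds the target
    always finds a valid bracket, even on unscaled data where a fixed
    cap of 10.0 would fail silently.
    """
    shifted = np.maximum(dists - dists.min(), 0.0)

    def smooth_knn(sigma):
        return np.sum(np.exp(-shifted / sigma))

    lo, hi = 0.0, 1.0
    while smooth_knn(hi) < target:  # bracket-doubling phase
        hi *= 2.0
    for _ in range(max_iter):       # standard binary search phase
        mid = (lo + hi) / 2.0
        if smooth_knn(mid) < target:
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    return (lo + hi) / 2.0
```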
Adds a `compile_model` flag (default False) that wraps the MLP with
torch.compile for 10-30% faster training on PyTorch 2.x. Save/load
uses the unwrapped module so checkpoints stay portable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previously the resample call used the default seed (0) every epoch,
producing identical negative edges. Now passes random_state + epoch + 1
so each epoch gets distinct but deterministic samples.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
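Sketched as a seeding helper (hypothetical name):

```python
import numpy as np

def epoch_rng(random_state, epoch):
    # random_state + epoch + 1: distinct per epoch (fresh negative
    # samples), deterministic for a given random_state (reproducible
    # runs), and never reuses the bare base seed for epoch 0.
    return np.random.default_rng(random_state + epoch + 1)
```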
Replace torch.tensor() with torch.as_tensor() in transform() and
VariableDataset to share memory with the backing numpy array instead
of always copying.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Rejection sampling is O(k) per node for sparse graphs, making the
multiprocessing path pure overhead (pickling adj_sets, spawning
processes). Single-threaded is now the only code path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
With exact duplicate points, FAISS may return a different point at
position 0 (same zero distance), causing self to remain in the neighbor
list. Filter by matching point index instead of assuming position.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
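An index-based filter might look like this (hypothetical helper operating on simulated FAISS output):

```python
import numpy as np

def drop_self_matches(indices, query_ids):
    """Remove each query point from its own neighbor list by id.

    indices: (n, k) FAISS result ids; query_ids: (n,) ids of the query
    points. With exact duplicates, FAISS may return a duplicate point
    (same zero distance) at position 0 and the query itself elsewhere,
    so filtering by position 0 is wrong; filter by matching id instead.
    """
    n, k = indices.shape
    out = np.empty((n, k - 1), dtype=indices.dtype)
    for i in range(n):
        row = indices[i]
        non_self = row[row != query_ids[i]]
        # If duplicates pushed the self-match out of the top k entirely,
        # non_self still has k entries; keep the first k - 1 either way.
        out[i] = non_self[: k - 1]
    return out
```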
By default transform() sends all data to the device in one pass. When
that causes OOM, users can now pass batch_size to process in chunks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
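A sketch of the chunked path (hypothetical helper; `model_fn` stands in for the trained encoder):

```python
import numpy as np

def transform_in_chunks(model_fn, X, batch_size=None):
    # batch_size=None preserves the original single-pass behavior;
    # otherwise rows are embedded in chunks so only one chunk occupies
    # the device at a time, avoiding OOM on large inputs.
    if batch_size is None:
        return model_fn(X)
    parts = [
        model_fn(X[i:i + batch_size]) for i in range(0, len(X), batch_size)
    ]
    return np.concatenate(parts, axis=0)
```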
@fcarli fcarli self-assigned this Feb 22, 2026
@fcarli fcarli merged commit 494ef89 into dev Feb 22, 2026
10 checks passed
