Bug fixes, performance improvements, and new features #7
Merged
- Guard correlation loss against zero-variance inputs by clamping stds
- Clamp qs values before BCELoss to prevent log(0) = -inf
- Handle negative FAISS squared distances before sqrt
- Reuse precomputed Z_distances instead of redundant torch.norm call
- Suppress tqdm progress bars when verbose=False
- Use optimizer.zero_grad(set_to_none=True) for minor perf gain

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The previous approach allocated np.arange(n_nodes) and ran np.isin for every node, making negative edge sampling O(N) per node and O(N^2) total. Rejection sampling draws random candidates and checks against the (small) adjacency set, reducing cost to O(k) per node for sparse graphs. Also adds input validation for negative k values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
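A minimal sketch of the rejection-sampling idea, assuming a simplified adjacency-set representation; the function name, signature, and rng handling here are illustrative, not the library's actual API:

```python
import random

def sample_negative_edges(n_nodes, adj_sets, k, rng=None):
    """Draw k non-neighbor targets per node by rejection sampling.

    adj_sets[u] is the set of neighbors of node u. For sparse graphs
    the rejection rate is low, so the expected cost is O(k) per node
    instead of the O(N) of materializing all non-neighbors.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    rng = rng or random.Random(0)
    negatives = []
    for u in range(n_nodes):
        forbidden = adj_sets[u] | {u}
        picked = set()
        while len(picked) < k:
            v = rng.randrange(n_nodes)  # O(1) candidate draw
            if v not in forbidden:      # reject neighbors and self
                picked.add(v)
        negatives.append(sorted(picked))
    return negatives
```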
Replace list[tuple] edge storage with numpy arrays of shape (n, 2) and co-located weight arrays. EdgeBatchIterator now yields numpy slices + indices, enabling direct tensor indexing into a pre-placed device tensor of weights. This eliminates TorchSparseDataset from the training loop entirely:

- Per-batch weight lookup: O(B log nnz) searchsorted → O(B) index
- Edge unpacking: Python list comprehensions → numpy column slices
- Shuffle: Python list copy → numpy permutation indexing
- Edge memory: ~72 bytes/edge (Python tuples) → 16 bytes/edge

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
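A rough illustration of the array-backed batching idea, with a simplified EdgeBatchIterator; the constructor signature and rng handling are invented for this sketch:

```python
import numpy as np

class EdgeBatchIterator:
    """Yields (edge_batch, index_batch) over an (n, 2) int64 edge array.

    Because edges and weights are co-located arrays, a batch's weights
    can be gathered with a single fancy-index (O(B)) instead of a
    per-edge searchsorted lookup.
    """
    def __init__(self, edges, batch_size, rng=None):
        self.edges = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
        self.batch_size = batch_size
        self.rng = rng or np.random.default_rng(0)

    def __iter__(self):
        order = self.rng.permutation(len(self.edges))  # shuffle = index permutation
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            yield self.edges[idx], idx  # numpy slices, no Python tuples

edges = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
weights = np.array([0.1, 0.2, 0.3, 0.4])
for batch_edges, idx in EdgeBatchIterator(edges, batch_size=2):
    src, dst = batch_edges[:, 0], batch_edges[:, 1]  # column slices
    w = weights[idx]                                  # O(B) weight gather
```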
Commit 426be57 incorrectly refactored the q-value computation from torch.norm(..., p=2*b) to torch.pow(torch.norm(...), 2*b). These are not equivalent: the former computes the L^(2b) norm, while the latter computes the L2 norm raised to power 2b. With default b=1.0 this squares the distance, compressing embeddings and degrading quality. Restore the original p=2*b norm parameter and use Z_distances directly without the redundant pow. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
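The non-equivalence is easy to verify numerically. Using a numpy stand-in for torch.norm (the L^p norm (sum |x_i|^p)^(1/p) is what torch.norm(x, p=p) computes):

```python
import numpy as np

def lp_norm(x, p):
    # L^p norm: (sum |x_i|^p)^(1/p)
    return np.sum(np.abs(x) ** p) ** (1.0 / p)

x = np.array([3.0, 4.0])
b = 1.0
correct = lp_norm(x, 2 * b)       # L^(2b) norm; with b=1 this is the L2 norm: 5.0
wrong = lp_norm(x, 2) ** (2 * b)  # L2 norm raised to 2b; with b=1 this is 25.0
```

With b=1 the refactored form squares every distance, and for b != 1 the two expressions belong to different norm families entirely (L^(2b) vs a power of L2).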
Explicitly persist input_dim in save dict instead of inferring it from weight matrix key names, and check feature dimensions in transform() to give a clear ValueError on mismatch. Old checkpoints still load via fallback to weight introspection. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
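A hypothetical sketch of the persist-plus-fallback scheme; the key names "state", "input_dim", and "layer0.weight" are assumptions for illustration, not the actual checkpoint format:

```python
import numpy as np

def save_checkpoint(weights, input_dim):
    # Persist input_dim explicitly alongside the weights.
    return {"state": weights, "input_dim": input_dim}

def load_input_dim(checkpoint):
    # New checkpoints carry input_dim; old ones fall back to
    # introspecting the first weight matrix's column count.
    if "input_dim" in checkpoint:
        return checkpoint["input_dim"]
    return checkpoint["state"]["layer0.weight"].shape[1]

def check_transform_input(X, input_dim):
    # Fail fast with a clear message instead of a shape error deep
    # inside the forward pass.
    if X.shape[1] != input_dim:
        raise ValueError(
            f"X has {X.shape[1]} features, model expects {input_dim}"
        )
```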
fit_transform() was missing resample_negatives, n_processes, and random_state, silently ignoring them if passed by callers. Also fixes pre-existing EM102 lint violation in transform(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Input-space L2 distances between edge endpoints never change across epochs, yet were recomputed via torch.norm every batch. Now computed once with np.linalg.norm and placed on device alongside edge weights, then indexed per batch. Also fixes fit_transform test assertions to match the full parameter forwarding from the earlier commit. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
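The one-time precomputation amounts to something like the following (helper name illustrative):

```python
import numpy as np

def precompute_edge_distances(X, edges):
    """Input-space L2 distance between each edge's endpoints.

    X: (n_points, d) features; edges: (n_edges, 2) endpoint indices.
    These never change across epochs, so compute once up front.
    """
    diffs = X[edges[:, 0]] - X[edges[:, 1]]
    return np.linalg.norm(diffs, axis=1)

X = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
edges = np.array([[0, 1], [1, 2], [0, 2]])
dists = precompute_edge_distances(X, edges)
# Per batch, the training loop then just indexes: dists[batch_idx]
```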
Replace hardcoded upper bound (10.0) with a doubling strategy that finds the correct bracket per-sample before binary search begins. Prevents silent convergence to incorrect sigma on unscaled data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
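The doubling-then-bisection strategy, sketched against a generic monotone function rather than the actual sigma objective:

```python
def solve_increasing(f, target, tol=1e-6, max_iter=100):
    """Find x > 0 with f(x) ~= target for monotonically increasing f.

    First doubles the upper bound until [lo, hi] brackets the target,
    then bisects. This avoids a hardcoded cap (e.g. 10.0) that silently
    returns the cap itself when the true root lies above it, as it can
    on unscaled data.
    """
    lo, hi = 0.0, 1.0
    while f(hi) < target:       # doubling: grow the bracket until it contains the root
        lo, hi = hi, hi * 2.0
    for _ in range(max_iter):   # standard bisection inside the bracket
        mid = 0.5 * (lo + hi)
        if abs(f(mid) - target) < tol:
            return mid
        if f(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Root at 20 — well above a hardcoded 10.0 upper bound.
sigma = solve_increasing(lambda s: s * s, target=400.0)
```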
Adds a `compile_model` flag (default False) that wraps the MLP with torch.compile for 10-30% faster training on PyTorch 2.x. Save/load uses the unwrapped module so checkpoints stay portable. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Previously the resample call used the default seed (0) every epoch, producing identical negative edges. Now passes random_state + epoch + 1 so each epoch gets distinct but deterministic samples. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
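Illustrated with a stand-in resampler — the function body is hypothetical; only the random_state + epoch + 1 seed derivation comes from the commit:

```python
import random

def resample_negatives(random_state, epoch):
    # Distinct but deterministic seed per epoch: random_state + epoch + 1.
    rng = random.Random(random_state + epoch + 1)
    return [rng.randrange(1000) for _ in range(5)]

epoch0 = resample_negatives(42, 0)  # reproducible for a given random_state
epoch1 = resample_negatives(42, 1)  # different draw than epoch 0
```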
Replace torch.tensor() with torch.as_tensor() in transform() and VariableDataset to share memory with the backing numpy array instead of always copying. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
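numpy has the same copy-vs-view distinction, which makes the behavior easy to demonstrate without a torch dependency:

```python
import numpy as np

x = np.arange(4, dtype=np.float32)
copied = np.array(x)    # like torch.tensor(): always copies the buffer
shared = np.asarray(x)  # like torch.as_tensor(): reuses the buffer when possible

x[0] = 99.0  # mutate the source: the copy is unaffected, the view sees it
```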
Rejection sampling is O(k) per node for sparse graphs, making the multiprocessing path pure overhead (pickling adj_sets, spawning processes). Single-threaded is now the only code path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
With exact duplicate points, FAISS may return a different point at position 0 (same zero distance), causing self to remain in the neighbor list. Filter by matching point index instead of assuming position. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
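The fix, sketched over plain arrays rather than a live FAISS index; shapes assume a k+1 nearest-neighbor query, and the helper name is illustrative:

```python
import numpy as np

def drop_self(indices, distances, query_ids, k):
    """Remove each query point from its own neighbor list by id match.

    With exact duplicate points, the query's own index can land at any
    position among the zero-distance ties, so filter by index equality
    instead of dropping column 0. indices/distances: (n, k+1) results
    of a k+1 NN search.
    """
    out_idx = np.empty((len(query_ids), k), dtype=indices.dtype)
    out_dst = np.empty((len(query_ids), k), dtype=distances.dtype)
    for row, q in enumerate(query_ids):
        keep = indices[row] != q              # mask out the query itself
        out_idx[row] = indices[row][keep][:k]
        out_dst[row] = distances[row][keep][:k]
    return out_idx, out_dst
```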
By default transform() sends all data to the device in one pass. When that causes OOM, users can now pass batch_size to process in chunks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
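The chunking logic reduces to something like this sketch, where model_fn stands in for the encoder's forward pass and the actual transform() signature may differ:

```python
import numpy as np

def transform_batched(model_fn, X, batch_size=None):
    """Apply model_fn over X, optionally in chunks to bound peak memory.

    batch_size=None keeps the original single-pass behavior; otherwise
    each chunk is processed independently and the results concatenated.
    """
    if batch_size is None:
        return model_fn(X)
    chunks = [model_fn(X[i:i + batch_size])
              for i in range(0, len(X), batch_size)]
    return np.concatenate(chunks, axis=0)
```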
Summary
- torch.device accepted in the device parameter
- torch.as_tensor() used to avoid unnecessary copies
- ProcessPoolExecutor removed from negative sampling
- batch_size parameter added to transform() for large inputs
- optional torch.compile support added for the MLP

Test plan

- uv run pytest passes
- uv run ruff check . passes

🤖 Generated with Claude Code