[WIP] Sparse emd implementation #778
base: master
Conversation
This PR implements a sparse bipartite graph EMD solver for memory-efficient
optimal transport when the cost matrix has many infinite or forbidden edges.
Changes:
- Implement sparse bipartite graph EMD solver in C++
- Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py)
- Add unit tests to verify sparse and dense solvers produce identical results
- Tests use augmented k-NN approach to ensure fair comparison
- Update setup.py to include sparse solver compilation
Tests verify correctness:
* test_emd_sparse_vs_dense() - verifies identical costs and marginal constraints
* test_emd2_sparse_vs_dense() - verifies cost-only version
Status: WIP - seeking feedback on implementation approach
TODO: Add example script and documentation
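For context, here is a minimal usage sketch of what the sparse path is meant to look like, assuming (as in the later commits) that ot.emd accepts a scipy.sparse cost matrix and treats missing entries as forbidden edges. This is an illustration of the intended API, not the final implementation.

```python
# Minimal sketch of the intended sparse usage (assumed API: the sparse path
# is selected when M is scipy.sparse; absent entries are forbidden edges).
import numpy as np
from scipy.sparse import coo_matrix
import ot

a = np.array([0.5, 0.5])  # source weights
b = np.array([0.5, 0.5])  # target weights

# Only the stored entries are allowed edges; everything else is forbidden.
rows = np.array([0, 0, 1])
cols = np.array([0, 1, 1])
costs = np.array([1.0, 3.0, 2.0])
M_sparse = coo_matrix((costs, (rows, cols)), shape=(2, 2))

G, log = ot.emd(a, b, M_sparse, log=True)
print(log["cost"])  # optimal cost restricted to the allowed edges
```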
…trix parameter from emd and fix linting issues
Codecov Report
Additional details and impacted files:

@@            Coverage Diff             @@
##           master     #778      +/-   ##
==========================================
- Coverage   97.15%   97.12%   -0.04%
==========================================
  Files         107      107
  Lines       21906    22195     +289
==========================================
+ Hits        21283    21556     +273
- Misses        623      639      +16
- Remove tuple format support for sparse matrices (use scipy.sparse only)
- Change index types from int64_t to uint64_t throughout (indices are never negative)
- Refactor emd() and emd2() with clear sparse/dense code path separation
- Add sparse_bipartitegraph.h to MANIFEST.in to fix build
- Add test_emd_sparse_backends() to verify backend compatibility
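As a rough illustration of the scipy.sparse-only convention and the unsigned index types mentioned above, here is a small sketch of how edge triplets could be extracted from a sparse cost matrix before being handed to the C++ solver; the helper name is hypothetical, not the PR's actual code.

```python
# Sketch: extract edge triplets (row, col, cost) from a scipy.sparse cost
# matrix with uint64 indices; cost_matrix_to_edges is a hypothetical helper.
import numpy as np
from scipy.sparse import coo_matrix, issparse


def cost_matrix_to_edges(M):
    """Return (rows, cols, costs) for the stored (allowed) edges of M."""
    if not issparse(M):
        raise TypeError("expected a scipy.sparse cost matrix")
    M = M.tocoo()
    return (
        M.row.astype(np.uint64),
        M.col.astype(np.uint64),
        M.data.astype(np.float64),
    )


# Example: a 3x3 cost matrix with only four allowed edges.
M = coo_matrix(([1.0, 0.5, 2.0, 0.3], ([0, 0, 1, 2], [0, 1, 1, 2])), shape=(3, 3))
rows, cols, costs = cost_matrix_to_edges(M)
print(rows, cols, costs)
```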
rflamary left a comment:
Thanks so much @nathanneike for this PR.
I have many small comments but it is already looking very nice.
The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan.
.. image:: /_static/images/comparison.png
Let us not add a static image file. Could you do a quick bench below and compare computational time?
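In the spirit of that request, here is a quick timing sketch one could drop into the example. It assumes the sparse path is triggered by passing a scipy.sparse cost matrix; note that with a small k the restricted problem may be infeasible, so the augmented k-NN trick used in the tests may be needed in practice.

```python
# Quick bench sketch: dense vs. sparse EMD timing (assumed scipy.sparse API).
import time
import numpy as np
from scipy.sparse import coo_matrix
import ot

rng = np.random.RandomState(0)
n, k = 500, 20
xs = rng.randn(n, 2)
xt = rng.randn(n, 2) + 2.0
a = b = np.ones(n) / n

M = ot.dist(xs, xt)

# Keep only the k cheapest targets per source; other edges are forbidden.
idx = np.argsort(M, axis=1)[:, :k]
rows = np.repeat(np.arange(n), k)
cols = idx.ravel()
M_sparse = coo_matrix((M[rows, cols], (rows, cols)), shape=(n, n))

t0 = time.perf_counter()
G_dense = ot.emd(a, b, M)
t_dense = time.perf_counter() - t0

t0 = time.perf_counter()
G_sparse = ot.emd(a, b, M_sparse)  # sparse code path (assumed)
t_sparse = time.perf_counter() - t0

print(f"dense: {t_dense:.3f}s   sparse: {t_sparse:.3f}s")
```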
# %%
X = np.array([[0, 0], [1, 0]])
very simple example, maybe we can design a sparse example with more points?
# Solve sparse OT (intra-cluster only)
G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True)
cost_sparse = log_sparse["cost"]
No need for log; np.sum(G_sparse * M_sparse_large) should work.
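One small caveat with that suggestion: if G_sparse and M_sparse_large both end up as scipy.sparse objects, * is matrix multiplication, so the elementwise cost would be computed as in this sketch (names taken from the quoted snippet above):

```python
# Elementwise cost for scipy.sparse operands (for dense ndarrays,
# np.sum(G_sparse * M_sparse_large) works as suggested).
cost_sparse = G_sparse.multiply(M_sparse_large).sum()
```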
# Dense OT
plt.subplot(1, 2, 1)
for i in range(nA):
we have a function for that, ot.plot.plot2D_samples_mat
https://pythonot.github.io/gen_modules/ot.plot.html#id1
The function should be updated to handle sparse OT matrices, maybe.
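For reference, a small sketch of using the existing helper while densifying a scipy.sparse plan first, since plot2D_samples_mat currently expects an array-like coupling; the wrapper function here is just illustrative.

```python
# Sketch: plot the coupling with the existing POT helper; a scipy.sparse plan
# is densified first since plot2D_samples_mat expects an array-like matrix.
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import issparse
import ot.plot


def plot_coupling(xs, xt, G):
    G_arr = G.toarray() if issparse(G) else np.asarray(G)
    ot.plot.plot2D_samples_mat(xs, xt, G_arr, c=[0.5, 0.5, 1])
    plt.plot(xs[:, 0], xs[:, 1], "+b", label="Source samples")
    plt.plot(xt[:, 0], xt[:, 1], "xr", label="Target samples")
    plt.legend()
```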
return None, log_dict
else:
raise ValueError(
Return the OT plan in sparse format (same type and device as M).
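A sketch of what returning the plan in sparse format could look like on the scipy side; variable and helper names are illustrative, and matching M's type and device would presumably go through the POT backend rather than scipy directly.

```python
# Sketch: assemble the OT plan as a scipy.sparse matrix from the nonzero
# flows returned by the solver (flows_to_plan is a hypothetical helper).
import numpy as np
from scipy.sparse import coo_matrix


def flows_to_plan(flow_rows, flow_cols, flow_values, n_source, n_target):
    """Build the sparse OT plan from the solver's nonzero flow triplets."""
    return coo_matrix(
        (np.asarray(flow_values, dtype=np.float64),
         (np.asarray(flow_rows), np.asarray(flow_cols))),
        shape=(n_source, n_target),
    )
```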
np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7)

# Reconstruct sparse matrix from flow for marginal checks
if G_sparse is None:
Return a sparse matrix instead of doing that.
cols = []
data = []

for i in range(n_source):
same here
)

C_augmented_dense = np.full((n_source, n_target), large_cost)
C_augmented_array = C_augmented.toarray()
This already returns a dense matrix.
b, G_sparse_reconstructed.sum(0), rtol=1e-5, atol=1e-7
)
else:
np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7)
yep just do that
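If the solver does return a scipy.sparse plan, the marginal checks can be done directly on it, with one small wrinkle: .sum(axis=...) on a scipy.sparse matrix returns a 2-D matrix, hence the ravel in this sketch (the helper is illustrative, not from the PR).

```python
# Sketch: marginal checks that work for both dense and scipy.sparse plans
# (.sum(axis=...) on a sparse matrix returns a 2-D matrix, hence the ravel).
import numpy as np


def check_marginals(a, b, G, rtol=1e-5, atol=1e-7):
    row_sums = np.asarray(G.sum(axis=1)).ravel()
    col_sums = np.asarray(G.sum(axis=0)).ravel()
    np.testing.assert_allclose(a, row_sums, rtol=rtol, atol=atol)
    np.testing.assert_allclose(b, col_sums, rtol=rtol, atol=atol)
```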
cols_aug.append(j)
data_aug.append(C[i, j])

C_augmented = coo_matrix(
use nx.from_numpy here
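For reference, a minimal sketch of the backend conversion the comment refers to, so the test data follows the active backend instead of hand-built numpy/scipy objects; shapes and values here are placeholders.

```python
# Sketch: convert test arrays via the POT backend helper (nx.from_numpy).
import numpy as np
from ot.backend import get_backend

a = np.ones(4) / 4
b = np.ones(4) / 4
C = np.random.rand(4, 4)

nx = get_backend(a, b, C)          # numpy backend in this sketch
a_b, b_b, C_b = nx.from_numpy(a, b, C)
```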
Types of changes
Motivation and context / Related issue
This PR implements a sparse EMD solver for memory-efficient optimal transport when the cost matrix has many infinite or forbidden edges (e.g., k-NN graphs, sparse networks).
Problem: The current dense EMD solver requires O(n²) memory for the full cost matrix, which becomes prohibitive for large-scale
problems even when most edges are forbidden.
Solution: This PR adds a sparse bipartite graph solver that only stores edges with finite costs, reducing memory usage from O(n²) to O(E) where E is the number of edges.
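To make the O(E) point concrete, here is a sketch (not from the PR) of the k-NN use case: only n·k finite-cost edges are stored instead of the full n² matrix.

```python
# Sketch: k-NN sparse cost matrix; only n*k edges are stored instead of n^2.
import numpy as np
from scipy.sparse import coo_matrix
from scipy.spatial import cKDTree

n, k = 10_000, 10
rng = np.random.RandomState(0)
xs = rng.randn(n, 2)
xt = rng.randn(n, 2)

dists, idx = cKDTree(xt).query(xs, k=k)  # k nearest targets for each source
rows = np.repeat(np.arange(n), k)
M_knn = coo_matrix((dists.ravel(), (rows, idx.ravel())), shape=(n, n))

# ~n*k*8 bytes of stored costs (about 0.8 MB) vs ~n*n*8 bytes (about 800 MB) dense.
print(M_knn.nnz, "stored edges out of", n * n, "possible")
```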
Use cases: k-NN graph couplings, sparse networks, and other problems where most pairwise edges are forbidden.
How has this been tested
Unit Tests
Added two comprehensive tests in test/test_ot.py:
- test_emd_sparse_vs_dense() - Verifies sparse and dense solvers produce identical transport matrices
- test_emd2_sparse_vs_dense() - Verifies sparse and dense solvers produce identical costs
Both tests use the augmented k-NN approach.
Test results: All 50 tests in test/test_ot.py pass.
Verification
PR checklist
TODO before [MRG]:
- Add an example script in the examples/ folder demonstrating sparse solver usage
Feedback requested:
- sparse=True parameter