-
Notifications
You must be signed in to change notification settings - Fork 537
[WIP] Sparse emd implementation #778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nathanneike
wants to merge
10
commits into
PythonOT:master
Choose a base branch
from
nathanneike:sparse-emd-implementation
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,698
−113
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
04c12a0
Add sparse EMD solver with unit tests
nathanneike 0eee6f1
[WIP] Add sparse EMD solver with unit tests
nathanneike 022720b
Fix int64_t type compatibility for Linux, remove sparse and return ma…
nathanneike aa5f1c9
refactor: Clean up sparse EMD implementation
nathanneike fae9f02
fix : Quick test file fix
nathanneike 1e28771
Merge branch 'master' into sparse-emd-implementation
rflamary b184cd4
Added Example for documentation and modified back setup file to origi…
nathanneike 1152398
feat: Add backend-agnostic sparse EMD support
nathanneike 1a3dc41
throw error when unsupported backend used for sparse
nathanneike 54479d5
Fix sparse tensor gradients and add backend checks
nathanneike File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,225 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| ============================================ | ||
| Sparse Optimal Transport | ||
| ============================================ | ||
|
|
||
| In many real-world optimal transport (OT) problems, the transport plan is | ||
| naturally sparse: only a small fraction of all possible source-target pairs | ||
| actually exchange mass. Using sparse OT solvers can provide significant | ||
| computational speedups and memory savings compared to dense solvers. | ||
|
|
||
| This example demonstrates how to use sparse cost matrices with POT's EMD solver, | ||
| comparing sparse and dense formulations on both a minimal example and a larger | ||
| concentric circles dataset. | ||
| """ | ||
|
|
||
| # Author: Nathan Neike <nathan.neike@example.com> | ||
| # License: MIT License | ||
| # sphinx_gallery_thumbnail_number = 2 | ||
|
|
||
| import numpy as np | ||
| import matplotlib.pyplot as plt | ||
| from scipy.sparse import coo_matrix | ||
| import ot | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Minimal example with 4 points | ||
| # ------------------------------ | ||
|
|
||
| # %% | ||
|
|
||
| X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]]) | ||
| Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]]) | ||
| a = np.array([0.25, 0.25, 0.25, 0.25]) | ||
| b = np.array([0.25, 0.25, 0.25, 0.25]) | ||
|
|
||
| # Build sparse cost matrix allowing only selected edges | ||
| rows = [0, 1, 2, 3] | ||
| cols = [0, 1, 2, 3] | ||
| vals = [np.linalg.norm(X[i] - Y[j]) for i, j in zip(rows, cols)] | ||
| M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4)) | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Solve and display sparse OT solution | ||
| # ------------------------------------- | ||
|
|
||
| # %% | ||
|
|
||
| G, log = ot.emd(a, b, M_sparse, log=True) | ||
|
|
||
| print("Sparse OT cost:", log["cost"]) | ||
| print("Solution format:", type(G)) | ||
| print("Non-zero edges:", G.nnz) | ||
| print("\nEdges:") | ||
| G_coo = G if isinstance(G, coo_matrix) else G.tocoo() | ||
| for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data): | ||
| if v > 1e-10: | ||
| print(f" source {i} -> target {j}, flow={v:.3f}") | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Visualize sparse vs dense edge structure | ||
| # ----------------------------------------- | ||
|
|
||
| # %% | ||
|
|
||
| plt.figure(figsize=(8, 4)) | ||
|
|
||
| plt.subplot(1, 2, 1) | ||
| plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) | ||
| plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) | ||
| for i, j in zip(rows, cols): | ||
| plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.6) | ||
| plt.title("Sparse OT: Allowed Edges Only") | ||
| plt.xlim(-0.5, 2.0) | ||
| plt.ylim(-0.5, 1.5) | ||
|
|
||
| plt.subplot(1, 2, 2) | ||
| plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) | ||
| plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) | ||
| for i in range(len(X)): | ||
| for j in range(len(Y)): | ||
| plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.3) | ||
| plt.title("Dense OT: All Possible Edges") | ||
| plt.xlim(-0.5, 2.0) | ||
| plt.ylim(-0.5, 1.5) | ||
|
|
||
| plt.tight_layout() | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Larger example: concentric circles | ||
| # ----------------------------------- | ||
|
|
||
| # %% | ||
|
|
||
| n_clusters = 8 | ||
| points_per_cluster = 25 | ||
| n = n_clusters * points_per_cluster | ||
| k_neighbors = 8 | ||
| rng = np.random.default_rng(0) | ||
|
|
||
| r_source = 1.0 | ||
| r_target = 2.0 | ||
| noise_scale = 0.06 | ||
|
|
||
| theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False) | ||
| cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster) | ||
|
|
||
| X_large = np.column_stack( | ||
| [r_source * np.cos(theta), r_source * np.sin(theta)] | ||
| ) + rng.normal(scale=noise_scale, size=(n, 2)) | ||
| Y_large = np.column_stack( | ||
| [r_target * np.cos(theta), r_target * np.sin(theta)] | ||
| ) + rng.normal(scale=noise_scale, size=(n, 2)) | ||
|
|
||
| a_large = np.zeros(n) | ||
| b_large = np.zeros(n) | ||
| for k in range(n_clusters): | ||
| idx = np.where(cluster_labels == k)[0] | ||
| a_large[idx] = 1.0 / n_clusters / points_per_cluster | ||
| b_large[idx] = 1.0 / n_clusters / points_per_cluster | ||
|
|
||
| M_full = ot.dist(X_large, Y_large, metric="euclidean") | ||
|
|
||
| # Build sparse cost matrix: intra-cluster k-nearest neighbors | ||
| angles_X = np.arctan2(X_large[:, 1], X_large[:, 0]) | ||
| angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0]) | ||
|
|
||
| rows = [] | ||
| cols = [] | ||
| vals = [] | ||
| for k in range(n_clusters): | ||
| src_idx = np.where(cluster_labels == k)[0] | ||
| tgt_idx = np.where(cluster_labels == k)[0] | ||
| for i in src_idx: | ||
| diff = np.angle(np.exp(1j * (angles_Y[tgt_idx] - angles_X[i]))) | ||
| idx = np.argsort(np.abs(diff))[:k_neighbors] | ||
| for j_local in idx: | ||
| j = tgt_idx[j_local] | ||
| rows.append(i) | ||
| cols.append(j) | ||
| vals.append(M_full[i, j]) | ||
|
|
||
| M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n)) | ||
| allowed_sparse = set(zip(rows, cols)) | ||
|
|
||
| ############################################################################## | ||
| # Visualize edge structures | ||
| # -------------------------- | ||
|
|
||
| # %% | ||
|
|
||
| plt.figure(figsize=(16, 6)) | ||
|
|
||
| plt.subplot(1, 2, 1) | ||
| for i in range(n): | ||
| for j in range(n): | ||
| plt.plot( | ||
| [X_large[i, 0], Y_large[j, 0]], | ||
| [X_large[i, 1], Y_large[j, 1]], | ||
| color="blue", | ||
| alpha=0.2, | ||
| linewidth=0.05, | ||
| ) | ||
| plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) | ||
| plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) | ||
| plt.axis("equal") | ||
| plt.title("Dense OT: All Possible Edges") | ||
|
|
||
| plt.subplot(1, 2, 2) | ||
| for i, j in allowed_sparse: | ||
| plt.plot( | ||
| [X_large[i, 0], Y_large[j, 0]], | ||
| [X_large[i, 1], Y_large[j, 1]], | ||
| color="blue", | ||
| alpha=1, | ||
| linewidth=0.05, | ||
| ) | ||
| plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) | ||
| plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) | ||
| plt.axis("equal") | ||
| plt.title("Sparse OT: Intra-Cluster k-NN Edges") | ||
|
|
||
| plt.tight_layout() | ||
| plt.show() | ||
|
|
||
| ############################################################################## | ||
| # Solve and visualize transport plans | ||
| # ------------------------------------ | ||
|
|
||
| # %% | ||
|
|
||
| G_dense = ot.emd(a_large, b_large, M_full) | ||
| cost_dense = np.sum(G_dense * M_full) | ||
| print(f"Dense OT cost: {cost_dense:.6f}") | ||
|
|
||
| G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True) | ||
| cost_sparse = log_sparse["cost"] | ||
| print(f"Sparse OT cost: {cost_sparse:.6f}") | ||
|
|
||
| plt.figure(figsize=(16, 6)) | ||
|
|
||
| plt.subplot(1, 2, 1) | ||
| ot.plot.plot2D_samples_mat( | ||
| X_large, Y_large, G_dense, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 | ||
| ) | ||
| plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) | ||
| plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3) | ||
| plt.axis("equal") | ||
| plt.title("Dense OT: Optimal Transport Plan") | ||
|
|
||
| plt.subplot(1, 2, 2) | ||
| ot.plot.plot2D_samples_mat( | ||
| X_large, Y_large, G_sparse, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 | ||
| ) | ||
| plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) | ||
| plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3) | ||
| plt.axis("equal") | ||
| plt.title("Sparse OT: Optimal Transport Plan") | ||
|
|
||
| plt.tight_layout() | ||
| plt.show() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for log np.sum(G_sparse*M_sparse_large) shoudl work