Skip to content

Commit dc82e14

Browse files
authored
Merge branch 'master' into doc_cleanup
2 parents e7a58a4 + d2ea084 commit dc82e14

File tree

11 files changed

+790
-19
lines changed

11 files changed

+790
-19
lines changed

.github/labeler.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ ot.solvers:
7474

7575
ot.partial:
7676
- changed-files:
77-
- any-glob-to-any-file: ot/partial.py
77+
- any-glob-to-any-file: ot/partial/**
7878

7979
ot.sliced:
8080
- changed-files:
@@ -94,4 +94,4 @@ ot.dr:
9494

9595
ot.gnn:
9696
- changed-files:
97-
- any-glob-to-any-file: ot/gnn/**
97+
- any-glob-to-any-file: ot/gnn/**

.github/workflows/build_doc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
python -m pip install --user --upgrade --progress-bar off pip
2727
python -m pip install --user --upgrade --progress-bar off -r requirements_all.txt
2828
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
29-
python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
29+
python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler
3030
python -m pip install --user -e .
3131
# Look at what we have and fail early if there is some library conflict
3232
- name: Check installation

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ docs/modules/
1515

1616
# Cython output
1717
ot/lp/emd_wrap.cpp
18+
ot/partial/partial_cython.cpp
1819

1920
# Distribution / packaging
2021
.Python

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,5 @@ Artificial Intelligence.
435435
[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.
436436

437437
[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.
438+
439+
[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations.

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- Backend implementation of `ot.dist` for (PR #701)
1818
- Updated documentation Quickstart guide and User guide with new API (PR #726)
1919
- Fix jax version for auto-grad (PR #732)
20+
- Implement 1d solver for partial optimal transport (PR #741)
2021
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
2122
- Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
2223
- Removed release information from quickstart guide (PR #744)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
=========================
3+
Partial Wasserstein in 1D
4+
=========================
5+
6+
This script demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.
7+
8+
We illustrate the intermediate transport plans for all `k = 1...n`, where `n = min(len(x_a), len(x_b))`.
9+
"""
10+
11+
# sphinx_gallery_thumbnail_number = 5
12+
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
from ot.partial import partial_wasserstein_1d
16+
17+
18+
def plot_partial_transport(
19+
ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None
20+
):
21+
y_a = np.ones_like(x_a)
22+
y_b = -np.ones_like(x_b)
23+
min_min = min(x_a.min(), x_b.min())
24+
max_max = max(x_a.max(), x_b.max())
25+
26+
ax.plot([min_min - 1, max_max + 1], [1, 1], "k-", lw=0.5, alpha=0.5)
27+
ax.plot([min_min - 1, max_max + 1], [-1, -1], "k-", lw=0.5, alpha=0.5)
28+
29+
# Plot transport lines
30+
if indices_a is not None and indices_b is not None:
31+
subset_a = np.sort(x_a[indices_a])
32+
subset_b = np.sort(x_b[indices_b])
33+
34+
for x_a_i, x_b_j in zip(subset_a, subset_b):
35+
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)
36+
37+
# Plot all points
38+
ax.plot(x_a, y_a, "o", color="C0", label="x_a", markersize=8)
39+
ax.plot(x_b, y_b, "o", color="C1", label="x_b", markersize=8)
40+
41+
if marginal_costs is not None:
42+
k = len(marginal_costs)
43+
ax.set_title(
44+
f"Partial Transport - k = {k}, Cumulative Cost = {sum(marginal_costs):.2f}",
45+
fontsize=16,
46+
)
47+
else:
48+
ax.set_title("Original 1D Discrete Distributions", fontsize=16)
49+
ax.legend(loc="upper right", fontsize=14)
50+
ax.set_yticks([])
51+
ax.set_xticks([])
52+
ax.set_ylim(-2, 2)
53+
ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1)
54+
ax.axis("off")
55+
56+
57+
# Simulate two 1D discrete distributions
58+
np.random.seed(0)
59+
n = 6
60+
x_a = np.sort(np.random.uniform(0, 10, size=n))
61+
x_b = np.sort(np.random.uniform(0, 10, size=n))
62+
63+
# Plot original distributions
64+
plt.figure(figsize=(6, 2))
65+
plot_partial_transport(plt.gca(), x_a, x_b)
66+
plt.show()
67+
68+
# %%
69+
indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)
70+
71+
# Compute cumulative cost
72+
cumulative_costs = np.cumsum(marginal_costs)
73+
74+
# Visualize all partial transport plans
75+
for k in range(n):
76+
plt.figure(figsize=(6, 2))
77+
plot_partial_transport(
78+
plt.gca(),
79+
x_a,
80+
x_b,
81+
indices_a[: k + 1],
82+
indices_b[: k + 1],
83+
marginal_costs[: k + 1],
84+
)
85+
plt.show()

ot/partial/__init__.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Efficient 1D solver for the partial optimal transport problem.
4+
"""
5+
6+
# Author: Romain Tavenard <romain.tavenard@univ-rennes2.fr>
7+
#
8+
# License: MIT License
9+
10+
# import compiled emd
11+
from .partial_solvers import (
12+
partial_wasserstein_lagrange,
13+
partial_wasserstein,
14+
partial_wasserstein2,
15+
entropic_partial_wasserstein,
16+
gwgrad_partial,
17+
gwloss_partial,
18+
partial_gromov_wasserstein,
19+
partial_gromov_wasserstein2,
20+
entropic_partial_gromov_wasserstein,
21+
entropic_partial_gromov_wasserstein2,
22+
partial_wasserstein_1d,
23+
)
24+
25+
__all__ = [
26+
"partial_wasserstein_1d",
27+
"partial_wasserstein_lagrange",
28+
"partial_wasserstein",
29+
"partial_wasserstein2",
30+
"entropic_partial_wasserstein",
31+
"gwgrad_partial",
32+
"gwloss_partial",
33+
"partial_gromov_wasserstein",
34+
"partial_gromov_wasserstein2",
35+
"entropic_partial_gromov_wasserstein",
36+
"entropic_partial_gromov_wasserstein2",
37+
]

0 commit comments

Comments
 (0)