Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
- Add numerically stable log-domain entropic partial Wasserstein solver `entropic_partial_wasserstein_logscale` (Issue #723)

#### Closed issues

- Mitigate the NaN regime of `entropic_partial_wasserstein` at small `reg` by adding the log-domain alternative `entropic_partial_wasserstein_logscale` (Issue #723). The standard solver itself is unchanged; callers must opt into the log-domain variant.
- Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788)
- Fix MSVC Windows build by removing __restrict__ keyword (PR #788)
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785)
Expand Down
7 changes: 5 additions & 2 deletions docs/source/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -791,8 +791,11 @@ Interestingly the problem can be casted into a regular OT problem by adding rese
in which the surplus mass is sent [29]_. We provide a solver for partial OT
in :any:`ot.partial`. The exact resolution of the problem is computed in :any:`ot.partial.partial_wasserstein`
and :any:`ot.partial.partial_wasserstein2` that return respectively the OT matrix and the value of the
linear term. The entropic solution of the problem is computed in :any:`ot.partial.entropic_partial_wasserstein`
(see [3]_).
linear term. The entropic solution of the problem is computed in :any:`ot.partial.entropic_partial_wasserstein`
(see [3]_). A numerically stable log-domain variant
:any:`ot.partial.entropic_partial_wasserstein_logscale` is also provided for small regularization
values, where the standard solver can return NaN; it solves exactly the same problem but is
slower because all computations are carried out in log-space.

The partial Gromov-Wasserstein formulation of the problem

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
==========================================================================
Numerically-stable entropic partial Wasserstein (log-domain solver)
==========================================================================

.. note::
    Example added in release: 0.9.7.

`ot.partial.entropic_partial_wasserstein` is numerically unstable at small
regularisation: the iterates underflow to zero and the returned plan
contains NaNs (see PythonOT/POT issue #723). This example reproduces the
failure mode on a small problem and shows that
:any:`ot.partial.entropic_partial_wasserstein_logscale` produces a finite
plan over the same sweep, agreeing with the original solver at large
``reg`` and degrading gracefully at small ``reg``.

The log-domain solver is slower per iteration than the standard one, so
the recommendation is to use the standard solver by default and fall
back to the log-domain solver when ``reg`` is small enough to risk
underflow.
"""

# Author: wzm2256 <wzm2256@qq.com> (original PR #724)
# License: MIT License

import numpy as np
import scipy as sp

# Explicit subpackage import: `import scipy` alone does not guarantee that
# `sp.spatial.distance` is reachable on every SciPy version.
import scipy.spatial.distance  # noqa: F401

import matplotlib.pylab as pl

import ot

##############################################################################
# Construct a 50x50 cost matrix
# -----------------------------
#
# Mirrors the cost-matrix scale (~50) used in PythonOT/POT issue #723.

rng = np.random.RandomState(0)
n = 50
xs = rng.rand(n, 2)
xt = rng.rand(n, 2)
M = sp.spatial.distance.cdist(xs, xt) * 50.0

a = np.ones(n) / n
b = np.ones(n) / n
m = 0.6  # transport ~60% of the mass

##############################################################################
# Sweep regularisation
# --------------------
#
# Run both solvers across a range of ``reg`` values. On this 50×50 problem
# at cost-scale 50 the standard solver returns NaN at the ``reg`` values
# closest to the underflow boundary (typically ``reg`` ~0.01–0.05 in our
# runs, though the exact transition depends on the BLAS / platform's
# float64 underflow behaviour); the log-domain solver stays finite over
# the whole sweep, including the very small ``reg`` regime where the
# standard exp(−M/reg) path would underflow to zero everywhere.

regs = [1.0, 0.5, 0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4]
standard_finite = []
logscale_finite = []
standard_mass = []
logscale_mass = []

for reg in regs:
    G_std = ot.partial.entropic_partial_wasserstein(
        a, b, M, reg=reg, m=m, numItermax=2000
    )
    G_log = ot.partial.entropic_partial_wasserstein_logscale(
        a, b, M, reg=reg, m=m, numItermax=2000
    )
    standard_finite.append(bool(np.isfinite(G_std).all()))
    logscale_finite.append(bool(np.isfinite(G_log).all()))
    # Record the transported mass (should approach the target m); NaN plans
    # are reported as NaN mass rather than propagating a misleading sum.
    standard_mass.append(float(G_std.sum()) if np.isfinite(G_std).all() else np.nan)
    logscale_mass.append(float(G_log.sum()))

print(
    "reg standard_finite logscale_finite std_mass logscale_mass (target m={:.2f})".format(
        m
    )
)
for reg, sf, lf, sm, lm in zip(
    regs, standard_finite, logscale_finite, standard_mass, logscale_mass
):
    print(f"{reg:>10.4g} {str(sf):<14} {str(lf):<14} {sm:>8.3f} {lm:>8.3f}")

##############################################################################
# Plot the resulting plans at large vs. small reg
# -----------------------------------------------

fig, axes = pl.subplots(2, 2, figsize=(9, 8))
for ax, reg in zip(axes[:, 0], (1.0, 0.01)):
    G_std = ot.partial.entropic_partial_wasserstein(
        a, b, M, reg=reg, m=m, numItermax=2000
    )
    if not np.isfinite(G_std).all():
        # Replace the NaN plan with zeros so imshow still renders something.
        G_std = np.zeros_like(G_std)
        ax.set_title(f"standard, reg={reg} (NaN)")
    else:
        ax.set_title(f"standard, reg={reg}")
    ax.imshow(G_std, cmap="viridis", aspect="auto")
    ax.set_xlabel("target")
    ax.set_ylabel("source")

for ax, reg in zip(axes[:, 1], (1.0, 0.01)):
    G_log = ot.partial.entropic_partial_wasserstein_logscale(
        a, b, M, reg=reg, m=m, numItermax=2000
    )
    ax.set_title(f"logscale, reg={reg}")
    ax.imshow(G_log, cmap="viridis", aspect="auto")
    ax.set_xlabel("target")
    ax.set_ylabel("source")

fig.tight_layout()
pl.show()
2 changes: 2 additions & 0 deletions ot/partial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
partial_wasserstein,
partial_wasserstein2,
entropic_partial_wasserstein,
entropic_partial_wasserstein_logscale,
gwgrad_partial,
gwloss_partial,
partial_gromov_wasserstein,
Expand All @@ -28,6 +29,7 @@
"partial_wasserstein",
"partial_wasserstein2",
"entropic_partial_wasserstein",
"entropic_partial_wasserstein_logscale",
"gwgrad_partial",
"gwloss_partial",
"partial_gromov_wasserstein",
Expand Down
165 changes: 165 additions & 0 deletions ot/partial/partial_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,171 @@ def entropic_partial_wasserstein(
return K


def entropic_partial_wasserstein_logscale(
    a, b, M, reg, m=None, numItermax=1000, stopThr=1e-100, verbose=False, log=False
):
    r"""
    Solves the partial optimal transport problem
    and returns the OT plan

    This function solves exactly the same problem as
    :any:`entropic_partial_wasserstein`, but it is in log-scale, i.e., the log-sum-exp trick is used,
    so it is more stable but slower.

    The input and output format is the same as :any:`entropic_partial_wasserstein`.

    The function considers the following problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma,
        \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)

        s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
             \gamma^T \mathbf{1} &\leq \mathbf{b} \\
             \gamma &\geq 0 \\
             \mathbf{1}^T \gamma^T \mathbf{1} = m
             &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\

    where :

    - :math:`\mathbf{M}` is the metric cost matrix
    - :math:`\Omega` is the entropic regularization term,
      :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
    - `m` is the amount of mass to be transported

    The formulation of the problem has been proposed in
    :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)


    Parameters
    ----------
    a : np.ndarray (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : np.ndarray (dim_b,)
        Unnormalized histograms of dimension `dim_b`
    M : np.ndarray (dim_a, dim_b)
        cost matrix
    reg : float
        Regularization term > 0
    m : float, optional
        Amount of mass to be transported
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (dim_a, dim_b) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`


    Examples
    --------
    >>> import ot
    >>> a = [.1, .2]
    >>> b = [.1, .1]
    >>> M = [[0., 1.], [2., 3.]]
    >>> np.round(entropic_partial_wasserstein_logscale(a, b, M, 1, 0.1), 2)
    array([[0.06, 0.02],
           [0.01, 0.  ]])


    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
       (2015). Iterative Bregman projections for regularized transportation
       problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

    See Also
    --------
    ot.partial.partial_wasserstein: exact Partial Wasserstein
    ot.partial.entropic_partial_wasserstein: numerically unstable entropic version
    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(a, b, M)

    dim_a, dim_b = M.shape

    # Log-domain caps for the row/column rescalings: log(1) = 0, so the
    # projections below never scale a row/column *up* past its marginal
    # (mirrors the `minimum(..., 1)` cap of the standard solver).
    Ldx = nx.zeros(dim_a, type_as=a)
    Ldy = nx.zeros(dim_b, type_as=b)

    # Empty histograms default to uniform weights.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=a) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=b) / dim_b

    # Marginals in log-space (all subsequent arithmetic stays in log-space).
    La = nx.log(a)
    Lb = nx.log(b)

    # Default m: transport the full feasible mass, min(|a|_1, |b|_1).
    if m is None:
        m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
    if m < 0:
        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
    if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
        raise ValueError(
            "Problem infeasible. Parameter m should lower or"
            " equal than min(|a|_1, |b|_1)."
        )

    log_e = {"err": []}

    # Log of the Gibbs kernel exp(-M/reg), renormalised (via log-sum-exp)
    # so the plan's total mass equals m.  Computing -M/reg directly avoids
    # the exp() underflow that makes the standard solver return NaN.
    LK = -M / reg
    LK = LK + nx.log(m) - nx.logsumexp(LK)

    err, cpt = 1, 0

    # Dykstra correction terms, one per constraint set (rows, columns,
    # total mass), all kept in log-space and initialised to log(1) = 0.
    Lq1 = nx.zeros(M.shape, type_as=M)
    Lq2 = nx.zeros(M.shape, type_as=M)
    Lq3 = nx.zeros(M.shape, type_as=M)

    while err > stopThr and cpt < numItermax:
        LKprev = LK
        # Projection on {gamma 1 <= a}: re-apply correction Lq1, then cap each
        # row so its log-mass does not exceed log(a_i).
        LK = LK + Lq1
        LK1 = nx.reshape(nx.minimum(La - nx.logsumexp(LK, 1), Ldx), (-1, 1)) + LK
        # Dykstra update of the correction: (iterate + correction) - projection.
        Lq1 = Lq1 + LKprev - LK1
        LK1prev = LK1
        # Projection on {gamma^T 1 <= b}: same scheme on the columns.
        LK1 = LK1 + Lq2
        LK2 = LK1 + nx.reshape(nx.minimum(Lb - nx.logsumexp(LK1, 0), Ldy), (1, -1))
        Lq2 = Lq2 + LK1prev - LK2
        LK2prev = LK2
        # Projection on {total mass = m}: renormalise so sum(exp(LK)) == m.
        LK2 = LK2 + Lq3
        LK = LK2 + nx.log(m) - nx.logsumexp(LK2)
        Lq3 = Lq3 + LK2prev - LK

        # Bail out if numerics still blew up; the warning flags that the
        # returned plan (exp of the current LK) may contain NaN/inf.
        if nx.any(nx.isnan(LK)) or nx.any(nx.isinf(LK)):
            warnings.warn(
                f"Numerical errors at iteration {cpt} of entropic_partial_wasserstein_logscale",
                stacklevel=2,
            )
            break
        # Convergence check every 10 iterations (norm of the log-plan change).
        if cpt % 10 == 0:
            err = nx.norm(LKprev - LK)
            if log:
                log_e["err"].append(err)
            if verbose:
                # Re-print the header every 200 iterations.
                if cpt % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 11)
                print("{:5d}|{:8e}|".format(cpt, err))

        cpt = cpt + 1
    # Linear-term value <M, gamma> of the returned plan.
    log_e["partial_w_dist"] = nx.sum(M * nx.exp(LK))
    if log:
        return nx.exp(LK), log_e
    else:
        return nx.exp(LK)


def gwgrad_partial(C1, C2, T):
"""Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`
as the marginals may not sum to 1.
Expand Down
Loading
Loading