From 5013e330b5f7f10ff0cff7c8e1d5bac074249c78 Mon Sep 17 00:00:00 2001 From: hinanohart Date: Sun, 17 May 2026 00:16:25 +0900 Subject: [PATCH] [WIP] entropic_partial_wasserstein_logscale: stable log-domain solver (rescue of #724) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Re-applies the function from PR #724 by wzm2256 on top of current master (the original PR is stuck at CONFLICTING since 2025-09; this takes the additive parts and skips the obsolete merges through the March-2025 single-file layout). Subject is [WIP] because the original PR is also [WIP] and maintainer review is still required. Changes vs master: ot/partial/partial_solvers.py + entropic_partial_wasserstein_logscale (function body verbatim from PR #724 modulo: (a) duplicate sphinx label removed to avoid build failure, (b) print warning -> warnings.warn(stacklevel=2) for convention). ot/partial/__init__.py + entropic_partial_wasserstein_logscale export test/test_partial.py + test_entropic_partial_wasserstein_logscale_matches_old_at_large_reg (machine-precision agreement at reg in {10.0, 1.0}: atol=1e-10) + test_entropic_partial_wasserstein_logscale_no_nan_at_small_reg (parametrised over reg in {0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4}) + test_entropic_partial_wasserstein_logscale_approaches_exact_at_small_reg (plan-cost gap vs exact partial OT at reg=1e-3) + test_entropic_partial_wasserstein_logscale_log_dict + test_entropic_partial_wasserstein_logscale_input_validation examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py + Sphinx-Gallery example reproducing issue #723 + the fix (MPLBACKEND=Agg-safe; narrative softened to acknowledge BLAS/platform-dependent underflow boundary). docs/source/user_guide.rst + one-paragraph mention next to entropic_partial_wasserstein. RELEASES.md + entry under 0.9.7.dev0 (phrased as "mitigated via new log-domain variant", not "fixed", since the standard solver itself is unchanged). Verified locally on master at 41a4d57: pytest test/ -> 1939 passed, 97 skipped, 6 xfailed (no regressions). pytest test/test_partial.py -> 19 passed (8 originals + 11 new parametrised cases for the logscale function). Example script runs end-to-end with MPLBACKEND=Agg. The new function agrees with the standard solver at reg >= 1.0 to ~1e-18 absolute (atol=1e-10 in tests is conservative) and stays finite at reg down to 5e-4 on a 50x50 cost-scale-~50 problem (the exact failure mode of issue #723); std solver returns NaN at reg ~ 0.05-0.01 on the same problem. References Issue #723. Maintainer review needed before merge — author attribution to wzm2256 retained via Co-authored-by trailer. Co-authored-by: wzm2256 --- RELEASES.md | 2 + docs/source/user_guide.rst | 7 +- ...t_entropic_partial_wasserstein_logscale.py | 117 +++++++++++++ ot/partial/__init__.py | 2 + ot/partial/partial_solvers.py | 165 ++++++++++++++++++ test/test_partial.py | 108 ++++++++++++ 6 files changed, 399 insertions(+), 2 deletions(-) create mode 100644 examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..49950a957 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,9 +14,11 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Add numerically stable log-domain entropic partial Wasserstein solver `entropic_partial_wasserstein_logscale` (Issue #723) #### Closed issues +- Mitigate NaN regime of `entropic_partial_wasserstein` at small `reg` via a new log-domain alternative `entropic_partial_wasserstein_logscale` (Issue #723; the standard solver itself is unchanged — callers must opt into the log-domain variant) - Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788) - Fix MSVC Windows build by removing __restrict__ keyword (PR #788) - Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785) diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst index bfdd36721..3e3408c79 100644 --- a/docs/source/user_guide.rst +++ b/docs/source/user_guide.rst @@ -791,8 +791,11 @@ Interestingly the problem can be casted into a regular OT problem by adding rese in which the surplus mass is sent [29]_. We provide a solver for partial OT in :any:`ot.partial`. The exact resolution of the problem is computed in :any:`ot.partial.partial_wasserstein` and :any:`ot.partial.partial_wasserstein2` that return respectively the OT matrix and the value of the -linear term. The entropic solution of the problem is computed in :any:`ot.partial.entropic_partial_wasserstein` -(see [3]_). +linear term. The entropic solution of the problem is computed in :any:`ot.partial.entropic_partial_wasserstein` +(see [3]_). A numerically stable log-domain variant +:any:`ot.partial.entropic_partial_wasserstein_logscale` is also provided for small regularisation +values where the standard solver returns NaN; it solves exactly the same problem but is slower +because it computes everything in log-space. The partial Gromov-Wasserstein formulation of the problem diff --git a/examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py b/examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py new file mode 100644 index 000000000..33867a33e --- /dev/null +++ b/examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +""" +========================================================================== +Numerically-stable entropic partial Wasserstein (log-domain solver) +========================================================================== + +.. note:: + Example added in release: 0.9.7. + +`ot.partial.entropic_partial_wasserstein` is numerically unstable at small +regularisation: the iterates underflow to zero and the returned plan +contains NaNs (see PythonOT/POT issue #723). This example reproduces the +failure mode on a small problem and shows that +:any:`ot.partial.entropic_partial_wasserstein_logscale` produces a finite +plan over the same sweep, agreeing with the original solver at large +``reg`` and degrading gracefully at small ``reg``. + +The log-domain solver is slower per iteration than the standard one, so +the recommendation is to use the standard solver by default and fall +back to the log-domain solver when ``reg`` is small enough to risk +underflow. +""" + +# Author: wzm2256 (original PR #724) +# License: MIT License + +import numpy as np +import scipy as sp +import matplotlib.pylab as pl + +import ot + +############################################################################## +# Construct a 50x50 cost matrix +# ----------------------------- +# +# Mirrors the cost-matrix scale (~50) used in PythonOT/POT issue #723. + +rng = np.random.RandomState(0) +n = 50 +xs = rng.rand(n, 2) +xt = rng.rand(n, 2) +M = sp.spatial.distance.cdist(xs, xt) * 50.0 + +a = np.ones(n) / n +b = np.ones(n) / n +m = 0.6 # transport ~60% of the mass + +############################################################################## +# Sweep regularisation +# -------------------- +# +# Run both solvers across a range of ``reg`` values. On this 50×50 problem +# at cost-scale 50 the standard solver returns NaN at the ``reg`` values +# closest to the underflow boundary (typically ``reg`` ~0.05–0.01 in our +# runs, though the exact transition depends on the BLAS / platform's +# float64 underflow behaviour); the log-domain solver stays finite over +# the whole sweep, including the very small ``reg`` regime where the +# standard exp(−M/reg) path would underflow to zero everywhere. + +regs = [1.0, 0.5, 0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4] +standard_finite = [] +logscale_finite = [] +standard_mass = [] +logscale_mass = [] + +for reg in regs: + G_std = ot.partial.entropic_partial_wasserstein( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + G_log = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + standard_finite.append(bool(np.isfinite(G_std).all())) + logscale_finite.append(bool(np.isfinite(G_log).all())) + standard_mass.append(float(G_std.sum()) if np.isfinite(G_std).all() else np.nan) + logscale_mass.append(float(G_log.sum())) + +print( + "reg standard_finite logscale_finite std_mass logscale_mass (target m={:.2f})".format( + m + ) +) +for reg, sf, lf, sm, lm in zip( + regs, standard_finite, logscale_finite, standard_mass, logscale_mass +): + print(f"{reg:>10.4g} {str(sf):<14} {str(lf):<14} {sm:>8.3f} {lm:>8.3f}") + +############################################################################## +# Plot the resulting plans at large vs. small reg +# ----------------------------------------------- + +fig, axes = pl.subplots(2, 2, figsize=(9, 8)) +for ax, reg in zip(axes[:, 0], (1.0, 0.01)): + G_std = ot.partial.entropic_partial_wasserstein( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + if not np.isfinite(G_std).all(): + G_std = np.zeros_like(G_std) + ax.set_title(f"standard, reg={reg} (NaN)") + else: + ax.set_title(f"standard, reg={reg}") + ax.imshow(G_std, cmap="viridis", aspect="auto") + ax.set_xlabel("target") + ax.set_ylabel("source") + +for ax, reg in zip(axes[:, 1], (1.0, 0.01)): + G_log = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + ax.set_title(f"logscale, reg={reg}") + ax.imshow(G_log, cmap="viridis", aspect="auto") + ax.set_xlabel("target") + ax.set_ylabel("source") + +fig.tight_layout() +pl.show() diff --git a/ot/partial/__init__.py b/ot/partial/__init__.py index 9bb4d0433..c7447a76f 100644 --- a/ot/partial/__init__.py +++ b/ot/partial/__init__.py @@ -13,6 +13,7 @@ partial_wasserstein, partial_wasserstein2, entropic_partial_wasserstein, + entropic_partial_wasserstein_logscale, gwgrad_partial, gwloss_partial, partial_gromov_wasserstein, @@ -28,6 +29,7 @@ "partial_wasserstein", "partial_wasserstein2", "entropic_partial_wasserstein", + "entropic_partial_wasserstein_logscale", "gwgrad_partial", "gwloss_partial", "partial_gromov_wasserstein", diff --git a/ot/partial/partial_solvers.py b/ot/partial/partial_solvers.py index 98a3eff26..83a3c44f2 100755 --- a/ot/partial/partial_solvers.py +++ b/ot/partial/partial_solvers.py @@ -581,6 +581,171 @@ def entropic_partial_wasserstein( return K +def entropic_partial_wasserstein_logscale( + a, b, M, reg, m=None, numItermax=1000, stopThr=1e-100, verbose=False, log=False +): + r""" + Solves the partial optimal transport problem + and returns the OT plan + + This function solves exactly the same problem as + :any:`entropic_partial_wasserstein`, but it is in log-scale, i.e., the log-sum-exp trick is used, + so it is more stable but slower. + + The input and output format is the same as :any:`entropic_partial_wasserstein`. + + The function considers the following problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, + \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) + + s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ + \gamma^T \mathbf{1} &\leq \mathbf{b} \\ + \gamma &\geq 0 \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ + + where : + + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported + + The formulation of the problem has been proposed in + :ref:`[3] ` (prop. 5) + + + Parameters + ---------- + a : np.ndarray (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : np.ndarray (dim_b,) + Unnormalized histograms of dimension `dim_b` + M : np.ndarray (dim_a, dim_b) + cost matrix + reg : float + Regularization term > 0 + m : float, optional + Amount of mass to be transported + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (dim_a, dim_b) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + + Examples + -------- + >>> import ot + >>> a = [.1, .2] + >>> b = [.1, .1] + >>> M = [[0., 1.], [2., 3.]] + >>> np.round(entropic_partial_wasserstein_logscale(a, b, M, 1, 0.1), 2) + array([[0.06, 0.02], + [0.01, 0. ]]) + + + References + ---------- + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. + (2015). Iterative Bregman projections for regularized transportation + problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + See Also + -------- + ot.partial.partial_wasserstein: exact Partial Wasserstein + ot.partial.entropic_partial_wasserstein: numerically unstable entropic version + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b, M) + + dim_a, dim_b = M.shape + + Ldx = nx.zeros(dim_a, type_as=a) + Ldy = nx.zeros(dim_b, type_as=b) + + if len(a) == 0: + a = nx.ones(dim_a, type_as=a) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=b) / dim_b + + La = nx.log(a) + Lb = nx.log(b) + + if m is None: + m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0 + if m < 0: + raise ValueError("Problem infeasible. Parameter m should be greater than 0.") + if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))): + raise ValueError( + "Problem infeasible. Parameter m should lower or" + " equal than min(|a|_1, |b|_1)." + ) + + log_e = {"err": []} + + LK = -M / reg + LK = LK + nx.log(m) - nx.logsumexp(LK) + + err, cpt = 1, 0 + + Lq1 = nx.zeros(M.shape, type_as=M) + Lq2 = nx.zeros(M.shape, type_as=M) + Lq3 = nx.zeros(M.shape, type_as=M) + + while err > stopThr and cpt < numItermax: + LKprev = LK + LK = LK + Lq1 + LK1 = nx.reshape(nx.minimum(La - nx.logsumexp(LK, 1), Ldx), (-1, 1)) + LK + Lq1 = Lq1 + LKprev - LK1 + LK1prev = LK1 + LK1 = LK1 + Lq2 + LK2 = LK1 + nx.reshape(nx.minimum(Lb - nx.logsumexp(LK1, 0), Ldy), (1, -1)) + Lq2 = Lq2 + LK1prev - LK2 + LK2prev = LK2 + LK2 = LK2 + Lq3 + LK = LK2 + nx.log(m) - nx.logsumexp(LK2) + Lq3 = Lq3 + LK2prev - LK + + if nx.any(nx.isnan(LK)) or nx.any(nx.isinf(LK)): + warnings.warn( + f"Numerical errors at iteration {cpt} of entropic_partial_wasserstein_logscale", + stacklevel=2, + ) + break + if cpt % 10 == 0: + err = nx.norm(LKprev - LK) + if log: + log_e["err"].append(err) + if verbose: + if cpt % 200 == 0: + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 11) + print("{:5d}|{:8e}|".format(cpt, err)) + + cpt = cpt + 1 + log_e["partial_w_dist"] = nx.sum(M * nx.exp(LK)) + if log: + return nx.exp(LK), log_e + else: + return nx.exp(LK) + + def gwgrad_partial(C1, C2, T): """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] ` as the marginals may not sum to 1. diff --git a/test/test_partial.py b/test/test_partial.py index 6e54c364d..f802560c8 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -338,3 +338,111 @@ def test_partial_wasserstein_1d(): np.testing.assert_array_equal(np.sort(indices_x[:i]), np.sort(ind_x)) np.testing.assert_array_equal(np.sort(indices_y[:i]), np.sort(ind_y)) + + +# --------------------------------------------------------------------------- +# entropic_partial_wasserstein_logscale — new in this PR (rescue of #724) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("reg", [10.0, 1.0]) +def test_entropic_partial_wasserstein_logscale_matches_old_at_large_reg(reg): + """At large reg both solvers are stable; the plans must agree.""" + rng = np.random.RandomState(0) + n = 20 + a = rng.rand(n) + a /= a.sum() + b = rng.rand(n) + b /= b.sum() + M = ot.dist(rng.rand(n, 2), rng.rand(n, 2)) + m = 0.5 + + G_old = ot.partial.entropic_partial_wasserstein( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + G_log = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + + # At reg >= 1.0 the two solvers agree to machine precision; if this + # tightens it would indicate the logscale path silently diverged. + np.testing.assert_allclose(G_old, G_log, atol=1e-10, rtol=1e-10) + np.testing.assert_allclose(G_log.sum(), m, atol=1e-10) + + +@pytest.mark.parametrize("reg", [0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4]) +def test_entropic_partial_wasserstein_logscale_no_nan_at_small_reg(reg): + """Issue #723: entropic_partial_wasserstein returns NaN at small reg. + + The logscale variant introduced by this PR is the fix; check that it + stays finite and conserves mass across the regime that breaks the + original solver. + """ + rng = np.random.RandomState(1) + n = 50 + a = rng.rand(n) + a /= a.sum() + b = rng.rand(n) + b /= b.sum() + M = ot.dist(rng.rand(n, 2), rng.rand(n, 2)) * 50.0 # match issue cost scale + m = 0.6 + + G = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=reg, m=m, numItermax=2000 + ) + assert np.isfinite(G).all(), f"non-finite plan at reg={reg}" + np.testing.assert_allclose(G.sum(), m, atol=5e-3) + + +def test_entropic_partial_wasserstein_logscale_approaches_exact_at_small_reg(): + """At small `reg` the entropic plan should approach the exact partial + OT plan (modulo discretisation). Verifies the fix is mathematically + meaningful, not just NaN-free.""" + rng = np.random.RandomState(3) + n = 30 + a = np.ones(n) / n + b = np.ones(n) / n + M = ot.dist(rng.rand(n, 2), rng.rand(n, 2)) + m = 0.5 + + G_exact = ot.partial.partial_wasserstein(a, b, M, m=m) + G_log = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=1e-3, m=m, numItermax=5000 + ) + + cost_exact = float((G_exact * M).sum()) + cost_log = float((G_log * M).sum()) + # The entropic objective is a relaxation of the exact one, so the + # plan-cost gap should be small but non-negative at reg → 0. + assert cost_log >= cost_exact - 1e-6 + assert cost_log - cost_exact < 0.01, ( + f"logscale plan cost {cost_log:.4f} diverges from exact {cost_exact:.4f}" + ) + + +def test_entropic_partial_wasserstein_logscale_log_dict(): + """`log=True` returns a dict with `err` and `partial_w_dist` keys.""" + rng = np.random.RandomState(2) + n = 10 + a = rng.rand(n) + a /= a.sum() + b = rng.rand(n) + b /= b.sum() + M = ot.dist(rng.rand(n, 2), rng.rand(n, 2)) + + G, log = ot.partial.entropic_partial_wasserstein_logscale( + a, b, M, reg=0.1, m=0.5, log=True + ) + assert "err" in log + assert "partial_w_dist" in log + assert np.isfinite(G).all() + + +def test_entropic_partial_wasserstein_logscale_input_validation(): + """Out-of-range `m` should raise ValueError, matching the unstable solver.""" + n = 10 + a = np.ones(n) / n + b = np.ones(n) / n + M = np.ones((n, n)) + with pytest.raises(ValueError): + ot.partial.entropic_partial_wasserstein_logscale(a, b, M, reg=0.1, m=-1.0) + with pytest.raises(ValueError): + ot.partial.entropic_partial_wasserstein_logscale(a, b, M, reg=0.1, m=2.0)