[MRG] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete#812
Open
Ferdinand-Genans wants to merge 6 commits into
Open
Conversation
Introduces ot.semidiscrete: Projected Averaged SGD on the semi-dual, with an optional decreasing entropic-regularization schedule (DRAG) from Genans et al. 2025. Works with NumPy, PyTorch, JAX, CuPy and TensorFlow via ot.backend. - ot/semidiscrete.py: solve_semidiscrete, atom_weights, ot_map, c_transform. Closed-form gradient, no autograd graph through the loop; quadratic cost by default with custom-callable override. - ot/__init__.py: register the new submodule. - test/test_semidiscrete.py: convergence on three toy problems with known optimal potentials, helper-function contracts (row- stochasticity of atom_weights, identity for c_transform at g=0, shape and finiteness of ot_map), and solver options (warm-start, projection, log, polyak_average off, entropic regime, custom cost). All tests parametrized over the nx fixture (NumPy + PyTorch). - examples/others/plot_semidiscrete.py: gallery example on a small 2D toy problem with Laguerre cells, empirical cell masses and a Monte Carlo estimate of the semi-dual cost. - RELEASES.md: new-features entry under 0.9.7.dev0.
modified: RELEASES.md modified: examples/others/plot_semidiscrete.py Final small doc modifications.
solve_semidiscrete, and added a more detailed explanation of the effect of this argument on convergence in the example scipt.
"max_cost" explanation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Types of changes
Motivation and context / Related issue
Semi-discrete optimal transport — continuous source, discrete target — is a setting that does not naturally fit POT's discrete-discrete solvers, yet appears in many applications: Monge–Kantorovich statistical depth and quantiles, generative modeling with a continuous prior, Brenier-map estimation, and any setup where the source is given by a sampler rather than a finite empirical distribution.
This PR adds
ot.semidiscrete, a single-file backend-agnostic solver for the semi-dual of semi-discrete OT. Every public function takes areg: floatargument, so the user can switch between unregularized OT (reg=0) and the entropic formulation (reg>0) with a single parameter. The underlying algorithm is Averaged SGD on the semi-dual; optionally, the regularization can be decreased throughout the iterations via the DRAG schedule ([83] inREADME.md; Genans et al., NeurIPS 2025), which empirically improves convergence, especially on large scale problems, when the target regularization is small.No existing issue tracks this addition.
How has this been tested
test/test_semidiscrete.py— 34 tests, automatically parametrized over every available POT backend (NumPy and PyTorch on the CI runner). Covers:atom_weightsrow-stochasticity,c_transformreducing tomin_j c(x, y_j)atg = 0,ot_mapshape and finiteness);polyak_average=False,init_potentialis not mutated).solve_semidiscrete(a 500-iter smoke run).35 passed in 5.88 s.examples/others/plot_semidiscrete.pyruns end-to-end in ~5 s on CPU NumPy and produces the Laguerre-cells figure that becomes the gallery page.PR checklist