[MRG] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete #812

Open
Ferdinand-Genans wants to merge 6 commits into PythonOT:master from Ferdinand-Genans:feature/semi-discrete

@Ferdinand-Genans

Types of changes

  • New feature (non-breaking change which adds functionality)
  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation / Examples update

Motivation and context / Related issue

Semi-discrete optimal transport — continuous source, discrete target — is a setting that does not naturally fit POT's discrete-discrete solvers, yet appears in many applications: Monge–Kantorovich statistical depth and quantiles, generative modeling with a continuous prior, Brenier-map estimation, and any setup where the source is given by a sampler rather than a finite empirical distribution.

This PR adds ot.semidiscrete, a single-file, backend-agnostic solver for the semi-dual of semi-discrete OT. Every public function takes a reg: float argument, so the user can switch between unregularized OT (reg=0) and the entropic formulation (reg>0) with a single parameter. The underlying algorithm is averaged SGD on the semi-dual. Optionally, the regularization can be decreased over the iterations via the DRAG schedule ([83] in README.md; Genans et al., NeurIPS 2025), which empirically improves convergence when the target regularization is small, especially on large-scale problems.
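
For readers unfamiliar with the approach, here is a minimal NumPy sketch of averaged SGD on the semi-dual. It is not the code added in this PR: the name `semidual_asgd_sketch`, its arguments, the `lr / sqrt(t)` step size, and the convention of weighting the soft assignment by the target weights `b` are assumptions made for illustration, and the DRAG schedule is not shown.

```python
import numpy as np


def semidual_asgd_sketch(sample_source, y, b, reg=0.0, n_iter=2000,
                         batch_size=64, lr=0.5, seed=0):
    """Illustrative averaged SGD (ascent) on the semi-dual; hypothetical helper."""
    rng = np.random.default_rng(seed)
    g = np.zeros(len(b))      # current dual potential, one value per atom
    g_avg = np.zeros(len(b))  # running (Polyak) average of the iterates
    for t in range(1, n_iter + 1):
        x = sample_source(batch_size, rng)                       # (batch, d)
        cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)    # (batch, m)
        score = g[None, :] - cost
        if reg > 0:
            # Entropic case: soft assignment of each sample to the atoms.
            z = score / reg + np.log(b)[None, :]
            z -= z.max(axis=1, keepdims=True)  # stabilise the exponentials
            chi = np.exp(z)
            chi /= chi.sum(axis=1, keepdims=True)
        else:
            # Unregularized case: hard assignment to the Laguerre cell.
            chi = np.zeros_like(cost)
            chi[np.arange(len(x)), score.argmax(axis=1)] = 1.0
        grad = b - chi.mean(axis=0)        # stochastic semi-dual gradient
        g = g + lr / np.sqrt(t) * grad     # ascent step, decaying step size
        g_avg += (g - g_avg) / t           # update the running average
    return g_avg


def sample_uniform(n, rng):
    """Source sampler for the toy problem: uniform on [0, 1]."""
    return rng.uniform(size=(n, 1))


# Toy 1D problem: two atoms with nonuniform weights.
y = np.array([[0.25], [0.75]])
b = np.array([0.3, 0.7])
g = semidual_asgd_sketch(sample_uniform, y, b)
```

On this toy problem the cell boundary sits at `0.5 + g[0] - g[1]`, so matching the target weights requires `g[0] - g[1]` to be close to `-0.2`.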

No existing issue tracks this addition.

How has this been tested

  • test/test_semidiscrete.py: 34 tests, automatically parametrized over every available POT backend (NumPy and PyTorch on the CI runner). Covers:
    • convergence to a known closed-form optimum on three toy problems (regular grid, nonuniform target weights, shifted 1D);
    • both plain SGD and the DRAG schedule;
    • the entropic regime;
    • a user-supplied custom cost;
    • helper functions (atom_weights row-stochasticity, c_transform reducing to min_j c(x, y_j) at g = 0, ot_map shape and finiteness);
    • solver options (warm start, projection bound, log dict, polyak_average=False, init_potential is not mutated).
  • One doctest on solve_semidiscrete (a 500-iter smoke run).
  • Full suite + doctest run locally: 35 passed in 5.88 s.
  • examples/others/plot_semidiscrete.py runs end-to-end in ~5 s on CPU NumPy and produces the Laguerre-cells figure that becomes the gallery page.
  • Also verified locally on a larger-scale problem on GPU with PyTorch: a discrete measure with 20 000 atoms and 30 000 iterations with minibatches of 64 run in ~17 s on a single CUDA device.
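
The helper-function contracts listed above can be illustrated with a small NumPy sketch. The functions below are hypothetical stand-ins written for this illustration (`c_transform_sketch`, `atom_weights_sketch`, and `sq_cost` are assumed names and signatures, not the ones in ot/semidiscrete.py), using the squared Euclidean cost and a soft assignment weighted by the target weights `b`.

```python
import numpy as np


def sq_cost(x, y):
    """Squared Euclidean cost matrix, shape (n_samples, n_atoms)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def c_transform_sketch(g, x, y, b, reg=0.0):
    """(Soft) c-transform of the potential g at the points x; hypothetical."""
    score = g[None, :] - sq_cost(x, y)
    if reg == 0:
        return -score.max(axis=1)  # equals min_j c(x, y_j) - g_j
    z = score / reg + np.log(b)[None, :]
    z_max = z.max(axis=1, keepdims=True)
    lse = z_max[:, 0] + np.log(np.exp(z - z_max).sum(axis=1))  # log-sum-exp
    return -reg * lse


def atom_weights_sketch(g, x, y, b, reg=0.0):
    """Assignment weights of each sample over the atoms; hypothetical."""
    score = g[None, :] - sq_cost(x, y)
    if reg == 0:
        w = np.zeros_like(score)
        w[np.arange(len(x)), score.argmax(axis=1)] = 1.0  # one-hot rows
        return w
    z = score / reg + np.log(b)[None, :]
    z -= z.max(axis=1, keepdims=True)
    w = np.exp(z)
    return w / w.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2))  # samples from the source
y = rng.normal(size=(5, 2))    # target atoms
b = np.full(5, 0.2)            # uniform target weights
g0 = np.zeros(5)

# Contract 1: at g = 0 and reg = 0 the c-transform is min_j c(x, y_j).
assert np.allclose(c_transform_sketch(g0, x, y, b), sq_cost(x, y).min(axis=1))

# Contract 2: the assignment weights are row-stochastic in both regimes.
for reg in (0.0, 0.1):
    assert np.allclose(atom_weights_sketch(g0, x, y, b, reg).sum(axis=1), 1.0)
```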

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Introduces ot.semidiscrete: Projected Averaged SGD on the semi-dual,
with an optional decreasing entropic-regularization schedule (DRAG)
from Genans et al. 2025. Works with NumPy, PyTorch, JAX, CuPy and
TensorFlow via ot.backend.

- ot/semidiscrete.py: solve_semidiscrete, atom_weights, ot_map,
  c_transform. Closed-form gradient, no autograd graph through the
  loop; quadratic cost by default with custom-callable override.
- ot/__init__.py: register the new submodule.
- test/test_semidiscrete.py: convergence on three toy problems
  with known optimal potentials, helper-function contracts (row-
  stochasticity of atom_weights, identity for c_transform at g=0,
  shape and finiteness of ot_map), and solver options (warm-start,
  projection, log, polyak_average off, entropic regime, custom
  cost). All tests parametrized over the nx fixture (NumPy + PyTorch).
- examples/others/plot_semidiscrete.py: gallery example on a small
  2D toy problem with Laguerre cells, empirical cell masses and a
  Monte Carlo estimate of the semi-dual cost.
- RELEASES.md: new-features entry under 0.9.7.dev0.
	modified:   RELEASES.md
	modified:   examples/others/plot_semidiscrete.py

Final small doc modifications.
solve_semidiscrete, and added a more detailed
explanation of the effect of this argument on
convergence in the example script.