From 3e573404055ada83cb02493e483f5d7cbf913be1 Mon Sep 17 00:00:00 2001 From: mcorneli Date: Wed, 13 May 2026 16:18:59 +0200 Subject: [PATCH 1/4] ?? --- ot/gaussian.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/ot/gaussian.py b/ot/gaussian.py index 7c25cd660..233c06a40 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -93,6 +93,119 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): return A, b +def bures_wasserstein_mapping_hd( + ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt, log=False +): + r"""Return OT linear operator between HD Gaussian distritutions. + + The function estimates the optimal linear operator that aligns the two + HD Gaussian distributions :math:`\mathcal{N}(\mu_s, U_s, l_s, \sigma_s^2, d_s)` + and :math:`\mathcal{N}(\mu_t, U_t, l_t, \sigma_t^2, d_t)` as proposed in + :ref:`[3] `, Th. 2.9 + . + + The linear operator from source to target :math:`M` + + .. math:: + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} + + where : + + .. math:: + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} + \Sigma_s^{-1/2} \\ + + \Sigma_s^{1/2} &=\sigma_s I_p + U_s C_s U_s^T \\ + + C_s &=\diag(\sqrt{l_{s1} + \sigma_s^2} - \sigma_s, \dots, \sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s) \\ + + \Sigma_s^{-1/2} &= \frac{1}{\sigma_s} (I_p - U_s D_s U_s^T ) \\ + + D_s &= \diag((\sqrt{l_{s1} + \sigma_s^2} - \sigma_s)/\sqrt{l_{s1} + \sigma_s^2}, \dots, (\sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s)/\sqrt{l_{sd_s} + \sigma_s^2}) \\ + + \Sigma_t &= U_t \diag(l_t) U_t^T + \sigma_t^2 I_p \\ + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s + + Parameters + ---------- + ms : array-like (p,) + mean of the source distribution + mt : array-like (p,) + mean of the target distribution + Us : array-like (p,ds) + orthogonal matrix spanning the principal subspace of the source distribution + Ut : array-like (p,dt) + orthogonal matrix spanning the principal subspace of the target distribution + ls : array-like (ds,) + the variances associated with the principal sub-axes for the source distribution + lt : array-like (dt,) + the variances associated with the principal sub-axes for the target distribution + sigma_s^2 : array-like (1,) + the residual variance of the source distribution + sigma_t^2 : array-like (1,) + the residual variance of the target distribution + ds : array-like (1,) + the intrinsic dimension of the source distribution + dt : array-like (1,) + the intrinsic dimension of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + A : (d, d) array-like + Linear operator + b : (1, d) array-like + bias + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-mapping-linear: + References + ---------- + .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + .. [3] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions") + """ + + ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt = list_to_array( + ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt + ) + nx = get_backend(ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt) + + p = Us.shape[0] + + # source + Cs = nx.diag(nx.sqrt(ls + sigma2_s) - nx.sqrt(sigma2_s)) + Ss_sq = dots(Us, Cs, Us.T) + nx.sqrt(sigma2_s) * nx.eye(p) + Ds = nx.diag((nx.sqrt(ls + sigma2_s) - nx.sqrt(sigma2_s)) / nx.sqrt(ls + sigma2_s)) + Ss_sqinv = (1 / nx.sqrt(sigma2_s)) * (nx.eye(p) - dots(Us, Ds, Us.T)) + + # destination + St = dots(Ut, nx.diag(lt + sigma2_t), Ut.T) + sigma2_t * nx.eye(p) + + M0 = nx.sqrtm(dots(Ss_sq, St, Ss_sq)) + + A = dots(Ss_sqinv, M0, Ss_sqinv) + b = mt - nx.dot(ms, A) + + if log: + log = {} + log["Ss_sq"] = Ss_sq + log["Ss_sqinv"] = Ss_sqinv + return A, b, log + else: + return A, b + + def empirical_bures_wasserstein_mapping( xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False ): From 4dcd41b28b6f6f8f77517267444395478c9c9669 Mon Sep 17 00:00:00 2001 From: mcorneli Date: Mon, 18 May 2026 12:53:52 +0200 Subject: [PATCH 2/4] =?UTF-8?q?bug=20corrig=C3=A9=20=C3=A0=20la=20ligne=20?= =?UTF-8?q?193?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ot/gaussian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index 233c06a40..5d9c97622 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -190,7 +190,7 @@ def bures_wasserstein_mapping_hd( Ss_sqinv = (1 / nx.sqrt(sigma2_s)) * (nx.eye(p) - dots(Us, Ds, Us.T)) # destination - St = dots(Ut, nx.diag(lt + sigma2_t), Ut.T) + sigma2_t * nx.eye(p) + St = dots(Ut, nx.diag(lt), Ut.T) + sigma2_t * nx.eye(p) M0 = nx.sqrtm(dots(Ss_sq, St, Ss_sq)) From 2d672d10a8777c11d8a78108e3fe29820c8111cf Mon Sep 17 00:00:00 2001 From: mcorneli Date: Tue, 19 May 2026 14:23:42 +0200 Subject: [PATCH 3/4] First draft of make_gauss_hd after some issues with pre-commit --- ot/datasets.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/ot/datasets.py b/ot/datasets.py index 6e3be518a..161fce670 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -8,6 +8,7 @@ import numpy as np import scipy as sp +from scipy.stats import ortho_group, multivariate_normal from .utils import check_random_state, deprecated @@ -180,3 +181,97 @@ def make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **k def get_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs): """Deprecated see make_data_classif""" return make_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs) + + +def make_gauss_hd( + ns, nt, p=100, dim=5, m_diff=3, a=(10, 15), b=(3, 3), sub_the_same=False +): + """Generation of source and target domains from Gaussian HD distributions + + Parameters + ---------- + ns : int + number of samples (source) + nt : int + number of samples (target) + p : int + dimension of the ambient space the data live in + dim : (int,int) or int + the intrinsic dimensions of the source and target Gaussian HD distriutions. If a single int the intrinsic dimension is assumed to be the same + m_diff : float + the shift in the first coordinate of the means of the Gaussian HD distributions, i.e. ms_0 and mt_0, respectively (see code) + a : (float, float) + positive floating numbers corresponding to the isotropic variances in the principal subspace, for the source and target distributions, respectively. The same as \delta in :ref:`[1] `, Proposition 2.2 + b : (float, float) + positive floating numbers corresponding to the isotropic variance outside the principal subspace for the source and target distributions, respectively. + sub_the_same : bool + should the source/target Gaussian HD distributions live in the same principal subspace? + + Returns + ------- + Xs : ndarray, shape (ns, p) + `ns` observations of size `p` (source) + Xt : ndarray, shape (nt, p) + `nt` observations of size `p` (destination) + pmts : list + a list containing the parameters of the Gaussian HD distributions + + .. _references-make_gauss_hd: + References + ---------- + + .. [1] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions") + + """ + d = (dim, dim) if isinstance(dim, int) else dim + mu = np.zeros((2, p)) + S = [] + mu[1, 0] = m_diff + Q = [ortho_group.rvs(p) for _ in range(2)] + + if sub_the_same: + Q[1] = Q[0] + + S.append( + Q[0] + @ np.diag(np.hstack((np.full(d[0], a[0]), np.full(p - d[0], b[0])))) + @ Q[0].T + ) + S.append( + Q[1] + @ np.diag(np.hstack((np.full(d[1], a[1]), np.full(p - d[1], b[1])))) + @ Q[1].T + ) + + Xs = multivariate_normal.rvs(mean=mu[0], cov=S[0], size=ns) + Xt = multivariate_normal.rvs(mean=mu[1], cov=S[1], size=ns) + + ms = mu[0] + mt = mu[1] + ds = d[0] + dt = d[1] + sigma2_s = np.array(b[0]) + sigma2_t = np.array(b[1]) + ls = np.repeat(a[0], ds) - sigma2_s + lt = np.repeat(a[1], dt) - sigma2_t + Us = Q[0][:, :ds] + Ut = Q[1][:, :dt] + ds = np.array([ds]) + dt = np.array([dt]) + + prmts = { + "ms": ms, + "mt": mt, + "sigma2_s": sigma2_s, + "sigma2_t": sigma2_t, + "ls": ls, + "lt": lt, + "Us": Us, + "Ut": Ut, + "ds": ds, + "dt": dt, + "Cs": S[0], + "Ct": S[1], + } + + return Xs, Xt, prmts From f163035991978e4a42136eb4c8a01604676bada7 Mon Sep 17 00:00:00 2001 From: mcorneli Date: Tue, 19 May 2026 14:51:11 +0200 Subject: [PATCH 4/4] test bures wasserstein hd --- test/test_gaussian.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 733fcfab9..fdb174caf 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -10,7 +10,7 @@ import pytest import ot -from ot.datasets import make_data_classif +from ot.datasets import make_data_classif, make_gauss_hd from ot.utils import is_all_finite @@ -42,6 +42,34 @@ def test_bures_wasserstein_mapping(nx): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) +def test_bures_wasserstein_mapping_hd(nx): + ns = 100 + nt = 100 + + Xs, Xt, ll = make_gauss_hd(ns, nt, p=50, dim=10, m_diff=5, a=(7, 7), b=(1, 1)) + + ms = ll["ms"] + mt = ll["mt"] + sigma2_s = ll["sigma2_s"] + sigma2_t = ll["sigma2_t"] + ls = ll["ls"] + lt = ll["lt"] + Us = ll["Us"] + Ut = ll["Ut"] + ds = ll["ds"] + dt = ll["dt"] + Cs = ll["Cs"] + Ct = ll["Ct"] + + A_hd, b_hd = ot.gaussian.bures_wasserstein_mapping_hd( + ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt, log=False + ) + A, b = ot.gaussian.bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False) + + np.testing.assert_allclose(A_hd, A, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(b_hd, b, rtol=1e-2, atol=1e-2) + + @pytest.mark.parametrize("bias", [True, False]) def test_empirical_bures_wasserstein_mapping(nx, bias): ns = 50