Skip to content
95 changes: 95 additions & 0 deletions ot/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import scipy as sp
from scipy.stats import ortho_group, multivariate_normal
from .utils import check_random_state, deprecated


Expand Down Expand Up @@ -180,3 +181,97 @@ def make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **k
def get_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs):
    """Deprecated, see make_data_classif.

    Thin wrapper kept for backward compatibility: every argument is forwarded
    to ``make_data_classif`` and its result is returned unchanged.
    """
    # Forward the caller's values; the previous version passed the literal
    # defaults (nz=0.5, theta=0, random_state=None) and silently ignored them.
    return make_data_classif(
        dataset, n, nz=nz, theta=theta, random_state=random_state, **kwargs
    )


def make_gauss_hd(
    ns,
    nt,
    p=100,
    dim=5,
    m_diff=3,
    a=(10, 15),
    b=(3, 3),
    sub_the_same=False,
    random_state=None,
):
    r"""Generation of source and target domains from Gaussian HD distributions

    Both covariances have the form ``Q @ diag(a, ..., a, b, ..., b) @ Q.T``
    where ``Q`` is a random orthogonal matrix, the value ``a`` is repeated
    over the intrinsic dimension and ``b`` over the remaining ``p - dim``
    directions.

    Parameters
    ----------
    ns : int
        number of samples (source)
    nt : int
        number of samples (target)
    p : int
        dimension of the ambient space the data live in
    dim : (int, int) or int
        the intrinsic dimensions of the source and target Gaussian HD
        distributions. If a single int, the intrinsic dimension is assumed to
        be the same for both. Each value must lie in ``[0, p]``.
    m_diff : float
        the shift in the first coordinate of the target mean; the source mean
        is the zero vector
    a : (float, float)
        positive floating numbers corresponding to the isotropic variances in
        the principal subspace, for the source and target distributions,
        respectively. The same as \delta in
        :ref:`[1] <references-make_gauss_hd>`, Proposition 2.2
    b : (float, float)
        positive floating numbers corresponding to the isotropic variance
        outside the principal subspace for the source and target
        distributions, respectively.
    sub_the_same : bool
        should the source/target Gaussian HD distributions live in the same
        principal subspace?
    random_state : None, int, numpy.random.RandomState or numpy.random.Generator
        source of randomness. ``None`` (default) keeps the historical
        behaviour and lets scipy use the global NumPy random state. An int is
        used to seed a single ``RandomState`` shared by all the draws made in
        this function.

    Returns
    -------
    Xs : ndarray, shape (ns, p)
        `ns` observations of size `p` (source)
    Xt : ndarray, shape (nt, p)
        `nt` observations of size `p` (destination)
    prmts : dict
        the parameters of the two Gaussian HD distributions, with keys
        ``ms``, ``mt`` (means), ``sigma2_s``, ``sigma2_t`` (``b[0]``,
        ``b[1]``), ``ls``, ``lt`` (``a - b`` repeated over the intrinsic
        dimension), ``Us``, ``Ut`` (first ``dim`` columns of the orthogonal
        matrices), ``ds``, ``dt`` (intrinsic dimensions as arrays of shape
        (1,)) and ``Cs``, ``Ct`` (full covariance matrices).

    Raises
    ------
    ValueError
        if an intrinsic dimension is outside ``[0, p]``.

    .. _references-make_gauss_hd:
    References
    ----------

    .. [1] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")

    """
    # Accept NumPy integers as well as Python ints for a shared dimension.
    d = (dim, dim) if isinstance(dim, (int, np.integer)) else dim
    for d_k in (d[0], d[1]):
        if not 0 <= d_k <= p:
            raise ValueError(
                "intrinsic dimensions must lie in [0, p], got dim={} with p={}".format(
                    dim, p
                )
            )

    # An int seed must become ONE shared generator: handing the same int to
    # every scipy call would reproduce the same stream each time (and hence
    # identical orthogonal bases for source and target).
    if random_state is None or isinstance(
        random_state, (np.random.RandomState, np.random.Generator)
    ):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    mu = np.zeros((2, p))
    mu[1, 0] = m_diff

    # Two bases are always drawn so the random stream does not depend on
    # `sub_the_same`.
    Q = [ortho_group.rvs(p, random_state=rng) for _ in range(2)]
    if sub_the_same:
        Q[1] = Q[0]

    S = [
        Q[k]
        @ np.diag(np.hstack((np.full(d[k], a[k]), np.full(p - d[k], b[k]))))
        @ Q[k].T
        for k in range(2)
    ]

    # scipy squeezes singleton axes; reshape so the documented (n, p) shape
    # holds even when a single sample is requested.
    Xs = np.reshape(
        multivariate_normal.rvs(mean=mu[0], cov=S[0], size=ns, random_state=rng),
        (ns, p),
    )
    # The target sample size is `nt` (the previous version drew `ns` samples).
    Xt = np.reshape(
        multivariate_normal.rvs(mean=mu[1], cov=S[1], size=nt, random_state=rng),
        (nt, p),
    )

    ds, dt = d[0], d[1]
    sigma2_s = np.array(b[0])
    sigma2_t = np.array(b[1])

    prmts = {
        "ms": mu[0],
        "mt": mu[1],
        "sigma2_s": sigma2_s,
        "sigma2_t": sigma2_t,
        "ls": np.repeat(a[0], ds) - sigma2_s,
        "lt": np.repeat(a[1], dt) - sigma2_t,
        "Us": Q[0][:, :ds],
        "Ut": Q[1][:, :dt],
        "ds": np.array([ds]),
        "dt": np.array([dt]),
        "Cs": S[0],
        "Ct": S[1],
    }

    return Xs, Xt, prmts
262 changes: 261 additions & 1 deletion ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,119 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
return A, b


def bures_wasserstein_mapping_hd(
    ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt, log=False
):
    r"""Return OT linear operator between HD Gaussian distributions.

    The function computes the linear operator that aligns the two
    HD Gaussian distributions :math:`\mathcal{N}(\mu_s, U_s, l_s, \sigma_s^2, d_s)`
    and :math:`\mathcal{N}(\mu_t, U_t, l_t, \sigma_t^2, d_t)` as proposed in
    :ref:`[3] <references-OT-mapping-linear>`, Th. 2.9.

    The linear operator from source to target :math:`M`

    .. math::
        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}

    where :

    .. math::
        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
        \Sigma_s^{-1/2} \\

        \Sigma_s^{1/2} &=\sigma_s I_p + U_s C_s U_s^T \\

        C_s &=\mathrm{diag}(\sqrt{l_{s1} + \sigma_s^2} - \sigma_s, \dots, \sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s) \\

        \Sigma_s^{-1/2} &= \frac{1}{\sigma_s} (I_p - U_s D_s U_s^T ) \\

        D_s &= \mathrm{diag}((\sqrt{l_{s1} + \sigma_s^2} - \sigma_s)/\sqrt{l_{s1} + \sigma_s^2}, \dots, (\sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s)/\sqrt{l_{sd_s} + \sigma_s^2}) \\

        \Sigma_t &= U_t \mathrm{diag}(l_t) U_t^T + \sigma_t^2 I_p \\

        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s

    Parameters
    ----------
    ms : array-like (p,)
        mean of the source distribution
    mt : array-like (p,)
        mean of the target distribution
    Us : array-like (p,ds)
        orthogonal matrix spanning the principal subspace of the source distribution
    Ut : array-like (p,dt)
        orthogonal matrix spanning the principal subspace of the target distribution
    ls : array-like (ds,)
        the variances associated with the principal sub-axes for the source distribution
    lt : array-like (dt,)
        the variances associated with the principal sub-axes for the target distribution
    sigma2_s : array-like (1,)
        the residual variance of the source distribution
    sigma2_t : array-like (1,)
        the residual variance of the target distribution
    ds : array-like (1,)
        the intrinsic dimension of the source distribution (only used for
        backend detection here)
    dt : array-like (1,)
        the intrinsic dimension of the target distribution (only used for
        backend detection here)
    log : bool, optional
        record log if True


    Returns
    -------
    A : (p, p) array-like
        Linear operator
    b : array-like
        bias, computed as ``mt - ms @ A``
    log : dict
        returned only if ``log`` is true; holds ``Ss_sq`` (square root of the
        source covariance) and ``Ss_sqinv`` (its inverse)


    .. _references-OT-mapping-linear:
    References
    ----------
    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
        distributions", Journal of Optimization Theory and Applications
        Vol 43, 1984

    .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    .. [3] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")
    """

    ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt = list_to_array(
        ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt
    )
    nx = get_backend(ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt)

    p = Us.shape[0]

    # Source side: closed-form square root and inverse square root of
    # Sigma_s = Us diag(ls) Us^T + sigma2_s * I, without any decomposition.
    std_s = nx.sqrt(sigma2_s)
    axis_std_s = nx.sqrt(ls + sigma2_s)
    gap_s = axis_std_s - std_s

    Ss_sq = dots(Us, nx.diag(gap_s), Us.T) + std_s * nx.eye(p)
    Ss_sqinv = (1 / std_s) * (
        nx.eye(p) - dots(Us, nx.diag(gap_s / axis_std_s), Us.T)
    )

    # Target side: full covariance matrix.
    St = dots(Ut, nx.diag(lt), Ut.T) + sigma2_t * nx.eye(p)

    middle = nx.sqrtm(dots(Ss_sq, St, Ss_sq))

    A = dots(Ss_sqinv, middle, Ss_sqinv)
    b = mt - nx.dot(ms, A)

    if not log:
        return A, b

    return A, b, {"Ss_sq": Ss_sq, "Ss_sqinv": Ss_sqinv}


def empirical_bures_wasserstein_mapping(
xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False
):
Expand Down Expand Up @@ -128,7 +241,7 @@ def empirical_bures_wasserstein_mapping(
regularization added to the diagonals of covariances (>0)
ws : array-like (ns,1), optional
weights for the source samples
wt : array-like (ns,1), optional
wt : array-like (nt,1), optional
weights for the target samples
bias: boolean, optional
estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
Expand Down Expand Up @@ -200,6 +313,153 @@ def empirical_bures_wasserstein_mapping(
return A, b


def empirical_bures_wasserstein_mapping_hd(
    xs, xt, ds, dt, reg=0.0, ws=None, wt=None, bias=True, log=False
):
    r"""Return OT HD linear operator between samples.

    The parameters of two HD Gaussian distributions
    :math:`\mathcal{N}(\mu_s, U_s, l_s, \sigma_s^2, d_s)` and
    :math:`\mathcal{N}(\mu_t, U_t, l_t, \sigma_t^2, d_t)` are estimated from
    the (weighted) samples, then the closed-form mapping between them is
    computed as proposed in
    :ref:`[3] <references-OT-mapping-linear>`, Th. 2.9.

    The linear operator from source to target :math:`M`

    .. math::
        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}

    where :

    .. math::
        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
        \Sigma_s^{-1/2} \\

        \Sigma_s^{1/2} &=\sigma_s I_p + U_s C_s U_s^T \\

        C_s &=\mathrm{diag}(\sqrt{l_{s1} + \sigma_s^2} - \sigma_s, \dots, \sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s) \\

        \Sigma_s^{-1/2} &= \frac{1}{\sigma_s} (I_p - U_s D_s U_s^T ) \\

        D_s &= \mathrm{diag}((\sqrt{l_{s1} + \sigma_s^2} - \sigma_s)/\sqrt{l_{s1} + \sigma_s^2}, \dots, (\sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s)/\sqrt{l_{sd_s} + \sigma_s^2}) \\

        \Sigma_t &= U_t \mathrm{diag}(l_t) U_t^T + \sigma_t^2 I_p \\

        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s


    Parameters
    ----------
    xs : array-like (ns,p)
        samples in the source domain
    xt : array-like (nt,p)
        samples in the target domain
    ds : array-like (1,)
        the intrinsic dimension of the source distribution
    dt : array-like (1,)
        the intrinsic dimension of the target distribution
    reg : float, optional
        regularization added to the diagonals of covariances (null by default)
    ws : array-like (ns,1), optional
        weights for the source samples (uniform if None)
    wt : array-like (nt,1), optional
        weights for the target samples (uniform if None)
    bias : boolean, optional
        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
    log : bool, optional
        record log if True


    Returns
    -------
    A : (p, p) array-like
        Linear operator
    b : (1, p) array-like
        bias
    log : dict
        log dictionary return only if log==True in parameters; on top of the
        entries produced by ``bures_wasserstein_mapping_hd`` it holds the
        estimated ``Us``, ``Ut``, ``ls``, ``lt``, ``sigma2_s`` and ``sigma2_t``


    .. _references-OT-mapping-linear:
    References
    ----------
    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
        distributions", Journal of Optimization Theory and Applications
        Vol 43, 1984

    .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    .. [3] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")

    """

    xs, xt, ds, dt = list_to_array(xs, xt, ds, dt)
    nx = get_backend(xs, xt, ds, dt)
    is_input_finite = is_all_finite(xs, xt, ds, dt)

    p = xs.shape[1]

    # Default to uniform weights on each sample set.
    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    if not bias:
        mxs = nx.zeros((1, p), type_as=xs)
        mxt = nx.zeros((1, p), type_as=xs)
    else:
        # Weighted means, then centre the samples before the covariances.
        mxs = nx.dot(ws.T, xs) / nx.sum(ws)
        mxt = nx.dot(wt.T, xt) / nx.sum(wt)

        xs = xs - mxs
        xt = xt - mxt

    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(p, type_as=xs)
    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(p, type_as=xt)

    def _hd_parameters(cov, d):
        # Keep the trailing d[0] eigenpairs returned by nx.eigh (the largest
        # ones, assuming eigenvalues come back in ascending order) and spread
        # the remaining trace evenly over the other p - d directions.
        eig = nx.eigh(cov)
        kept = eig[0][-d[0] :]
        noise = (nx.trace(cov) - nx.sum(kept)) / (p - d)
        basis = eig[1][:, -d[0] :]
        return basis, kept - noise, noise

    Us, ls, sgm2_s = _hd_parameters(Cs, ds)
    Ut, lt, sgm2_t = _hd_parameters(Ct, dt)

    # The callee returns (A, b) or (A, b, log) depending on the flag.
    res = bures_wasserstein_mapping_hd(
        mxs, mxt, Us, Ut, ls, lt, sgm2_s, sgm2_t, ds, dt, log=log
    )
    A, b = res[0], res[1]
    if log:
        log = res[2]

    if is_input_finite and not is_all_finite(A, b):
        warnings.warn(
            "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping_hd. "
            "Consider increasing the regularization parameter `reg` or reducing the intrinsic dimensions ds/dt."
        )

    if not log:
        return A, b

    log["Us"] = Us
    log["Ut"] = Ut
    log["ls"] = ls
    log["lt"] = lt
    log["sigma2_s"] = sgm2_s
    log["sigma2_t"] = sgm2_t
    return A, b, log


def bures_distance(Cs, Ct, paired=False, log=False, nx=None):
r"""Return Bures distance.

Expand Down
Loading
Loading