Sinkhorn Loss

Module for the SinkhornLoss class.

class SinkhornLoss(p=2, eps=0.1, iterations=100)

Bases: BaseDualLoss

Implementation of the Sinkhorn loss, which measures the entropy-regularized optimal transport cost between two empirical distributions.

Given an input tensor \(x\) with \(N\) samples and a target tensor \(y\) with \(M\) samples, both in \(\mathbb{R}^D\), the loss is defined through the entropy-regularized optimal transport problem:

\[W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} \langle C, \pi \rangle - \varepsilon H(\pi)\]

where \(\mu\) and \(\nu\) are the empirical distributions associated with \(x\) and \(y\), \(\pi\) is a transport plan, and \(\Pi(\mu, \nu)\) is the set of admissible transport plans with marginals \(\mu\) and \(\nu\).

The cost matrix is defined as:

\[C_{ij} = \left\| x_i - y_j \right\|_2^p\]
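For illustration, the cost matrix can be assembled in PyTorch from the pairwise Euclidean distances returned by `torch.cdist`, raised elementwise to the power \(p\). This is a minimal sketch, not the class's internal code; the helper name `cost_matrix` is hypothetical.

```python
import torch


def cost_matrix(x: torch.Tensor, y: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Hypothetical helper: C[i, j] = ||x_i - y_j||_2 ** p for x of shape (N, D), y of shape (M, D)."""
    # torch.cdist with p=2.0 returns pairwise Euclidean distances, shape (N, M).
    return torch.cdist(x, y, p=2.0) ** p


x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # N = 3 samples in R^2
y = torch.tensor([[0.0, 0.0], [3.0, 4.0]])              # M = 2 samples in R^2
C = cost_matrix(x, y, p=2)
print(C.shape)  # torch.Size([3, 2])
```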

and the entropy term is:

\[H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij}\]

where \(\varepsilon > 0\) controls the strength of the entropic regularization.
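The objective \(\langle C, \pi \rangle - \varepsilon H(\pi)\) can be evaluated directly for any admissible plan. The sketch below is an illustration rather than part of the documented API: it uses uniform weights \(a_i = 1/N\), \(b_j = 1/M\), the independent coupling \(\pi = a b^\top\) (always admissible, generally suboptimal), and the convention \(0 \log 0 = 0\) through `torch.xlogy`. The helper names `entropy` and `primal_objective` are hypothetical.

```python
import torch


def entropy(plan: torch.Tensor) -> torch.Tensor:
    # H(pi) = -sum_ij pi_ij log pi_ij; torch.xlogy returns 0 where pi_ij = 0.
    return -torch.xlogy(plan, plan).sum()


def primal_objective(C: torch.Tensor, plan: torch.Tensor, eps: float) -> torch.Tensor:
    # <C, pi> - eps * H(pi)
    return (C * plan).sum() - eps * entropy(plan)


N, M, eps = 3, 2, 0.1
a = torch.full((N,), 1.0 / N)  # uniform weights on the N input samples
b = torch.full((M,), 1.0 / M)  # uniform weights on the M target samples
plan = torch.outer(a, b)       # independent coupling: marginals are exactly a and b
C = torch.tensor([[0.0, 25.0], [1.0, 20.0], [4.0, 13.0]])
value = primal_objective(C, plan, eps)
```

For this plan \(H(\pi) = \log(NM) = \log 6\), so the value is \(63/6 - 0.1 \log 6 \approx 10.32\). The regularized optimum can only be lower, since it minimizes the same objective over all of \(\Pi(\mu, \nu)\).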

The Sinkhorn iterations compute the optimal dual potentials \(f^\ast\) and \(g^\ast\) in log space. The regularized optimal transport cost is then recovered from the dual formulation as:

\[W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle\]

where \(a\) and \(b\) are uniform probability weights over the \(N\) input samples and \(M\) target samples, respectively.
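The iterations can be sketched as follows. This is a minimal illustration under an assumed parametrization, not necessarily how `SinkhornLoss` is implemented: the plan is written as \(\pi_{ij} = \exp\big((f_i + g_j - C_{ij})/\varepsilon\big)\), and each update enforces one marginal constraint exactly in log space with `torch.logsumexp`. Under this choice, once both marginals hold, \(\langle C, \pi \rangle - \varepsilon H(\pi)\) equals \(\langle a, f \rangle + \langle b, g \rangle\), matching the dual formula above. The function name `sinkhorn_log` is hypothetical.

```python
import torch


def sinkhorn_log(C: torch.Tensor, eps: float, iterations: int = 100):
    """Hypothetical log-domain Sinkhorn with uniform weights; returns f, g and the dual value."""
    N, M = C.shape
    a = torch.full((N,), 1.0 / N, dtype=C.dtype, device=C.device)  # input weights
    b = torch.full((M,), 1.0 / M, dtype=C.dtype, device=C.device)  # target weights
    f = torch.zeros(N, dtype=C.dtype, device=C.device)
    g = torch.zeros(M, dtype=C.dtype, device=C.device)
    for _ in range(iterations):
        # Row update: enforces sum_j pi_ij = a_i given the current g.
        f = eps * a.log() - eps * torch.logsumexp((g[None, :] - C) / eps, dim=1)
        # Column update: enforces sum_i pi_ij = b_j given the new f.
        g = eps * b.log() - eps * torch.logsumexp((f[:, None] - C) / eps, dim=0)
    # Dual recovery: W_eps = <a, f> + <b, g>.
    return f, g, torch.dot(a, f) + torch.dot(b, g)


torch.manual_seed(0)
x = torch.rand(5, 2, dtype=torch.float64)  # N = 5 input samples in the unit square
y = torch.rand(4, 2, dtype=torch.float64)  # M = 4 target samples in the unit square
C = torch.cdist(x, y) ** 2                 # p = 2
eps = 0.5
f, g, W = sinkhorn_log(C, eps=eps, iterations=500)
plan = torch.exp((f[:, None] + g[None, :] - C) / eps)  # transport plan recovered from f, g
```

With points in the unit square the squared costs lie in \([0, 2]\), so `eps=0.5` converges well within 500 iterations; smaller values of eps typically need more.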

Unlike pointwise losses, the Sinkhorn loss compares entire empirical distributions rather than individual samples, so its output is always a scalar.

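Smaller values of eps bring \(W_\varepsilon\) closer to the unregularized optimal transport cost \(\min_{\pi \in \Pi(\mu, \nu)} \langle C, \pi \rangle\), which is the \(p\)-th power of the \(p\)-Wasserstein distance, but may require more Sinkhorn iterations to converge.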
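The effect of eps can be checked on a two-point example where everything is explicit. Because \(0 \le H(\pi) \le \log(NM)\) for every admissible plan, the regularized value satisfies \(W_0 - \varepsilon \log(NM) \le W_\varepsilon \le W_0\), where \(W_0\) is the unregularized cost, so the gap shrinks at least linearly as \(\varepsilon \to 0\). With \(x = y = \{0, 1\}\) in one dimension and \(p = 2\), \(W_0 = 0\). The sketch assumes the log-domain parametrization \(\pi_{ij} = \exp\big((f_i + g_j - C_{ij})/\varepsilon\big)\) with uniform weights; `regularized_cost` is a hypothetical helper, not part of the documented API.

```python
import math

import torch


def regularized_cost(C: torch.Tensor, eps: float, iterations: int = 100) -> float:
    """Hypothetical helper: log-domain Sinkhorn with uniform weights, returns <a, f> + <b, g>."""
    N, M = C.shape
    log_a = torch.full((N,), -math.log(N), dtype=C.dtype)  # log(1/N)
    log_b = torch.full((M,), -math.log(M), dtype=C.dtype)  # log(1/M)
    f = torch.zeros(N, dtype=C.dtype)
    g = torch.zeros(M, dtype=C.dtype)
    for _ in range(iterations):
        f = eps * log_a - eps * torch.logsumexp((g[None, :] - C) / eps, dim=1)
        g = eps * log_b - eps * torch.logsumexp((f[:, None] - C) / eps, dim=0)
    return (f.mean() + g.mean()).item()  # uniform weights: <a, f> = mean(f)


x = torch.tensor([[0.0], [1.0]], dtype=torch.float64)
C = torch.cdist(x, x) ** 2  # [[0, 1], [1, 0]]; the unregularized cost W_0 is 0
results = {eps: regularized_cost(C, eps) for eps in (1.0, 0.1, 0.01)}
for eps, W in results.items():
    lower = -eps * math.log(4)  # W_0 - eps * log(N * M) with N = M = 2
    print(f"eps={eps:<5} W_eps={W:.6f} bound=[{lower:.6f}, 0]")
```

Here the iterations reach their fixed point after one update because both rows of the kernel \(e^{-C/\varepsilon}\) have the same sum, and the value works out to \(-\varepsilon \log\big(2(1 + e^{-1/\varepsilon})\big)\): negative, and shrinking toward \(W_0 = 0\) as eps decreases.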

See also

Original reference: Patrini, G., Carioni, M., Forré, P., Bhargav, S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019). Sinkhorn AutoEncoders. In Proceedings of the 35th Conference on Uncertainty in Artificial Intelligence. URL: https://openreview.net/forum?id=BygNqoR9tm.

Initialization of the SinkhornLoss class.

Parameters:
  • p (int) – The exponent applied to the Euclidean distance in the cost matrix \(C_{ij} = \left\| x_i - y_j \right\|_2^p\). Default is 2.

  • eps (int | float) – The entropy regularization strength \(\varepsilon\), which must be positive. Smaller values provide a closer approximation to the unregularized optimal transport cost, but may require more iterations for convergence. Default is 0.1.

  • iterations (int) – The number of Sinkhorn iterations. Default is 100.

Raises:

forward(input, target)

Compute the Sinkhorn loss between the empirical distributions of input and target.

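Parameters:
  • input – The input tensor \(x\), containing \(N\) samples in \(\mathbb{R}^D\).

  • target – The target tensor \(y\), containing \(M\) samples in \(\mathbb{R}^D\).
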
Returns:

The computed Sinkhorn loss value, as a scalar tensor.

Return type:

torch.Tensor
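The formulas on this page can be combined behind the documented interface. The following hypothetical `torch.nn.Module`, `SketchSinkhornLoss`, mirrors the constructor arguments `p`, `eps`, `iterations` and the `forward(input, target)` signature. It is an illustrative sketch built only from the equations above, assuming the log-domain parametrization \(\pi_{ij} = \exp\big((f_i + g_j - C_{ij})/\varepsilon\big)\) with uniform weights; it does not reproduce the behavior of `BaseDualLoss` or any argument checking the real class may perform.

```python
import torch
from torch import nn


class SketchSinkhornLoss(nn.Module):
    """Hypothetical stand-in mirroring the documented constructor and forward signature."""

    def __init__(self, p: int = 2, eps: float = 0.1, iterations: int = 100):
        super().__init__()
        self.p = p
        self.eps = eps
        self.iterations = iterations

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Cost matrix C_ij = ||x_i - y_j||_2^p for input (N, D) and target (M, D).
        C = torch.cdist(input, target) ** self.p
        N, M = C.shape
        log_a = torch.full((N,), 1.0 / N, dtype=C.dtype, device=C.device).log()
        log_b = torch.full((M,), 1.0 / M, dtype=C.dtype, device=C.device).log()
        f = torch.zeros(N, dtype=C.dtype, device=C.device)
        g = torch.zeros(M, dtype=C.dtype, device=C.device)
        for _ in range(self.iterations):
            # Alternating log-domain updates enforcing the row, then column, marginals.
            f = self.eps * (log_a - torch.logsumexp((g[None, :] - C) / self.eps, dim=1))
            g = self.eps * (log_b - torch.logsumexp((f[:, None] - C) / self.eps, dim=0))
        # Uniform weights: <a, f> + <b, g> = mean(f) + mean(g), a 0-dim tensor.
        return f.mean() + g.mean()


torch.manual_seed(0)
loss_fn = SketchSinkhornLoss(p=2, eps=0.1, iterations=100)
x = torch.rand(6, 2, requires_grad=True)  # N = 6 input samples in R^2
y = torch.rand(3, 2)                      # M = 3 target samples; N and M may differ
loss = loss_fn(x, y)                      # a single scalar for the whole point cloud
loss.backward()                           # gradients flow through the unrolled iterations
```

Unrolling all `iterations` steps keeps the sketch differentiable with plain autograd, at the cost of memory that grows linearly with the iteration count.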