Sinkhorn Loss
Module for the SinkhornLoss class.
- class SinkhornLoss(p=2, eps=0.1, iterations=100)
Bases: BaseDualLoss
Implementation of the Sinkhorn loss, which measures the entropy-regularized optimal transport distance between two empirical distributions.
Given an input tensor \(x\) with \(N\) samples and a target tensor \(y\) with \(M\) samples, both in \(\mathbb{R}^D\), the loss is defined through the entropy-regularized optimal transport problem:
\[W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} \langle C, \pi \rangle - \varepsilon H(\pi)\]
where \(\mu\) and \(\nu\) are the empirical distributions associated with \(x\) and \(y\), \(\pi\) is a transport plan, and \(\Pi(\mu, \nu)\) is the set of admissible transport plans with marginals \(\mu\) and \(\nu\).
The cost matrix is defined as:
\[C_{ij} = \left\| x_i - y_j \right\|_2^p\]
and the entropy term is:
\[H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij}\]
where \(\varepsilon > 0\) controls the strength of the entropic regularization.
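To make the definitions concrete, the sketch below evaluates \(C\), \(H(\pi)\), and the objective \(\langle C, \pi \rangle - \varepsilon H(\pi)\) for a fixed, not necessarily optimal, plan. It is a dependency-free illustration in plain Python rather than the library's torch code; the helper names (cost_matrix, entropy, regularized_cost) are hypothetical, and points are assumed to be given as tuples.

```python
import math


def cost_matrix(x, y, p=2):
    """C_ij = ||x_i - y_j||_2 ** p for lists of points given as tuples."""
    return [[math.dist(xi, yj) ** p for yj in y] for xi in x]


def entropy(plan):
    """H(pi) = -sum_ij pi_ij * log(pi_ij), with the convention 0 * log(0) = 0."""
    return -sum(v * math.log(v) for row in plan for v in row if v > 0)


def regularized_cost(cost, plan, eps):
    """Primal objective <C, pi> - eps * H(pi) for a given transport plan."""
    transport = sum(c * v for c_row, p_row in zip(cost, plan) for c, v in zip(c_row, p_row))
    return transport - eps * entropy(plan)


x = [(0.0, 0.0), (1.0, 0.0)]              # N = 2 input samples in R^2
y = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]  # M = 3 target samples in R^2
C = cost_matrix(x, y, p=2)

# The independent coupling pi_ij = a_i * b_j (uniform a, b) is admissible but
# generally not optimal, so its objective is only an upper bound on W_eps.
n, m = len(x), len(y)
plan = [[1.0 / (n * m)] * m for _ in range(n)]
print(regularized_cost(C, plan, eps=0.1))
```

Since every admissible plan gives an upper bound, the Sinkhorn iterations can be read as searching for the plan that makes this objective as small as possible.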
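The Sinkhorn iterations compute the optimal dual potentials \(f^\ast\) and \(g^\ast\) in log space, which keeps the computation numerically stable for small \(\varepsilon\), where the kernel entries \(\exp(-C_{ij}/\varepsilon)\) would otherwise underflow. The regularized optimal transport cost is then recovered from the dual formulation as:
\[W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle\]
where \(a\) and \(b\) are uniform probability weights over the \(N\) input samples and \(M\) target samples, respectively, that is, \(a_i = 1/N\) and \(b_j = 1/M\).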
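The log-space iterations can be sketched as follows. This is a hedged, standard-library illustration, not the library's implementation: it assumes uniform weights, starts both potentials at zero, and parametrizes the plan as \(\pi_{ij} = \exp((f_i + g_j - C_{ij})/\varepsilon)\), for which alternately enforcing the row and column marginals gives the two updates below. The function names sinkhorn and logsumexp are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def sinkhorn(x, y, p=2, eps=0.1, iterations=100):
    """Hypothetical log-domain Sinkhorn sketch with uniform weights a_i = 1/N, b_j = 1/M.

    Returns the dual value <a, f> + <b, g>, the potentials f and g, and the plan.
    """
    n, m = len(x), len(y)
    cost = [[math.dist(xi, yj) ** p for yj in y] for xi in x]
    log_a, log_b = math.log(1.0 / n), math.log(1.0 / m)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(iterations):
        # Row update: afterwards sum_j pi_ij = a_i for every i.
        f = [eps * (log_a - logsumexp([(g[j] - cost[i][j]) / eps for j in range(m)]))
             for i in range(n)]
        # Column update: afterwards sum_i pi_ij = b_j for every j.
        g = [eps * (log_b - logsumexp([(f[i] - cost[i][j]) / eps for i in range(n)]))
             for j in range(m)]
    # Transport plan implied by the potentials.
    plan = [[math.exp((f[i] + g[j] - cost[i][j]) / eps) for j in range(m)] for i in range(n)]
    value = sum(f) / n + sum(g) / m  # <a, f> + <b, g> for uniform a and b
    return value, f, g, plan


x = [(0.0, 0.0), (1.0, 0.0)]              # N = 2 input samples in R^2
y = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]  # M = 3 target samples in R^2
value, f, g, plan = sinkhorn(x, y, p=2, eps=0.5, iterations=500)
print(value)
```

Why the dual value is the right number: with this parametrization, \(\varepsilon \sum_{ij} \pi_{ij} \log \pi_{ij} = \sum_{ij} \pi_{ij} (f_i + g_j - C_{ij})\), so once both marginal constraints hold, the primal objective \(\langle C, \pi \rangle - \varepsilon H(\pi)\) collapses to \(\langle a, f \rangle + \langle b, g \rangle\).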
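Unlike pointwise losses, which pair each input sample with one target sample, the Sinkhorn loss compares whole empirical distributions: the sample counts \(N\) and \(M\) need not match, and the order of the samples is irrelevant. Therefore, the output is always a scalar value.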
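A small contrast with a pointwise loss illustrates the difference. The sketch below is a hypothetical, standard-library stand-in for the loss (same log-domain iteration with uniform weights, not the library's code), compared with a hand-written mean squared error: reordering the targets changes the pointwise loss but not the Sinkhorn value, and unequal sample counts still yield one scalar.

```python
import math


def sinkhorn_loss(x, y, p=2, eps=0.1, iterations=100):
    """Hypothetical log-domain Sinkhorn sketch with uniform weights; returns a float."""
    n, m = len(x), len(y)
    cost = [[math.dist(xi, yj) ** p for yj in y] for xi in x]

    def lse(values):
        top = max(values)
        return top + math.log(sum(math.exp(v - top) for v in values))

    f, g = [0.0] * n, [0.0] * m
    for _ in range(iterations):
        f = [eps * (math.log(1 / n) - lse([(g[j] - cost[i][j]) / eps for j in range(m)]))
             for i in range(n)]
        g = [eps * (math.log(1 / m) - lse([(f[i] - cost[i][j]) / eps for i in range(n)]))
             for j in range(m)]
    return sum(f) / n + sum(g) / m


def pointwise_mse(x, y):
    """A pointwise loss for contrast: pairs x_i with y_i, so it needs N == M."""
    return sum(math.dist(xi, yi) ** 2 for xi, yi in zip(x, y)) / len(x)


x = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
y = [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]
y_shuffled = y[::-1]  # same empirical distribution, different sample order

print(pointwise_mse(x, y), pointwise_mse(x, y_shuffled))  # order matters
print(sinkhorn_loss(x, y), sinkhorn_loss(x, y_shuffled))  # order does not
print(sinkhorn_loss(x, y[:2]))  # N = 3, M = 2 still gives a single scalar
```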
Smaller values of eps provide a closer approximation to the unregularized optimal transport cost (the \(p\)-th power of the \(p\)-Wasserstein distance), but may require more Sinkhorn iterations to converge.
See also
Original reference: Patrini, G., Carioni, M., Forré, P., Bhargav, S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019). Sinkhorn AutoEncoders. In Proceedings of the 35th Conference on Uncertainty in Artificial Intelligence. URL: https://openreview.net/forum?id=BygNqoR9tm.
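The effect of eps can be checked on a one-dimensional example, where for \(N = M\), uniform weights, and a convex cost \(|x - y|^p\) with \(p \ge 1\), pairing the sorted samples is an optimal unregularized plan. Because the optimal plan can be taken to be a permutation (entropy \(\log N\)) and no plan has entropy above \(\log(NM)\), the exact regularized value satisfies \(\varepsilon \log N \le W_0 - W_\varepsilon \le 2\varepsilon \log N\), so the gap shrinks linearly with eps. The sketch reuses a hypothetical standard-library Sinkhorn iteration (not the library's code); the iteration budget of 2000 is an arbitrary, generous choice.

```python
import math


def sinkhorn_1d(x, y, p=2, eps=0.1, iterations=100):
    """Hypothetical log-domain Sinkhorn for 1-D samples with uniform weights."""
    n, m = len(x), len(y)
    cost = [[abs(xi - yj) ** p for yj in y] for xi in x]

    def lse(values):
        top = max(values)
        return top + math.log(sum(math.exp(v - top) for v in values))

    f, g = [0.0] * n, [0.0] * m
    for _ in range(iterations):
        f = [eps * (math.log(1 / n) - lse([(g[j] - cost[i][j]) / eps for j in range(m)]))
             for i in range(n)]
        g = [eps * (math.log(1 / m) - lse([(f[i] - cost[i][j]) / eps for i in range(n)]))
             for j in range(m)]
    return sum(f) / n + sum(g) / m


def exact_ot_1d(x, y, p=2):
    """Unregularized OT cost for N == M in 1-D: pair the sorted samples."""
    return sum(abs(a - b) ** p for a, b in zip(sorted(x), sorted(y))) / len(x)


x = [0.0, 1.0, 2.0, 3.0]
y = [0.2, 1.2, 2.2, 3.2]
exact = exact_ot_1d(x, y)  # each sorted pair is 0.2 apart: 0.04 up to rounding
for eps in (1.0, 0.1, 0.01):
    approx = sinkhorn_1d(x, y, eps=eps, iterations=2000)
    print(f"eps={eps}: W_eps={approx:.6f}, exact={exact:.6f}, gap={exact - approx:.6f}")
```

The regularized value always lies below the exact cost, because the entropy term is subtracted; how many iterations a given eps needs depends on the data, so a fixed budget is a trade-off rather than a guarantee.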
Initialization of the SinkhornLoss class.
- Parameters:
p (int) – The exponent \(p\) of the cost \(C_{ij} = \left\| x_i - y_j \right\|_2^p\). Default is 2.
eps (int | float) – The entropy regularization strength \(\varepsilon\). Smaller values provide a closer approximation to the unregularized optimal transport cost, but may require more iterations to converge. Default is 0.1.
iterations (int) – The number of Sinkhorn iterations. Default is 100.
- Raises:
AssertionError – If iterations is not a positive integer.
AssertionError – If p is not a positive integer.
ValueError – If eps is not a positive numeric value.
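The documented checks can be mirrored by a few lines of validation. This is a hypothetical sketch (the function name check_sinkhorn_args is invented for illustration), and the library's actual checks may differ, for example in how booleans or NumPy scalars are treated.

```python
def check_sinkhorn_args(p=2, eps=0.1, iterations=100):
    """Hypothetical checks mirroring the documented Raises section."""
    # Note: bool is a subclass of int in Python; this sketch does not special-case it.
    assert isinstance(iterations, int) and iterations > 0, "iterations must be a positive integer"
    assert isinstance(p, int) and p > 0, "p must be a positive integer"
    if not isinstance(eps, (int, float)) or eps <= 0:
        raise ValueError("eps must be a positive numeric value")


check_sinkhorn_args()  # the defaults pass
for bad in ({"iterations": 0}, {"p": 1.5}, {"eps": 0}):
    try:
        check_sinkhorn_args(**bad)
    except (AssertionError, ValueError) as err:
        print(type(err).__name__, err)
```

One design consequence worth knowing: assert statements are skipped when Python runs with the -O flag, so checks written this way only raise AssertionError in normal (non-optimized) runs, whereas the ValueError check always runs.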
- forward(input, target)
Forward method of the loss function.
- Parameters:
input (torch.Tensor) – The input tensor.
target (torch.Tensor) – The target tensor.
- Returns:
The computed Sinkhorn loss value.
- Return type: