"""Module for the SinkhornLoss class."""
import torch
from pina._src.loss.base_dual_loss import BaseDualLoss
from pina._src.core.utils import check_consistency, check_positive_integer
class SinkhornLoss(BaseDualLoss):
    r"""
    Implementation of the Sinkhorn loss measuring the entropy-regularized
    optimal transport distance between two empirical distributions.

    Given an input tensor :math:`x` with :math:`N` samples and a target tensor
    :math:`y` with :math:`M` samples, both in :math:`\mathbb{R}^D`, the loss is
    defined through the entropy-regularized optimal transport problem:

    .. math::
        W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)}
        \langle C, \pi \rangle - \varepsilon H(\pi)

    where :math:`\mu` and :math:`\nu` are the empirical distributions associated
    with :math:`x` and :math:`y`, :math:`\pi` is a transport plan, and
    :math:`\Pi(\mu, \nu)` is the set of admissible transport plans with
    marginals :math:`\mu` and :math:`\nu`.

    The cost matrix is defined as:

    .. math::
        C_{ij} = \left\| x_i - y_j \right\|_2^p

    and the entropy term is:

    .. math::
        H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij}

    where :math:`\varepsilon > 0` controls the strength of the entropic
    regularization.

    The Sinkhorn iterations compute the optimal dual potentials :math:`f^\ast`
    and :math:`g^\ast` in log space. The regularized optimal transport cost is
    then recovered from the dual formulation as:

    .. math::
        W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle

    where :math:`a` and :math:`b` are uniform probability weights over the
    :math:`N` input samples and :math:`M` target samples, respectively.

    Unlike pointwise losses, the Sinkhorn loss compares whole empirical
    distributions. Therefore, the output is always a scalar value.

    Smaller values of ``eps`` provide a closer approximation to the true
    Wasserstein distance, but may require more Sinkhorn iterations to converge.

    .. seealso::

        **Original reference:** Patrini, G., Carioni, M., Forré, P., Bhargav,
        S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019).
        *Sinkhorn AutoEncoders*.
        In Proceedings of the 35th Conference on Uncertainty in Artificial
        Intelligence.
        URL: `<https://openreview.net/forum?id=BygNqoR9tm>`_.
    """

    def __init__(self, p=2, eps=0.1, iterations=100):
        """
        Initialization of the :class:`SinkhornLoss` class.

        :param int p: The exponent of the cost function. Default is ``2``.
        :param eps: The entropy regularization strength. Smaller values provide
            a closer approximation to the unregularized Wasserstein distance,
            but may require more iterations for convergence. Default is ``0.1``.
        :type eps: int | float
        :param int iterations: The number of Sinkhorn iterations.
            Default is ``100``.
        :raises AssertionError: If ``iterations`` is not a positive integer.
        :raises AssertionError: If ``p`` is not a positive integer.
        :raises ValueError: If ``eps`` is not a strictly positive, finite
            numeric value.
        """
        # Initialize the base class with mean reduction
        super().__init__(reduction="mean")

        # Check consistency
        check_positive_integer(iterations, strict=True)
        check_positive_integer(p, strict=True)
        check_consistency(eps, (int, float))
        if eps <= 0:
            raise ValueError(
                f"Expected 'eps' to be strictly positive, but got {eps}."
            )
        # NaN compares False against everything, so it slips through the
        # ``eps <= 0`` test above; infinity is positive and slips through too.
        # Either value would silently poison every division by ``eps`` in
        # ``forward``, so reject them here.
        if not eps < float("inf"):
            raise ValueError(
                f"Expected 'eps' to be a finite number, but got {eps}."
            )

        # Initialize parameters
        self.iterations = iterations
        self.eps = eps
        self.p = p

    def forward(self, input, target):
        """
        Forward method of the loss function.

        Runs ``self.iterations`` log-domain Sinkhorn updates on the dual
        potentials and evaluates the dual objective with uniform weights.

        :param torch.Tensor input: The input samples, of shape ``(N, D)``.
        :param torch.Tensor target: The target samples, of shape ``(M, D)``.
        :return: The computed Sinkhorn loss value.
        :rtype: torch.Tensor
        """
        # Extract the number of samples in input and target
        n, m = input.shape[0], target.shape[0]

        # Log of the uniform weights 1/n and 1/m, i.e. -log(n) and -log(m),
        # broadcast to one entry per sample
        log_a = -input.new_tensor(n).log().expand(n)
        log_b = -target.new_tensor(m).log().expand(m)

        # Initialize dual potentials f and g
        f = torch.zeros(n, dtype=input.dtype, device=input.device)
        g = torch.zeros(m, dtype=target.dtype, device=target.device)

        # Cost matrix C_ij = ||x_i - y_j||_p ** p, shape (n, m)
        cost = torch.cdist(input, target, p=self.p) ** self.p

        # Perform Sinkhorn iterations in log space for numerical stability
        for _ in range(self.iterations):
            # Update f so that each row i of the plan
            # exp((f_i + g_j - C_ij) / eps) sums to a_i
            row_lse = torch.logsumexp(
                (g.unsqueeze(0) - cost) / self.eps, dim=1
            )
            f = self.eps * (log_a - row_lse)

            # Update g so that each column j of the plan sums to b_j
            col_lse = torch.logsumexp(
                (f.unsqueeze(1) - cost) / self.eps, dim=0
            )
            g = self.eps * (log_b - col_lse)

        # With uniform weights, <a, f> + <b, g> is the sum of the two means
        loss = f.mean() + g.mean()

        # The base-class reduction expects a tensor with a leading dimension
        return self._reduction(loss.unsqueeze(0))