"""Module for the Simplex Domain."""
from copy import deepcopy
import torch
from .base_domain import BaseDomain
from ..label_tensor import LabelTensor
from ..utils import check_consistency
[docs]
class SimplexDomain(BaseDomain):
"""
Implementation of the simplex domain.
:Example:
>>> simplex_domain = SimplexDomain(
[
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
LabelTensor(torch.tensor([[0, 1]]), labels=["x", "y"]),
]
)
"""
def __init__(self, simplex_matrix, sample_surface=False):
"""
Initialization of the :class:`SimplexDomain` class.
:param simplex_matrix: The matrix of the simplex vertices.
:type simplex_matrix: list[LabelTensor] | tuple[LabelTensor]
:param bool sample_surface: If ``True``, only the surface of the simplex
is considered part of the domain. Default is ``False``.
:raises ValueError: If any element of ``simplex_matrix`` is not a
:class:`LabelTensor`.
:raises TypeError: If ``simplex_matrix`` is not a list or tuple.
:raises ValueError: If ``sample_surface`` is not a boolean.
:raises ValueError: If the labels of the vertices do not match.
:raises ValueError: If the number of vertices is not equal to the
dimension of the simplex plus one.
"""
super().__init__()
# Initialization
self._sample_modes = ("random",)
self.sample_surface = sample_surface
self.vert_matrix = simplex_matrix
[docs]
def is_inside(self, point, check_border=False):
"""
Check if a point is inside the simplex.
:param LabelTensor point: The point to check.
:param bool check_border: If ``True``, the boundary is considered inside
the domain. Default is ``False``.
:raises ValueError: If ``point`` is not a :class:`LabelTensor`.
:raises ValueError: If the labels of ``point`` differ from the variables
of the domain.
:return: Whether the point is inside the domain or not.
:rtype: bool
"""
# Checks on point
check_consistency(point, LabelTensor)
if set(self.variables) != set(point.labels):
raise ValueError(
"Point labels differ from constructor vertices labels. "
f"Got {sorted(point.labels)}, expected {self.variables}."
)
# Shift the point by the last vertex
shift_point = point[self.variables] - self._vert_matrix[-1]
shift_point = shift_point.tensor.reshape(-1, 1)
# Shift the vertices by the last vertex
shift_vert = (self._vert_matrix[:-1] - self._vert_matrix[-1]).T
# Compute barycentric coordinates
coords = torch.linalg.solve(shift_vert, shift_point)
last_coord = 1.0 - torch.sum(coords)
coords = torch.vstack([coords, last_coord])
# If check_border is False -- use tolerance for numerical errors
if not check_border:
return torch.all(coords > 1e-6) & torch.all(coords < 1 - 1e-6)
return torch.all(coords >= -1e-6) & torch.all(coords <= 1 + 1e-6)
[docs]
def update(self, domain):
"""
Update the current domain by substituting the simplex vertices with
those contained in ``domain``. Only domains of the same type can be used
for update.
:param SimplexDomain domain: The domain whose vertices are to be set
into the current one.
:raises TypeError: If the domain is not a :class:`SimplexDomain` object.
:return: A new domain instance with the merged labels.
:rtype: SimplexDomain
"""
# Raise an error if the domain types do not match
if not isinstance(domain, type(self)):
raise TypeError(
f"Cannot update domain of type {type(self)} "
f"with domain of type {type(domain)}."
)
# Compute new vertex matrix
vert_matrix = []
for v in domain.vert_matrix:
vert = v.reshape(1, -1)
vert.labels = domain.variables
vert_matrix.append(vert)
# Replace geometry
updated = deepcopy(self)
updated.vert_matrix = vert_matrix
return updated
[docs]
def sample(self, n, mode="random", variables="all"):
"""
Sampling routine.
:param int n: The number of samples to generate.
:param str mode: The sampling method. Available modes: ``random`` for
random sampling. Default is ``random``.
:param variables: The list of variables to sample. If ``all``, all
variables are sampled. Default is ``all``.
:type variables: list[str] | str
:raises AssertionError: If ``n`` is not a positive integer.
:raises ValueError: If the sampling mode is invalid.
:raises ValueError: If ``variables`` is neither ``all``, a string, nor a
list/tuple of strings.
:raises ValueError: If any of the specified variables is unknown.
:return: The sampled points.
:rtype: LabelTensor
:Example:
>>> simplex_domain = SimplexDomain(
[
LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
LabelTensor(torch.tensor([[0, 1]]), labels=["x", "y"]),
]
)
>>> simplex_domain.sample(n=5)
LabelTensor([[0.0125, 0.0439],
[0.1346, 0.1950],
[0.8811, 0.9939],
[0.2722, 0.5535],
[0.4750, 0.7433]])
"""
# Validate sampling settings
variables = self._validate_sampling(n, mode, variables)
# Extract vertex matrix for the requested variables
vert_matrix = self._vert_matrix[variables].tensor
# Sample barycentric coordinates using the Dirichlet distribution over
# the simplex. This can be efficiently done by using samples obtained
# via: -log(U(0,1)) ~ Exp(1) ~ Gamma(1, 1) ~ Dirichlet(1, ..., 1).
coords = -torch.rand((n, vert_matrix.shape[0])).clamp_min(1e-12).log()
# If only the surface is to be sampled
if self._sample_surface:
# Pick one face of the simplex at random for each point and set the
# corresponding barycentric coordinate to zero.
face_idx = torch.randint(0, vert_matrix.shape[0], (n,))
coords.scatter_(1, face_idx.view(-1, 1), 0.0)
# Normalize the coords
coords = coords / coords.sum(dim=1, keepdim=True).clamp_min(1e-12)
# Prepare output
pts = (coords @ vert_matrix).as_subclass(LabelTensor)
pts.labels = variables
return pts[sorted(pts.labels)]
[docs]
def partial(self):
"""
Return the boundary of the domain as a new domain object.
:return: The boundary of the domain.
:rtype: SimplexDomain
"""
boundary = deepcopy(self)
boundary.sample_surface = True
return boundary
@property
def variables(self):
"""
The list of variables of the domain.
:return: The list of variables of the domain.
:rtype: list[str]
"""
return sorted(self._vert_matrix.labels)
@property
def domain_dict(self):
"""
The dictionary representing the domain. For the simplex domain, the keys
are of the form 'v0', 'v1', ..., 'vn', where each key corresponds to a
vertex of the simplex.
:return: The dictionary representing the domain.
:rtype: dict
"""
return {
f"v{i}": self._vert_matrix[i]
for i in range(self._vert_matrix.shape[0])
}
@property
def range(self):
"""
Return an empty dictionary since the simplex domain does not have range
variables. Implemented to comply with the :class:`BaseDomain` interface.
:return: The range variables of the domain.
:rtype: dict
"""
return {}
@property
def fixed(self):
"""
Return an empty dictionary since the simplex domain does not have fixed
variables. Implemented to comply with the :class:`BaseDomain` interface.
:return: The fixed variables of the domain.
:rtype: dict
"""
return {}
@property
def sample_surface(self):
"""
Whether only the surface of the simplex is considered part of the
domain.
:return: ``True`` if only the surface is considered part of the domain,
``False`` otherwise.
:rtype: bool
"""
return self._sample_surface
@sample_surface.setter
def sample_surface(self, value):
"""
Setter for the sample_surface property.
:param bool value: The new value for the sample_surface property.
:raises ValueError: If ``value`` is not a boolean.
"""
check_consistency(value, bool)
self._sample_surface = value
@property
def vert_matrix(self):
"""
The vertex matrix of the simplex.
:return: The vertex matrix.
:rtype: LabelTensor
"""
return self._vert_matrix
@vert_matrix.setter
def vert_matrix(self, value):
"""
Setter for the vertex matrix.
:param LabelTensor value: The new vertex matrix.
:raises ValueError: If any element of ``value`` is not a
:class:`LabelTensor`.
:raises TypeError: If ``value`` is not a list or tuple.
:raises ValueError: If the labels of the vertices do not match.
:raises ValueError: If the number of vertices is not equal to the
dimension of the simplex plus one.
"""
# Check consistency
check_consistency(value, LabelTensor)
if not isinstance(value, (list, tuple)):
raise TypeError(
"The simplex matrix must be a list or tuple of LabelTensor."
)
# Check that all labels match
matrix_labels = value[0].labels
if not all(vert.labels == matrix_labels for vert in value):
raise ValueError("Labels of all vertices must match.")
# Check dimensionality
if len(value) != len(matrix_labels) + 1:
raise ValueError(
"An n-dimensional simplex needs n+1 vertices in R^n."
)
self._vert_matrix = LabelTensor.vstack(value).to(torch.float32)