# Source code for pina.domain.simplex_domain

"""Module for the Simplex Domain."""

from copy import deepcopy
import torch
from .base_domain import BaseDomain
from ..label_tensor import LabelTensor
from ..utils import check_consistency


class SimplexDomain(BaseDomain):
    """
    Implementation of the simplex domain.

    The simplex is described by ``n + 1`` vertices in ``R^n``, each given as a
    :class:`LabelTensor` carrying the same labels. The vertices are stacked
    into a single :class:`LabelTensor` (see :attr:`vert_matrix`) whose columns
    keep the label order of the input vertices, which is not necessarily
    sorted. :attr:`variables`, instead, is always sorted. Every method
    therefore selects columns by label before doing arithmetic.

    :Example:
        >>> simplex_domain = SimplexDomain(
                [
                    LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
                    LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
                    LabelTensor(torch.tensor([[0, 1]]), labels=["x", "y"]),
                ]
            )
    """

    def __init__(self, simplex_matrix, sample_surface=False):
        """
        Initialization of the :class:`SimplexDomain` class.

        :param simplex_matrix: The matrix of the simplex vertices.
        :type simplex_matrix: list[LabelTensor] | tuple[LabelTensor]
        :param bool sample_surface: If ``True``, only the surface of the
            simplex is considered part of the domain. Default is ``False``.
        :raises ValueError: If any element of ``simplex_matrix`` is not a
            :class:`LabelTensor`.
        :raises TypeError: If ``simplex_matrix`` is not a list or tuple.
        :raises ValueError: If ``sample_surface`` is not a boolean.
        :raises ValueError: If the labels of the vertices do not match.
        :raises ValueError: If the number of vertices is not equal to the
            dimension of the simplex plus one.
        """
        super().__init__()

        # Initialization (validation happens in the property setters)
        self._sample_modes = ("random",)
        self.sample_surface = sample_surface
        self.vert_matrix = simplex_matrix

    def is_inside(self, point, check_border=False):
        """
        Check if a point is inside the simplex.

        The test computes the barycentric coordinates of ``point`` with
        respect to the simplex vertices and checks that all of them lie in
        ``(0, 1)`` (or ``[0, 1]`` when ``check_border`` is ``True``), up to a
        ``1e-6`` tolerance for numerical errors.

        :param LabelTensor point: The point to check.
        :param bool check_border: If ``True``, the boundary is considered
            inside the domain. Default is ``False``.
        :raises ValueError: If ``point`` is not a :class:`LabelTensor`.
        :raises ValueError: If the labels of ``point`` differ from the
            variables of the domain.
        :return: Whether the point is inside the domain or not, as a
            0-dimensional boolean tensor (usable directly in ``if``).
        :rtype: torch.Tensor
        """
        # Checks on point
        check_consistency(point, LabelTensor)
        if set(self.variables) != set(point.labels):
            raise ValueError(
                "Point labels differ from constructor vertices labels. "
                f"Got {sorted(point.labels)}, expected {self.variables}."
            )

        # Select the columns of BOTH the vertices and the point by the same
        # (sorted) variable list. The stored vertex matrix keeps the label
        # order given at construction, so mixing it with point[variables]
        # would misalign coordinates whenever labels are not sorted.
        vertices = self._vert_matrix[self.variables].tensor
        # Match dtype/device of the vertices: torch.linalg.solve requires
        # both operands to share the same dtype.
        pt = point[self.variables].tensor.to(
            dtype=vertices.dtype, device=vertices.device
        )

        # Shift the point and the vertices by the last vertex
        shift_point = (pt - vertices[-1]).reshape(-1, 1)
        shift_vert = (vertices[:-1] - vertices[-1]).T

        # Compute barycentric coordinates; the last one is implied by the
        # constraint that all coordinates sum to one.
        coords = torch.linalg.solve(shift_vert, shift_point)
        last_coord = 1.0 - torch.sum(coords)
        coords = torch.vstack([coords, last_coord])

        # If check_border is False -- use tolerance for numerical errors
        if not check_border:
            return torch.all(coords > 1e-6) & torch.all(coords < 1 - 1e-6)

        return torch.all(coords >= -1e-6) & torch.all(coords <= 1 + 1e-6)

    def update(self, domain):
        """
        Update the current domain by substituting the simplex vertices with
        those contained in ``domain``. Only domains of the same type can be
        used for update.

        :param SimplexDomain domain: The domain whose vertices are to be set
            into the current one.
        :raises TypeError: If the domain is not a :class:`SimplexDomain`
            object.
        :return: A new domain instance holding the vertices of ``domain``.
        :rtype: SimplexDomain
        """
        # Raise an error if the domain types do not match
        if not isinstance(domain, type(self)):
            raise TypeError(
                f"Cannot update domain of type {type(self)} "
                f"with domain of type {type(domain)}."
            )

        # Labels must follow the column order of the stored vertex matrix,
        # not the sorted ``domain.variables``; otherwise each value would be
        # relabelled with another variable's name when labels are unsorted.
        labels = domain.vert_matrix.labels

        # Compute new vertex matrix, one (1, n) LabelTensor per vertex
        vert_matrix = []
        for v in domain.vert_matrix:
            vert = v.reshape(1, -1)
            vert.labels = list(labels)
            vert_matrix.append(vert)

        # Replace geometry on a copy so that ``self`` is left untouched
        updated = deepcopy(self)
        updated.vert_matrix = vert_matrix

        return updated

    def sample(self, n, mode="random", variables="all"):
        """
        Sampling routine.

        Interior points are drawn uniformly over the simplex. When
        :attr:`sample_surface` is ``True``, points are drawn uniformly over
        the boundary: each point first selects a facet with probability
        proportional to the facet measure, then a uniform point on it.

        :param int n: The number of samples to generate.
        :param str mode: The sampling method. Available modes: ``random`` for
            random sampling. Default is ``random``.
        :param variables: The list of variables to sample. If ``all``, all
            variables are sampled. Default is ``all``.
        :type variables: list[str] | str
        :raises AssertionError: If ``n`` is not a positive integer.
        :raises ValueError: If the sampling mode is invalid.
        :raises ValueError: If ``variables`` is neither ``all``, a string, nor
            a list/tuple of strings.
        :raises ValueError: If any of the specified variables is unknown.
        :return: The sampled points.
        :rtype: LabelTensor

        :Example:
            >>> simplex_domain = SimplexDomain(
                    [
                        LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
                        LabelTensor(torch.tensor([[1, 1]]), labels=["x", "y"]),
                        LabelTensor(torch.tensor([[0, 1]]), labels=["x", "y"]),
                    ]
                )
            >>> simplex_domain.sample(n=5)
            LabelTensor([[0.0125, 0.0439],
                         [0.1346, 0.1950],
                         [0.8811, 0.9939],
                         [0.2722, 0.5535],
                         [0.4750, 0.7433]])
        """
        # Validate sampling settings
        variables = self._validate_sampling(n, mode, variables)

        # Extract vertex matrix for the requested variables
        vert_matrix = self._vert_matrix[variables].tensor

        # Sample barycentric coordinates using the Dirichlet distribution over
        # the simplex. This can be efficiently done by using samples obtained
        # via: -log(U(0,1)) ~ Exp(1) ~ Gamma(1, 1) ~ Dirichlet(1, ..., 1).
        coords = -torch.rand((n, vert_matrix.shape[0])).clamp_min(1e-12).log()

        # If only the surface is to be sampled
        if self._sample_surface:
            # Facet i is the facet opposite vertex i, obtained by zeroing the
            # i-th barycentric coordinate. Facets are chosen proportionally to
            # their measure so that points are uniform over the whole
            # boundary, not uniform per facet.
            face_idx = torch.multinomial(
                self._facet_weights(), n, replacement=True
            ).to(coords.device)
            coords.scatter_(1, face_idx.view(-1, 1), 0.0)

        # Normalize the coords
        coords = coords / coords.sum(dim=1, keepdim=True).clamp_min(1e-12)

        # Prepare output
        pts = (coords @ vert_matrix).as_subclass(LabelTensor)
        pts.labels = variables

        return pts[sorted(pts.labels)]

    def _facet_weights(self):
        """
        Compute the facet-selection weights used for surface sampling.

        Facet ``i`` is the facet opposite vertex ``i``. Its weight is the
        square root of the Gram determinant of its edge vectors, which is
        proportional to the facet's (d-1)-dimensional measure. The common
        ``1/(d-1)!`` factor is dropped because only relative weights matter.
        The full vertex matrix is used (not a subset of variables), since
        surface points are generated in the full space and only projected
        afterwards.

        :return: Non-negative weights, one per vertex. Uniform weights are
            returned for 1D simplices (facets are points) and for degenerate
            simplices whose facet measures are all zero or non-finite.
        :rtype: torch.Tensor
        """
        vertices = self._vert_matrix.tensor
        n_vert = vertices.shape[0]

        # In 1D every facet is a single point: all facets are equivalent.
        if n_vert <= 2:
            return torch.ones(n_vert, dtype=vertices.dtype)

        weights = []
        for i in range(n_vert):
            facet = torch.cat([vertices[:i], vertices[i + 1 :]])
            edges = facet[1:] - facet[0]
            gram = edges @ edges.T
            # clamp_min guards against tiny negative determinants caused by
            # floating-point round-off before taking the square root.
            weights.append(torch.linalg.det(gram).clamp_min(0.0).sqrt())
        weights = torch.stack(weights)

        # Degenerate geometry: fall back to uniform facet selection so that
        # torch.multinomial always receives a valid distribution.
        if not bool(torch.isfinite(weights).all()) or weights.sum() <= 0:
            return torch.ones(n_vert, dtype=vertices.dtype)

        return weights

    def partial(self):
        """
        Return the boundary of the domain as a new domain object.

        :return: The boundary of the domain.
        :rtype: SimplexDomain
        """
        boundary = deepcopy(self)
        boundary.sample_surface = True

        return boundary

    @property
    def variables(self):
        """
        The list of variables of the domain.

        :return: The sorted list of variables of the domain.
        :rtype: list[str]
        """
        return sorted(self._vert_matrix.labels)

    @property
    def domain_dict(self):
        """
        The dictionary representing the domain. For the simplex domain, the
        keys are of the form 'v0', 'v1', ..., 'vn', where each key corresponds
        to a vertex of the simplex.

        :return: The dictionary representing the domain.
        :rtype: dict
        """
        return {
            f"v{i}": self._vert_matrix[i]
            for i in range(self._vert_matrix.shape[0])
        }

    @property
    def range(self):
        """
        Return an empty dictionary since the simplex domain does not have
        range variables. Implemented to comply with the :class:`BaseDomain`
        interface.

        :return: The range variables of the domain.
        :rtype: dict
        """
        return {}

    @property
    def fixed(self):
        """
        Return an empty dictionary since the simplex domain does not have
        fixed variables. Implemented to comply with the :class:`BaseDomain`
        interface.

        :return: The fixed variables of the domain.
        :rtype: dict
        """
        return {}

    @property
    def sample_surface(self):
        """
        Whether only the surface of the simplex is considered part of the
        domain.

        :return: ``True`` if only the surface is considered part of the
            domain, ``False`` otherwise.
        :rtype: bool
        """
        return self._sample_surface

    @sample_surface.setter
    def sample_surface(self, value):
        """
        Setter for the sample_surface property.

        :param bool value: The new value for the sample_surface property.
        :raises ValueError: If ``value`` is not a boolean.
        """
        check_consistency(value, bool)
        self._sample_surface = value

    @property
    def vert_matrix(self):
        """
        The vertex matrix of the simplex.

        :return: The vertex matrix, one row per vertex, with columns in the
            label order of the vertices given to the setter.
        :rtype: LabelTensor
        """
        return self._vert_matrix

    @vert_matrix.setter
    def vert_matrix(self, value):
        """
        Setter for the vertex matrix.

        :param value: The new vertices of the simplex.
        :type value: list[LabelTensor] | tuple[LabelTensor]
        :raises ValueError: If any element of ``value`` is not a
            :class:`LabelTensor`.
        :raises TypeError: If ``value`` is not a list or tuple.
        :raises ValueError: If the labels of the vertices do not match.
        :raises ValueError: If the number of vertices is not equal to the
            dimension of the simplex plus one (including an empty ``value``).
        """
        # Check consistency
        check_consistency(value, LabelTensor)
        if not isinstance(value, (list, tuple)):
            raise TypeError(
                "The simplex matrix must be a list or tuple of LabelTensor."
            )

        # An empty sequence would otherwise crash with IndexError on value[0]
        if len(value) == 0:
            raise ValueError(
                "An n-dimensional simplex needs n+1 vertices in R^n."
            )

        # Check that all labels match
        matrix_labels = value[0].labels
        if not all(vert.labels == matrix_labels for vert in value):
            raise ValueError("Labels of all vertices must match.")

        # Check dimensionality
        if len(value) != len(matrix_labels) + 1:
            raise ValueError(
                "An n-dimensional simplex needs n+1 vertices in R^n."
            )

        self._vert_matrix = LabelTensor.vstack(value).to(torch.float32)