Source code for pina.domain.base_domain
"""Module for the Base class for domains."""
from copy import deepcopy
from abc import ABCMeta
from .domain_interface import DomainInterface
from ..utils import check_consistency, check_positive_integer
[docs]
class BaseDomain(DomainInterface, metaclass=ABCMeta):
    """
    Base class for all geometric domains, implementing common functionality.

    All specific domain types should inherit from this class and implement the
    abstract methods of :class:`~pina.domain.domain_interface.DomainInterface`.

    This class is not meant to be instantiated directly.
    """

    def __init__(self, variables_dict=None):
        """
        Initialization of the :class:`BaseDomain` class.

        :param variables_dict: A dictionary where the keys are the variable
            names and the values are the domain extrema. The domain extrema can
            be either a list or tuple with two elements or a single number. If
            the domain extrema is a single number, the variable is fixed to that
            value. If ``None``, validation is skipped and the domain starts
            with no fixed and no ranged variables.
        :type variables_dict: dict | None
        :raises TypeError: If the domain dictionary is not a dictionary.
        :raises ValueError: If the domain dictionary is empty.
        :raises ValueError: If the domain dictionary contains variables with
            invalid ranges (``low >= high``).
        :raises ValueError: If the domain dictionary contains values that are
            neither numbers nor lists/tuples of numbers of length 2.
        """
        # Initialize fixed and ranged variables
        self._fixed = {}
        self._range = {}
        invalid = []

        # Skip checks if variables_dict is None -- SimplexDomain case
        if variables_dict is None:
            return

        # Check variables_dict is a dictionary
        if not isinstance(variables_dict, dict):
            raise TypeError(
                "variables_dict must be dict: {name: number | (low, high)}"
            )

        # Check variables_dict is not empty
        if not variables_dict:
            raise ValueError(
                "The dictionary defining the domain cannot be empty."
            )

        # Every value (or every element of a list/tuple value) must be numeric
        for v in variables_dict.values():
            check_consistency(v, (int, float))

        # Split the entries into fixed and ranged variables
        for k, v in variables_dict.items():

            # Fixed variables: a single number
            if isinstance(v, (int, float)):
                self._fixed[k] = v

            # Ranged variables: a (low, high) pair with low strictly below high
            elif isinstance(v, (list, tuple)) and len(v) == 2:
                low, high = v
                if low >= high:
                    raise ValueError(
                        f"Invalid range for variable '{k}': "
                        f"low ({low}) >= high ({high})"
                    )
                # Normalize lists to tuples so stored ranges are immutable
                self._range[k] = (low, high)

            # Anything else (e.g. a sequence of the wrong length) is collected
            # so that all offending keys are reported at once
            else:
                invalid.append(k)

        # Raise an error if there are invalid keys
        if invalid:
            raise ValueError(f"Invalid value(s) for key(s): {invalid}")

    def update(self, domain):
        """
        Update the current domain by adding the labels contained in ``domain``.

        Each new label introduces a new dimension. Only domains of the same type
        can be used for update. When a label is defined in both domains, the
        definition in ``domain`` takes precedence. This also holds when the
        label changes kind (fixed to ranged or vice versa): it is removed from
        the other category, so it is never listed twice.

        The current domain is not modified; a deep copy is updated and returned.

        :param BaseDomain domain: The domain whose labels are to be merged
            into the current one.
        :raises TypeError: If the provided domain is not of the same type as
            the current one.
        :return: A new domain instance with the merged labels.
        :rtype: BaseDomain
        """
        # Raise an error if the domain types do not match
        if not isinstance(domain, type(self)):
            raise TypeError(
                f"Cannot update domain of type {type(self)} "
                f"with domain of type {type(domain)}."
            )

        updated = deepcopy(self)

        # A redefined label must end up in exactly one of the two dicts;
        # otherwise ``variables`` would report it twice and ``domain_dict``
        # would silently drop one of the two definitions.
        for name in domain.fixed:
            updated.range.pop(name, None)
        for name in domain.range:
            updated.fixed.pop(name, None)

        # Update fixed and ranged variables (incoming definitions win)
        updated.fixed.update(domain.fixed)
        updated.range.update(domain.range)

        return updated

    def _validate_sampling(self, n, mode, variables):
        """
        Validate the sampling settings.

        :param int n: The number of samples to generate.
        :param str mode: The sampling method.
        :param variables: The list of variables to sample. If ``all``, all
            variables are sampled.
        :type variables: str | list[str] | tuple[str]
        :raises AssertionError: If ``n`` is not a positive integer.
        :raises ValueError: If the sampling mode is invalid.
        :raises ValueError: If ``variables`` is neither ``all``, a string, nor a
            list/tuple of strings.
        :raises ValueError: If any of the specified variables is unknown.
        :return: The validated list of variables to sample, de-duplicated and
            sorted alphabetically.
        :rtype: list[str]
        """
        # Validate n
        check_positive_integer(value=n, strict=True)

        # Validate mode
        if mode not in self.sample_modes:
            raise ValueError(
                f"Invalid sampling mode: {mode}. Available: {self.sample_modes}"
            )

        # Validate variables
        check_consistency(variables, str)
        if variables == "all":
            variables = self.variables
        elif isinstance(variables, str):
            variables = [variables]
        else:
            # Remove duplicates while preserving first-occurrence order
            variables = list(dict.fromkeys(variables))

        # Check for unknown variables; compute the known names once instead
        # of re-sorting them via the ``variables`` property for every entry
        available = set(self.variables)
        unknown = [v for v in variables if v not in available]
        if unknown:
            raise ValueError(
                f"Unknown variable(s): {unknown}. Available: {self.variables}"
            )

        return sorted(variables)

    @property
    def sample_modes(self):
        """
        The list of available sampling modes.

        Reads ``self._sample_modes``, which is not set in this class;
        presumably each concrete subclass defines it.

        :return: The list of available sampling modes.
        :rtype: list[str]
        """
        return list(self._sample_modes)

    @property
    def variables(self):
        """
        The list of variables of the domain.

        :return: The names of all fixed and ranged variables, sorted
            alphabetically.
        :rtype: list[str]
        """
        return sorted(list(self._fixed.keys()) + list(self._range.keys()))

    @property
    def domain_dict(self):
        """
        The dictionary representing the domain.

        :return: A new dictionary merging the fixed variables (name -> value)
            and the ranged variables (name -> ``(low, high)``).
        :rtype: dict
        """
        return {**self._fixed, **self._range}

    @property
    def range(self):
        """
        The range variables of the domain.

        The internal dictionary itself is returned, not a copy, so mutating it
        mutates the domain.

        :return: The range variables of the domain, mapping each name to a
            ``(low, high)`` tuple.
        :rtype: dict
        """
        return self._range

    @property
    def fixed(self):
        """
        The fixed variables of the domain.

        The internal dictionary itself is returned, not a copy, so mutating it
        mutates the domain.

        :return: The fixed variables of the domain, mapping each name to its
            value.
        :rtype: dict
        """
        return self._fixed