Source code for pina.condition.condition

"""Module for the Condition class."""

from .data_condition import DataCondition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputEquationCondition
from .input_target_condition import InputTargetCondition


[docs] class Condition: """ The :class:`Condition` class is a core component of the PINA framework that provides a unified interface to define heterogeneous constraints that must be satisfied by a :class:`~pina.problem.abstract_problem.AbstractProblem`. It encapsulates all types of constraints - physical, boundary, initial, or data-driven - that the solver must satisfy during training. The specific behavior is inferred from the arguments passed to the constructor. Multiple types of conditions can be used within the same problem, allowing for a high degree of flexibility in defining complex problems. The :class:`Condition` class behavior specializes internally based on the arguments provided during instantiation. Depending on the specified keyword arguments, the class automatically selects the appropriate internal implementation. Available `Condition` types: - :class:`~pina.condition.input_target_condition.InputTargetCondition`: represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to reproduce the ``target`` values given the ``input``. Supported data types include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. The class automatically selects the appropriate implementation based on the types of ``input`` and ``target``. - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition` : represents a general physics-informed condition defined by a ``domain`` and an ``equation``. The model learns to minimize the equation residual through evaluations performed at points sampled from the specified domain. - :class:`~pina.condition.input_equation_condition.InputEquationCondition`: represents a general physics-informed condition defined by ``input`` points and an ``equation``. The model learns to minimize the equation residual through evaluations performed at the provided ``input``. Supported data types for the ``input`` include :class:`~pina.label_tensor.LabelTensor` or :class:`~pina.graph.Graph`. The class automatically selects the appropriate implementation based on the types of the ``input``. - :class:`~pina.condition.data_condition.DataCondition`: represents an unsupervised, data-driven condition defined by the ``input`` only. The model is trained using a custom unsupervised loss determined by the chosen :class:`~pina.solver.solver.SolverInterface`, while leveraging the provided data during training. Optional ``conditional_variables`` can be specified when the model depends on additional parameters. Supported data types include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`. The class automatically selects the appropriate implementation based on the type of the ``input``. .. note:: The user should always instantiate :class:`Condition` directly, without manually creating subclass instances. Please refer to the specific :class:`Condition` classes for implementation details. :Example: >>> from pina import Condition >>> # Example of InputTargetCondition signature >>> condition = Condition(input=input, target=target) >>> # Example of DomainEquationCondition signature >>> condition = Condition(domain=domain, equation=equation) >>> # Example of InputEquationCondition signature >>> condition = Condition(input=input, equation=equation) >>> # Example of DataCondition signature >>> condition = Condition(input=data, conditional_variables=cond_vars) """ # Combine all possible keyword arguments from the different Condition types __slots__ = list( set( InputTargetCondition.__slots__ + InputEquationCondition.__slots__ + DomainEquationCondition.__slots__ + DataCondition.__slots__ ) ) def __new__(cls, *args, **kwargs): """ Instantiate the appropriate :class:`Condition` object based on the keyword arguments passed. :param tuple args: The positional arguments (should be empty). :param dict kwargs: The keyword arguments corresponding to the parameters of the specific :class:`Condition` type to instantiate. :raises ValueError: If unexpected positional arguments are provided. :raises ValueError: If the keyword arguments are invalid. :return: The appropriate :class:`Condition` object. :rtype: ConditionInterface """ # Check keyword arguments if len(args) != 0: raise ValueError( "Condition takes only the following keyword " f"arguments: {Condition.__slots__}." ) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition if sorted_keys == sorted(InputTargetCondition.__slots__): return InputTargetCondition(**kwargs) # Input - Equation Condition if sorted_keys == sorted(InputEquationCondition.__slots__): return InputEquationCondition(**kwargs) # Domain - Equation Condition if sorted_keys == sorted(DomainEquationCondition.__slots__): return DomainEquationCondition(**kwargs) # Data Condition if ( sorted_keys == sorted(DataCondition.__slots__) or sorted_keys[0] == DataCondition.__slots__[0] ): return DataCondition(**kwargs) # Invalid keyword arguments raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")