Source code for pina._src.condition.condition

"""Module for the Condition class."""

from pina._src.condition.input_equation_condition import InputEquationCondition
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.condition.graph_time_series_condition import (
    GraphTimeSeriesCondition,
)
from pina._src.condition.data_condition import DataCondition
from pina._src.condition.domain_equation_condition import (
    DomainEquationCondition,
)


class Condition:
    """
    The :class:`Condition` class is a core component of the PINA framework that
    provides a unified interface to define heterogeneous constraints that must
    be satisfied by a :class:`~pina.problem.base_problem.BaseProblem`.

    It encapsulates all types of constraints - physical, boundary, initial, or
    data-driven - that the solver must satisfy during training. The specific
    behavior is inferred from the arguments passed to the constructor.

    Multiple types of conditions can be used within the same problem, allowing
    for a high degree of flexibility in defining complex problems.

    The :class:`Condition` class behavior specializes internally based on the
    arguments provided during instantiation. Depending on the specified keyword
    arguments, the class automatically selects the appropriate internal
    implementation.

    Available `Condition` types:

    - :class:`~pina.condition.input_target_condition.InputTargetCondition`:
      represents a supervised condition defined by both ``input`` and ``target``
      data. The model is trained to reproduce the ``target`` values given the
      ``input``. Supported data types include :class:`torch.Tensor`,
      :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, or
      :class:`~torch_geometric.data.Data`. The class automatically selects the
      appropriate implementation based on the types of ``input`` and ``target``.

    - :class:`~pina.condition.domain_equation_condition.DomainEquationCondition`:
      represents a general physics-informed condition defined by a ``domain``
      and an ``equation``. The model learns to minimize the equation residual
      through evaluations performed at points sampled from the specified domain.

    - :class:`~pina.condition.input_equation_condition.InputEquationCondition`:
      represents a general physics-informed condition defined by ``input``
      points and an ``equation``. The model learns to minimize the equation
      residual through evaluations performed at the provided ``input``.
      Supported data types for the ``input`` include
      :class:`~pina.graph.Graph` or :class:`~pina.label_tensor.LabelTensor`.
      The class automatically selects the appropriate implementation based on
      the type of the ``input``.

    - :class:`~pina.condition.time_series_condition.TimeSeriesCondition`:
      represents a condition designed for time series data, where the model is
      trained to capture temporal dependencies and dynamics. It is defined by an
      ``input`` tensor of shape ``[trajectories, time_steps, *features]``
      containing time series data. Supported data types for the ``input``
      include :class:`~pina.label_tensor.LabelTensor` or :class:`torch.Tensor`.
      The class automatically selects the appropriate implementation based on
      the type of the ``input``.

    - :class:`~pina.condition.data_condition.DataCondition`: represents an
      unsupervised, data-driven condition defined by the ``input`` only.
      The model is trained using a custom unsupervised loss determined by the
      chosen :class:`~pina.solver.base_solver.BaseSolver`, while leveraging the
      provided data during training. Optional ``conditional_variables`` can be
      specified when the model depends on additional parameters.
      Supported data types include :class:`~pina.label_tensor.LabelTensor`,
      :class:`torch.Tensor`, :class:`~torch_geometric.data.Data`, or
      :class:`~pina.graph.Graph`. The class automatically selects the
      appropriate implementation based on the type of the ``input``.

    .. note::

        The user should always instantiate :class:`Condition` directly,
        without manually creating subclass instances. Please refer to the
        specific :class:`Condition` classes for implementation details.

    :Example:

    >>> from pina import Condition

    >>> # Example of InputTargetCondition signature
    >>> condition = Condition(input=input, target=target)

    >>> # Example of DomainEquationCondition signature
    >>> condition = Condition(domain=domain, equation=equation)

    >>> # Example of InputEquationCondition signature
    >>> condition = Condition(input=input, equation=equation)

    >>> # Example of TimeSeriesCondition signature
    >>> condition = Condition(
    ...     input=input, n_windows=n_windows, unroll_length=unroll_length
    ... )

    >>> # Example of DataCondition signature
    >>> condition = Condition(input=data, conditional_variables=cond_vars)
    """

    # Internal specifications for condition types, used for dispatching
    # Each tuple contains: (condition class, required kwargs, optional kwargs)
    _SPECS = (
        (InputTargetCondition, {"input", "target"}, set()),
        (InputEquationCondition, {"input", "equation"}, set()),
        (DomainEquationCondition, {"domain", "equation"}, set()),
        (DataCondition, {"input"}, {"conditional_variables"}),
        (
            TimeSeriesCondition,
            {"input", "n_windows", "unroll_length"},
            {"randomize"},
        ),
        (
            GraphTimeSeriesCondition,
            {"input", "n_windows", "unroll_length"},
            {"key", "randomize"},
        ),
    )

    # Compute the set of all available keyword arguments (optional + required)
    available_kwargs = sorted(set().union(*(rq | op for _, rq, op in _SPECS)))

    def __new__(cls, *args, **kwargs):
        """
        Instantiate the appropriate :class:`Condition` object based on the
        keyword arguments passed.

        :param tuple args: The positional arguments (should be empty).
        :param dict kwargs: The keyword arguments corresponding to the
            parameters of the specific :class:`Condition` type to instantiate.
        :raises ValueError: If unexpected positional arguments are provided.
        :raises ValueError: If the keyword arguments do not match any valid
            signature for the available condition types.
        :return: The appropriate :class:`Condition` object.
        :rtype: ConditionInterface
        """
        # Ensure no positional arguments are provided
        if args:
            raise ValueError(
                "Condition takes only keyword arguments. "
                f"Available arguments are: {cls.available_kwargs}."
            )

        # Iterate through the specifications to find a matching condition type
        for condition_cls, required, optional in cls._SPECS:

            # Find allowed keys for condition type
            allowed = required | optional

            # Check if the provided keys match the required and optional keys
            if required <= set(kwargs) <= allowed:
                return condition_cls(**kwargs)

        # If no valid signature is found, prepare a list of valid signatures
        valid_signatures = [
            sorted(required | optional) for _, required, optional in cls._SPECS
        ]

        # If no valid signature is found, raise an error
        raise ValueError(
            f"Invalid keyword arguments {sorted(set(kwargs))}. "
            f"Valid signatures are: {valid_signatures}."
        )
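
The keyword-based dispatch performed in ``__new__`` can be tried in isolation. The snippet below is a minimal sketch and not PINA code: the stub classes (``SupervisedStub``, ``DataStub``, ``TimeSeriesStub``, ``GraphTimeSeriesStub``), the ``SPECS`` table, and the ``dispatch`` helper are hypothetical stand-ins that only mirror the ``required <= set(kwargs) <= allowed`` test. Any further selection the real condition classes may perform based on the type of ``input`` is not modelled here.

```python
class _Stub:
    """Hypothetical stand-in that just records the kwargs it receives."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class SupervisedStub(_Stub): ...
class DataStub(_Stub): ...
class TimeSeriesStub(_Stub): ...
class GraphTimeSeriesStub(_Stub): ...


# Each entry: (class, required kwargs, optional kwargs), as in ``_SPECS``
SPECS = (
    (SupervisedStub, {"input", "target"}, set()),
    (DataStub, {"input"}, {"conditional_variables"}),
    (TimeSeriesStub, {"input", "n_windows", "unroll_length"}, {"randomize"}),
    (
        GraphTimeSeriesStub,
        {"input", "n_windows", "unroll_length"},
        {"key", "randomize"},
    ),
)


def dispatch(**kwargs):
    """Return an instance of the first class whose signature matches."""
    keys = set(kwargs)
    for cls, required, optional in SPECS:
        # All required keys must be present and no unknown key may appear
        if required <= keys <= required | optional:
            return cls(**kwargs)
    raise ValueError(f"Invalid keyword arguments {sorted(keys)}.")


supervised = dispatch(input=[1.0], target=[2.0])
data_only = dispatch(input=[1.0], conditional_variables=[0.5])
# Two specs share the same required keys: the first one listed wins ...
series = dispatch(input=[1.0], n_windows=4, unroll_length=2)
# ... unless a key accepted only by the later spec is supplied
graph_series = dispatch(input=[1.0], n_windows=4, unroll_length=2, key="x")
```

Because the specifications are scanned in order and the first match is returned, two entries with identical required keys are told apart only by their optional keys. In this sketch, passing ``key`` is what routes the call to the graph variant.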