Source code for pina.condition.input_target_condition
"""
This module contains condition classes for supervised learning tasks.
"""
import torch
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph
from .condition_interface import ConditionInterface
[docs]
class InputTargetCondition(ConditionInterface):
"""
Condition defined by input and target data. This condition can be used in
both supervised learning and Physics-informed problems. Based on the type of
the input and target, different condition implementations are available:
- :class:`TensorInputTensorTargetCondition`: For :class:`torch.Tensor` or \
:class:`~pina.label_tensor.LabelTensor` input and target data.
- :class:`TensorInputGraphTargetCondition`: For :class:`torch.Tensor` or \
:class:`~pina.label_tensor.LabelTensor` input and \
:class:`~pina.graph.Graph` or :class:`torch_geometric.data.Data` \
target data.
- :class:`GraphInputTensorTargetCondition`: For :class:`~pina.graph.Graph` \
or :class:`~torch_geometric.data.Data` input and :class:`torch.Tensor` \
or :class:`~pina.label_tensor.LabelTensor` target data.
- :class:`GraphInputGraphTargetCondition`: For :class:`~pina.graph.Graph` \
or :class:`~torch_geometric.data.Data` input and target data.
"""
__slots__ = ["input", "target"]
_avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
_avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
def __new__(cls, input, target):
"""
Instantiate the appropriate subclass of InputTargetCondition based on
the types of input and target data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:return: Subclass of InputTargetCondition
:rtype: pina.condition.input_target_condition.
TensorInputTensorTargetCondition |
pina.condition.input_target_condition.
TensorInputGraphTargetCondition |
pina.condition.input_target_condition.
GraphInputTensorTargetCondition |
pina.condition.input_target_condition.GraphInputGraphTargetCondition
:raises ValueError: If ``input`` and/or ``target`` are not of type
:class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
:class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
"""
if cls != InputTargetCondition:
return super().__new__(cls)
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (torch.Tensor, LabelTensor)
):
subclass = TensorInputTensorTargetCondition
return subclass.__new__(subclass, input, target)
if isinstance(input, (torch.Tensor, LabelTensor)) and isinstance(
target, (Graph, Data, list, tuple)
):
cls._check_graph_list_consistency(target)
subclass = TensorInputGraphTargetCondition
return subclass.__new__(subclass, input, target)
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (torch.Tensor, LabelTensor)
):
cls._check_graph_list_consistency(input)
subclass = GraphInputTensorTargetCondition
return subclass.__new__(subclass, input, target)
if isinstance(input, (Graph, Data, list, tuple)) and isinstance(
target, (Graph, Data, list, tuple)
):
cls._check_graph_list_consistency(input)
cls._check_graph_list_consistency(target)
subclass = GraphInputGraphTargetCondition
return subclass.__new__(subclass, input, target)
raise ValueError(
"Invalid input/target types. "
"Please provide either torch_geometric.data.Data, Graph, "
"LabelTensor or torch.Tensor objects."
)
def __init__(self, input, target):
"""
Initialize the object by storing the ``input`` and ``target`` data.
:param input: Input data for the condition.
:type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
:param target: Target data for the condition.
:type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
list[Data] | tuple[Graph] | tuple[Data]
.. note::
If either input or target consists of a list of
:class:~pina.graph.Graph or :class:~torch_geometric.data.Data
objects, all elements must have the same structure (matching
keys and data types).
"""
super().__init__()
self._check_input_target_len(input, target)
self.input = input
self.target = target
@staticmethod
def _check_input_target_len(input, target):
if isinstance(input, (Graph, Data)) or isinstance(
target, (Graph, Data)
):
return
if len(input) != len(target):
raise ValueError(
"The input and target lists must have the same length."
)
[docs]
class TensorInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``input`` and ``target`` data.
"""
[docs]
class TensorInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``input`` and
:class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` `target`
data.
"""
[docs]
class GraphInputTensorTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`~pina.graph.Graph` o
:class:`~torch_geometric.data.Data` ``input`` and :class:`torch.Tensor` or
:class:`~pina.label_tensor.LabelTensor` ``target`` data.
"""
[docs]
class GraphInputGraphTargetCondition(InputTargetCondition):
"""
InputTargetCondition subclass for :class:`~pina.graph.Graph`/
:class:`~torch_geometric.data.Data` ``input`` and ``target`` data.
"""