# Source code for pina.condition.input_target_condition
"""
This module contains condition classes for supervised learning tasks.
"""
import torch
from torch_geometric.data import Data
from ..label_tensor import LabelTensor
from ..graph import Graph
from .condition_interface import ConditionInterface
[docs]
class InputTargetCondition(ConditionInterface):
    """
    The :class:`InputTargetCondition` class represents a supervised condition
    defined by both ``input`` and ``target`` data. The model is trained to
    reproduce the ``target`` values given the ``input``. Supported data types
    include :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
    :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`.
    Graph-like data may also be provided as a ``list`` or ``tuple`` of
    :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
    objects.

    The class automatically selects the appropriate implementation based on
    the types of ``input`` and ``target``. Depending on whether the ``input``
    and ``target`` are tensors or graph-based data, one of the following
    specialized subclasses is instantiated:

    - :class:`TensorInputTensorTargetCondition`: For cases where both
      ``input`` and ``target`` data are either :class:`torch.Tensor` or
      :class:`~pina.label_tensor.LabelTensor`.
    - :class:`TensorInputGraphTargetCondition`: For cases where ``input`` is
      either a :class:`torch.Tensor` or
      :class:`~pina.label_tensor.LabelTensor` and ``target`` is either a
      :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data`.
    - :class:`GraphInputTensorTargetCondition`: For cases where ``input`` is
      either a :class:`~pina.graph.Graph` or
      :class:`torch_geometric.data.Data` and ``target`` is either a
      :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor`.
    - :class:`GraphInputGraphTargetCondition`: For cases where both
      ``input`` and ``target`` are either :class:`~pina.graph.Graph` or
      :class:`torch_geometric.data.Data`.

    :Example:

    >>> from pina import Condition, LabelTensor
    >>> from pina.graph import Graph
    >>> import torch
    >>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
    >>> edge_index = torch.randint(0, 100, (2, 300))
    >>> graph = Graph(pos=pos, edge_index=edge_index)
    >>> input = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
    >>> condition = Condition(input=input, target=graph)
    """

    # Instance attributes stored by ``__init__``.
    __slots__ = ["input", "target"]

    # Available input and target data types
    _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)
    _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple)

    def __new__(cls, input, target):
        """
        Instantiate the appropriate subclass of :class:`InputTargetCondition`
        based on the types of both ``input`` and ``target`` data.

        :param input: The input data for the condition.
        :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
            list[Data] | tuple[Graph] | tuple[Data]
        :param target: The target data for the condition.
        :type target: torch.Tensor | LabelTensor | Graph | Data |
            list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
        :return: An instance of the subclass of :class:`InputTargetCondition`
            matching the types of ``input`` and ``target``.
        :rtype: pina.condition.input_target_condition.
            TensorInputTensorTargetCondition |
            pina.condition.input_target_condition.
            TensorInputGraphTargetCondition |
            pina.condition.input_target_condition.
            GraphInputTensorTargetCondition |
            pina.condition.input_target_condition.
            GraphInputGraphTargetCondition
        :raises ValueError: If ``input`` and/or ``target`` are not of type
            :class:`torch.Tensor`, :class:`~pina.label_tensor.LabelTensor`,
            :class:`~pina.graph.Graph`, or :class:`~torch_geometric.data.Data`
            (or a list/tuple of the graph types).
        """
        # When a concrete subclass is being created (including the re-entrant
        # call below), skip dispatch and allocate the instance directly.
        if cls is not InputTargetCondition:
            return super().__new__(cls)

        tensor_types = (torch.Tensor, LabelTensor)
        # Lists/tuples are treated as collections of Graph/Data objects; their
        # contents are validated by ``_check_graph_list_consistency``.
        graph_types = (Graph, Data, list, tuple)

        input_is_tensor = isinstance(input, tensor_types)
        input_is_graph = isinstance(input, graph_types)
        target_is_tensor = isinstance(target, tensor_types)
        target_is_graph = isinstance(target, graph_types)

        if input_is_tensor and target_is_tensor:
            # Tensor - Tensor
            subclass = TensorInputTensorTargetCondition
        elif input_is_tensor and target_is_graph:
            # Tensor - Graph
            cls._check_graph_list_consistency(target)
            subclass = TensorInputGraphTargetCondition
        elif input_is_graph and target_is_tensor:
            # Graph - Tensor
            cls._check_graph_list_consistency(input)
            subclass = GraphInputTensorTargetCondition
        elif input_is_graph and target_is_graph:
            # Graph - Graph
            cls._check_graph_list_consistency(input)
            cls._check_graph_list_consistency(target)
            subclass = GraphInputGraphTargetCondition
        else:
            # If the input and/or target are not of the correct type raise
            # an error.
            raise ValueError(
                "Invalid input | target types. "
                "Please provide either torch_geometric.data.Data, Graph, "
                "LabelTensor or torch.Tensor objects."
            )

        return subclass.__new__(subclass, input, target)

    def __init__(self, input, target):
        """
        Initialization of the :class:`InputTargetCondition` class.

        :param input: The input data for the condition.
        :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
            list[Data] | tuple[Graph] | tuple[Data]
        :param target: The target data for the condition.
        :type target: torch.Tensor | LabelTensor | Graph | Data |
            list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
        :raises ValueError: If ``input`` and ``target`` have different
            lengths (see :meth:`_check_input_target_len`).

        .. note::
            If either ``input`` or ``target`` is a list of
            :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
            objects, all elements in the list must share the same structure,
            with matching keys and consistent data types.
        """
        super().__init__()
        self._check_input_target_len(input, target)
        self.input = input
        self.target = target

    @staticmethod
    def _check_input_target_len(input, target):
        """
        Check that the lengths of ``input`` and ``target`` are the same.

        The check is skipped when either ``input`` or ``target`` is a single
        :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
        object. Otherwise ``len()`` is compared, i.e. the first dimension of a
        tensor or the number of elements of a list/tuple.

        :param input: The input data.
        :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] |
            list[Data] | tuple[Graph] | tuple[Data]
        :param target: The target data.
        :type target: torch.Tensor | LabelTensor | Graph | Data |
            list[Graph] | list[Data] | tuple[Graph] | tuple[Data]
        :raises ValueError: If the lengths of the input and target do not
            match.
        """
        if isinstance(input, (Graph, Data)) or isinstance(
            target, (Graph, Data)
        ):
            return
        # Raise an error if the lengths of the input and target do not match
        if len(input) != len(target):
            raise ValueError(
                "The input and target lists must have the same length."
            )
[docs]
class TensorInputTensorTargetCondition(InputTargetCondition):
    """
    Tensor-to-tensor specialization of :class:`InputTargetCondition`.

    Selected automatically when ``input`` and ``target`` are both
    :class:`torch.Tensor` or :class:`~pina.label_tensor.LabelTensor`
    objects.
    """
[docs]
class TensorInputGraphTargetCondition(InputTargetCondition):
    """
    Specialization of the :class:`InputTargetCondition` class for the case
    where ``input`` is either a :class:`torch.Tensor` or a
    :class:`~pina.label_tensor.LabelTensor` object and ``target`` is either a
    :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data`
    object, or a ``list`` / ``tuple`` of such objects.

    .. note::
        When ``target`` is a ``list`` or ``tuple``, all of its elements must
        share the same structure, and its length must match the length of
        ``input``.
    """
[docs]
class GraphInputTensorTargetCondition(InputTargetCondition):
    """
    Specialization of the :class:`InputTargetCondition` class for the case
    where ``input`` is either a :class:`~pina.graph.Graph` or
    :class:`torch_geometric.data.Data` object, or a ``list`` / ``tuple`` of
    such objects, and ``target`` is either a :class:`torch.Tensor` or a
    :class:`~pina.label_tensor.LabelTensor` object.

    .. note::
        When ``input`` is a ``list`` or ``tuple``, all of its elements must
        share the same structure, and its length must match the length of
        ``target``.
    """
[docs]
class GraphInputGraphTargetCondition(InputTargetCondition):
    """
    Specialization of the :class:`InputTargetCondition` class for the case
    where both ``input`` and ``target`` are either :class:`~pina.graph.Graph`
    or :class:`torch_geometric.data.Data` objects, or ``list`` / ``tuple``
    collections of such objects.

    .. note::
        Every ``list`` / ``tuple`` collection must contain elements sharing
        the same structure. When neither side is a single
        :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data`
        object, ``input`` and ``target`` must have the same length.
    """