Input Target Condition
- class InputTargetCondition(input, target)
Bases: ConditionInterface
The InputTargetCondition class represents a supervised condition defined by both input and target data: the model is trained to reproduce the target values given the input. Supported data types are torch.Tensor, LabelTensor, Graph, and torch_geometric.data.Data.

The class automatically selects the appropriate implementation based on the types of input and target. Depending on whether each of them is tensor data or graph-based data, one of the following specialized subclasses is instantiated:
  - TensorInputTensorTargetCondition: both input and target are torch.Tensor or LabelTensor.
  - TensorInputGraphTargetCondition: input is a torch.Tensor or LabelTensor, and target is a Graph or torch_geometric.data.Data.
  - GraphInputTensorTargetCondition: input is a Graph or torch_geometric.data.Data, and target is a torch.Tensor or LabelTensor.
  - GraphInputGraphTargetCondition: both input and target are Graph or torch_geometric.data.Data.
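The type-based selection described above can be pictured with the following minimal sketch, which dispatches inside __new__. It is not PINA's implementation: Tensor and Graph are hypothetical stand-in classes (for torch.Tensor/LabelTensor and Graph/torch_geometric.data.Data), treating a non-empty list or tuple of graphs as graph data is an assumption based on the parameter types listed for this class, and raising ValueError for unsupported combinations is likewise assumed.

```python
# Minimal sketch of type-based dispatch in __new__ (NOT PINA's actual code).
# Tensor and Graph are hypothetical stand-ins for torch.Tensor/LabelTensor and
# Graph/torch_geometric.data.Data respectively.


class Tensor:
    """Hypothetical stand-in for torch.Tensor or LabelTensor."""


class Graph:
    """Hypothetical stand-in for Graph or torch_geometric.data.Data."""


def _is_graph(value):
    # Assumption: a single graph, or a non-empty list/tuple of graphs, counts as graph data.
    if isinstance(value, (list, tuple)):
        return len(value) > 0 and all(isinstance(v, Graph) for v in value)
    return isinstance(value, Graph)


def _is_tensor(value):
    return isinstance(value, Tensor)


class InputTargetCondition:
    def __new__(cls, input, target):
        # Dispatch only when the generic class is called; subclasses build themselves.
        if cls is not InputTargetCondition:
            return super().__new__(cls)
        if _is_tensor(input) and _is_tensor(target):
            subclass = TensorInputTensorTargetCondition
        elif _is_tensor(input) and _is_graph(target):
            subclass = TensorInputGraphTargetCondition
        elif _is_graph(input) and _is_tensor(target):
            subclass = GraphInputTensorTargetCondition
        elif _is_graph(input) and _is_graph(target):
            subclass = GraphInputGraphTargetCondition
        else:
            # Assumed error type for unsupported input/target combinations.
            raise ValueError("Unsupported combination of input and target types.")
        return super().__new__(subclass)

    def __init__(self, input, target):
        self.input = input
        self.target = target


class TensorInputTensorTargetCondition(InputTargetCondition):
    pass


class TensorInputGraphTargetCondition(InputTargetCondition):
    pass


class GraphInputTensorTargetCondition(InputTargetCondition):
    pass


class GraphInputGraphTargetCondition(InputTargetCondition):
    pass


cond = InputTargetCondition(input=Tensor(), target=Graph())
print(type(cond).__name__)  # TensorInputGraphTargetCondition
```

Because __new__ returns an instance of a subclass of InputTargetCondition, Python then calls __init__ on that instance with the same arguments, so callers only ever need to write InputTargetCondition(input, target).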
- Example:
>>> from pina import Condition, LabelTensor
>>> from pina.graph import Graph
>>> import torch
>>> pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> edge_index = torch.randint(0, 100, (2, 300))
>>> graph = Graph(pos=pos, edge_index=edge_index)
>>> # Tensor input with a graph target: the TensorInputGraphTargetCondition case
>>> input_tensor = LabelTensor(torch.randn(100, 2), labels=["x", "y"])
>>> condition = Condition(input=input_tensor, target=graph)
Initialization of the InputTargetCondition class.
- Parameters:
input (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The input data for the condition.
target (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The target data for the condition.
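In the supervised setting described at the top of this section, training drives the model's predictions on input toward target. The sketch below illustrates that objective under two assumptions: a mean squared error loss and a hypothetical one-dimensional linear model, both in plain Python. The loss actually used depends on the solver and is not specified by this condition.

```python
# Minimal sketch of a supervised input/target objective.
# Assumptions: mean squared error as the loss and a hypothetical toy linear model;
# neither is prescribed by InputTargetCondition itself.


def mse_loss(model, inputs, targets):
    """Mean squared error between model(x) and y over paired samples."""
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    residuals = [model(x) - y for x, y in zip(inputs, targets)]
    return sum(r * r for r in residuals) / len(residuals)


def linear(x, weight=2.0, bias=1.0):
    # Hypothetical toy model: y_hat = weight * x + bias
    return weight * x + bias


inputs = [0.0, 1.0, 2.0]
targets = [1.0, 3.0, 4.0]
print(mse_loss(linear, inputs, targets))  # 0.3333333333333333
```

Here the predictions are 1.0, 3.0 and 5.0, so only the last sample contributes a residual of 1.0 and the mean squared error is 1/3; training would adjust weight and bias to reduce it.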
- class TensorInputTensorTargetCondition(input, target)
Bases: InputTargetCondition
Specialization of the InputTargetCondition class for the case where both input and target are torch.Tensor or LabelTensor objects.
Initialization of the InputTargetCondition class.
- Parameters:
input (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The input data for the condition.
target (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The target data for the condition.
- class TensorInputGraphTargetCondition(input, target)
Bases: InputTargetCondition
Specialization of the InputTargetCondition class for the case where input is either a torch.Tensor or a LabelTensor object and target is either a Graph or a torch_geometric.data.Data object.
Initialization of the InputTargetCondition class.
- Parameters:
input (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The input data for the condition.
target (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The target data for the condition.
- class GraphInputTensorTargetCondition(input, target)
Bases: InputTargetCondition
Specialization of the InputTargetCondition class for the case where input is either a Graph or a torch_geometric.data.Data object and target is either a torch.Tensor or a LabelTensor object.
Initialization of the InputTargetCondition class.
- Parameters:
input (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The input data for the condition.
target (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The target data for the condition.
- class GraphInputGraphTargetCondition(input, target)
Bases: InputTargetCondition
Specialization of the InputTargetCondition class for the case where both input and target are either Graph or torch_geometric.data.Data objects.
Initialization of the InputTargetCondition class.
- Parameters:
input (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The input data for the condition.
target (torch.Tensor | LabelTensor | Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data]) – The target data for the condition.