Network
- class Network(model, input_variables, output_variables, extra_features=None)
Bases: torch.nn.Module
Network class with a standard forward method and the possibility to pass extra features. This class is used internally in PINA to convert any torch.nn.Module into a PINA module.
- Parameters:
model (torch.nn.Module) – The torch model to convert into a PINA model.
input_variables (list(str)) – The input variables of the AbstractProblem, whose type depends on the type of domain (spatial, temporal, and parameter).
output_variables (list(str)) – The output variables of the AbstractProblem, whose type depends on the problem setting.
extra_features (list(torch.nn.Module)) – List of torch models used to augment the input, defaults to None.
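The core job of the wrapper is to pick the columns named in input_variables out of a labelled input before the model sees them. A minimal plain-Python sketch of that selection step follows, assuming a batch is stored as a list of rows plus a separate list of column labels. The helper name extract_columns is hypothetical and is not part of the PINA API, and treating an empty variable list as "keep every column" is an assumption of this sketch.

```python
from typing import List, Sequence


def extract_columns(
    rows: Sequence[Sequence[float]],
    labels: Sequence[str],
    variables: Sequence[str],
) -> List[List[float]]:
    """Hypothetical helper: return the columns named in `variables`, in that order."""
    if not variables:
        # Assumption of this sketch: no variables means keep every column.
        return [list(row) for row in rows]
    # Map each label to its column index once.
    index = {label: i for i, label in enumerate(labels)}
    missing = [v for v in variables if v not in index]
    if missing:
        raise KeyError(f"unknown variables: {missing}")
    cols = [index[v] for v in variables]
    # Pick the requested columns from every row, in the requested order.
    return [[row[c] for c in cols] for row in rows]


# Two points labelled (x, y, t); the model only needs (t, x).
points = [[0.1, 0.2, 0.0], [0.5, 0.6, 1.0]]
print(extract_columns(points, ["x", "y", "t"], ["t", "x"]))
# [[0.0, 0.1], [1.0, 0.5]]
```

Selecting by label rather than by position means the wrapped model does not depend on the column order in which the problem happens to store its points.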
- forward(x)
Forward method for the Network class. This method implements the standard forward pass and adds the possibility to pass extra features. The forward methods of all PINA models are overridden by this class to enable the extraction of pina.label_tensor.LabelTensor labels.
- Parameters:
x (torch.Tensor) – Input of the network.
- Return torch.Tensor:
Output of the network.
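The forward pass described above can be sketched point by point in plain Python, assuming each point is a list of floats with a separate label list, the wrapped model maps one input row to one output row, and each extra feature returns a single new column. The function name sketch_forward and all of its parameters are hypothetical; the sketch only mirrors the sequence the documentation implies (select the input variables, augment with extra features, evaluate, label the output) and is not PINA's implementation, which works on whole LabelTensor batches.

```python
import math
from typing import Callable, Dict, List, Sequence

Row = List[float]


def sketch_forward(
    rows: Sequence[Row],
    labels: Sequence[str],
    input_variables: Sequence[str],
    output_variables: Sequence[str],
    model: Callable[[Row], Row],
    extra_features: Sequence[Callable[[Row], float]] = (),
) -> List[Dict[str, float]]:
    """Hypothetical point-wise forward: extract, augment, evaluate, label."""
    index = {label: i for i, label in enumerate(labels)}
    outputs = []
    for row in rows:
        # 1. Keep only the input variables, in the requested order.
        x = [row[index[v]] for v in input_variables]
        # 2. Append one column per extra feature; each feature sees the
        #    inputs plus any feature columns appended before it.
        for feature in extra_features:
            x = x + [feature(x)]
        # 3. Evaluate the wrapped model and attach the output labels.
        y = model(x)
        if len(y) != len(output_variables):
            raise ValueError("model output size does not match output_variables")
        outputs.append(dict(zip(output_variables, y)))
    return outputs


# Toy model u = sum of its inputs, augmented with the feature sin(pi * x).
points = [[0.5, 2.0], [1.0, 3.0]]
out = sketch_forward(
    points,
    labels=["x", "y"],
    input_variables=["x"],
    output_variables=["u"],
    model=lambda x: [sum(x)],
    extra_features=[lambda x: math.sin(math.pi * x[0])],
)
print(out)
```

One reason to pass features as extra input columns is that known structure (here a sine term) reaches the model without any change to the model itself.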
- forward_map(x)
Forward method for the Network class when the input is a tuple. This method is simply a forward pass with the input cast as a tuple or list of torch.Tensor. The forward methods of all PINA models are overridden by this class to enable the extraction of pina.label_tensor.LabelTensor labels.
- Parameters:
x (list(torch.Tensor) | tuple(torch.Tensor)) – Input of the network.
- Return torch.Tensor:
Output of the network.
Note
This function does not extract the input variables; all variables are used for both tensors. Output variables are applied correctly.
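For the tuple case, the note says input variables are not extracted while output labels are still applied. A minimal plain-Python sketch of that behaviour follows, assuming each input is a list of rows and the wrapped model takes the whole tuple at once. The names sketch_forward_map and toy_model are hypothetical and used only for illustration; they are not PINA functions.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Row = List[float]


def sketch_forward_map(
    inputs: Tuple[Sequence[Row], ...],
    output_variables: Sequence[str],
    model: Callable[[Tuple[Sequence[Row], ...]], Sequence[Row]],
) -> List[Dict[str, float]]:
    """Hypothetical tuple forward: no input extraction, labelled outputs."""
    # Every column of every input is passed to the model unchanged.
    rows_out = model(tuple(inputs))
    labelled = []
    for row in rows_out:
        if len(row) != len(output_variables):
            raise ValueError("model output size does not match output_variables")
        # Output labels are still attached, one per output column.
        labelled.append(dict(zip(output_variables, row)))
    return labelled


def toy_model(inputs):
    """Toy two-input model: product of the row sums of both inputs."""
    first, second = inputs
    return [[sum(a) * sum(b)] for a, b in zip(first, second)]


first = [[1.0, 2.0], [0.5, 0.5]]
second = [[3.0], [4.0]]
print(sketch_forward_map((first, second), ["u"], toy_model))
# [{'u': 9.0}, {'u': 4.0}]
```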