# Source code for pina.model.network
import torch
import torch.nn as nn
from ..utils import check_consistency
from ..label_tensor import LabelTensor
# [docs] -- link label left over from the HTML source listing, not code
class Network(torch.nn.Module):
    """
    Wrapper converting any :class:`torch.nn.Module` into a PINA module.

    The wrapped model is called on the input restricted to the configured
    input variables (optionally augmented by extra features), and its
    output is returned as a :class:`pina.label_tensor.LabelTensor` whose
    labels are the configured output variables.
    """

    def __init__(
        self, model, input_variables, output_variables, extra_features=None
    ):
        """
        Network class with standard forward method
        and possibility to pass extra features. This
        class is used internally in PINA to convert
        any :class:`torch.nn.Module` s in a PINA module.

        :param model: The torch model to convert in a PINA model.
        :type model: torch.nn.Module
        :param list(str) input_variables: The input variables of the
            :class:`AbstractProblem`, whose type depends on the
            type of domain (spatial, temporal, and parameter).
        :param list(str) output_variables: The output variables of the
            :class:`AbstractProblem`, whose type depends on the
            problem setting.
        :param extra_features: Iterable of torch models to augment the
            input, defaults to None. Any iterable is accepted, including
            one-shot iterators such as generators.
        :type extra_features: list(torch.nn.Module)
        """
        super().__init__()

        # check model consistency
        check_consistency(model, nn.Module)
        check_consistency(input_variables, str)
        check_consistency(output_variables, str)

        self._model = model
        self._input_variables = input_variables
        self._output_variables = output_variables

        # check consistency and assign extra features
        if extra_features is None:
            self._extra_features = []
        else:
            # Materialise once: the features are iterated twice below
            # (validation, then unpacking into ``nn.Sequential``). A
            # generator would be exhausted by the first pass and the
            # container would silently end up empty.
            features = list(extra_features)
            for feat in features:
                check_consistency(feat, nn.Module)
            self._extra_features = nn.Sequential(*features)

        # check model works with inputs
        # TODO

    def forward(self, x):
        """
        Forward method for Network class. This class
        implements the standard forward method, and
        it adds the possibility to pass extra features.
        All the PINA models ``forward`` s are overridden
        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
        extraction.

        :param LabelTensor x: Input of the network.
        :return: Output of the network, labelled with the output variables.
        :rtype: LabelTensor
        :raises AssertionError: If ``x`` is not a
            :class:`pina.label_tensor.LabelTensor`.
        """
        # only labeltensors as input
        assert isinstance(
            x, LabelTensor
        ), "Expected LabelTensor as input to the model."

        # keep only the columns matching the input variables;
        # when `input_variables` is empty no extraction is performed
        # and the input is used as it is
        if self._input_variables:
            x = x.extract(self._input_variables)

        # evaluate the extra features one after the other and append each
        # result to the input (so a feature also sees what the previous
        # ones appended)
        for feature in self._extra_features:
            x = x.append(feature(x))

        # perform forward pass + converting to LabelTensor
        output = self._model(x).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables

        return output

    # TODO to remove in next releases (only used in GAROM solver)
    def forward_map(self, x):
        """
        Forward method for Network class when the input is
        a tuple. This class is simply a forward with the input casted as a
        tuple or list :class:`torch.Tensor`.
        All the PINA models ``forward`` s are overridden
        by this class, to enable :class:`pina.label_tensor.LabelTensor` labels
        extraction.

        :param list (torch.Tensor) | tuple(torch.Tensor) x: Input of the
            network.
        :return: Output of the network, labelled with the output variables.
        :rtype: LabelTensor

        .. note::
            This function does not extract the input variables, all the
            variables are used for both tensors. Output variables are
            correctly applied.
        """
        # convert LabelTensor s to torch.Tensor s
        tensors = [item.as_subclass(torch.Tensor) for item in x]

        # perform forward pass (using torch.Tensor) + converting to LabelTensor
        output = self._model(tensors).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables

        return output

    @property
    def torchmodel(self):
        """
        The wrapped torch model, exactly as passed to the constructor.

        :rtype: torch.nn.Module
        """
        return self._model

    @property
    def extra_features(self):
        """
        The extra features used to augment the input: an empty list when
        none were given, otherwise a :class:`torch.nn.Sequential` holding
        them.

        :rtype: list | torch.nn.Sequential
        """
        return self._extra_features