# Source code for pina.model.network

import torch
import torch.nn as nn
from ..utils import check_consistency
from ..label_tensor import LabelTensor


class Network(torch.nn.Module):
    """
    Wrapper turning any :class:`torch.nn.Module` into a PINA model.

    The wrapped model receives the requested input variables, optionally
    augmented by extra features. Its output is returned as a
    :class:`pina.label_tensor.LabelTensor` labelled with the output
    variables.
    """

    def __init__(
        self, model, input_variables, output_variables, extra_features=None
    ):
        """
        Network class with standard forward method and possibility to pass
        extra features. This class is used internally in PINA to convert any
        :class:`torch.nn.Module` into a PINA module.

        :param model: The torch model to convert into a PINA model.
        :type model: torch.nn.Module
        :param list(str) input_variables: The input variables of the
            :class:`AbstractProblem`, whose type depends on the type of
            domain (spatial, temporal, and parameter).
        :param list(str) output_variables: The output variables of the
            :class:`AbstractProblem`, whose type depends on the problem
            setting.
        :param extra_features: List of torch models to augment the input,
            defaults to None.
        :type extra_features: list(torch.nn.Module)
        """
        super().__init__()

        # check model consistency
        check_consistency(model, nn.Module)
        check_consistency(input_variables, str)
        check_consistency(output_variables, str)

        self._model = model
        self._input_variables = input_variables
        self._output_variables = output_variables

        # check consistency and assign extra features
        if extra_features is None:
            self._extra_features = []
        else:
            # Materialize once: a one-shot iterator (e.g. a generator) would
            # otherwise be exhausted by the validation loop below, leaving
            # the nn.Sequential silently empty.
            extra_features = list(extra_features)
            for feat in extra_features:
                check_consistency(feat, nn.Module)
            # nn.Sequential registers the features as submodules
            self._extra_features = nn.Sequential(*extra_features)

        # TODO: check model works with inputs

    def forward(self, x):
        """
        Forward method for Network class.

        This class implements the standard forward method, and it adds the
        possibility to pass extra features. All the PINA models ``forward``
        methods are overridden by this class, to enable
        :class:`pina.label_tensor.LabelTensor` labels extraction.

        :param torch.Tensor x: Input of the network; must be a
            :class:`pina.label_tensor.LabelTensor`.
        :return torch.Tensor: Output of the network, as a
            :class:`pina.label_tensor.LabelTensor` labelled with the output
            variables.
        :raises AssertionError: If ``x`` is not a
            :class:`pina.label_tensor.LabelTensor`.
        """
        # only LabelTensors as input
        assert isinstance(
            x, LabelTensor
        ), "Expected LabelTensor as input to the model."

        # extract the labelled columns; in case `input_variables = []`
        # all points are used
        if self._input_variables:
            x = x.extract(self._input_variables)

        # extract features and append them; since `x` is reassigned, each
        # feature receives the input already augmented by the previous ones
        for feature in self._extra_features:
            x = x.append(feature(x))

        # perform forward pass + converting to LabelTensor
        output = self._model(x).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables
        return output

    # TODO to remove in next releases (only used in GAROM solver)
    def forward_map(self, x):
        """
        Forward method for Network class when the input is a tuple.

        This is simply a forward with the input cast to a list of plain
        :class:`torch.Tensor`. All the PINA models ``forward`` methods are
        overridden by this class, to enable
        :class:`pina.label_tensor.LabelTensor` labels extraction.

        :param list(torch.Tensor) | tuple(torch.Tensor) x: Input of the
            network.
        :return torch.Tensor: Output of the network, as a
            :class:`pina.label_tensor.LabelTensor` labelled with the output
            variables.

        .. note::
            This function does not extract the input variables, all the
            variables are used for both tensors. Output variables are
            correctly applied.
        """
        # convert LabelTensors to plain torch.Tensors
        x = [tensor.as_subclass(torch.Tensor) for tensor in x]

        # perform forward pass (using torch.Tensor) + converting to LabelTensor
        output = self._model(x).as_subclass(LabelTensor)

        # set the labels for LabelTensor
        output.labels = self._output_variables
        return output

    @property
    def torchmodel(self):
        """The wrapped :class:`torch.nn.Module`."""
        return self._model

    @property
    def extra_features(self):
        """
        The extra-feature modules: an empty list when none were given,
        otherwise a :class:`torch.nn.Sequential` holding them.
        """
        return self._extra_features