Source code for pina.model.average_neural_operator

"""Module for the Averaging Neural Operator model class."""

import torch
from torch import nn
from .block.average_neural_operator_block import AVNOBlock
from .kernel_neural_operator import KernelNeuralOperator
from ..utils import check_consistency


class AveragingNeuralOperator(KernelNeuralOperator):
    """
    Averaging Neural Operator model class.

    The Averaging Neural Operator is a general architecture for learning
    operators, which map functions to functions. It can be trained both with
    Supervised and Physics-Informed learning strategies. The Averaging Neural
    Operator performs convolution by means of a field average.

    .. seealso::

        **Original reference**: Lanthaler S., Li, Z., Stuart, A. (2023).
        *The Nonlocal Neural Operator: Universal Approximation*.
        DOI: `arXiv preprint arXiv:2304.13221.
        <https://arxiv.org/abs/2304.13221>`_
    """

    def __init__(
        self,
        lifting_net,
        projecting_net,
        field_indices,
        coordinates_indices,
        n_layers=4,
        func=nn.GELU,
    ):
        """
        Initialization of the :class:`AveragingNeuralOperator` class.

        The input dimension of each network is inferred from the last
        dimension of its first parameter, and the output dimension of the
        lifting network is inferred by running it once on a random tensor
        shaped like that first parameter.

        :param torch.nn.Module lifting_net: The lifting neural network mapping
            the input to its hidden dimension. It must take as input the input
            field and the coordinates at which the input field is evaluated.
        :param torch.nn.Module projecting_net: The projection neural network
            mapping the hidden representation to the output function. It must
            take as input the embedding dimension plus the dimension of the
            coordinates.
        :param list[str] field_indices: The labels of the fields in the input
            tensor.
        :param list[str] coordinates_indices: The labels of the coordinates in
            the input tensor.
        :param int n_layers: The number of hidden layers. Default is ``4``.
        :param torch.nn.Module func: The activation function to use.
            Default is :class:`torch.nn.GELU`.
        :raises ValueError: If the lifting or the projecting network has no
            parameters, so that its input dimension cannot be inferred.
        :raises ValueError: If the input dimension does not match with the
            labels of the fields and coordinates.
        :raises ValueError: If the input dimension of the projecting network
            does not match with the hidden dimension of the lifting network.
        """
        # check consistency
        check_consistency(field_indices, str)
        check_consistency(coordinates_indices, str)
        check_consistency(n_layers, int)
        check_consistency(func, nn.Module, subclass=True)

        # The dimension checks below rely on the first parameter of each
        # network; fail with a clear message instead of a bare StopIteration.
        try:
            lifting_param = next(lifting_net.parameters())
            projecting_param = next(projecting_net.parameters())
        except StopIteration as err:
            raise ValueError(
                "The lifting_net and the projecting_net must have at least "
                "one parameter to infer their input dimensions."
            ) from err

        # check hidden dimensions match
        input_lifting_net = lifting_param.size()[-1]
        # Build the probe on the same device/dtype as the network so the
        # check also works for networks already moved off CPU or cast to a
        # different precision; no graph is needed since only the shape is
        # read.
        with torch.no_grad():
            probe = torch.rand(
                size=lifting_param.size(),
                dtype=lifting_param.dtype,
                device=lifting_param.device,
            )
            output_lifting_net = lifting_net(probe).shape[-1]
        projecting_net_input = projecting_param.size()[-1]

        if len(field_indices) + len(coordinates_indices) != input_lifting_net:
            raise ValueError(
                "The lifting_net must take as input the "
                "coordinates vector and the field vector."
            )

        if (
            output_lifting_net + len(coordinates_indices)
            != projecting_net_input
        ):
            raise ValueError(
                "The projecting_net input must be equal to "
                "the embedding dimension (which is the output) "
                "of the lifting_net plus the dimension of the "
                "coordinates, i.e. len(coordinates_indices)."
            )

        # assign
        self.coordinates_indices = coordinates_indices
        self.field_indices = field_indices
        integral_net = nn.Sequential(
            *[AVNOBlock(output_lifting_net, func) for _ in range(n_layers)]
        )
        super().__init__(lifting_net, integral_net, projecting_net)

    def forward(self, x):
        r"""
        Forward pass for the :class:`AveragingNeuralOperator` model.

        The ``lifting_net`` maps the input to the hidden dimension.
        Then, several layers of
        :class:`~pina.model.block.average_neural_operator_block.AVNOBlock` are
        applied. Finally, the ``projecting_net`` maps the hidden
        representation, concatenated with the coordinates, to the output
        function.

        :param LabelTensor x: The input tensor for performing the computation.
            It expects a tensor :math:`B \times N \times D`, where :math:`B`
            is the batch_size, :math:`N` the number of points in the mesh,
            :math:`D` the dimension of the problem, i.e. the sum of
            ``len(coordinates_indices)`` and ``len(field_indices)``.
        :return: The output tensor.
        :rtype: torch.Tensor
        """
        points_tmp = x.extract(self.coordinates_indices)
        new_batch = x.extract(self.field_indices)
        # Fields first, coordinates last: the lifting network sees both.
        new_batch = torch.cat((new_batch, points_tmp), dim=-1)
        new_batch = self._lifting_operator(new_batch)
        new_batch = self._integral_kernels(new_batch)
        # Re-append the coordinates so the projection can depend on them.
        new_batch = torch.cat((new_batch, points_tmp), dim=-1)
        new_batch = self._projection_operator(new_batch)
        return new_batch