Source code for pina.model.low_rank_neural_operator

"""Module for the Low Rank Neural Operator model class."""

import torch
from torch import nn

from ..utils import check_consistency

from .kernel_neural_operator import KernelNeuralOperator
from .block.low_rank_block import LowRankBlock


class LowRankNeuralOperator(KernelNeuralOperator):
    """
    Low Rank Neural Operator model class.

    The Low Rank Neural Operator is a general architecture for learning
    operators, i.e. maps between function spaces. It can be trained with both
    Supervised and Physics-Informed learning strategies. The convolution step
    is carried out through a low rank approximation of the kernel.

    .. seealso::

        **Original reference**: Kovachki, N., Li, Z., Liu, B.,
        Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
        (2023). *Neural operator: Learning maps between function spaces with
        applications to PDEs*. Journal of Machine Learning Research, 24(89),
        1-97.
    """

    def __init__(
        self,
        lifting_net,
        projecting_net,
        field_indices,
        coordinates_indices,
        n_kernel_layers,
        rank,
        inner_size=20,
        n_layers=2,
        func=torch.nn.Tanh,
        bias=True,
    ):
        """
        Initialization of the :class:`LowRankNeuralOperator` class.

        :param torch.nn.Module lifting_net: The lifting neural network mapping
            the input to its hidden dimension. It must take as input the input
            field and the coordinates at which the input field is evaluated.
        :param torch.nn.Module projecting_net: The projection neural network
            mapping the hidden representation to the output function. It must
            take as input the embedding dimension plus the dimension of the
            coordinates.
        :param list[str] field_indices: The labels of the fields in the input
            tensor.
        :param list[str] coordinates_indices: The labels of the coordinates in
            the input tensor.
        :param int n_kernel_layers: The number of hidden kernel layers.
        :param int rank: The rank of the low rank approximation.
        :param int inner_size: The number of neurons for each hidden layer in
            the basis function neural network. Default is ``20``.
        :param int n_layers: The number of hidden layers in the basis function
            neural network. Default is ``2``.
        :param func: The activation function. If a list is passed, it must
            have the same length as ``n_layers``. If a single function is
            passed, it is used for all layers, except for the last one.
            Default is :class:`torch.nn.Tanh`.
        :type func: torch.nn.Module | list[torch.nn.Module]
        :param bool bias: If ``True`` bias is considered for the basis
            function neural network. Default is ``True``.
        :raises ValueError: If the input dimension does not match with the
            labels of the fields and coordinates.
        :raises ValueError: If the input dimension of the projecting network
            does not match with the hidden dimension of the lifting network.
        """
        # Validate label lists and layer count before any tensor work.
        check_consistency(field_indices, str)
        check_consistency(coordinates_indices, str)
        check_consistency(n_kernel_layers, int)

        # Infer dimensions from the networks themselves: the first parameter's
        # last axis gives the expected input size, and a forward pass on a
        # random tensor of that parameter's shape reveals the embedding size.
        lifting_in_dim = next(lifting_net.parameters()).size()[-1]
        probe = torch.rand(size=next(lifting_net.parameters()).size())
        lifting_out_dim = lifting_net(probe).shape[-1]
        projecting_in_dim = next(projecting_net.parameters()).size()[-1]

        # The lifting network must consume fields and coordinates jointly.
        expected_in = len(field_indices) + len(coordinates_indices)
        if expected_in != lifting_in_dim:
            raise ValueError(
                "The lifting_net must take as input the "
                "coordinates vector and the field vector."
            )

        # The projecting network must consume the embedding concatenated with
        # the coordinates (see forward: torch.cat((x, coords), dim=-1)).
        if lifting_out_dim + len(coordinates_indices) != projecting_in_dim:
            raise ValueError(
                "The projecting_net input must be equal to "
                "the embedding dimension (which is the output) "
                "of the lifting_net plus the dimension of the "
                "coordinates, i.e. len(coordinates_indices)."
            )

        # Stored so forward() can extract the coordinate columns by label.
        self.coordinates_indices = coordinates_indices
        self.field_indices = field_indices

        # Stack of identical low rank kernel blocks forming the integral part.
        # NOTE: "embedding_dimenion" mirrors the (misspelled) keyword of the
        # LowRankBlock API and must not be corrected here.
        kernel_blocks = [
            LowRankBlock(
                input_dimensions=len(coordinates_indices),
                embedding_dimenion=lifting_out_dim,
                rank=rank,
                inner_size=inner_size,
                n_layers=n_layers,
                func=func,
                bias=bias,
            )
            for _ in range(n_kernel_layers)
        ]
        super().__init__(lifting_net, nn.Sequential(*kernel_blocks), projecting_net)

    def forward(self, x):
        r"""
        Forward pass for the :class:`LowRankNeuralOperator` model.

        The ``lifting_net`` maps the input to the hidden dimension. Then,
        several layers of
        :class:`~pina.model.block.low_rank_block.LowRankBlock` are applied.
        Finally, the ``projecting_net`` maps the hidden representation to the
        output function.

        :param LabelTensor x: The input tensor for performing the computation.
            It expects a tensor :math:`B \times N \times D`, where :math:`B`
            is the batch_size, :math:`N` the number of points in the mesh,
            :math:`D` the dimension of the problem, i.e. the sum of
            ``len(coordinates_indices)`` and ``len(field_indices)``.
        :return: The output tensor.
        :rtype: torch.Tensor
        """
        # Keep the coordinate columns aside: every kernel block and the final
        # projection consume them alongside the evolving hidden state.
        pts = x.extract(self.coordinates_indices)

        # Lift input to the embedding dimension.
        hidden = self._lifting_operator(x)

        # Apply each low rank kernel block in sequence.
        for kernel_layer in self._integral_kernels:
            hidden = kernel_layer(hidden, pts)

        # Project embedding + coordinates down to the output function.
        return self._projection_operator(torch.cat((hidden, pts), dim=-1))