Source code for pina.model.block.low_rank_block

"""Module for the Low Rank Neural Operator Block class."""

import torch

from ...utils import check_consistency


class LowRankBlock(torch.nn.Module):
    """
    The inner block of the Low Rank Neural Operator.

    .. seealso::

        **Original reference**: Kovachki, N., Li, Z., Liu, B.,
        Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A.
        (2023). *Neural operator: Learning maps between function spaces with
        applications to PDEs*. Journal of Machine Learning Research, 24(89),
        1-97.
    """

    def __init__(
        self,
        input_dimensions,
        embedding_dimenion,
        rank,
        inner_size=20,
        n_layers=2,
        func=torch.nn.Tanh,
        bias=True,
    ):
        r"""
        Initialization of the :class:`LowRankBlock` class.

        :param int input_dimensions: The input dimension of the coordinates
            fed to the basis function neural network.
        :param int embedding_dimenion: The embedding dimension of the field.
            (The misspelled name is kept for backward compatibility.)
        :param int rank: The rank :math:`r` of the low rank approximation,
            i.e. the number of basis pairs
            :math:`(\psi^{(i)}, \phi^{(i)})`. The basis network outputs
            ``2 * rank * embedding_dimenion`` values per point, split evenly
            between the :math:`\psi` and :math:`\phi` functions. Must be a
            positive integer.
        :param int inner_size: The number of neurons for each hidden layer in
            the basis function neural network. Default is ``20``.
        :param int n_layers: The number of hidden layers in the basis function
            neural network. Default is ``2``.
        :param func: The activation function class. If a list is passed, it
            must have the same length as ``n_layers``; it is forwarded to the
            basis network, and its last entry is also used as the activation
            applied to the block output. If a single class is passed, it is
            forwarded to the basis network (where, per the basis network's
            convention, it is used for all layers except the last one) and
            is also applied to the block output.
            Default is :class:`torch.nn.Tanh`.
        :type func: torch.nn.Module | list[torch.nn.Module]
        :param bool bias: If ``True`` bias is considered for the basis function
            neural network. Default is ``True``.
        :raises ValueError: If ``rank`` is not a positive integer, or if
            ``func`` is an empty list.
        """
        super().__init__()
        # Local import (kept from the original), presumably to avoid an
        # import cycle between blocks and models.
        from ..feed_forward import FeedForward

        # Validate rank *before* it is used to size the basis network; a
        # non-positive rank would otherwise build a zero-width network and
        # only fail later inside ``forward`` with an obscure reshape error.
        check_consistency(rank, int)
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}.")
        self._rank = rank

        # Assignment (check consistency inside FeedForward)
        self._basis = FeedForward(
            input_dimensions=input_dimensions,
            output_dimensions=2 * rank * embedding_dimenion,
            inner_size=inner_size,
            n_layers=n_layers,
            func=func,
            bias=bias,
        )
        self._nn = torch.nn.Linear(embedding_dimenion, embedding_dimenion)
        self._func = self._output_activation(func)

    @staticmethod
    def _output_activation(func):
        """
        Instantiate the activation applied to the output of the block.

        :param func: An activation class, or a list/tuple of activation
            classes (as accepted by the basis network).
        :raises ValueError: If ``func`` is an empty list/tuple.
        :return: An instance of the selected activation.
        :rtype: torch.nn.Module
        """
        if isinstance(func, (list, tuple)):
            if not func:
                raise ValueError("func must not be an empty list.")
            # A list configures the basis network layer by layer; calling the
            # list itself would raise TypeError, so reuse its last entry.
            func = func[-1]
        return func()

    def forward(self, x, coords):
        r"""
        Forward pass of the block.

        It performs an affine transformation of the field, followed by a low
        rank approximation. The basis functions :math:`\psi^{(i)}` and
        :math:`\phi^{(i)}` are evaluated at ``coords``; the coefficients of the
        expansion are the dot products (over the embedding dimension) of
        :math:`\psi^{(i)}` with the field :math:`v` given by ``x``, and they
        are used to expand :math:`\phi^{(i)}`. The output activation is
        applied to the sum of the affine term and the expansion.

        :param torch.Tensor x: The input field. Its last dimension must equal
            the embedding dimension, and its leading dimensions must
            broadcast with those of ``coords``.
        :param torch.Tensor coords: The coordinates at which the basis
            functions are evaluated.
        :return: The output tensor, with the same last dimension as ``x``.
        :rtype: torch.Tensor
        """
        # Drop any torch.Tensor subclass wrapper before feeding the network.
        coords = coords.as_subclass(torch.Tensor)
        basis = self._basis(coords)
        # reshape [..., D, 2*rank], D being the embedding dimension
        shape = list(basis.shape[:-1]) + [-1, 2 * self.rank]
        basis = basis.reshape(shape)
        # first `rank` columns are psi, the remaining `rank` are phi
        psi = basis[..., : self.rank]
        phi = basis[..., self.rank :]
        # NOTE(review): the contraction below is over the embedding dimension
        # `d` only, not over the points, so the expansion is pointwise. The
        # cited reference uses an integral over the domain for the
        # coefficients; confirm this is intended.
        coeff = torch.einsum("...dr,...d->...r", psi, x)
        # expand the basis
        expansion = torch.einsum("...r,...dr->...d", coeff, phi)
        # apply linear layer, add expansion, and activate
        return self._func(self._nn(x) + expansion)

    @property
    def rank(self):
        """
        The basis rank.

        :return: The basis rank.
        :rtype: int
        """
        return self._rank