Low Rank Neural Operator Block

class LowRankBlock(input_dimensions, embedding_dimenion, rank, inner_size=20, n_layers=2, func=torch.nn.Tanh, bias=True)

Bases: Module

The inner block of the Low Rank Neural Operator.

See also

Original reference: Kovachki, N., Li, Z., Liu, B., Azizzadenesheli, K., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2023). Neural operator: Learning maps between function spaces with applications to PDEs. Journal of Machine Learning Research, 24(89), 1-97.

Initialization of the LowRankBlock class.

Parameters:
  • input_dimensions (int) – The input dimension of the field.

  • embedding_dimenion (int) – The embedding dimension of the field.

  • rank (int) – The rank of the low rank approximation. The expected value is \(2d\), where \(d\) is the rank of each basis function.

  • inner_size (int) – The number of neurons for each hidden layer in the basis function neural network. Default is 20.

  • n_layers (int) – The number of hidden layers in the basis function neural network. Default is 2.

  • func (torch.nn.Module | list[torch.nn.Module]) – The activation function. If a list is passed, it must have the same length as n_layers. If a single function is passed, it is used for all layers, except for the last one. Default is torch.nn.Tanh.

  • bias (bool) – If True, bias terms are used in the basis function neural network. Default is True.
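The parameters inner_size, n_layers, func, and bias describe the basis function network. The following is a minimal sketch of how those documented rules could map onto a PyTorch MLP. The helper build_basis_mlp and its output_dimensions argument are hypothetical (this page does not document the output size of the basis network), and activations are assumed to be passed as classes, as the torch.nn.Tanh default suggests.

```python
import torch


def build_basis_mlp(input_dimensions, output_dimensions, inner_size=20,
                    n_layers=2, func=torch.nn.Tanh, bias=True):
    """Hypothetical helper mirroring the documented parameter rules."""
    # A list of activations must provide one entry per hidden layer;
    # a single activation is reused for every hidden layer.
    if isinstance(func, (list, tuple)):
        if len(func) != n_layers:
            raise ValueError("func must have the same length as n_layers")
        funcs = list(func)
    else:
        funcs = [func] * n_layers

    layers = []
    in_size = input_dimensions
    for act in funcs:
        # Hidden layer: linear map to inner_size neurons, then activation.
        layers += [torch.nn.Linear(in_size, inner_size, bias=bias), act()]
        in_size = inner_size
    # The last (output) layer gets no activation.
    layers.append(torch.nn.Linear(in_size, output_dimensions, bias=bias))
    return torch.nn.Sequential(*layers)
```

Passing func=[torch.nn.ReLU, torch.nn.SiLU] with n_layers=2 gives the two hidden layers different activations, while the output layer stays linear.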

forward(x, coords)

Forward pass of the block. It performs an affine transformation of the field, followed by a low rank approximation. The latter takes the dot product of the basis functions \(\psi^{(i)}\) with the vector field \(v\) to obtain coefficients, which are then used to expand the basis functions \(\phi^{(i)}\) evaluated at the spatial input \(x\). In terms of the method arguments, the field \(v\) is passed as x, while the spatial points are passed as coords.
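For reference, the low-rank construction in the cited work factorizes the kernel with \(r\) pairs of basis functions, which reduces the integral operator to exactly this coefficient-and-expansion computation. The last line assumes the block adds the low-rank term to the affine map \(W v(x) + b\) before an activation \(\sigma\), as in the generic neural operator layer of that reference; the approximation of the inner product assumes \(N\) uniformly weighted sample points \(y_n\) in the domain \(D\).

```latex
\begin{aligned}
(\mathcal{K} v)(x) &= \int_D \kappa(x, y)\, v(y)\, \mathrm{d}y
  = \sum_{i=1}^{r} \big\langle \psi^{(i)}, v \big\rangle\, \phi^{(i)}(x),
  \qquad \kappa(x, y) = \sum_{i=1}^{r} \phi^{(i)}(x)\, \psi^{(i)}(y)^{\top}, \\
\big\langle \psi^{(i)}, v \big\rangle &= \int_D \psi^{(i)}(y) \cdot v(y)\, \mathrm{d}y
  \approx \frac{|D|}{N} \sum_{n=1}^{N} \psi^{(i)}(y_n) \cdot v(y_n), \\
v_{\text{out}}(x) &= \sigma\big( W v(x) + b + (\mathcal{K} v)(x) \big).
\end{aligned}
```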

Parameters:
  • x (torch.Tensor) – The input field \(v\) on which the block operates.

  • coords (torch.Tensor) – The coordinates at which the field is evaluated.

Returns:

The output tensor.

Return type:

torch.Tensor
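The forward pass described above can be sketched in plain PyTorch. This is a minimal, hypothetical re-implementation (the class name LowRankSketch and its arguments are ours), not the library's code. It assumes the field x has shape [B, N, D] and coords has shape [B, N, coord_dim]; a one-hidden-layer basis network whose output is reshaped into rank pairs \((\psi^{(i)}, \phi^{(i)})\) per channel (here rank simply counts those pairs, which may differ from the convention in the rank parameter description); coefficients computed as a mean over the N points, i.e. a uniform-weight quadrature of the inner product with the domain measure dropped; and a Tanh applied to the sum of the affine and low-rank terms.

```python
import torch


class LowRankSketch(torch.nn.Module):
    """Hypothetical low-rank neural operator block (illustration only)."""

    def __init__(self, coord_dim, embed_dim, rank, inner_size=20):
        super().__init__()
        self.rank = rank  # number of (psi, phi) basis pairs in this sketch
        # Basis network: coordinates -> 2 * rank values per field channel,
        # holding psi and phi side by side (this layout is an assumption).
        self.basis = torch.nn.Sequential(
            torch.nn.Linear(coord_dim, inner_size),
            torch.nn.Tanh(),
            torch.nn.Linear(inner_size, 2 * rank * embed_dim),
        )
        # Pointwise affine transformation of the field: W v + b.
        self.affine = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x, coords):
        # x: field v, shape [B, N, D]; coords: points, shape [B, N, coord_dim].
        basis = self.basis(coords).reshape(*x.shape, 2 * self.rank)
        psi, phi = basis[..., : self.rank], basis[..., self.rank :]
        # <psi^(i), v>: dot product over channels, averaged over the N points.
        coeff = torch.einsum("bndr,bnd->br", psi, x) / x.shape[1]
        # Expand phi^(i) at every point using the shared coefficients.
        expansion = torch.einsum("br,bndr->bnd", coeff, phi)
        # Affine term plus low-rank term, followed by the activation.
        return torch.tanh(self.affine(x) + expansion)


if __name__ == "__main__":
    torch.manual_seed(0)
    block = LowRankSketch(coord_dim=2, embed_dim=8, rank=3)
    x = torch.randn(4, 100, 8)      # 4 fields, 100 points, 8 channels
    coords = torch.rand(4, 100, 2)  # 2D coordinates of the points
    print(block(x, coords).shape)   # torch.Size([4, 100, 8])
```

Averaging over the points rather than summing keeps the coefficients on a comparable scale however many points are sampled; this is a design choice of the sketch, not something stated on this page.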

property rank

The basis rank.

Returns:

The basis rank.

Return type:

int