FNO
- class FNO(lifting_net, projecting_net, n_modes, dimensions=3, padding=8, padding_type='constant', inner_size=20, n_layers=2, func=torch.nn.Tanh, layers=None)
Bases: KernelNeuralOperator
Fourier Neural Operator model class.
The Fourier Neural Operator (FNO) is a general architecture for learning operators, i.e. mappings between function spaces. It can be trained with both Supervised and Physics-Informed learning strategies. Its Fourier layers perform a global convolution by multiplying a truncated set of Fourier modes by learned weights.
See also
Original reference: Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2020). Fourier neural operator for parametric partial differential equations. arXiv preprint arXiv:2010.08895.
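The global convolution in Fourier space can be illustrated with a minimal NumPy sketch of a 1D spectral convolution in the spirit of the reference above. It assumes a channel-last input of shape [batch, X, channels] and a complex weight array of shape [n_modes, in_channels, out_channels]; spectral_conv_1d is a hypothetical helper, not part of PINA's API. Multiplying the retained modes by weights is equivalent to a circular convolution over the whole grid, so every output point depends on every input point.

```python
import numpy as np


def spectral_conv_1d(u: np.ndarray, weights: np.ndarray, n_modes: int) -> np.ndarray:
    """Hypothetical 1D Fourier-space convolution (sketch, not PINA code).

    u: real array of shape [batch, X, in_channels] (channel-last).
    weights: complex array of shape [n_modes, in_channels, out_channels].
    """
    batch, n_points, _ = u.shape
    # Real FFT along the spatial axis X; this yields n_points // 2 + 1 modes.
    u_hat = np.fft.rfft(u, axis=1)
    out_hat = np.zeros((batch, u_hat.shape[1], weights.shape[2]), dtype=complex)
    # Keep only the lowest n_modes frequencies and mix channels mode by mode.
    out_hat[:, :n_modes, :] = np.einsum("bmi,mio->bmo", u_hat[:, :n_modes, :], weights)
    # Inverse FFT back to the original grid size; the discarded modes stay zero.
    return np.fft.irfft(out_hat, n=n_points, axis=1)


rng = np.random.default_rng(0)
u = rng.standard_normal((4, 64, 2))  # [batch, X, channels]
w = rng.standard_normal((8, 2, 5)) + 1j * rng.standard_normal((8, 2, 5))
print(spectral_conv_1d(u, w, n_modes=8).shape)  # (4, 64, 5)
```

With an identity weight for every mode the sketch returns its input unchanged, while with n_modes=1 only the mean over X survives, which shows the low-pass effect of truncating modes.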
- Parameters:
lifting_net (torch.nn.Module) – The lifting neural network mapping the input to its hidden dimension.
projecting_net (torch.nn.Module) – The projection neural network mapping the hidden representation to the output function.
n_modes (int | list[int]) – The number of Fourier modes retained in each Fourier layer.
dimensions (int) – The number of spatial dimensions of the input. It can be set to 1, 2, or 3. Default is 3.
padding (int) – The padding size. Default is 8.
padding_type (str) – The padding strategy. Default is 'constant'.
inner_size (int) – The dimension of each inner layer, used when layers is None. Default is 20.
n_layers (int) – The number of inner layers, used when layers is None. Default is 2.
func (torch.nn.Module | list[torch.nn.Module]) – The activation function. If a list is passed, it must have the same length as n_layers. If a single function is passed, it is used for all layers except the last one. Default is torch.nn.Tanh.
layers (list[int]) – The list of inner layer dimensions. If None, n_layers layers of dimension inner_size are used; otherwise, it overrides the values passed to n_layers and inner_size. Default is None.
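The interplay between layers, n_layers, inner_size, and func can be restated as a small pure-Python sketch. resolve_layers and identity are hypothetical helpers, not PINA functions, and math.tanh stands in for torch.nn.Tanh. Two points are assumptions rather than documented behavior: when a single function is passed, the last layer is taken to have no activation (identity), and a list of activations is checked against the effective number of layers, i.e. len(layers) when layers is given.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

Activation = Callable[[float], float]


def identity(x: float) -> float:
    # Assumed "no activation" used for the last layer.
    return x


def resolve_layers(
    n_layers: int = 2,
    inner_size: int = 20,
    layers: Optional[List[int]] = None,
    func: Union[Activation, Sequence[Activation]] = math.tanh,
) -> List[Tuple[int, Activation]]:
    """Hypothetical helper pairing each inner-layer size with its activation."""
    # layers, when given, overrides both n_layers and inner_size.
    sizes = list(layers) if layers is not None else [inner_size] * n_layers
    if callable(func):
        # A single activation is reused for every layer except the last one.
        acts = [func] * (len(sizes) - 1) + [identity]
    else:
        # A list of activations must provide exactly one entry per layer.
        acts = list(func)
        if len(acts) != len(sizes):
            raise ValueError(f"expected {len(sizes)} activations, got {len(acts)}")
    return list(zip(sizes, acts))


print([size for size, _ in resolve_layers(layers=[16, 32, 16])])  # [16, 32, 16]
```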
- forward(x)
Forward pass for the FNO model.
The lifting_net maps the input to the hidden dimension. Then, several Fourier blocks are applied. Finally, the projecting_net maps the hidden representation to the output function.
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor. Depending on the dimensions value set at initialization, it must have one of the following shapes:
1D tensors: [batch, X, channels]
2D tensors: [batch, X, Y, channels]
3D tensors: [batch, X, Y, Z, channels]
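The lifting, Fourier-block, and projection steps can be sketched for the 1D case (dimensions=1, input shape [batch, X, channels]) in NumPy. This is a simplified stand-in, not PINA's implementation: fno_forward_1d is hypothetical, lifting_net and projecting_net are replaced by plain pointwise linear maps, each Fourier block follows the update σ(W v + K v) from Li et al. with tanh standing in for func and no activation after the last block, and padding is omitted.

```python
import numpy as np


def fno_forward_1d(x, lift_w, blocks, proj_w, n_modes):
    """Hypothetical 1D FNO forward pass for x of shape [batch, X, channels].

    lift_w:  [in_channels, hidden]      (stand-in for lifting_net)
    blocks:  list of (spectral_w [n_modes, hidden, hidden] complex,
                      pointwise_w [hidden, hidden] real)
    proj_w:  [hidden, out_channels]     (stand-in for projecting_net)
    """
    if x.ndim != 3:  # dimensions=1 expects [batch, X, channels]
        raise ValueError(f"expected [batch, X, channels], got shape {x.shape}")
    n_points = x.shape[1]
    v = x @ lift_w  # lifting: pointwise map to the hidden width
    for i, (spectral_w, pointwise_w) in enumerate(blocks):
        # Global term: truncated multiplication in Fourier space along X.
        v_hat = np.fft.rfft(v, axis=1)
        k_hat = np.zeros((v.shape[0], v_hat.shape[1], spectral_w.shape[2]), dtype=complex)
        k_hat[:, :n_modes] = np.einsum("bmi,mio->bmo", v_hat[:, :n_modes], spectral_w)
        kv = np.fft.irfft(k_hat, n=n_points, axis=1)
        # Global term plus local pointwise term, then tanh on every block except the last.
        v = kv + v @ pointwise_w
        if i < len(blocks) - 1:
            v = np.tanh(v)
    return v @ proj_w  # projection: pointwise map to the output channels


rng = np.random.default_rng(0)
batch, n_points, c_in, hidden, c_out, n_modes = 2, 32, 1, 20, 1, 8
x = rng.standard_normal((batch, n_points, c_in))
lift_w = rng.standard_normal((c_in, hidden))
blocks = [
    (
        0.1 * (rng.standard_normal((n_modes, hidden, hidden))
               + 1j * rng.standard_normal((n_modes, hidden, hidden))),
        0.1 * rng.standard_normal((hidden, hidden)),
    )
    for _ in range(2)
]
proj_w = rng.standard_normal((hidden, c_out))
print(fno_forward_1d(x, lift_w, blocks, proj_w, n_modes).shape)  # (2, 32, 1)
```

For dimensions=2 or 3 the same structure applies, with the FFT taken over all spatial axes instead of X alone.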
- Returns:
The output tensor.
- Return type: