PirateNet

class PirateNet(input_dimension, inner_size, output_dimension, embedding=None, n_layers=3, activation=<class 'torch.nn.modules.activation.Tanh'>)

Bases: torch.nn.Module

Implementation of the Physics-Informed Residual Adaptive Network (PirateNet).

The model consists of a Fourier feature embedding layer, multiple PirateNet blocks, and a final output layer. Each PirateNet block consists of three dense layers with a dual gating mechanism and an adaptive residual connection, whose contribution is controlled by a trainable parameter alpha.
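For reference, the PirateNet formulation of Wang et al. writes the shared gates and the update of block $l$ as below, where $\Phi$ is the embedding, $\sigma$ the activation, $\odot$ the elementwise product and $\mathbf{x}^{(l)}$ the input to block $l$. Whether this class uses exactly this parameterization is an assumption; treat the equations as a description of the technique rather than of the source code.

```latex
\begin{aligned}
\mathbf{U} &= \sigma(\mathbf{W}_U \Phi(\mathbf{x}) + \mathbf{b}_U), \qquad
\mathbf{V} = \sigma(\mathbf{W}_V \Phi(\mathbf{x}) + \mathbf{b}_V), \qquad
\mathbf{x}^{(1)} = \Phi(\mathbf{x}) \\
\mathbf{f}^{(l)} &= \sigma(\mathbf{W}_1^{(l)} \mathbf{x}^{(l)} + \mathbf{b}_1^{(l)}), \qquad
\mathbf{z}_1^{(l)} = \mathbf{f}^{(l)} \odot \mathbf{U} + (1 - \mathbf{f}^{(l)}) \odot \mathbf{V} \\
\mathbf{g}^{(l)} &= \sigma(\mathbf{W}_2^{(l)} \mathbf{z}_1^{(l)} + \mathbf{b}_2^{(l)}), \qquad
\mathbf{z}_2^{(l)} = \mathbf{g}^{(l)} \odot \mathbf{U} + (1 - \mathbf{g}^{(l)}) \odot \mathbf{V} \\
\mathbf{h}^{(l)} &= \sigma(\mathbf{W}_3^{(l)} \mathbf{z}_2^{(l)} + \mathbf{b}_3^{(l)}), \qquad
\mathbf{x}^{(l+1)} = \alpha^{(l)} \mathbf{h}^{(l)} + (1 - \alpha^{(l)})\, \mathbf{x}^{(l)}
\end{aligned}
```

In that formulation $\alpha^{(l)}$ is initialized to zero, so every block starts as the identity map and the network gradually "uses" its depth as training moves $\alpha^{(l)}$ away from zero.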

Augmented with random weight factorization, PirateNet is designed to mitigate spectral bias in deep networks.
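Random weight factorization re-parameterizes each dense weight matrix as W = diag(exp(s)) V, with a trainable log-scale s per output neuron. The sketch below illustrates the idea on a single layer; the class name `FactorizedLinear` and the initialization values `mu=1.0`, `sigma=0.1` are illustrative assumptions, not part of this library's API.

```python
import torch
import torch.nn as nn


class FactorizedLinear(nn.Module):
    """Hypothetical dense layer with weight W = diag(exp(s)) @ V (sketch)."""

    def __init__(self, in_features, out_features, mu=1.0, sigma=0.1):
        super().__init__()
        # Glorot-initialize a conventional weight matrix W (out x in) first.
        weight = torch.empty(out_features, in_features)
        nn.init.xavier_normal_(weight)
        # One log-scale per output neuron, sampled from N(mu, sigma^2).
        s = mu + sigma * torch.randn(out_features)
        # Choose V so that diag(exp(s)) @ V equals W at initialization.
        self.s = nn.Parameter(s)
        self.v = nn.Parameter(weight / torch.exp(s).unsqueeze(1))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def effective_weight(self):
        # Reassemble W from the two trainable factors (row-wise scaling of V).
        return torch.exp(self.s).unsqueeze(1) * self.v

    def forward(self, x):
        return nn.functional.linear(x, self.effective_weight(), self.bias)
```

The function represented at initialization is unchanged; only the optimization landscape differs, because gradients now flow separately through the per-neuron scales and the directions.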

See also

Original reference: Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025). Simulating Three-dimensional Turbulence with Physics-informed Neural Networks. arXiv preprint arXiv:2507.08972.

Initialization of the PirateNet class.

Parameters:
  • input_dimension (int) – The number of input features.

  • inner_size (int) – The number of hidden units in the dense layers.

  • output_dimension (int) – The number of output features.

  • embedding (torch.nn.Module) – The embedding module used to transform the input into a higher-dimensional feature space. If None, a default FourierFeatureEmbedding with a scaling factor of 2 is used. Default is None.

  • n_layers (int) – The number of PirateNet blocks in the model. Default is 3.

  • activation (torch.nn.Module) – The activation function class used in the blocks. Default is torch.nn.Tanh.
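As an illustration of the default embedding, the sketch below implements a random Fourier feature map x -> [cos(xB), sin(xB)] with a frozen Gaussian matrix B. It assumes that the "scaling factor" is the standard deviation `sigma` of B and that the output width matches `inner_size`; both are assumptions about the default, and `FourierEmbeddingSketch` is a hypothetical name, not the library's FourierFeatureEmbedding.

```python
import torch
import torch.nn as nn


class FourierEmbeddingSketch(nn.Module):
    """Hypothetical random Fourier feature embedding (sketch)."""

    def __init__(self, input_dimension, output_dimension, sigma=2.0):
        super().__init__()
        # Half the features are cosines and half are sines.
        if output_dimension % 2 != 0:
            raise ValueError("output_dimension must be even")
        # Fixed (non-trainable) projection B ~ N(0, sigma^2), stored as a buffer.
        self.register_buffer(
            "B", sigma * torch.randn(input_dimension, output_dimension // 2)
        )

    def forward(self, x):
        projected = x @ self.B
        return torch.cat([torch.cos(projected), torch.sin(projected)], dim=-1)
```

A larger `sigma` injects higher frequencies into the feature space, which is how Fourier features counteract spectral bias; a custom `embedding` would presumably need to map inputs of width `input_dimension` to width `inner_size` in the same way.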

forward(input_)

Forward pass of the PirateNet model. It applies the embedding (a Fourier feature embedding by default), computes the shared gating tensors U and V, and passes the embedded input through each PirateNet block. Finally, it applies the output layer to produce the final output.

Parameters:

input_ (torch.Tensor | LabelTensor) – The input tensor for the model.

Returns:

The output tensor of the model.

Return type:

torch.Tensor | LabelTensor
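The forward pass described above can be sketched in plain PyTorch as follows. This is a minimal sketch of the PirateNet technique as described by Wang et al., not this class's source: the names `PirateBlockSketch` and `PirateNetSketch`, the `nn.Linear` stand-in used when no embedding is given, and the zero initialization of each block's `alpha` are assumptions for illustration.

```python
import torch
import torch.nn as nn


class PirateBlockSketch(nn.Module):
    """Three dense layers, dual gating with U and V, adaptive residual."""

    def __init__(self, inner_size, activation=nn.Tanh):
        super().__init__()
        self.dense1 = nn.Linear(inner_size, inner_size)
        self.dense2 = nn.Linear(inner_size, inner_size)
        self.dense3 = nn.Linear(inner_size, inner_size)
        self.act = activation()
        # Trainable residual weight; alpha = 0 makes the block an identity map.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x, U, V):
        f = self.act(self.dense1(x))
        z1 = f * U + (1 - f) * V          # first gating with shared U, V
        g = self.act(self.dense2(z1))
        z2 = g * U + (1 - g) * V          # second gating with shared U, V
        h = self.act(self.dense3(z2))
        return self.alpha * h + (1 - self.alpha) * x


class PirateNetSketch(nn.Module):
    """Embedding -> shared gates U, V -> residual blocks -> linear output."""

    def __init__(self, input_dimension, inner_size, output_dimension,
                 embedding=None, n_layers=3, activation=nn.Tanh):
        super().__init__()
        # Assumed stand-in: a linear lift instead of a Fourier embedding.
        if embedding is None:
            embedding = nn.Linear(input_dimension, inner_size)
        self.embedding = embedding
        self.linear_u = nn.Linear(inner_size, inner_size)
        self.linear_v = nn.Linear(inner_size, inner_size)
        self.act = activation()
        self.blocks = nn.ModuleList(
            PirateBlockSketch(inner_size, activation) for _ in range(n_layers)
        )
        self.output_layer = nn.Linear(inner_size, output_dimension)

    def forward(self, input_):
        x = self.embedding(input_)
        # Gating tensors are computed once and shared by every block.
        U = self.act(self.linear_u(x))
        V = self.act(self.linear_v(x))
        for block in self.blocks:
            x = block(x, U, V)
        return self.output_layer(x)
```

Because every `alpha` starts at zero in this sketch, each block initially returns its input unchanged, so the untrained network reduces to the output layer applied to the embedded input.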

property alpha

Return the alpha values of all PirateNetBlock layers.

Returns:

A list of alpha values from each block.

Return type:

list
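A property like this is mainly useful for monitoring how far each block has moved away from the identity during training. The sketch below shows one way to expose it over a `torch.nn.ModuleList`; the names `AlphaOnlyBlock` and `BlockStackSketch` are hypothetical, and converting each value to a Python float with `.item()` is a choice made here (the actual property may return tensors).

```python
import torch
import torch.nn as nn


class AlphaOnlyBlock(nn.Module):
    """Hypothetical block reduced to its trainable residual weight."""

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1))


class BlockStackSketch(nn.Module):
    def __init__(self, n_layers=3):
        super().__init__()
        self.blocks = nn.ModuleList(AlphaOnlyBlock() for _ in range(n_layers))

    @property
    def alpha(self):
        # One scalar per block, detached from autograd for logging.
        return [block.alpha.item() for block in self.blocks]
```

For example, logging `model.alpha` once per epoch shows which blocks the optimizer has started to use.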