Source code for pina.model.pirate_network
"""Module for the PirateNet model class."""
import torch
from .block import FourierFeatureEmbedding, PirateNetBlock
from ..utils import check_consistency, check_positive_integer
class PirateNet(torch.nn.Module):
"""
Implementation of the Physics-Informed Residual Adaptive Network
(PirateNet). The model consists of a Fourier feature embedding layer,
multiple PirateNet blocks, and a final output layer. Each PirateNet block
consists of three dense layers with a dual gating mechanism and an adaptive
residual connection, whose contribution is controlled by a trainable
parameter ``alpha``.

The PirateNet, augmented with random weight factorization, is designed to
mitigate spectral bias in deep networks.
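
Each block transforms its input :math:`x` using the two gating tensors
:math:`U` and :math:`V`, which are computed once from the embedded
coordinates and shared by all blocks. A sketch of the update, following
the original PirateNet formulation (see
:class:`~pina.model.block.PirateNetBlock` for the exact implementation):

.. math::

    f &= \sigma(W_1 x + b_1), \qquad z_1 = f \odot U + (1 - f) \odot V, \\
    g &= \sigma(W_2 z_1 + b_2), \qquad z_2 = g \odot U + (1 - g) \odot V, \\
    h &= \sigma(W_3 z_2 + b_3), \qquad x_{\text{out}} = \alpha \, h + (1 - \alpha) \, x,

where :math:`\sigma` is the activation function and :math:`\alpha` is the
trainable residual weight of the block.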
.. seealso::
**Original reference**:
Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
*Simulating Three-dimensional Turbulence with Physics-informed Neural
Networks*.
arXiv preprint: `arXiv:2507.08972
<https://arxiv.org/abs/2507.08972>`_
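
:Example:
    A minimal usage sketch (illustrative only; it assumes a plain
    :class:`torch.Tensor` input and the default Fourier feature embedding):

    >>> import torch
    >>> from pina.model.pirate_network import PirateNet
    >>> model = PirateNet(input_dimension=2, inner_size=64, output_dimension=1)
    >>> out = model(torch.rand(10, 2))
    >>> out.shape
    torch.Size([10, 1])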
"""
def __init__(
self,
input_dimension,
inner_size,
output_dimension,
embedding=None,
n_layers=3,
activation=torch.nn.Tanh,
):
"""
Initialization of the :class:`PirateNet` class.
:param int input_dimension: The number of input features.
:param int inner_size: The number of hidden units in the dense layers.
:param int output_dimension: The number of output features.
:param torch.nn.Module embedding: The embedding module used to transform
the input into a higher-dimensional feature space. If ``None``, a
default :class:`~pina.model.block.FourierFeatureEmbedding` with a
scaling factor ``sigma=2.0`` is used. Default is ``None``.
:param int n_layers: The number of PirateNet blocks in the model.
Default is 3.
:param torch.nn.Module activation: The activation function used in the
shared dense layers and in the blocks. Default is :class:`torch.nn.Tanh`.
"""
super().__init__()
# Check consistency
check_consistency(activation, torch.nn.Module, subclass=True)
check_positive_integer(input_dimension, strict=True)
check_positive_integer(inner_size, strict=True)
check_positive_integer(output_dimension, strict=True)
check_positive_integer(n_layers, strict=True)
# Initialize the activation function
self.activation = activation()
# Initialize the embedding (defaults to a Fourier feature embedding)
self.embedding = embedding or FourierFeatureEmbedding(
input_dimension=input_dimension,
output_dimension=inner_size,
sigma=2.0,
)
# Initialize the shared dense layers
self.linear1 = torch.nn.Linear(inner_size, inner_size)
self.linear2 = torch.nn.Linear(inner_size, inner_size)
# Initialize the PirateNet blocks
self.blocks = torch.nn.ModuleList(
[PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
)
# Initialize the output layer
self.output_layer = torch.nn.Linear(inner_size, output_dimension)
def forward(self, input_):
"""
Forward pass of the PirateNet model. It applies the feature embedding,
computes the shared gating tensors ``U`` and ``V``, passes the embedded
input through each block of the network, and finally applies the output
layer to produce the output.
:param input_: The input tensor for the model.
:type input_: torch.Tensor | LabelTensor
:return: The output tensor of the model.
:rtype: torch.Tensor | LabelTensor
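
:Example:
    Shape flow for a batch of ``N`` input points (a sketch; the names
    refer to the attributes used in this method)::

        input_                    # (N, input_dimension)
        x = embedding(input_)     # (N, inner_size)
        U = act(linear1(x))       # (N, inner_size), shared by all blocks
        V = act(linear2(x))       # (N, inner_size), shared by all blocks
        x = block(x, U, V)        # (N, inner_size), repeated for each block
        out = output_layer(x)     # (N, output_dimension)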
"""
# Apply the Fourier feature embedding
x = self.embedding(input_)
# Compute U and V from the shared dense layers
U = self.activation(self.linear1(x))
V = self.activation(self.linear2(x))
# Pass through each block in the network
for block in self.blocks:
x = block(x, U, V)
return self.output_layer(x)
@property
def alpha(self):
"""
Return the ``alpha`` values of all PirateNet blocks in the model.
:return: A list of alpha values from each block.
:rtype: list
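
:Example:
    Illustrative sketch (with the default ``n_layers=3``):

    >>> from pina.model.pirate_network import PirateNet
    >>> model = PirateNet(input_dimension=2, inner_size=64, output_dimension=1)
    >>> len(model.alpha)  # one alpha per PirateNet block
    3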
"""
return [block.alpha.item() for block in self.blocks]