# Source code for pina.model.block.spectral

"""Module for spectral convolution blocks."""

import torch
from torch import nn
from ...utils import check_consistency


######## 1D Spectral Convolution ###########
class SpectralConvBlock1D(nn.Module):
    """
    Spectral Convolution Block for one-dimensional tensors.

    This class computes the spectral convolution of the input with a linear
    kernel in the fourier space, and then it maps the input back to the
    physical space.

    The block expects an input of size [``batch``, ``input_numb_fields``,
    ``N``] and returns an output of size [``batch``, ``output_numb_fields``,
    ``N``].
    """

    def __init__(self, input_numb_fields, output_numb_fields, n_modes):
        r"""
        Initialization of the :class:`SpectralConvBlock1D` class.

        :param int input_numb_fields: The number of channels for the input.
        :param int output_numb_fields: The number of channels for the output.
        :param int n_modes: The number of lowest-frequency modes kept by the
            kernel. It must be at most equal to
            :math:`\lfloor N/2 \rfloor + 1`, the number of coefficients
            produced by the real FFT of an input of length ``N``.
        """
        super().__init__()

        # check type consistency
        check_consistency(input_numb_fields, int)
        check_consistency(output_numb_fields, int)

        # assign variables
        self._modes = n_modes
        self._input_channels = input_numb_fields
        self._output_channels = output_numb_fields

        # scaling factor
        scale = 1.0 / (self._input_channels * self._output_channels)
        # One complex weight per (input channel, output channel, kept mode).
        self._weights = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes,
                dtype=torch.cfloat,
            )
        )

    def _compute_mult1d(self, input, weights):
        """
        Compute the matrix multiplication of the input and the linear kernel
        weights, independently for every retained Fourier mode.

        :param torch.Tensor input: The (truncated) input spectrum. Expected of
            size [``batch``, ``input_numb_fields``, ``n_modes``].
        :param torch.Tensor weights: The kernel weights. Expected of size
            [``input_numb_fields``, ``output_numb_fields``, ``n_modes``].
        :return: The result of the matrix multiplication, of size
            [``batch``, ``output_numb_fields``, ``n_modes``].
        :rtype: torch.Tensor
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Forward pass.

        :param torch.Tensor x: The input tensor. Expected of size
            [``batch``, ``input_numb_fields``, ``N``].
        :return: The output tensor, of size
            [``batch``, ``output_numb_fields``, ``N``].
        :rtype: torch.Tensor
        """
        batch_size = x.shape[0]

        # Compute Fourier transform of the input (last dimension); the real
        # FFT returns N // 2 + 1 coefficients.
        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes; every coefficient beyond the
        # first ``n_modes`` stays zero, i.e. high frequencies are truncated.
        out_ft = torch.zeros(
            batch_size,
            self._output_channels,
            x.size(-1) // 2 + 1,
            device=x.device,
            dtype=torch.cfloat,
        )
        out_ft[:, :, : self._modes] = self._compute_mult1d(
            x_ft[:, :, : self._modes], self._weights
        )

        # Return to physical space. Passing ``n`` restores the original
        # length, which irfft alone would get wrong for odd N.
        return torch.fft.irfft(out_ft, n=x.size(-1))
######## 2D Spectral Convolution ###########
class SpectralConvBlock2D(nn.Module):
    """
    Spectral Convolution Block for two-dimensional tensors.

    This class computes the spectral convolution of the input with a linear
    kernel in the fourier space, and then it maps the input back to the
    physical space.

    The block expects an input of size [``batch``, ``input_numb_fields``,
    ``Nx``, ``Ny``] and returns an output of size [``batch``,
    ``output_numb_fields``, ``Nx``, ``Ny``].
    """

    def __init__(self, input_numb_fields, output_numb_fields, n_modes):
        r"""
        Initialization of the :class:`SpectralConvBlock2D` class.

        :param int input_numb_fields: The number of channels for the input.
        :param int output_numb_fields: The number of channels for the output.
        :param n_modes: The number of modes to select for each dimension, or
            a single int used for both. Along ``Nx`` both the lowest and the
            highest (negative) frequencies are kept, so ``n_modes[0]`` should
            not exceed :math:`\lfloor Nx/2 \rfloor`; otherwise the two blocks
            overlap and the second one overwrites the first. Along ``Ny``
            only the lowest frequencies of the real FFT are kept, so
            ``n_modes[1]`` must be at most :math:`\lfloor Ny/2 \rfloor + 1`.
        :type n_modes: int | list[int] | tuple[int]
        :raises ValueError: If the number of modes is not consistent.
        :raises ValueError: If the number of modes is not a list or tuple.
        """
        super().__init__()

        # check type consistency
        check_consistency(input_numb_fields, int)
        check_consistency(output_numb_fields, int)
        check_consistency(n_modes, int)
        if isinstance(n_modes, (tuple, list)):
            if len(n_modes) != 2:
                raise ValueError(
                    "Expected n_modes to be a list or tuple of len two, "
                    "with each entry corresponding to the number of modes "
                    "for each dimension "
                )
        elif isinstance(n_modes, int):
            n_modes = [n_modes] * 2
        else:
            raise ValueError(
                "Expected n_modes to be a list or tuple of len two, "
                "with each entry corresponding to the number of modes "
                "for each dimension; or an int value representing the "
                "number of modes for all dimensions"
            )

        # assign variables
        self._modes = n_modes
        self._input_channels = input_numb_fields
        self._output_channels = output_numb_fields

        # scaling factor
        scale = 1.0 / (self._input_channels * self._output_channels)
        # Two independent kernels: one for the lowest and one for the highest
        # (negative) frequencies along the first spatial dimension.
        self._weights1 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                dtype=torch.cfloat,
            )
        )
        self._weights2 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                dtype=torch.cfloat,
            )
        )

    def _compute_mult2d(self, input, weights):
        """
        Compute the matrix multiplication of the input and the linear kernel
        weights, independently for every retained Fourier mode.

        :param torch.Tensor input: A block of the input spectrum. Expected of
            size [``batch``, ``input_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``].
        :param torch.Tensor weights: The kernel weights. Expected of size
            [``input_numb_fields``, ``output_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``].
        :return: The result of the matrix multiplication, of size
            [``batch``, ``output_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``].
        :rtype: torch.Tensor
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Forward pass.

        :param torch.Tensor x: The input tensor. Expected of size
            [``batch``, ``input_numb_fields``, ``Nx``, ``Ny``].
        :return: The output tensor, of size
            [``batch``, ``output_numb_fields``, ``Nx``, ``Ny``].
        :rtype: torch.Tensor
        """
        batch_size = x.shape[0]
        modes_x, modes_y = self._modes

        # Compute Fourier transform of the input over the last two dims.
        x_ft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes; all other coefficients stay zero.
        out_ft = torch.zeros(
            batch_size,
            self._output_channels,
            x.size(-2),
            x.size(-1) // 2 + 1,
            device=x.device,
            dtype=torch.cfloat,
        )
        # Lowest and highest (negative) frequencies along x, lowest along y.
        low = (slice(None), slice(None), slice(modes_x), slice(modes_y))
        high = (
            slice(None),
            slice(None),
            slice(-modes_x, None),
            slice(modes_y),
        )
        out_ft[low] = self._compute_mult2d(x_ft[low], self._weights1)
        # Written second: if the two blocks overlap, this one wins.
        out_ft[high] = self._compute_mult2d(x_ft[high], self._weights2)

        # Return to physical space with the original spatial size.
        return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
######## 3D Spectral Convolution ###########
class SpectralConvBlock3D(nn.Module):
    """
    Spectral Convolution Block for three-dimensional tensors.

    This class computes the spectral convolution of the input with a linear
    kernel in the fourier space, and then it maps the input back to the
    physical space.

    The block expects an input of size [``batch``, ``input_numb_fields``,
    ``Nx``, ``Ny``, ``Nz``] and returns an output of size [``batch``,
    ``output_numb_fields``, ``Nx``, ``Ny``, ``Nz``].
    """

    def __init__(self, input_numb_fields, output_numb_fields, n_modes):
        r"""
        Initialization of the :class:`SpectralConvBlock3D` class.

        :param int input_numb_fields: The number of channels for the input.
        :param int output_numb_fields: The number of channels for the output.
        :param n_modes: The number of modes to select for each dimension, or
            a single int used for all three. Along ``Nx`` and ``Ny`` both the
            lowest and the highest (negative) frequencies are kept, so
            ``n_modes[0]`` and ``n_modes[1]`` should not exceed
            :math:`\lfloor Nx/2 \rfloor` and :math:`\lfloor Ny/2 \rfloor`;
            otherwise the blocks overlap and later ones overwrite earlier
            ones. Along ``Nz`` only the lowest frequencies of the real FFT
            are kept, so ``n_modes[2]`` must be at most
            :math:`\lfloor Nz/2 \rfloor + 1`.
        :type n_modes: int | list[int] | tuple[int]
        :raises ValueError: If the number of modes is not consistent.
        :raises ValueError: If the number of modes is not a list or tuple.
        """
        super().__init__()

        # check type consistency
        check_consistency(input_numb_fields, int)
        check_consistency(output_numb_fields, int)
        check_consistency(n_modes, int)
        if isinstance(n_modes, (tuple, list)):
            if len(n_modes) != 3:
                raise ValueError(
                    "Expected n_modes to be a list or tuple of len three, "
                    "with each entry corresponding to the number of modes "
                    "for each dimension "
                )
        elif isinstance(n_modes, int):
            n_modes = [n_modes] * 3
        else:
            raise ValueError(
                "Expected n_modes to be a list or tuple of len three, "
                "with each entry corresponding to the number of modes "
                "for each dimension; or an int value representing the "
                "number of modes for all dimensions"
            )

        # assign variables
        self._modes = n_modes
        self._input_channels = input_numb_fields
        self._output_channels = output_numb_fields

        # scaling factor
        scale = 1.0 / (self._input_channels * self._output_channels)
        # Four independent kernels, one per (low/high x, low/high y) corner
        # of the spectrum; see ``forward`` for the exact assignment.
        self._weights1 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                self._modes[2],
                dtype=torch.cfloat,
            )
        )
        self._weights2 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                self._modes[2],
                dtype=torch.cfloat,
            )
        )
        self._weights3 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                self._modes[2],
                dtype=torch.cfloat,
            )
        )
        self._weights4 = nn.Parameter(
            scale
            * torch.rand(
                self._input_channels,
                self._output_channels,
                self._modes[0],
                self._modes[1],
                self._modes[2],
                dtype=torch.cfloat,
            )
        )

    def _compute_mult3d(self, input, weights):
        """
        Compute the matrix multiplication of the input and the linear kernel
        weights, independently for every retained Fourier mode.

        :param torch.Tensor input: A block of the input spectrum. Expected of
            size [``batch``, ``input_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``, ``n_modes[2]``].
        :param torch.Tensor weights: The kernel weights. Expected of size
            [``input_numb_fields``, ``output_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``, ``n_modes[2]``].
        :return: The result of the matrix multiplication, of size
            [``batch``, ``output_numb_fields``, ``n_modes[0]``,
            ``n_modes[1]``, ``n_modes[2]``].
        :rtype: torch.Tensor
        """
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x):
        """
        Forward pass.

        :param torch.Tensor x: The input tensor. Expected of size
            [``batch``, ``input_numb_fields``, ``Nx``, ``Ny``, ``Nz``].
        :return: The output tensor, of size
            [``batch``, ``output_numb_fields``, ``Nx``, ``Ny``, ``Nz``].
        :rtype: torch.Tensor
        """
        batch_size = x.shape[0]
        modes_x, modes_y, modes_z = self._modes

        # Compute Fourier transform of the input over the last three dims.
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1])

        # Multiply relevant Fourier modes; all other coefficients stay zero.
        out_ft = torch.zeros(
            batch_size,
            self._output_channels,
            x.size(-3),
            x.size(-2),
            x.size(-1) // 2 + 1,
            device=x.device,
            dtype=torch.cfloat,
        )

        low_x, high_x = slice(modes_x), slice(-modes_x, None)
        low_y, high_y = slice(modes_y), slice(-modes_y, None)
        # The last dim comes from a real FFT, so only its lowest modes exist.
        low_z = slice(modes_z)
        # Processed in this order; if corners overlap, later ones win.
        corners = (
            (low_x, low_y, self._weights1),
            (low_x, high_y, self._weights2),
            (high_x, low_y, self._weights3),
            (high_x, high_y, self._weights4),
        )
        for slice_x, slice_y, weights in corners:
            index = (slice(None), slice(None), slice_x, slice_y, low_z)
            out_ft[index] = self._compute_mult3d(x_ft[index], weights)

        # Return to physical space with the original spatial size.
        return torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1)))