# Source code for pina._src.adaptive_function.base_adaptive_function

"""Module for the Adaptive Function base class."""

import torch
from pina._src.core.utils import check_consistency
from pina._src.adaptive_function.adaptive_function_interface import (
    AdaptiveFunctionInterface,
)


class BaseAdaptiveFunction(torch.nn.Module, AdaptiveFunctionInterface):
    r"""
    Base class for all adaptive functions, implementing common functionality.

    This class extends a standard :class:`torch.nn.Module` activation function
    into a trainable adaptive form. It implements the common mechanism used to
    scale the output, and to scale and shift the input, of a given activation
    function.

    Given a function :math:`f:\mathbb{R}^n\rightarrow\mathbb{R}^m`, the
    adaptive function
    :math:`f_{\text{adaptive}}:\mathbb{R}^n\rightarrow\mathbb{R}^m` is defined
    as:

    .. math::

        f_{\text{adaptive}}(\mathbf{x}) = \alpha\,f(\beta\mathbf{x}+\gamma),

    where :math:`\alpha`, :math:`\beta`, and :math:`\gamma` are learnable
    parameters controlling output scaling, input scaling, and input shifting,
    respectively.

    All specific adaptive functions should inherit from this class and
    implement the abstract methods declared in the interface. Subclasses are
    also expected to assign the wrapped activation to ``self._func``, which is
    initialized to ``None`` here and required by :meth:`forward`.

    This class is not meant to be instantiated directly.

    .. seealso::

        **Original reference**: Godfrey, L. B., Gashler, M. S. (2015).
        *A continuum among logarithmic, linear, and exponential functions,
        and its potential to improve generalization in neural networks.*
        7th international joint conference on knowledge discovery, knowledge
        engineering and knowledge management (IC3K), Vol. 1.
        DOI: `arXiv preprint arXiv:1602.01321.
        <https://arxiv.org/abs/1602.01321>`_.

        **Original reference**: Jagtap, A. D., Karniadakis, G. E. (2020).
        *Adaptive activation functions accelerate convergence in deep and
        physics-informed neural networks*.
        Journal of Computational Physics, 404.
        DOI: `JCP 10.1016
        <https://doi.org/10.1016/j.jcp.2019.109136>`_.
    """

    def __init__(self, alpha=None, beta=None, gamma=None, fixed=None):
        """
        Initialization of the :class:`BaseAdaptiveFunction` class.

        :param alpha: The output scaling parameter of the adaptive function.
            If ``None``, it is initialized to ``1``. Default is ``None``.
        :type alpha: int | float
        :param beta: The input scaling parameter of the adaptive function.
            If ``None``, it is initialized to ``1``. Default is ``None``.
        :type beta: int | float
        :param gamma: The input shifting parameter of the adaptive function.
            If ``None``, it is initialized to ``0``. Default is ``None``.
        :type gamma: int | float
        :param fixed: The names of parameters to keep fixed during training.
            These parameters will not be optimized and will have
            ``requires_grad=False``. Available options are ``"alpha"``,
            ``"beta"``, and ``"gamma"``. If ``None``, all parameters are
            trainable. Default is ``None``.
        :type fixed: str | list[str]
        :raises ValueError: If alpha, when provided, is not a number.
        :raises ValueError: If beta, when provided, is not a number.
        :raises ValueError: If gamma, when provided, is not a number.
        :raises ValueError: If fixed, when provided, is neither a string nor
            a list of strings.
        :raises ValueError: If fixed contains invalid parameter names.
        """
        super().__init__()

        # Set default values for alpha, beta, gamma if they are None
        alpha = 1.0 if alpha is None else alpha
        beta = 1.0 if beta is None else beta
        gamma = 0.0 if gamma is None else gamma

        # Check consistency. The type errors documented above are delegated
        # to the project helper ``check_consistency``.
        check_consistency(alpha, (int, float))
        check_consistency(beta, (int, float))
        check_consistency(gamma, (int, float))

        # Process fixed parameters: normalize a single name or a collection
        # of names into a set, so membership tests and the validation below
        # are uniform.
        if fixed is not None:
            check_consistency(fixed, str)
            fixed = {fixed} if isinstance(fixed, str) else set(fixed)
        else:
            fixed = set()

        # Validate fixed parameter names
        invalid_names = fixed - {"alpha", "beta", "gamma"}
        if invalid_names:
            raise ValueError(
                f"Invalid fixed parameter name(s): {sorted(invalid_names)}. "
                "Available options are 'alpha', 'beta', and 'gamma'."
            )

        # Register either a trainable parameter or a fixed buffer
        def _register_adaptive_param(name, value):
            """
            Helper function to register an adaptive parameter as either a
            trainable parameter or a fixed buffer, depending on whether it is
            specified in the ``fixed`` argument.

            :param str name: The public parameter name; the value is stored
                on the module under ``_<name>``.
            :param value: The initial value of the parameter.
            :type value: int | float
            """
            # Convert value to tensor
            tensor = torch.tensor(value, dtype=torch.float32)

            # Register as buffer if fixed, otherwise as parameter. A buffer
            # is stored on the module but is not a torch.nn.Parameter, so it
            # is not among the parameters handed to an optimizer.
            if name in fixed:
                self.register_buffer(f"_{name}", tensor)
            else:
                setattr(self, f"_{name}", torch.nn.Parameter(tensor))

        # Register parameters
        _register_adaptive_param("alpha", alpha)
        _register_adaptive_param("beta", beta)
        _register_adaptive_param("gamma", gamma)

        # Initialize the adaptive function to None, to be set by subclasses
        self._func = None

    def forward(self, x):
        """
        Compute the transformation of the adaptive function on the input.

        The result is ``alpha * func(beta * x + gamma)``, where ``func`` is
        the callable stored in ``self._func``.

        :param x: The input tensor to evaluate the adaptive function.
        :type x: torch.Tensor | LabelTensor
        :raises RuntimeError: If the adaptive function has not been set.
        :raises RuntimeError: If the adaptive function is not callable.
        :return: The output of the adaptive function.
        :rtype: torch.Tensor | LabelTensor
        """
        # Raise an error if the adaptive function has not been set
        if self._func is None:
            raise RuntimeError("The adaptive function has not been set.")

        # Raise an error if the adaptive function is not callable
        if not callable(self._func):
            raise RuntimeError("The adaptive function is not callable.")

        return self.alpha * (self._func(self.beta * x + self.gamma))

    @property
    def alpha(self):
        """
        The output scaling parameter of the adaptive function.

        :return: The alpha parameter.
        :rtype: torch.nn.Parameter | torch.Tensor
        """
        return self._alpha

    @property
    def beta(self):
        """
        The input scaling parameter of the adaptive function.

        :return: The beta parameter.
        :rtype: torch.nn.Parameter | torch.Tensor
        """
        return self._beta

    @property
    def gamma(self):
        """
        The input shifting parameter of the adaptive function.

        :return: The gamma parameter.
        :rtype: torch.nn.Parameter | torch.Tensor
        """
        return self._gamma