Source code for pina._src.equation.zoo.fixed_laplacian
"""Module for defining the fixed laplacian equation."""
import warnings
from pina._src.equation.equation import Equation
from pina._src.core.operator import laplacian
from pina._src.core.utils import check_consistency
[docs]
class FixedLaplacian(Equation):
"""
Equation to enforce a fixed laplacian for a specific condition.
"""
def __init__(self, value, components=None, d=None):
"""
Initialization of the :class:`FixedLaplacian` class.
:param value: The fixed value to be enforced to the laplacian.
:type value: float | int
:param components: The name of the output variables for which the fixed
laplace condition is applied. It should be a subset of the output
labels. If ``None``, all output variables are considered. Default is
``None``.
:type components: str | list[str]
:param d: The name of the input variables on which the laplacian is
computed. It should be a subset of the input labels. If ``None``,
all the input variables are considered. Default is ``None``.
:type d: str | list[str]
:raises ValueError: If ``value`` is neither a float nor an integer.
:raises ValueError: If, when provided, ``components`` is neither a
string nor a list of strings.
:raises ValueError: If, when provided, ``d`` is neither a string nor a
list of strings.
"""
# Check consistency
check_consistency(value, (float, int))
if components is not None:
check_consistency(components, str)
if d is not None:
check_consistency(d, str)
def equation(input_, output_):
"""
Definition of the equation to enforce a fixed laplacian.
:param LabelTensor input_: The input points where the residual is
computed.
:param LabelTensor output_: The output tensor, potentially produced
by a :class:`torch.nn.Module` instance.
:return: The residual values of the equation.
:rtype: LabelTensor
"""
return (
laplacian(output_, input_, components=components, d=d) - value
)
super().__init__(equation)
# Back-compatibility with version 0.2, to be removed soon
[docs]
class Laplace(FixedLaplacian):
def __init__(self, components=None, d=None):
warnings.warn(
"Laplace is deprecated, use FixedLaplacian with value=0.0 instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(0.0, components=components, d=d)