"""Module for the DeepONet and MIONet model classes."""
from functools import partial
import torch
from torch import nn
from ..utils import check_consistency, is_function
class MIONet(torch.nn.Module):
    """
    MIONet model class.

    The MIONet is a general architecture for learning operators, which map
    functions to functions. It can be trained with both Supervised and
    Physics-Informed learning strategies.

    .. seealso::

        **Original reference**: Jin, P., Meng, S., and Lu L. (2022).
        *MIONet: Learning multiple-input operators via tensor product.*
        SIAM Journal on Scientific Computing 44.6 (2022): A3490-A3514
        DOI: `10.1137/22M1477751 <https://doi.org/10.1137/22M1477751>`_
    """

    def __init__(
        self,
        networks,
        aggregator="*",
        reduction="+",
        scale=True,
        translation=True,
    ):
        """
        Initialization of the :class:`MIONet` class.

        :param dict networks: The neural networks to use as models. The ``dict``
            takes as key a neural network, and as value the list of indices to
            extract from the input variable in the forward pass of the neural
            network. If a ``list[int]`` is passed, the corresponding columns of
            the inner most entries are extracted. If a ``list[str]`` is passed
            the variables of the corresponding
            :class:`~pina.label_tensor.LabelTensor` are extracted.
            Each :class:`torch.nn.Module` model has to take as input either a
            :class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
            The default implementation consists of several branch nets and one
            trunk net. Each network is evaluated once at construction time on a
            random tensor of shape ``(10, len(indices))`` to read its output
            dimension.
        :param aggregator: The aggregator to be used to aggregate component-wise
            partial results from the modules in ``networks``. Available
            aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
            min: ``min``, max: ``max``. Default is ``*``.
        :type aggregator: str or Callable
        :param reduction: The reduction to be used to reduce the aggregated
            result of the modules in ``networks`` to the desired output
            dimension. Available reductions include: sum: ``+``, product: ``*``,
            mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
        :type reduction: str or Callable
        :param bool scale: If ``True``, the final output is scaled before being
            returned in the forward pass. Default is ``True``.
        :param bool translation: If ``True``, the final output is translated
            before being returned in the forward pass. Default is ``True``.
        :raises ValueError: If the passed networks have not the same output
            dimension.

        .. warning::
            No checks are performed in the forward pass to verify if the input
            is instance of either :class:`~pina.label_tensor.LabelTensor` or
            :class:`torch.Tensor`. In general, in case of a
            :class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
            ``list[str]`` can be passed as ``networks`` dict values.
            Differently, in case of a :class:`torch.Tensor`, only a
            ``list[int]`` can be passed as ``networks`` dict values.

        :Example:
            >>> branch_net1 = FeedForward(input_dimensions=1,
            ...                           output_dimensions=10)
            >>> branch_net2 = FeedForward(input_dimensions=2,
            ...                           output_dimensions=10)
            >>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
            >>> networks = {branch_net1 : ['x'],
            ...             branch_net2 : ['x', 'y'],
            ...             trunk_net : ['z']}
            >>> model = MIONet(networks=networks,
            ...                reduction='+',
            ...                aggregator='*')
            >>> model
            MIONet(
              (models): ModuleList(
                (0): FeedForward(
                  (model): Sequential(
                    (0): Linear(in_features=1, out_features=20, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=20, out_features=20, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=20, out_features=10, bias=True)
                  )
                )
                (1): FeedForward(
                  (model): Sequential(
                    (0): Linear(in_features=2, out_features=20, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=20, out_features=20, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=20, out_features=10, bias=True)
                  )
                )
                (2): FeedForward(
                  (model): Sequential(
                    (0): Linear(in_features=1, out_features=20, bias=True)
                    (1): Tanh()
                    (2): Linear(in_features=20, out_features=20, bias=True)
                    (3): Tanh()
                    (4): Linear(in_features=20, out_features=10, bias=True)
                  )
                )
              )
            )
        """
        super().__init__()

        # check type consistency
        check_consistency(networks, dict)
        check_consistency(scale, bool)
        check_consistency(translation, bool)

        # check trunk/branch nets consistency: probe every network with a
        # random input and record the size of its last output dimension
        shapes = []
        for key, value in networks.items():
            check_consistency(value, (str, int))
            check_consistency(key, torch.nn.Module)
            input_ = torch.rand(10, len(value))
            shapes.append(key(input_).shape[-1])

        if any(shape != shapes[0] for shape in shapes):
            raise ValueError(
                "The passed networks have not the same output dimension."
            )

        # assign trunk and branch nets with their input indices; the indices
        # are materialised as a list because a ``dict_values`` view can be
        # neither deep-copied nor pickled
        self.models = torch.nn.ModuleList(networks.keys())
        self._indeces = list(networks.values())

        # initialize aggregation and reduction
        self._init_aggregator(aggregator=aggregator)
        self._init_reduction(reduction=reduction)

        # scale and translation
        self._init_affine(scale=scale, translation=translation)

    @staticmethod
    def _symbol_functions(**kwargs):
        """
        Return a dictionary of functions that can be used as aggregators or
        reductions.

        :param dict kwargs: Keyword arguments forwarded to every torch
            function in the dictionary (e.g. ``dim``).
        :return: A dictionary mapping each supported symbol to its function.
        :rtype: dict
        """
        return {
            "+": partial(torch.sum, **kwargs),
            "*": partial(torch.prod, **kwargs),
            "mean": partial(torch.mean, **kwargs),
            # torch.min / torch.max return (values, indices) when ``dim`` is
            # given: only the values are of interest here
            "min": lambda x: torch.min(x, **kwargs).values,
            "max": lambda x: torch.max(x, **kwargs).values,
        }

    def _init_aggregator(self, aggregator):
        """
        Initialize the aggregator.

        Symbolic aggregators act along dimension ``2``, i.e. the dimension
        along which the network outputs are stacked in the forward pass.

        :param aggregator: The aggregator to be used to aggregate.
        :type aggregator: str or Callable
        :raises ValueError: If the aggregator is not supported.
        """
        aggregator_funcs = self._symbol_functions(dim=2)
        if aggregator in aggregator_funcs:
            aggregator_func = aggregator_funcs[aggregator]
        elif isinstance(aggregator, nn.Module) or is_function(aggregator):
            aggregator_func = aggregator
        else:
            raise ValueError(f"Unsupported aggregation: {str(aggregator)}")

        self._aggregator = aggregator_func
        self._aggregator_type = aggregator

    def _init_reduction(self, reduction):
        """
        Initialize the reduction.

        Symbolic reductions act along the last dimension.

        :param reduction: The reduction to be used.
        :type reduction: str or Callable
        :raises ValueError: If the reduction is not supported.
        """
        reduction_funcs = self._symbol_functions(dim=-1)
        if reduction in reduction_funcs:
            reduction_func = reduction_funcs[reduction]
        elif isinstance(reduction, nn.Module) or is_function(reduction):
            reduction_func = reduction
        else:
            raise ValueError(f"Unsupported reduction: {reduction}")

        self._reduction = reduction_func
        self._reduction_type = reduction

    def _init_affine(self, scale, translation):
        """
        Initialize the output scale and translation.

        A learnable factor is stored as a :class:`torch.nn.Parameter`
        initialised to ``1.0``. A disabled factor is stored as a
        non-persistent buffer holding the identity of its operation (``1.0``
        for the scale, ``0.0`` for the translation): it follows the module
        across ``.to()`` calls, leaves the output untouched and does not
        appear in the ``state_dict``.

        :param bool scale: Whether the scale factor is learnable.
        :param bool translation: Whether the translation factor is learnable.
        """
        if scale:
            self._scale = torch.nn.Parameter(torch.tensor([1.0]))
        else:
            self.register_buffer(
                "_scale", torch.tensor([1.0]), persistent=False
            )

        if translation:
            self._trasl = torch.nn.Parameter(torch.tensor([1.0]))
        else:
            # additive identity: a disabled translation must not shift the
            # output
            self.register_buffer(
                "_trasl", torch.tensor([0.0]), persistent=False
            )

    def _get_vars(self, x, indeces):
        """
        Extract the variables from the input tensor.

        :param x: The input tensor.
        :type x: LabelTensor | torch.Tensor
        :param indeces: The indices to extract.
        :type indeces: list[int] | list[str]
        :raises RuntimeError: If failing to extract the variables.
        :raises RuntimeError: If failing to extract the right indices.
        :return: The extracted variables.
        :rtype: LabelTensor | torch.Tensor
        """
        # the type of the first entry decides the extraction strategy
        if isinstance(indeces[0], str):
            try:
                return x.extract(indeces)
            except AttributeError as e:
                raise RuntimeError(
                    "Not possible to extract input variables from tensor."
                    " Ensure that the passed tensor is a LabelTensor or"
                    " pass list of integers to extract variables. For"
                    " more information refer to warning in the documentation."
                ) from e

        if isinstance(indeces[0], int):
            return x[..., indeces]

        raise RuntimeError(
            "Not able to extract right indeces for tensor."
            " For more information refer to warning in the documentation."
        )

    def forward(self, x):
        """
        Forward pass for the :class:`MIONet` model.

        :param x: The input tensor.
        :type x: LabelTensor | torch.Tensor
        :return: The output tensor.
        :rtype: LabelTensor | torch.Tensor
        """
        # forward pass of every network on its own slice of the input
        output_ = [
            model(self._get_vars(x, indeces))
            for model, indeces in zip(self.models, self._indeces)
        ]

        # aggregation
        aggregated = self._aggregator(torch.dstack(output_))

        # reduce; symbolic reductions drop the last dimension, so the result
        # is brought back to a single-column tensor
        output_ = self._reduction(aggregated)
        if self._reduction_type in self._symbol_functions(dim=-1):
            output_ = output_.reshape(-1, 1)

        # scale and translate out-of-place: an in-place update would
        # overwrite a tensor that a custom reduction may have saved for its
        # backward pass
        return output_ * self._scale + self._trasl

    @property
    def aggregator(self):
        """
        The aggregator function.

        :return: The aggregator function.
        :rtype: Callable
        """
        return self._aggregator

    @property
    def reduction(self):
        """
        The reduction function.

        :return: The reduction function.
        :rtype: Callable
        """
        return self._reduction

    @property
    def scale(self):
        """
        The scale factor.

        :return: The scale factor.
        :rtype: torch.Tensor
        """
        return self._scale

    @property
    def translation(self):
        """
        The translation factor.

        :return: The translation factor.
        :rtype: torch.Tensor
        """
        return self._trasl

    @property
    def indeces_variables_extracted(self):
        """
        The input indices for each model in form of list.

        :return: The indices for each model.
        :rtype: list
        """
        return self._indeces

    @property
    def model(self):
        """
        The models in form of list.

        :return: The models.
        :rtype: torch.nn.ModuleList
        """
        return self.models
class DeepONet(MIONet):
    """
    DeepONet model class.

    The DeepONet is a general architecture for learning operators, which map
    functions to functions. It can be trained with both Supervised and
    Physics-Informed learning strategies. It is built as a :class:`MIONet`
    with exactly two networks: one branch net and one trunk net.

    .. seealso::

        **Original reference**: Lu, L., Jin, P., Pang, G. et al.
        *Learning nonlinear operators via DeepONet based on the universal
        approximation theorem of operator*.
        Nat Mach Intell 3, 218-229 (2021).
        DOI: `10.1038/s42256-021-00302-5
        <https://doi.org/10.1038/s42256-021-00302-5>`_
    """

    def __init__(
        self,
        branch_net,
        trunk_net,
        input_indeces_branch_net,
        input_indeces_trunk_net,
        aggregator="*",
        reduction="+",
        scale=True,
        translation=True,
    ):
        """
        Initialization of the :class:`DeepONet` class.

        :param torch.nn.Module branch_net: The neural network to use as branch
            model. It has to take as input either a
            :class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
            The output dimension has to be the same as that of ``trunk_net``.
        :param torch.nn.Module trunk_net: The neural network to use as trunk
            model. It has to take as input either a
            :class:`~pina.label_tensor.LabelTensor` or a :class:`torch.Tensor`.
            The output dimension has to be the same as that of ``branch_net``.
        :param input_indeces_branch_net: List of indices to extract from the
            input variable of the ``branch_net``.
            If a list of ``int`` is passed, the corresponding columns of the
            inner most entries are extracted. If a list of ``str`` is passed the
            variables of the corresponding
            :class:`~pina.label_tensor.LabelTensor` are extracted.
        :type input_indeces_branch_net: list[int] | list[str]
        :param input_indeces_trunk_net: List of indices to extract from the
            input variable of the ``trunk_net``.
            If a list of ``int`` is passed, the corresponding columns of the
            inner most entries are extracted. If a list of ``str`` is passed the
            variables of the corresponding
            :class:`~pina.label_tensor.LabelTensor` are extracted.
        :type input_indeces_trunk_net: list[int] | list[str]
        :param aggregator: The aggregator to be used to aggregate component-wise
            partial results from ``branch_net`` and ``trunk_net``. Available
            aggregators include: sum: ``+``, product: ``*``, mean: ``mean``,
            min: ``min``, max: ``max``. Default is ``*``.
        :type aggregator: str or Callable
        :param reduction: The reduction to be used to reduce the aggregated
            result of ``branch_net`` and ``trunk_net`` to the desired output
            dimension. Available reductions include: sum: ``+``, product: ``*``,
            mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
        :type reduction: str or Callable
        :param bool scale: If ``True``, the final output is scaled before being
            returned in the forward pass. Default is ``True``.
        :param bool translation: If ``True``, the final output is translated
            before being returned in the forward pass. Default is ``True``.
        :raises ValueError: If ``branch_net`` and ``trunk_net`` are the same
            module object.

        .. warning::
            No checks are performed in the forward pass to verify if the input
            is instance of either :class:`~pina.label_tensor.LabelTensor` or
            :class:`torch.Tensor`. In general, in case of a
            :class:`~pina.label_tensor.LabelTensor`, both a ``list[int]`` or a
            ``list[str]`` can be passed as ``input_indeces_branch_net`` and
            ``input_indeces_trunk_net``. Differently, in case of a
            :class:`torch.Tensor`, only a ``list[int]`` can be passed.

        :Example:
            >>> branch_net = FeedForward(input_dimensions=1,
            ...                          output_dimensions=10)
            >>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
            >>> model = DeepONet(branch_net=branch_net,
            ...                  trunk_net=trunk_net,
            ...                  input_indeces_branch_net=['x'],
            ...                  input_indeces_trunk_net=['t'],
            ...                  reduction='+',
            ...                  aggregator='*')
            >>> model.branch_net is branch_net
            True
            >>> model.trunk_net is trunk_net
            True
        """
        # The two networks are used as dictionary keys below: one shared
        # module object would collapse them into a single entry, silently
        # dropping the branch indices and leaving the model without a trunk.
        if branch_net is trunk_net:
            raise ValueError(
                "branch_net and trunk_net must be two distinct modules."
            )

        # insertion order fixes the positions read by the ``branch_net`` and
        # ``trunk_net`` properties
        networks = {
            branch_net: input_indeces_branch_net,
            trunk_net: input_indeces_trunk_net,
        }
        super().__init__(
            networks=networks,
            aggregator=aggregator,
            reduction=reduction,
            scale=scale,
            translation=translation,
        )

    def forward(self, x):
        """
        Forward pass for the :class:`DeepONet` model.

        :param x: The input tensor.
        :type x: LabelTensor | torch.Tensor
        :return: The output tensor.
        :rtype: LabelTensor | torch.Tensor
        """
        # kept as an explicit override so that the method is documented on
        # this class; the computation is entirely the parent's
        return super().forward(x)

    @property
    def branch_net(self):
        """
        The branch net of the DeepONet.

        :return: The branch net.
        :rtype: torch.nn.Module
        """
        return self.models[0]

    @property
    def trunk_net(self):
        """
        The trunk net of the DeepONet.

        :return: The trunk net.
        :rtype: torch.nn.Module
        """
        return self.models[1]