DeepONet#
- class DeepONet(branch_net, trunk_net, input_indeces_branch_net, input_indeces_trunk_net, aggregator='*', reduction='+', scale=True, translation=True)[source]#
Bases: MIONet
The PINA implementation of the DeepONet network.
DeepONet is a general architecture for learning operators. Unlike traditional machine learning methods, DeepONet is designed to map entire functions to other functions. It can be trained with either physics-informed or supervised learning strategies.
See also
Original reference: Lu, L., Jin, P., Pang, G. et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell 3, 218–229 (2021). DOI: 10.1038/s42256-021-00302-5
- Parameters:
branch_net (torch.nn.Module) – The neural network to use as the branch model. It has to take as input a pina.label_tensor.LabelTensor or a torch.Tensor. The number of dimensions of its output has to match that of trunk_net.
trunk_net (torch.nn.Module) – The neural network to use as the trunk model. It has to take as input a pina.label_tensor.LabelTensor or a torch.Tensor. The number of dimensions of its output has to match that of branch_net.
input_indeces_branch_net (list(int) or list(str)) – List of indices to extract from the input variable in the forward pass for the branch net. If a list of int is passed, the corresponding columns of the innermost entries are extracted. If a list of str is passed, the variables of the corresponding pina.label_tensor.LabelTensor are extracted.
input_indeces_trunk_net (list(int) or list(str)) – List of indices to extract from the input variable in the forward pass for the trunk net. If a list of int is passed, the corresponding columns of the innermost entries are extracted. If a list of str is passed, the variables of the corresponding pina.label_tensor.LabelTensor are extracted.
aggregator (str or Callable) – Aggregator used to aggregate the partial results from the modules in the nets. Partial results are aggregated component-wise; see the sketch after this parameter list. Available aggregators include sum: +, product: *, mean: mean, min: min, max: max.
reduction (str or Callable) – Reduction used to reduce the aggregated result of the modules in the nets to the desired output dimension. Available reductions include sum: +, product: *, mean: mean, min: min, max: max.
scale (bool or Callable) – Whether to scale the final output before returning it from the forward pass. Default is True.
translation (bool or Callable) – Whether to translate the final output before returning it from the forward pass. Default is True.
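With the default aggregator='*' and reduction='+', the branch and trunk outputs are combined as in the classical DeepONet formula G(u)(y) ≈ Σ_k b_k(u) · t_k(y): a component-wise product followed by a sum over the latent dimension. The doctest below is only a conceptual sketch of that combination step, not PINA's internal code (the optional scale and translation terms are omitted); the batch size and latent dimension are illustrative assumptions.
>>> import torch
>>> branch_out = torch.rand(5, 10)  # hypothetical branch-net output, latent dimension 10
>>> trunk_out = torch.rand(5, 10)   # trunk-net output with the same latent dimension
>>> aggregated = branch_out * trunk_out             # aggregator '*': component-wise product
>>> reduced = aggregated.sum(dim=-1, keepdim=True)  # reduction '+': sum over the latent dimension
>>> reduced.shape
torch.Size([5, 1])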
Warning
In the forward pass we do not check whether the input is an instance of pina.label_tensor.LabelTensor or torch.Tensor. As a general rule, for a pina.label_tensor.LabelTensor input both a list of integers and a list of strings can be passed for input_indeces_branch_net and input_indeces_trunk_net. Conversely, for a torch.Tensor input only a list of integers can be passed for input_indeces_branch_net and input_indeces_trunk_net (see the usage sketches below).
- Example:
>>> branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
>>> model = DeepONet(branch_net=branch_net,
...                  trunk_net=trunk_net,
...                  input_indeces_branch_net=['x'],
...                  input_indeces_trunk_net=['t'],
...                  reduction='+',
...                  aggregator='*')
>>> model
DeepONet(
  (trunk_net): FeedForward(
    (model): Sequential(
      (0): Linear(in_features=1, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=20, bias=True)
      (3): Tanh()
      (4): Linear(in_features=20, out_features=10, bias=True)
    )
  )
  (branch_net): FeedForward(
    (model): Sequential(
      (0): Linear(in_features=1, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=20, bias=True)
      (3): Tanh()
      (4): Linear(in_features=20, out_features=10, bias=True)
    )
  )
)
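Once built as above, the model can be evaluated on a pina.label_tensor.LabelTensor whose columns carry the labels referenced by input_indeces_branch_net and input_indeces_trunk_net. The call below is a usage sketch under the assumption that the LabelTensor constructor accepts a tensor together with its column labels.
>>> import torch
>>> from pina import LabelTensor
>>> pts = LabelTensor(torch.rand(20, 2), labels=['x', 't'])  # columns labelled 'x' and 't'
>>> output = model(pts)  # the branch net reads 'x', the trunk net reads 't'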
- forward(x)[source]#
Defines the computation performed at every call.
- Parameters:
x (LabelTensor or torch.Tensor) – The input tensor for the forward call.
- Returns:
The output computed by the DeepONet model.
- Return type:
LabelTensor or torch.Tensor
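As the warning above notes, a plain torch.Tensor input requires integer index lists. A minimal sketch, reusing the branch and trunk nets from the example; the model name and the index lists are illustrative assumptions.
>>> import torch
>>> model_plain = DeepONet(branch_net=branch_net,
...                        trunk_net=trunk_net,
...                        input_indeces_branch_net=[0],  # column 0 feeds the branch net
...                        input_indeces_trunk_net=[1],   # column 1 feeds the trunk net
...                        reduction='+',
...                        aggregator='*')
>>> pts = torch.rand(20, 2)       # plain tensor: only integer indices are valid here
>>> output = model_plain(pts)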
- property branch_net#
The branch net for DeepONet.
- property trunk_net#
The trunk net for DeepONet.