DeepONet
- class DeepONet(branch_net, trunk_net, input_indeces_branch_net, input_indeces_trunk_net, aggregator='*', reduction='+', scale=True, translation=True)
Bases: MIONet
DeepONet model class.
DeepONet is a general architecture for learning operators, which map functions to functions. It can be trained with both supervised and physics-informed learning strategies.
See also
Original reference: Lu, L., Jin, P., Pang, G. et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell 3, 218-229 (2021). DOI: 10.1038/s42256-021-00302-5
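For orientation, with the default aggregator='*' and reduction='+' the combination of the two networks reduces to the inner-product form used in the referenced paper. The notation below is an assumption made for illustration: b_k(u) is the k-th output of branch_net for the input function u, t_k(y) is the k-th output of trunk_net at the query point y, p is their shared output dimension (10 in the example), and α and β stand for the scale and translation terms, assumed here to be scalars.

```latex
G(u)(y) \approx \alpha \sum_{k=1}^{p} b_k(u)\, t_k(y) + \beta
```

In the paper the additive term is a bias b_0; the translation term plays an analogous role here.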
Initialization of the DeepONet class.
- Parameters:
branch_net (torch.nn.Module) – The neural network to use as branch model. It must take as input either a LabelTensor or a torch.Tensor. Its output dimension must match that of trunk_net.
trunk_net (torch.nn.Module) – The neural network to use as trunk model. It must take as input either a LabelTensor or a torch.Tensor. Its output dimension must match that of branch_net.
input_indeces_branch_net (list[int] | list[str]) – List of indices to extract from the input variable of the branch_net. If a list of int is passed, the corresponding columns of the innermost entries are extracted. If a list of str is passed, the corresponding variables of the LabelTensor are extracted.
input_indeces_trunk_net (list[int] | list[str]) – List of indices to extract from the input variable of the trunk_net. If a list of int is passed, the corresponding columns of the innermost entries are extracted. If a list of str is passed, the corresponding variables of the LabelTensor are extracted.
aggregator (str or Callable) – The aggregator used to combine, component-wise, the partial results of branch_net and trunk_net. Available aggregators: sum (+), product (*), mean (mean), min (min), max (max). Default is *.
reduction (str or Callable) – The reduction used to reduce the aggregated result of branch_net and trunk_net to the desired output dimension. Available reductions: sum (+), product (*), mean (mean), min (min), max (max). Default is +.
scale (bool) – If True, the final output is scaled before being returned in the forward pass. Default is True.
translation (bool) – If True, the final output is translated before being returned in the forward pass. Default is True.
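The aggregator and reduction options listed above can be illustrated with a minimal pure-Python sketch. This is not PINA's implementation: it assumes each network returns p features for a single sample, that the aggregator combines branch and trunk features component-wise, and that the reduction then collapses the p aggregated values into one output. The names `OPS` and `combine`, the calling convention for callables, and the treatment of scale and translation as a plain scalar multiplier and offset are all illustrative assumptions.

```python
import operator
from functools import reduce
from statistics import mean
from typing import Callable, Dict, Sequence, Union

# Hypothetical lookup table mirroring the string options listed above.
# Each entry maps a sequence of numbers to a single number.
OPS: Dict[str, Callable[[Sequence[float]], float]] = {
    "+": sum,
    "*": lambda values: reduce(operator.mul, values, 1.0),
    "mean": mean,
    "min": min,
    "max": max,
}

Op = Union[str, Callable[[Sequence[float]], float]]


def combine(branch_out: Sequence[float], trunk_out: Sequence[float],
            aggregator: Op = "*", reduction: Op = "+",
            scale: float = 1.0, translation: float = 0.0) -> float:
    """Combine one sample's branch and trunk features (each of length p)."""
    if len(branch_out) != len(trunk_out):
        raise ValueError("branch and trunk outputs must have the same dimension")
    # Assumption: a callable receives the values to combine as a sequence.
    agg = OPS[aggregator] if isinstance(aggregator, str) else aggregator
    red = OPS[reduction] if isinstance(reduction, str) else reduction
    # Component-wise aggregation across the two networks -> p values.
    aggregated = [agg(pair) for pair in zip(branch_out, trunk_out)]
    # Reduction of the p aggregated values -> one output value.
    # Assumption: scale/translation act as a scalar multiplier and offset.
    return scale * red(aggregated) + translation


b = [1.0, 2.0, 3.0]
t = [0.5, -1.0, 2.0]
print(combine(b, t))                   # 4.5  (0.5 - 2.0 + 6.0)
print(combine(b, t, reduction="max"))  # 6.0  (largest product)
print(combine(b, t, aggregator="+"))   # 7.5  (1.5 + 1.0 + 5.0)
```

In this sketch the aggregator acts across the two networks for each of the p components, and the reduction acts across the p components; with the defaults this is the dot product of the branch and trunk features.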
Warning
No checks are performed in the forward pass to verify whether the input is an instance of LabelTensor or torch.Tensor. For a LabelTensor input, either a list[int] or a list[str] can be passed as input_indeces_branch_net and input_indeces_trunk_net. For a torch.Tensor input, only a list[int] can be passed.
- Example:
>>> from pina.model import DeepONet, FeedForward
>>> branch_net = FeedForward(input_dimensions=1,
...                          output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
>>> model = DeepONet(branch_net=branch_net,
...                  trunk_net=trunk_net,
...                  input_indeces_branch_net=['x'],
...                  input_indeces_trunk_net=['t'],
...                  reduction='+',
...                  aggregator='*')
>>> model
DeepONet(
  (trunk_net): FeedForward(
    (model): Sequential(
      (0): Linear(in_features=1, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=20, bias=True)
      (3): Tanh()
      (4): Linear(in_features=20, out_features=10, bias=True)
    )
  )
  (branch_net): FeedForward(
    (model): Sequential(
      (0): Linear(in_features=1, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=20, bias=True)
      (3): Tanh()
      (4): Linear(in_features=20, out_features=10, bias=True)
    )
  )
)
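The indexing rule stated in the warning above can be illustrated with a small, hypothetical stand-in for a labelled tensor (this is not PINA's LabelTensor). A plain tensor, modelled here as a list of rows, only supports integer column indices, while a labelled input can also resolve variable names to columns. The class `LabeledRows` and the function `extract` are illustrative assumptions.

```python
from typing import List, Sequence, Union


class LabeledRows:
    """Hypothetical stand-in for a labelled tensor: rows plus column labels."""

    def __init__(self, rows: List[List[float]], labels: List[str]):
        self.rows = rows
        self.labels = labels


def extract(x: Union[List[List[float]], LabeledRows],
            indices: Sequence[Union[int, str]]) -> List[List[float]]:
    """Select columns of the innermost entries of `x` by position or name."""
    rows = x.rows if isinstance(x, LabeledRows) else x
    cols = []
    for idx in indices:
        if isinstance(idx, str):
            # Names can only be resolved when the input carries labels.
            if not isinstance(x, LabeledRows):
                raise TypeError(f"cannot resolve label {idx!r} on an unlabelled input")
            cols.append(x.labels.index(idx))
        else:
            cols.append(idx)
    return [[row[c] for c in cols] for row in rows]


data = [[0.1, 1.0], [0.2, 2.0]]
labeled = LabeledRows(data, labels=["x", "t"])
print(extract(labeled, ["x"]))  # [[0.1], [0.2]]
print(extract(labeled, [1]))    # [[1.0], [2.0]]
print(extract(data, [1]))       # [[1.0], [2.0]]
```

This sketch raises a TypeError explicitly for clarity; DeepONet itself performs no such check, as the warning states, so a wrong index type surfaces only as whatever error the underlying tensor operation produces.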
- forward(x)
Forward pass for the DeepONet model.
- Parameters:
x (LabelTensor | torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
LabelTensor | torch.Tensor
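The forward pass can be sketched end to end in plain Python under the default settings (aggregator='*', reduction='+'). This is a minimal sketch, not PINA's code: it assumes the branch and trunk networks are arbitrary callables on a list of input features, uses integer indices only, and treats scale and translation as scalars. `deeponet_forward`, `toy_branch`, and `toy_trunk` are hypothetical names.

```python
from typing import Callable, List, Sequence

# A "network" here is any callable mapping a feature list to p outputs.
Net = Callable[[List[float]], List[float]]


def deeponet_forward(x: List[List[float]], branch_net: Net, trunk_net: Net,
                     branch_idx: Sequence[int], trunk_idx: Sequence[int],
                     scale: float = 1.0,
                     translation: float = 0.0) -> List[List[float]]:
    """Return one output per input row, shaped as a column (batch, 1)."""
    out = []
    for row in x:
        # 1. Extract the branch and trunk variables from the same input row.
        b = branch_net([row[i] for i in branch_idx])
        t = trunk_net([row[i] for i in trunk_idx])
        if len(b) != len(t):
            raise ValueError("branch and trunk outputs must have the same dimension")
        # 2. Default combination: aggregate with '*', then reduce with '+'.
        value = sum(bk * tk for bk, tk in zip(b, t))
        # 3. Apply the (assumed scalar) scale and translation.
        out.append([scale * value + translation])
    return out


def toy_branch(v: List[float]) -> List[float]:
    """Hypothetical branch network with output dimension p = 2."""
    return [v[0], 2.0 * v[0]]


def toy_trunk(v: List[float]) -> List[float]:
    """Hypothetical trunk network with output dimension p = 2."""
    return [1.0, v[0]]


x = [[1.0, 3.0], [2.0, 0.5]]  # columns: x (branch input), t (trunk input)
print(deeponet_forward(x, toy_branch, toy_trunk, branch_idx=[0], trunk_idx=[1]))
# [[7.0], [4.0]]
```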
- property branch_net
The branch net of the DeepONet.
- Returns:
The branch net.
- Return type:
torch.nn.Module
- property trunk_net
The trunk net of the DeepONet.
- Returns:
The trunk net.
- Return type:
torch.nn.Module