MIONet
- class MIONet(networks, aggregator='*', reduction='+', scale=True, translation=True)
Bases: Module
MIONet model class.
MIONet is a general architecture for learning operators, i.e. maps from functions to functions, that take multiple input functions. It can be trained with both supervised and physics-informed learning strategies.
See also
Original reference: Jin, P., Meng, S., and Lu, L. (2022). MIONet: Learning multiple-input operators via tensor product. SIAM Journal on Scientific Computing 44.6: A3490-A351. DOI: 10.1137/22M1477751
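With the default settings documented below (aggregator *, reduction +, scaling and translation enabled), the model output can be written, as a sketch, in the low-rank tensor-product form used in the reference. The notation is ours and not taken from the library: b^{(i)} is the p-dimensional output of the i-th branch net on its extracted input u_i, t is the p-dimensional trunk net output at y, and alpha and beta stand for the scale and translation factors, assuming the scale multiplies the reduced output and the translation is added to it.

```latex
\mathcal{G}(u_1, \dots, u_n)(y) \approx \alpha \sum_{k=1}^{p} \left( \prod_{i=1}^{n} b^{(i)}_k(u_i) \right) t_k(y) + \beta
```

The product over i is the component-wise aggregation across networks, and the sum over k is the reduction to a scalar output.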
Initialization of the MIONet class.
- Parameters:
networks (dict) – The neural networks to use as models. The dict takes a neural network as key and, as value, the list of indices to extract from the input variable in the forward pass of that network. If a list[int] is passed, the corresponding columns of the innermost entries are extracted. If a list[str] is passed, the variables of the corresponding LabelTensor are extracted. Each torch.nn.Module model has to take as input either a LabelTensor or a torch.Tensor. The default implementation consists of several branch nets and one trunk net.
aggregator (str or Callable) – The aggregator used to combine the component-wise partial results of the modules in networks. Available aggregators: sum: +, product: *, mean: mean, min: min, max: max. Default is *.
reduction (str or Callable) – The reduction used to reduce the aggregated result of the modules in networks to the desired output dimension. Available reductions: sum: +, product: *, mean: mean, min: min, max: max. Default is +.
scale (bool) – If True, the final output is scaled before being returned in the forward pass. Default is True.
translation (bool) – If True, the final output is translated before being returned in the forward pass. Default is True.
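To make the aggregator and reduction options concrete, here is a framework-free sketch of their semantics on plain Python lists, assuming each network returns p components and that a custom Callable receives the per-component values as a list. The names OPERATIONS, resolve and combine are hypothetical and are not part of the library; the real model operates on tensors.

```python
import statistics
from functools import reduce
from operator import add, mul

# Hypothetical lookup table mirroring the string options listed above.
# Each entry folds a list of numbers into a single number.
OPERATIONS = {
    "+": lambda values: reduce(add, values),
    "*": lambda values: reduce(mul, values),
    "mean": statistics.mean,
    "min": min,
    "max": max,
}


def resolve(op):
    """Return the callable for a string key, or pass a callable through."""
    if callable(op):
        return op
    if op not in OPERATIONS:
        raise ValueError(f"Unknown operation: {op!r}")
    return OPERATIONS[op]


def combine(outputs, aggregator="*", reduction="+"):
    """Aggregate per-network outputs component-wise, then reduce to a scalar.

    outputs: one list of p components per network, all of equal length p.
    """
    agg = resolve(aggregator)
    red = resolve(reduction)
    # zip(*outputs) groups the k-th component of every network together.
    aggregated = [agg(list(components)) for components in zip(*outputs)]
    return red(aggregated)


branch = [1.0, 2.0, 3.0]
trunk = [4.0, 5.0, 6.0]
print(combine([branch, trunk]))  # 1*4 + 2*5 + 3*6
```

With the defaults, this is a dot product between the branch and trunk outputs, which is the classic DeepONet-style combination generalized to more than two networks.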
- Raises:
ValueError – If the passed networks do not have the same output dimension.
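The equal-dimension requirement follows from the component-wise aggregation: the k-th components of all networks can only be combined if every network produces the same number of components. A minimal sketch of such a check, assuming the output dimensions are already known; check_output_dimensions is a hypothetical helper, and how the library obtains each network's output dimension is not described in this section.

```python
def check_output_dimensions(output_dims):
    """Raise ValueError unless every network has the same output dimension.

    output_dims: hypothetical mapping from a network name to its output size.
    Returns the shared output dimension.
    """
    if not output_dims:
        raise ValueError("At least one network is required")
    unique = set(output_dims.values())
    if len(unique) > 1:
        raise ValueError(
            f"All networks must have the same output dimension, got {output_dims}"
        )
    return unique.pop()


print(check_output_dimensions({"branch_net1": 10, "branch_net2": 10, "trunk_net": 10}))
```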
Warning
No checks are performed in the forward pass to verify whether the input is an instance of LabelTensor or torch.Tensor. In general, when the input is a LabelTensor, either a list[int] or a list[str] can be passed as networks dict values. When the input is a torch.Tensor, only a list[int] can be passed as networks dict values.
- Example:
>>> from pina.model import FeedForward, MIONet
>>> branch_net1 = FeedForward(input_dimensions=1,
...                           output_dimensions=10)
>>> branch_net2 = FeedForward(input_dimensions=2,
...                           output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
>>> networks = {branch_net1: ['x'],
...             branch_net2: ['x', 'y'],
...             trunk_net: ['z']}
>>> model = MIONet(networks=networks,
...                reduction='+',
...                aggregator='*')
>>> model
MIONet(
  (models): ModuleList(
    (0): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=1, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
    (1): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=2, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
    (2): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=1, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
  )
)
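In the example above the dict values are variable names, so the model expects a LabelTensor whose labels include x, y and z. The following framework-free sketch illustrates the rule from the warning: names can be resolved only when the input carries labels, while integer indices work for both labeled and plain inputs. The helper extract is hypothetical, rows stand in for a tensor, and the explicit TypeError is an assumption added to make the failure visible, since the library itself performs no such check in the forward pass.

```python
def extract(rows, selection, labels=None):
    """Select columns from a 2D list of rows.

    rows: list of equal-length lists (stand-in for a tensor).
    selection: list of int column indices, or list of str column names.
    labels: column names, or None for an unlabeled, plain-tensor-like input.
    """
    if all(isinstance(s, int) for s in selection):
        columns = selection
    elif all(isinstance(s, str) for s in selection):
        if labels is None:
            # A plain tensor carries no names, so only integer indices can work.
            raise TypeError("Column names require a labeled input")
        columns = [labels.index(name) for name in selection]
    else:
        raise TypeError("Selection must be all int or all str")
    return [[row[c] for c in columns] for row in rows]


rows = [[1, 2, 3], [4, 5, 6]]
print(extract(rows, ["x", "z"], labels=["x", "y", "z"]))  # by name
print(extract(rows, [0, 2]))  # by index, no labels needed
```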
- forward(x)
Forward pass for the MIONet model.
- Parameters:
x (LabelTensor | torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
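Putting the pieces together, the forward pass can be sketched as: extract each network's input columns, evaluate every network, aggregate component-wise, reduce, then scale and translate. This toy version is a sketch under stated assumptions, not the library's implementation: networks are plain Python callables mapping a row to p floats, the defaults * and + are hard-coded, scale and translation are plain numbers standing in for the factors enabled by the scale and translation flags, and the function mionet_forward is hypothetical. It returns a flat list with one value per input row, whereas the real model returns a tensor.

```python
from functools import reduce
from operator import mul


def mionet_forward(x, labels, networks, scale=1.0, translation=0.0):
    """Toy forward pass with the default aggregator '*' and reduction '+'.

    x: list of input rows; labels: column names of x.
    networks: dict mapping a callable (list of floats -> list of p floats)
    to the list of column names it consumes.
    """
    outputs = []
    for row in x:
        per_network = []
        for net, names in networks.items():
            sub_row = [row[labels.index(name)] for name in names]
            per_network.append(net(sub_row))
        # Component-wise product across networks, then sum over components.
        aggregated = [reduce(mul, comps) for comps in zip(*per_network)]
        outputs.append(scale * sum(aggregated) + translation)
    return outputs


def branch(v):
    return [v[0], 2 * v[0]]


def trunk(v):
    return [v[0], v[0] + 1]


networks = {branch: ["x"], trunk: ["y"]}
print(mionet_forward([[1.0, 2.0], [3.0, 0.0]], ["x", "y"], networks))
```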
- property aggregator
The aggregator function.
- Returns:
The aggregator function.
- Return type:
str or Callable
- property reduction
The reduction function.
- Returns:
The reduction function.
- Return type:
str or Callable
- property scale
The scale factor.
- Returns:
The scale factor.
- Return type:
- property translation
The translation factor.
- Returns:
The translation factor.
- Return type:
- property indeces_variables_extracted
The input indices for each model, as a list.
- Returns:
The indices for each model.
- Return type:
- property model
The models, as a list.
- Returns:
The models.
- Return type:
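The model and indeces_variables_extracted properties expose the networks dict as two parallel lists. A minimal sketch of how such a split can keep models and index lists aligned, assuming they are taken from the dict in insertion order (guaranteed for dicts since Python 3.7), which is consistent with the ModuleList order shown in the example. The helper split_networks is hypothetical, and strings stand in for the network modules.

```python
def split_networks(networks):
    """Split a {network: selection} dict into two aligned lists.

    Dict insertion order is preserved, so models[i] always pairs with
    indices[i].
    """
    models = list(networks.keys())
    indices = list(networks.values())
    return models, indices


networks = {"branch_net1": ["x"], "branch_net2": ["x", "y"], "trunk_net": ["z"]}
models, indices = split_networks(networks)
for model, selection in zip(models, indices):
    print(model, selection)
```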