MIONet
- class MIONet(networks, aggregator='*', reduction='+', scale=True, translation=True)
Bases: Module
MIONet model class.
The MIONet is a general architecture for learning operators, which map functions to functions. It can be trained with both Supervised and Physics-Informed learning strategies.
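The original paper writes the learned operator as a high-order inner product. Its inputs are the n branch-net outputs b^1, ..., b^n (one per input function v_j) and the trunk-net output t at the query point y, each a vector of length p, and a bias b_0 is added. The notation below is chosen here for illustration, not taken from this class:

```latex
\mathcal{G}(v_1, \dots, v_n)(y) \approx \sum_{k=1}^{p} t_k(y) \prod_{j=1}^{n} b^{j}_{k}(v_j) + b_0
```

Under the default settings, the aggregator `*` corresponds to the product over networks for each component k, and the reduction `+` corresponds to the sum over k. Reading the `translation` option as the analogue of b_0 is an interpretation, not something this page states.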
See also
Original reference: Jin, P., Meng, S., and Lu, L. (2022). MIONet: Learning multiple-input operators via tensor product. SIAM Journal on Scientific Computing, 44(6), A3490-A351. DOI: 10.1137/22M1477751
Initialization of the MIONet class.
- Parameters:
networks (dict) – The neural networks to use as models. The dict takes a neural network as key and, as value, the list of indices to extract from the input variable in the forward pass of that network. If a list[int] is passed, the corresponding columns of the innermost entries are extracted. If a list[str] is passed, the variables of the corresponding LabelTensor are extracted. Each torch.nn.Module model has to take as input either a LabelTensor or a torch.Tensor. The default implementation consists of several branch nets and one trunk net.
aggregator (str or Callable) – The aggregator used to combine, component-wise, the partial results of the modules in networks. Available aggregators include: sum: +, product: *, mean: mean, min: min, max: max. Default is *.
reduction (str or Callable) – The reduction used to reduce the aggregated result of the modules in networks to the desired output dimension. Available reductions include: sum: +, product: *, mean: mean, min: min, max: max. Default is +.
scale (bool) – If True, the final output is scaled before being returned in the forward pass. Default is True.
translation (bool) – If True, the final output is translated before being returned in the forward pass. Default is True.
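The aggregator and reduction options can be read as two steps over the sub-network outputs. The aggregator combines the k-th component of every output, and the reduction then collapses the aggregated vector. The framework-free sketch below illustrates that reading with plain Python lists instead of tensors and ignores the batch dimension. The names `REDUCERS`, `resolve`, `aggregate`, and `reduce_outputs` are hypothetical and are not part of the class's API.

```python
import math
from statistics import mean
from typing import Callable, Dict, List, Sequence, Union

Reducer = Callable[[Sequence[float]], float]

# Hypothetical lookup table mirroring the documented symbols:
# sum "+", product "*", "mean", "min", "max".
REDUCERS: Dict[str, Reducer] = {
    "+": sum,
    "*": math.prod,
    "mean": mean,
    "min": min,
    "max": max,
}


def resolve(op: Union[str, Reducer]) -> Reducer:
    """Return the callable for a symbol, or pass a user-supplied callable through."""
    if callable(op):
        return op
    if op not in REDUCERS:
        raise ValueError(f"Unknown operation: {op!r}")
    return REDUCERS[op]


def aggregate(outputs: Sequence[Sequence[float]],
              aggregator: Union[str, Reducer] = "*") -> List[float]:
    """Combine the k-th component of every sub-network output (component-wise)."""
    agg = resolve(aggregator)
    return [agg(component) for component in zip(*outputs)]


def reduce_outputs(aggregated: Sequence[float],
                   reduction: Union[str, Reducer] = "+") -> float:
    """Collapse the aggregated vector into a single value."""
    return resolve(reduction)(aggregated)
```

For two networks with p = 2 outputs `[[1.0, 2.0], [3.0, 4.0]]`, the default `*` aggregation gives `[3.0, 8.0]` and the default `+` reduction then gives `11.0`. Accepting a callable as well as a string mirrors the documented `str or Callable` type.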
- Raises:
ValueError – If the passed networks do not have the same output dimension.
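The check exists because component-wise aggregation pairs the k-th entries of all outputs, so every network must produce the same number of entries. The minimal sketch below assumes each network's output size is already available as a plain integer. How the class actually obtains that size is not documented here, and `check_output_dimensions` is a hypothetical helper.

```python
from typing import Sequence


def check_output_dimensions(output_dims: Sequence[int]) -> int:
    """Return the common output dimension p, or raise ValueError.

    `output_dims` holds one integer per sub-network (hypothetically, the
    out_features of its last layer).
    """
    distinct = set(output_dims)
    if len(distinct) != 1:
        # Covers both mismatched sizes and an empty list of networks.
        raise ValueError(
            f"All networks must have the same output dimension, got {list(output_dims)}."
        )
    return distinct.pop()
```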
Warning
No checks are performed in the forward pass to verify whether the input is an instance of either LabelTensor or torch.Tensor. In general, for a LabelTensor, either a list[int] or a list[str] can be passed as networks dict values. For a torch.Tensor, only a list[int] can be passed as networks dict values.
- Example:
>>> branch_net1 = FeedForward(input_dimensions=1,
...                           output_dimensions=10)
>>> branch_net2 = FeedForward(input_dimensions=2,
...                           output_dimensions=10)
>>> trunk_net = FeedForward(input_dimensions=1, output_dimensions=10)
>>> networks = {branch_net1 : ['x'],
...             branch_net2 : ['x', 'y'],
...             trunk_net : ['z']}
>>> model = MIONet(networks=networks,
...                reduction='+',
...                aggregator='*')
>>> model
MIONet(
  (models): ModuleList(
    (0): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=1, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
    (1): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=2, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
    (2): FeedForward(
      (model): Sequential(
        (0): Linear(in_features=1, out_features=20, bias=True)
        (1): Tanh()
        (2): Linear(in_features=20, out_features=20, bias=True)
        (3): Tanh()
        (4): Linear(in_features=20, out_features=10, bias=True)
      )
    )
  )
)
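The Warning above describes two ways of selecting each network's inputs: by position (list[int]) or by variable name (list[str], labelled input only). The sketch below illustrates that rule. `ToyLabelTensor` is a hypothetical stand-in for a LabelTensor built from plain lists, and `extract` is a hypothetical helper. Unlike the real class, which per the Warning performs no such checks, this sketch raises an explicit error when names are used on unlabelled input.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass
class ToyLabelTensor:
    """Hypothetical stand-in for a LabelTensor: rows of floats plus column labels."""
    rows: List[List[float]]
    labels: List[str]


def extract(x: Union[ToyLabelTensor, List[List[float]]],
            indices: Sequence[Union[int, str]]) -> List[List[float]]:
    """Select innermost columns by position (list[int]) or by name (list[str])."""
    if all(isinstance(i, int) for i in indices):
        # Positional selection works for labelled and plain input alike.
        rows = x.rows if isinstance(x, ToyLabelTensor) else x
        cols = list(indices)
    elif all(isinstance(i, str) for i in indices):
        # Name-based selection only makes sense when labels exist.
        if not isinstance(x, ToyLabelTensor):
            raise TypeError("list[str] indices require labelled input")
        rows = x.rows
        cols = [x.labels.index(name) for name in indices]
    else:
        raise TypeError("indices must be all int or all str")
    return [[row[c] for c in cols] for row in rows]
```

With labels `['x', 'y', 'z']`, the selections `['x', 'y']` and `['z']` from the example dict pick the first two columns and the last column, respectively. The positional lists `[0, 1]` and `[2]` would do the same on a plain nested list.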
- forward(x)
Forward pass for the MIONet model.
- Parameters:
x (LabelTensor | torch.Tensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
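One way to read the forward pass as a whole is as follows. Each network is evaluated on its extracted inputs, the results are aggregated component-wise and then reduced, and the output is finally scaled and translated. The sketch below follows that reading for the default `'*'`/`'+'` choices using plain Python lists. `mionet_forward` is a hypothetical helper. Treating `scale` and `translation` as plain floats, and applying the scale before the translation, are assumptions made for illustration.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical sub-network type: maps the extracted inputs of one sample to a
# feature vector of length p (the real sub-networks are torch modules).
Net = Callable[[List[float]], List[float]]


def mionet_forward(
    rows: Sequence[Sequence[float]],
    networks: Dict[Net, Sequence[int]],
    scale: float = 1.0,
    translation: float = 0.0,
) -> List[float]:
    """Sketch of a MIONet-style forward pass with aggregator '*' and reduction '+'."""
    outputs = []
    for row in rows:
        # 1. Evaluate each network on the columns listed for it.
        partial = [net([row[i] for i in idx]) for net, idx in networks.items()]
        # 2. Aggregate component-wise: product over networks for each k.
        aggregated = [math.prod(component) for component in zip(*partial)]
        # 3. Reduce: sum over the p components.
        reduced = sum(aggregated)
        # 4. Scale, then translate (order assumed for this sketch).
        outputs.append(scale * reduced + translation)
    return outputs
```

As a usage example, take a branch net `lambda v: [v[0], 2 * v[0]]` reading column 0 and a trunk net `lambda v: [1.0, v[0]]` reading column 1. For the row `[3.0, 0.5]`, the partial outputs are `[3.0, 6.0]` and `[1.0, 0.5]`, their component-wise product is `[3.0, 3.0]`, and the sum gives `6.0`.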
- property aggregator
The aggregator function.
- Returns:
The aggregator function.
- Return type:
str or Callable
- property reduction
The reduction function.
- Returns:
The reduction function.
- Return type:
str or Callable
- property scale
The scale factor.
- Returns:
The scale factor.
- Return type:
- property translation
The translation factor.
- Returns:
The translation factor.
- Return type:
- property indeces_variables_extracted
The input indices for each model, as a list.
- Returns:
The indices for each model.
- Return type:
- property model
The models, as a list.
- Returns:
The models.
- Return type: