Source code for pina.model.multi_feed_forward
"""Module for the Multi Feed Forward model class."""
from abc import ABC, abstractmethod
import torch
from .feed_forward import FeedForward
[docs]
class MultiFeedForward(torch.nn.Module, ABC):
"""
Multi Feed Forward neural network model class.
This model allows to create a network with multiple Feed Forward neural
networks combined together. The user is required to define the ``forward``
method to choose how to combine the networks.
"""
def __init__(self, ffn_dict):
"""
Initialization of the :class:`MultiFeedForward` class.
:param dict ffn_dict: A dictionary containing the Feed Forward neural
networks to be combined.
:raises TypeError: If the input is not a dictionary.
"""
super().__init__()
if not isinstance(ffn_dict, dict):
raise TypeError
for name, constructor_args in ffn_dict.items():
setattr(self, name, FeedForward(**constructor_args))
[docs]
@abstractmethod
def forward(self, *args, **kwargs):
"""
Forward pass for the :class:`MultiFeedForward` model.
The user is required to define this method to choose how to combine the
networks.
"""