ezyrb.approximation.ann.ANN
- class ANN(layers, function, stop_training, loss=None, optimizer=<class 'torch.optim.adam.Adam'>, lr=0.001, l2_regularization=0, frequency_print=10, last_identity=True)
Feed-Forward Artificial Neural Network (ANN).
- Parameters:
layers (list) – ordered list with the number of neurons of each hidden layer.
function (torch.nn.modules.activation) – activation function at each layer. A single activation function can be passed or a list of them of length equal to the number of hidden layers.
stop_training (list) – list with the maximum number of training iterations (int) and/or the desired tolerance on the training loss (float).
loss (torch.nn.Module) – the loss function (mean squared error if not given).
optimizer (torch.optim) – the torch class implementing optimizer. Default value is Adam optimizer.
lr (float) – the learning rate. Default is 0.001.
l2_regularization (float) – the L2 regularization coefficient; it corresponds to the optimizer's “weight_decay” argument. Default is 0 (no regularization).
frequency_print (int) – how often, in epochs, the training progress is printed. Default is 10.
last_identity (boolean) – flag specifying whether the last activation function is the identity function. If the user provides the entire list of activation functions, this flag is ignored. Default value is True.
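The stop_training convention described above (an int read as the maximum number of iterations, a float read as the loss tolerance) can be sketched in plain Python. This is only an illustration of the documented semantics, not the library's implementation: the helper names parse_stop_training and should_stop are hypothetical, and the exact comparisons (epoch >= maximum, loss <= tolerance) are assumptions.

```python
def parse_stop_training(stop_training):
    """Split a stop_training list into (max_epochs, tolerance).

    Hypothetical helper: ints are read as the maximum number of
    training iterations, floats as the tolerance on the training loss.
    Either entry may be missing, in which case None is returned for it.
    """
    max_epochs, tolerance = None, None
    for criterion in stop_training:
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(criterion, bool):
            raise TypeError("stop_training entries must be int or float")
        if isinstance(criterion, int):
            max_epochs = criterion
        elif isinstance(criterion, float):
            tolerance = criterion
        else:
            raise TypeError("stop_training entries must be int or float")
    return max_epochs, tolerance


def should_stop(epoch, loss, stop_training):
    """Return True once either stopping criterion is met (assumed semantics)."""
    max_epochs, tolerance = parse_stop_training(stop_training)
    if max_epochs is not None and epoch >= max_epochs:
        return True
    if tolerance is not None and loss <= tolerance:
        return True
    return False
```

Under these assumptions, a list such as [20000, 1e-5] stops training at whichever criterion is reached first, while a list holding a single entry applies only that criterion.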
- Example:
>>> import ezyrb
>>> import numpy as np
>>> import torch.nn as nn
>>> x = np.random.uniform(-1, 1, size=(4, 2))
>>> y = np.array([np.sin(x[:, 0]), np.cos(x[:, 1]**3)]).T
>>> ann = ezyrb.ANN([10, 5], nn.Tanh(), [20000, 1e-5])
>>> ann.fit(x, y)
>>> y_pred = ann.predict(x)
>>> print(y)
>>> print(y_pred)
>>> print(len(ann.loss_trend))
>>> print(ann.loss_trend[-1])
- __init__(layers, function, stop_training, loss=None, optimizer=<class 'torch.optim.adam.Adam'>, lr=0.001, l2_regularization=0, frequency_print=10, last_identity=True)
Methods
__init__(layers, function, stop_training[, ...])
fit(points, values) – Build the ANN given 'points' and 'values' and perform training.
predict(new_point) – Evaluate the ANN at given 'new_points'.