Weighting callbacks

class LinearWeightUpdate(target_epoch, condition_name, initial_value, target_value)

Bases: Callback

Callback that linearly changes the weight of a condition from an initial value to a target value, reaching the target value at a specified epoch.

Initialize the callback.

Parameters:
  • target_epoch (int) – The epoch at which the weight of the condition should reach the target value.

  • condition_name (str) – The name of the condition whose weight should be adjusted.

  • initial_value (float) – The initial value of the weight.

  • target_value (float) – The target value of the weight.
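The schedule described by these parameters can be sketched as w(e) = initial_value + (target_value - initial_value) * min(e / target_epoch, 1). The function below is a minimal, hypothetical illustration of that schedule, not the library's code: it assumes epochs are counted from 0 and that the weight stays at target_value after target_epoch. The actual callback may index epochs or handle the final epoch differently.

```python
def linear_weight(epoch: int, target_epoch: int,
                  initial_value: float, target_value: float) -> float:
    """Weight of a condition at `epoch` under a linear schedule.

    Hypothetical helper. Assumes epochs are counted from 0, the weight equals
    `initial_value` at epoch 0, reaches `target_value` at `target_epoch`,
    and is held constant afterwards.
    """
    if target_epoch <= 0:
        raise ValueError("target_epoch must be positive")
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = min(max(epoch / target_epoch, 0.0), 1.0)
    return initial_value + (target_value - initial_value) * progress


# Example: ramp a weight from 1.0 to 10.0 over 5 epochs, then hold it.
schedule = [linear_weight(e, 5, 1.0, 10.0) for e in range(7)]
print(schedule)
```

Clamping the progress keeps the weight from overshooting target_value if training continues past target_epoch.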

on_train_start(trainer, pl_module)

Initialize the weight of the condition to the specified initial_value.

Parameters:
  • trainer – The trainer running the training loop.

  • pl_module – The module being trained.
on_train_epoch_start(trainer, pl_module)

Update the weight of the condition at the start of each epoch.

Parameters:
  • trainer – The trainer running the training loop.

  • pl_module – The module being trained.
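How the two hooks cooperate during training can be pictured with a small simulation. The sketch below is a hypothetical re-implementation for illustration only, not the library's source: `LinearWeightUpdateSketch`, the plain dict `weights` on the solver, and the condition names "physics" and "data" are all assumptions, and `SimpleNamespace` objects stand in for the real trainer and module. It also assumes `trainer.current_epoch` is the 0-based index of the epoch about to start.

```python
from types import SimpleNamespace


class LinearWeightUpdateSketch:
    """Hypothetical stand-in for LinearWeightUpdate (illustration only).

    Assumes condition weights live in a plain dict `pl_module.weights`,
    which is not necessarily where the real solver stores them.
    """

    def __init__(self, target_epoch, condition_name, initial_value, target_value):
        self.target_epoch = target_epoch
        self.condition_name = condition_name
        self.initial_value = initial_value
        self.target_value = target_value

    def on_train_start(self, trainer, pl_module):
        # Overwrite any previously configured weight with the initial value.
        pl_module.weights[self.condition_name] = self.initial_value

    def on_train_epoch_start(self, trainer, pl_module):
        # Recompute the weight from the epoch index instead of adding a fixed
        # increment each epoch, and hold it at target_value after target_epoch.
        progress = min(trainer.current_epoch / self.target_epoch, 1.0)
        pl_module.weights[self.condition_name] = (
            self.initial_value
            + (self.target_value - self.initial_value) * progress
        )


# Stand-ins for the trainer and the module being trained (hypothetical).
solver = SimpleNamespace(weights={"physics": 1.0, "data": 1.0})
trainer = SimpleNamespace(current_epoch=0)
callback = LinearWeightUpdateSketch(
    target_epoch=4, condition_name="physics", initial_value=0.0, target_value=2.0
)

history = []
callback.on_train_start(trainer, solver)
for epoch in range(6):
    trainer.current_epoch = epoch
    callback.on_train_epoch_start(trainer, solver)
    history.append(solver.weights["physics"])
print(history)  # [0.0, 0.5, 1.0, 1.5, 2.0, 2.0]
```

Only the named condition is touched; the weight of "data" stays at 1.0. Recomputing from the epoch index is one design choice among several; an implementation that accumulates a per-epoch step would give the same values up to floating-point rounding.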