It implements the decoupled weight decay described by Loshchilov & Hutter, in which the weight decay is decoupled from the optimization steps w.r.t. the loss function. For SGD variants, this simplifies hyperparameter search since it decouples the settings of weight decay and learning rate. For adaptive gradient algorithms, it regularizes variables with large gradients more than L2 regularization would, which was shown in the paper above to yield better training loss and generalization error.
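To make the distinction concrete, here is a minimal sketch with plain Python scalars (the function names are made up for this illustration): with L2 regularization the decay term is folded into the gradient and therefore rescaled by the learning rate (and, for adaptive methods, by the per-parameter step size), whereas decoupled weight decay shrinks the weight directly.

def sgd_l2_step(w, grad, lr, wd):
    # L2 regularization: the decay term is added to the gradient and is
    # rescaled by the learning rate together with the loss gradient.
    return w - lr * (grad + wd * w)

def sgdw_step(w, grad, lr, wd):
    # Decoupled weight decay (SGDW): the decay is applied to the weight
    # directly, independently of the gradient-based update.
    return w - lr * grad - wd * w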
This class alone is not an optimizer but rather extends existing optimizers with decoupled weight decay. We explicitly define the two examples used in the paper above (SGDW and AdamW), but in general this can extend any OptimizerX class by using ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX). Weight decay can then be set when instantiating the optimizer: optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001). When subclassing manually instead, the extension must be the first class the optimizer with weight decay inherits from, as in the sketch below.
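The following is a minimal sketch of both routes (the class names MyAdamW and MyOtherAdamW are illustrative, not part of the API), assuming a TensorFlow release compatible with TensorFlow Addons:

import tensorflow as tf
import tensorflow_addons as tfa

# Route 1: build an extended optimizer class with the factory function.
MyAdamW = tfa.optimizers.extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)

# Route 2: inherit explicitly; the extension must be the first base class.
class MyOtherAdamW(tfa.optimizers.DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
    def __init__(self, weight_decay, *args, **kwargs):
        super().__init__(weight_decay, *args, **kwargs)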
When a schedule is used to decay the learning rate, be sure to apply the same decay to weight_decay as well, for example:

step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    [10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
exclude_from_weight_decay
List of regex patterns of variables excluded from weight decay. Variables whose names contain a substring matching one of the patterns will be excluded (see the example after this argument list). Note decay_var_list in minimize or apply_gradients takes priority over exclude_from_weight_decay if specified.
**kwargs
Optional keyword arguments forwarded to the constructor of the wrapped base optimizer (for example, learning_rate).
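For instance, a minimal sketch of excluding variables from decay by name (the patterns "bias" and "layer_norm" are hypothetical and depend on how the model names its variables):

import tensorflow_addons as tfa

# Skip weight decay for any variable whose name contains "bias" or
# "layer_norm"; all other variables are decayed by weight_decay.
optimizer = tfa.optimizers.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-4,
    exclude_from_weight_decay=["bias", "layer_norm"],
)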
apply_gradients
This is the second part of minimize(). It returns an Operation that applies the gradients (see the example below).
Args
grads_and_vars
List of (gradient, variable) pairs.
name
Optional name for the returned operation. Defaults to the name passed to the Optimizer constructor.
decay_var_list
Optional list of variables to be decayed. Defaults to all variables in var_list. Note decay_var_list takes priority over exclude_from_weight_decay if specified.
**kwargs
Additional arguments to pass to the base optimizer's apply_gradients method; for example, TF 2.2 added the experimental_aggregate_gradients argument.
Returns
An Operation that applies the specified gradients.
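As an illustration, a minimal sketch of applying manually computed gradients while decaying only one of the variables (the variables w and b and the loss are made up for this example):

import tensorflow as tf
import tensorflow_addons as tfa

w = tf.Variable([[1.0, 2.0]], name="w")
b = tf.Variable([0.5], name="b")
optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.matmul(w, tf.ones([2, 1])) + b)

grads = tape.gradient(loss, [w, b])
# decay_var_list takes priority over exclude_from_weight_decay: only w is
# decayed here; b receives only the gradient-based Adam update.
optimizer.apply_gradients(zip(grads, [w, b]), decay_var_list=[w])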
minimize
This method simply computes the gradients using tf.GradientTape and calls apply_gradients(). If you want to process the gradients before applying them, call tf.GradientTape and apply_gradients() explicitly instead of using this function (see the example below).
Args
loss
Tensor or callable. If a callable, loss should take no arguments and return the value to minimize. If a Tensor, the tape argument must be passed.
var_list
list or tuple of Variable objects to update to minimize loss, or a callable returning the list or tuple of Variable objects. Use a callable when the variable list would otherwise be incomplete before minimize is called, since the variables are created the first time loss is called.
grad_loss
Optional. A Tensor holding the gradient computed for loss.
decay_var_list
Optional list of variables to be decayed. Defaults to all variables in var_list. Note decay_var_list takes priority over exclude_from_weight_decay if specified.
name
Optional name for the returned operation.
tape
(Optional) tf.GradientTape. If loss is provided as a Tensor, the tape that computed the loss must be provided.
Returns
An Operation that updates the variables in var_list.
Raises
ValueError
If some of the variables are not Variable objects.
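For example, a minimal sketch of minimizing a callable loss while decaying only one variable (the variables w and b and the loss are made up for this example; SGDW is the SGD variant provided by TensorFlow Addons):

import tensorflow as tf
import tensorflow_addons as tfa

w = tf.Variable([[1.0, 2.0]], name="w")
b = tf.Variable([0.5], name="b")
optimizer = tfa.optimizers.SGDW(learning_rate=1e-2, weight_decay=1e-4)

# The loss is a callable taking no arguments, so no tape is needed.
loss = lambda: tf.reduce_sum(tf.matmul(w, tf.ones([2, 1])) + b)
# Only w is decayed; b is updated by plain SGD.
optimizer.minimize(loss, var_list=[w, b], decay_var_list=[w])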
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2023-05-25 UTC."],[],[]]