neural_tangents.predict.max_learning_rate
- neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)
Computes the maximal feasible learning rate for infinite width NNs.
The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form `1 / (2 * batch_size * output_size) * ||f(train_x) - train_y||^2`. For vanilla SGD (i.e. `momentum = 0`), the maximal feasible learning rate is the largest `eta` such that the operator `I - eta / (batch_size * output_size) * NTK` is a contraction; since its eigenvalues are `1 - eta * lambda_i / (batch_size * output_size)`, this requires `eta < 2 * batch_size * output_size / lambda_max(NTK)`. When `momentum > 0`, the bound becomes `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)` (see the "Dynamics of Momentum" section in "Why Momentum Really Works", https://distill.pub/2017/momentum/).
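As a concrete illustration, here is a minimal NumPy sketch of the bound above. This is not the library's implementation, and the helper name `max_lr_sketch` is hypothetical; it assumes the NTK has already been flattened to a 2D `(N, N)` matrix with `trace_axes=()`, so that `N == batch_size * output_size`:

```python
import numpy as np

def max_lr_sketch(ntk, y_train_size=None, momentum=0.0, eps=1e-12):
    # Hypothetical sketch; `ntk` is assumed to be a 2D (N, N) kernel
    # with N = batch_size * output_size (trace_axes=()).
    factor = ntk.shape[0] if y_train_size is None else y_train_size
    # eigvalsh returns eigenvalues in ascending order; take the largest.
    lambda_max = np.linalg.eigvalsh(ntk)[-1]
    # eta_max = 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)
    return 2.0 * (1.0 + momentum) * factor / (lambda_max + eps)
```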
- Parameters:
  - **ntk_train_train** (`ndarray`) – analytic or empirical NTK on the training data.
  - **y_train_size** (`Optional[int]`) – total training set output size, i.e. `f(x_train).size == y_train.size`. If `y_train_size=None`, it is inferred from `ntk_train_train.shape` assuming `trace_axes=()`.
  - **momentum** (`float`) – the `momentum` parameter for momentum optimizers.
  - **eps** (`float`) – a small float added to the denominator to avoid division by zero.
- Return type:
`float`
- Returns:
The maximal feasible learning rate for infinite width NNs.
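For context, here is a usage sketch combining `max_learning_rate` with the library's `stax` API; the architecture and inputs below are toy placeholders:

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

# Toy infinite-width architecture and random training inputs.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
x_train = random.normal(random.PRNGKey(0), (32, 8))

# Analytic NTK on the training data.
ntk_train_train = kernel_fn(x_train, None, 'ntk')

# Largest stable learning rate for full-batch GD with momentum 0.9.
lr = nt.predict.max_learning_rate(ntk_train_train, momentum=0.9)
```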