neural_tangents.predict.max_learning_rate
- neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)
Computes the maximal feasible learning rate for infinite width NNs.
The network is assumed to be trained using mini-batch or full-batch gradient descent (GD) with momentum and a mean squared error loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) * |f(train_x) - train_y|^2.

For vanilla SGD (i.e. momentum = 0), the maximal feasible learning rate is the largest eta such that the operator (I - eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK). Since the NTK is positive semi-definite, the eigenvalues 1 - eta * lambda_i / (batch_size * output_size) of this operator stay above -1 exactly when eta < 2 * batch_size * output_size / lambda_max(NTK). When momentum > 0, we use 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in “Why Momentum Really Works”).
- Parameters:
ntk_train_train (ndarray) – analytic or empirical NTK on the training data.
y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size=None, it is inferred from ntk_train_train.shape assuming trace_axes=().
momentum – the momentum coefficient used by momentum optimizers (0.0 for vanilla SGD).
eps (float) – a small float added to lambda_max(NTK) to avoid division by zero.
- Return type:
float
- Returns:
The maximal feasible learning rate for infinite width NNs.
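To make the formula concrete, here is a minimal NumPy sketch, assuming a square, already-flattened NTK matrix and full-batch training (so batch_size * output_size equals y_train.size). `max_learning_rate_sketch` and `heavy_ball_residual_norm` are hypothetical helpers written for this illustration, not part of neural_tangents. The second helper runs GD with heavy-ball momentum on the linearized residual dynamics (u <- momentum * u - eta / n * NTK @ r, then r <- r + u), so one can check that a step size slightly below the computed rate converges while one slightly above it diverges.

```python
import numpy as np


def max_learning_rate_sketch(ntk_train_train, y_train_size=None,
                             momentum=0.0, eps=1e-12):
    """Hypothetical re-implementation of the documented formula.

    Assumes ntk_train_train is already a square (N*K, N*K) matrix.
    """
    ntk = np.asarray(ntk_train_train, dtype=np.float64)
    # batch_size * output_size; inferred from the NTK shape when not given.
    factor = ntk.shape[0] if y_train_size is None else y_train_size
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order.
    lambda_max = np.linalg.eigvalsh(ntk)[-1]
    return 2.0 * (1.0 + momentum) * factor / (lambda_max + eps)


def heavy_ball_residual_norm(ntk, y, lr, momentum=0.0, steps=200):
    """Hypothetical helper: full-batch GD + momentum on linearized dynamics.

    Tracks the residual r = f(x_train) - y_train under the loss
    1/(2 * n) * ||r||^2, starting from f = 0, and returns ||r|| at the end.
    """
    n = y.size
    r = -y.astype(np.float64)   # residual at initialization (f = 0)
    u = np.zeros_like(r)        # momentum buffer, expressed in function space
    for _ in range(steps):
        u = momentum * u - lr / n * (ntk @ r)
        r = r + u
    return np.linalg.norm(r)


rng = np.random.default_rng(0)
n, d = 8, 16
features = rng.normal(size=(n, d))
# A random symmetric positive-definite matrix standing in for an NTK.
ntk = features @ features.T / d + np.eye(n)
y = rng.normal(size=n)

for momentum in (0.0, 0.9):
    eta_max = max_learning_rate_sketch(ntk, momentum=momentum)
    below = heavy_ball_residual_norm(ntk, y, 0.95 * eta_max, momentum)
    above = heavy_ball_residual_norm(ntk, y, 1.05 * eta_max, momentum)
    print(f"momentum={momentum}: eta_max={eta_max:.3f}, "
          f"|r| at 0.95*eta_max={below:.2e}, at 1.05*eta_max={above:.2e}")
```

Simulating in function space keeps the sketch independent of any particular network parametrization. The sketch does not handle NTKs with separate output axes; it only covers the flattened, square case.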