neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)

Computes the maximal feasible learning rate for infinite width NNs.

The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) |f(train_x) - train_y|^2. For vanilla SGD (i.e. momentum = 0) the maximal feasible learning rate is the largest eta such that the operator (I - eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK), where lambda_max(NTK) is the largest eigenvalue of the NTK. When momentum > 0, we use 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in “Why Momentum Really Works”).
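The bound can be checked directly from the eigenvalues of the update operator. The short derivation below is a sketch in notation introduced here for convenience (not taken from the library): n = batch_size * output_size, \lambda_i are the eigenvalues of the NTK, and \beta is the momentum. The momentum case restates the stability condition from the cited article rather than deriving it.

```latex
% Eigenvalues of (I - (\eta / n) NTK) are 1 - \eta \lambda_i / n.
\begin{align*}
\left| 1 - \frac{\eta}{n}\,\lambda_i \right| < 1
  &\iff 0 < \eta < \frac{2n}{\lambda_i}
  \qquad \text{for each eigenvalue } \lambda_i > 0, \\
\eta_{\max} &= \frac{2n}{\lambda_{\max}(\mathrm{NTK})}
  \qquad (\text{momentum} = 0), \\
\eta_{\max} &= \frac{2\,(1 + \beta)\,n}{\lambda_{\max}(\mathrm{NTK})}
  \qquad (\text{momentum } \beta > 0).
\end{align*}
```

The tightest constraint comes from the largest eigenvalue, which is why only lambda_max(NTK) appears in the result.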

  • ntk_train_train (ndarray) – analytic or empirical NTK on the training data.

  • y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size=None it is inferred from ntk_train_train.shape assuming trace_axes=().

  • momentum – the momentum coefficient for momentum optimizers; 0 corresponds to vanilla GD/SGD.

  • eps (float) – a small constant added to the divisor to avoid division by zero.

Returns

The maximal feasible learning rate for infinite width NNs.
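
The formula described above can be illustrated with a minimal plain-Python sketch. This is not the library's implementation: largest_eigenvalue, sketch_max_learning_rate and residual_norm_after are hypothetical helper names, the 2 x 2 "NTK" is a made-up toy matrix, and the NTK is assumed to be a square 2D matrix whose side equals batch_size * output_size. The sketch computes the bound and then applies the operator (I - eta / (batch_size * output_size) * NTK) repeatedly to a residual vector, once with eta slightly below the bound and once slightly above it.

```python
import math


def largest_eigenvalue(matrix, iters=1000):
    """Largest eigenvalue of a symmetric PSD matrix via power iteration.

    Assumes the start vector is not orthogonal to the top eigenvector.
    """
    n = len(matrix)
    v = [1.0 / math.sqrt(n)] * n
    lam = 0.0
    for _ in range(iters):
        w = [sum(matrix[i][j] * v[j] for j in range(n)) for i in range(n)]
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:  # A v = 0: no progress possible from this vector
            return lam
        lam = sum(v[i] * w[i] for i in range(n))  # Rayleigh quotient v^T A v
        v = [x / norm for x in w]
    return lam


def sketch_max_learning_rate(ntk, y_train_size=None, momentum=0.0, eps=1e-12):
    """Hypothetical re-statement of 2 * (1 + momentum) * size / lambda_max."""
    # Assumption: `ntk` is square with side batch_size * output_size.
    factor = len(ntk) if y_train_size is None else y_train_size
    return 2.0 * (1.0 + momentum) * factor / (largest_eigenvalue(ntk) + eps)


def residual_norm_after(ntk, eta, factor, steps=50):
    """Norm of r after `steps` updates r <- (I - eta / factor * NTK) r."""
    n = len(ntk)
    r = [1.0] * n
    for _ in range(steps):
        r = [r[i] - eta / factor * sum(ntk[i][j] * r[j] for j in range(n))
             for i in range(n)]
    return math.sqrt(sum(x * x for x in r))


ntk = [[2.0, 1.0], [1.0, 3.0]]  # toy symmetric PSD matrix, for illustration only
eta_max = sketch_max_learning_rate(ntk)

# Slightly below the bound the residual shrinks; slightly above it grows.
norm_below = residual_norm_after(ntk, 0.9 * eta_max, factor=len(ntk))
norm_above = residual_norm_after(ntk, 1.1 * eta_max, factor=len(ntk))
print(eta_max, norm_below, norm_above)
```

With the library itself, the equivalent call would be neural_tangents.predict.max_learning_rate(ntk_train_train), passing y_train_size and momentum as needed.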