neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))

Predicts the outcome of function space training using gradient descent.

Uses an ODE solver. If momentum != None, solves a continuous-time version of gradient descent with momentum.


We use standard momentum as opposed to Nesterov momentum.

Solves the function space ODE for [momentum] gradient descent with a given loss (detailed in “Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent”), given a Neural Tangent Kernel[s] over the dataset[s], at arbitrary time[s] (step[s]) t. Note that for gradient descent absolute_time = learning_rate * t, so the scales of the learning rate and query step[s] t are interchangeable. However, the momentum gradient descent ODE is solved in units of learning_rate**0.5, so absolute_time = learning_rate**0.5 * t, and the scales of learning_rate and training time[s] (step[s]) t are not interchangeable.
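As a rough illustration of the plain gradient descent case (not the library's solver), the sketch below Euler-integrates the function-space ODE for a squared-error loss on a hand-written 2×2 kernel. The helper names matvec and euler_gradient_descent, the kernel values, and the fixed-step integrator are all assumptions made for this example; the point is only that the predictions depend on learning_rate and t through their product.

```python
def matvec(matrix, vector):
    """Multiply a nested-list matrix by a list vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def euler_gradient_descent(k_train_train, y_train, fx_train_0,
                           learning_rate, t, n_steps=10_000):
    """Integrate d f / d tau = -K (f - y) up to tau = learning_rate * t.

    Hypothetical helper: assumes the loss 0.5 * sum((f - y)**2), whose
    gradient with respect to f is (f - y), and uses explicit Euler steps.
    """
    absolute_time = learning_rate * t
    d_tau = absolute_time / n_steps
    fx = list(fx_train_0)
    for _ in range(n_steps):
        grad = [f - y for f, y in zip(fx, y_train)]     # loss gradient in function space
        step = matvec(k_train_train, grad)              # kernel-weighted update direction
        fx = [f - d_tau * s for f, s in zip(fx, step)]  # one Euler step
    return fx


k_train_train = [[2.0, 0.5], [0.5, 1.0]]  # toy positive-definite kernel (made up)
y_train = [1.0, -1.0]
fx_train_0 = [0.0, 0.0]

# The same absolute time (20.0) reached with two different learning rates.
fx_a = euler_gradient_descent(k_train_train, y_train, fx_train_0, 1.0, 20.0)
fx_b = euler_gradient_descent(k_train_train, y_train, fx_train_0, 0.1, 200.0)
```

Both calls integrate to the same absolute time, so fx_a and fx_b agree up to floating-point error, and both end up close to y_train. Under momentum the time unit is learning_rate**0.5, so the same rescaling of learning_rate and t would not leave the result unchanged.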


Example

>>> import neural_tangents as nt
>>> from jax import numpy as np
>>> from jax.nn import log_softmax
>>> #
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>> #
>>> # `f`, `params`, `x_train`, `x_test` and `y_train` are assumed to be defined.
>>> kernel_fn = nt.empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>> #
>>> cross_entropy = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
>>> predict_fn = nt.predict.gradient_descent(
...     cross_entropy, k_train_train, y_train, learning_rate, momentum)
>>> #
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>> #
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
...                                    k_test_train)

Parameters

  • loss (Callable[[ndarray, ndarray], float]) – a loss function whose signature is loss(f(x_train), y_train). Note: the loss function should treat the batch and output dimensions symmetrically.

  • k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.

  • y_train (ndarray) – targets for the training data.

  • learning_rate (float) – learning rate, step size.

  • momentum (Optional[float]) – momentum scalar.

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.
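
To make the shape requirement on k_train_train concrete, here is a small sketch that derives the expected kernel shape from y_train.shape and trace_axes. The helper expected_kernel_shape is hypothetical (it is not part of neural_tangents) and encodes one reading of the descriptions above: every non-traced axis of y_train contributes a pair of kernel dimensions of the same size, and traced axes contribute none.

```python
def expected_kernel_shape(y_shape, trace_axes=(-1,)):
    """Return the kernel shape implied by `y_shape` and `trace_axes`.

    Hypothetical helper: interleaves each non-traced axis size with itself,
    i.e. zip(y_shape, y_shape) flattened, with the traced axes dropped.
    """
    if isinstance(trace_axes, int):
        trace_axes = (trace_axes,)
    ndim = len(y_shape)
    traced = {axis % ndim for axis in trace_axes}  # normalize negative axes
    shape = ()
    for axis, size in enumerate(y_shape):
        if axis not in traced:
            shape += (size, size)
    return shape


# 20 training points with 10 logits each.
y_shape = (20, 10)
default_shape = expected_kernel_shape(y_shape)             # logit axis traced out
full_shape = expected_kernel_shape(y_shape, trace_axes=())  # nothing traced out
```

With the default trace_axes=(-1,) the kernel is (20, 20); with trace_axes=() it grows to (20, 20, 10, 10), which is the extra memory and compute that tracing out the logit axis avoids.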

Returns

A function that returns the output train [and test] set[s] predictions at time[s] t.