nt.predict – inference w/ NNGP & NTK

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function, predict_fn, that computes predictions on the train set, on a given test set, and/or at given timesteps.


The trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, it must be the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or the channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently support only a single channel_axis).

Prediction / inference functions

Functions to make train/test set predictions given NNGP/NTK kernels or the linearized function.

gp_inference(k_train_train, y_train[, ...])

Compute the mean and variance of the 'posterior' of NNGP/NTK/NTKGP.
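As a rough illustration of the quantity summarized above, the following NumPy sketch computes a standard Gaussian process posterior mean and covariance from precomputed kernel blocks. It is not the library's implementation: the names gp_posterior and rbf are hypothetical, the RBF kernel is only a stand-in for an NNGP/NTK kernel, and diag_reg is assumed here to be an absolute value added to the diagonal.

```python
import numpy as np

def gp_posterior(k_train_train, k_test_train, k_test_test, y_train, diag_reg=0.0):
    """Standard GP posterior from kernel blocks (hypothetical helper)."""
    n = k_train_train.shape[0]
    # Assumption: diag_reg is added to the diagonal as an absolute value.
    k_reg = k_train_train + diag_reg * np.eye(n)
    # Use linear solves rather than forming an explicit inverse.
    mean = k_test_train @ np.linalg.solve(k_reg, y_train)
    cov = k_test_test - k_test_train @ np.linalg.solve(k_reg, k_test_train.T)
    return mean, cov

def rbf(a, b):
    """Stand-in kernel for the example; not an NNGP or NTK kernel."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

x_train = np.array([[0.0], [1.0], [2.0]])
y_train = np.array([[0.0], [1.0], [0.5]])
x_test = np.array([[0.5], [1.5]])

mean, cov = gp_posterior(
    rbf(x_train, x_train), rbf(x_test, x_train), rbf(x_test, x_test), y_train
)
```

With diag_reg=0 the posterior mean interpolates the training targets exactly and the posterior covariance at the training points is zero, which is a convenient sanity check.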

gradient_descent(loss, k_train_train, y_train)

Predicts the outcome of function space training using gradient descent.

gradient_descent_mse(k_train_train, y_train)

Predicts the outcome of function space gradient descent training on MSE.
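The "accept training data, return predict_fn" pattern can be sketched for the MSE case in plain NumPy. The sketch below assumes the continuous-time dynamics df/dt = -learning_rate * K (f - y) for the unnormalized loss 0.5 * ||f - y||^2 and a positive-definite train-train kernel; the library may normalize the loss and time differently. make_mse_predict_fn is a hypothetical name, and the argument names merely mirror the ones used on this page.

```python
import numpy as np

def make_mse_predict_fn(k_train_train, y_train, learning_rate=1.0):
    """Return predict_fn for function-space gradient flow on MSE (sketch)."""
    # Eigendecompose the symmetric kernel once; reused for every t.
    # Assumption: all eigenvalues are strictly positive.
    evals, evecs = np.linalg.eigh(k_train_train)

    def predict_fn(t, fx_train_0=0.0, fx_test_0=0.0, k_test_train=None):
        # Residual at initialization (broadcasts if fx_train_0 is a scalar).
        resid_0 = y_train - fx_train_0
        # Eigenvalues of (I - exp(-learning_rate * K * t)).
        decay = 1.0 - np.exp(-learning_rate * evals * t)
        fx_train_t = fx_train_0 + evecs @ (decay[:, None] * (evecs.T @ resid_0))
        if k_test_train is None:
            return fx_train_t
        # Test outputs: f0 + K_test_train K^{-1} (I - exp(-lr K t)) resid_0.
        coeffs = evecs @ ((decay / evals)[:, None] * (evecs.T @ resid_0))
        fx_test_t = fx_test_0 + k_test_train @ coeffs
        return fx_train_t, fx_test_t

    return predict_fn

k = np.array([[2.0, 1.0], [1.0, 2.0]])
y = np.array([[1.0], [0.0]])
predict_fn = make_mse_predict_fn(k, y, learning_rate=0.1)

fx_train_t = predict_fn(t=5.0)                     # train outputs at time 5
k_test_train = np.array([[1.0, 0.5]])
_, fx_test_inf = predict_fn(t=1e6, k_test_train=k_test_train)
```

At t = 0 the function returns the initial outputs, and for large t the train outputs converge to y_train while the test outputs converge to the kernel regression prediction K_test_train K^{-1} y_train.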

gradient_descent_mse_ensemble(kernel_fn, ...)

Predicts the Gaussian embedding induced by gradient descent on MSE loss.


max_learning_rate(ntk_train_train[, ...])

Computes the maximal feasible learning rate for infinite width NNs.
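To illustrate why a maximal learning rate exists, the NumPy sketch below uses the classical stability bound for discrete gradient descent on the unnormalized quadratic loss 0.5 * ||f - y||^2: the update f <- f - lr * K (f - y) converges only if lr < 2 / lambda_max(K). This is an assumption-laden sketch; the library function may apply additional scaling (for example by output size or momentum). max_stable_learning_rate and run_gd are hypothetical names.

```python
import numpy as np

def max_stable_learning_rate(ntk_train_train):
    """Stability bound 2 / lambda_max for GD on a quadratic loss (sketch)."""
    # eigvalsh returns eigenvalues of a symmetric matrix in ascending order.
    lambda_max = np.linalg.eigvalsh(ntk_train_train)[-1]
    return 2.0 / lambda_max

def run_gd(ntk, y, lr, steps=200):
    """Discrete function-space gradient descent starting from f = 0."""
    f = np.zeros_like(y)
    for _ in range(steps):
        f = f - lr * ntk @ (f - y)
    return f

ntk = np.array([[2.0, 1.0], [1.0, 2.0]])   # eigenvalues 1 and 3
y = np.array([[1.0], [0.0]])

lr_max = max_stable_learning_rate(ntk)
converged = run_gd(ntk, y, 0.9 * lr_max)   # just below the bound
diverged = run_gd(ntk, y, 1.1 * lr_max)    # just above the bound
```

Running just below the bound drives the outputs to y, while running just above it makes the component along the top eigenvector grow without bound.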

Helper classes

Dataclasses and namedtuples used to return predictions.

Gaussian(mean, covariance)

A (mean, covariance) convenience namedtuple.

ODEState([fx_train, fx_test, qx_train, qx_test])

ODE state dataclass holding outputs and auxiliary variables.