# Predict¶

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the training set, on a given test set, or at given timesteps.

WARNING: trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).

class neural_tangents.predict.Gaussian(mean: jax.numpy.ndarray, covariance: jax.numpy.ndarray)[source]

A (mean, covariance) convenience namedtuple.

property covariance

Alias for field number 1

property mean

Alias for field number 0

class neural_tangents.predict.ODEState(fx_train=None, fx_test=None, qx_train=None, qx_test=None)[source]

ODE state dataclass holding outputs and auxiliary variables.

neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]

Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.

NNGP - the exact posterior of an infinitely wide Bayesian NN.

NTK - the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time.

NTKGP - the posterior of a GP (Gaussian process) with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this can correspond to infinite ensembles of infinitely wide NNs as well).

Note that first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.

Parameters
• k_train_train – train-train kernel. Can be (a) np.ndarray, (b) Kernel namedtuple, (c) Kernel object. Must contain the nngp and/or ntk kernels needed for the arguments provided to the returned predict_fn function. For example, if in future predict_fn invocations you request the posterior NTK covariance on the test set (only), k_train_train must contain both the ntk and nngp kernels.

• y_train (ndarray) – train targets.

• diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.

• diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).

• trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train[, and k_test_test] lack these pairs of dimensions and are to be interpreted as $$\Theta \otimes I$$, i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Returns

A function of signature predict_fn(get, k_test_train, k_test_test) computing ‘posterior’ Gaussian distribution (mean or mean and covariance) on a given test set.
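For intuition about what the returned predict_fn computes in the get="nngp" case, here is a plain-NumPy sketch of the standard GP posterior, using a Cholesky factorization of the regularized train-train kernel as described above. The helper nngp_posterior and the toy RBF kernel are illustrative stand-ins, not part of the library.

```python
# A plain-NumPy sketch of the NNGP posterior mean and covariance
# (illustrative only; the RBF kernel below is a stand-in for the NNGP).
import numpy as np

def nngp_posterior(k_train_train, y_train, k_test_train, k_test_test,
                   diag_reg=0.0):
    n = k_train_train.shape[0]
    # Diagonal regularization, as controlled by diag_reg.
    chol = np.linalg.cholesky(k_train_train + diag_reg * np.eye(n))
    # K^{-1} b via two triangular solves of the Cholesky factor.
    solve = lambda b: np.linalg.solve(chol.T, np.linalg.solve(chol, b))
    mean = k_test_train @ solve(y_train)
    cov = k_test_test - k_test_train @ solve(k_test_train.T)
    return mean, cov

rng = np.random.RandomState(0)
x_train, x_test = rng.randn(5, 2), rng.randn(3, 2)
kern = lambda a, b: np.exp(-((a[:, None] - b[None]) ** 2).sum(-1))
y_train = rng.randn(5, 1)
mean, cov = nngp_posterior(kern(x_train, x_train), y_train,
                           kern(x_test, x_train), kern(x_test, x_test),
                           diag_reg=1e-6)
```

The Cholesky factor is what gets cached across predict_fn invocations, which is why the first call is the expensive one.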

neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))[source]

Predicts the outcome of function space training using gradient descent.

Uses an ODE solver. If momentum != None, solves a continuous-time version of gradient descent with momentum (note: this case uses standard momentum as opposed to Nesterov momentum).

Solves the function space ODE for [momentum] gradient descent with a given loss (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) t. Note that for gradient descent absolute_time = learning_rate * t and the scales of the learning rate and query step[s] t are interchangeable. However, the momentum gradient descent ODE is solved in the units of learning_rate**0.5, and therefore absolute_time = learning_rate**0.5 * t, hence the learning_rate and training time[s] (step[s]) t scales are not interchangeable.

Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>>
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> from jax.experimental import stax
>>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
>>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
>>>                                       y_train, learning_rate, momentum)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)

Parameters
• loss (Callable[[ndarray, ndarray], float]) – a loss function whose signature is loss(f(x_train), y_train). Note: the loss function should treat the batch and output dimensions symmetrically.

• k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.

• y_train (ndarray) – targets for the training data.

• learning_rate (float) – learning rate, step size.

• momentum (Optional[float]) – momentum scalar.

• trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as $$\Theta \otimes I$$, i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Return type

Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray, ODEState], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray], ODEState]]

Returns

A function that returns output train [and test] set[s] predictions at time[s] t.
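For intuition, the train-set ODE being solved (momentum = None) is df/dt = -k_train_train @ dloss/df. A minimal forward-Euler sketch of this dynamic, using MSE so the result can be checked against the known closed form; the function and kernel below are illustrative, not the library's solver:

```python
# Forward-Euler sketch of the function-space ODE df/dt = -K d(loss)/df
# for loss(f, y) = 0.5 * ||f - y||^2 (names illustrative; momentum omitted).
import numpy as np

def gradient_flow(k_train_train, y_train, fx_train_0, t, n_steps=10000):
    dt = t / n_steps
    fx = fx_train_0.copy()
    for _ in range(n_steps):
        grad_loss = fx - y_train  # d/df of 0.5 * ||f - y||^2
        fx -= dt * (k_train_train @ grad_loss)
    return fx

rng = np.random.RandomState(0)
a = rng.randn(4, 4)
k = a @ a.T + 0.1 * np.eye(4)  # a positive-definite stand-in for the NTK
y_train, fx_0 = rng.randn(4, 1), rng.randn(4, 1)
fx_t = gradient_flow(k, y_train, fx_0, t=1.0)
```

For MSE this Euler trajectory approaches the closed form y + exp(-K t) (f_0 - y); for a general loss no closed form exists, which is why the library integrates the ODE numerically.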

neural_tangents.predict.gradient_descent_mse(k_train_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]

Predicts the outcome of function space gradient descent training on MSE.

Solves in closed form for the continuous-time version of gradient descent.

Uses the closed-form solution for gradient descent on an MSE loss in function space detailed in [,*] given a Neural Tangent or Neural Network Gaussian Process Kernel over the dataset. Given NNGP or NTK, this function will return a function that predicts the time evolution for function space points at arbitrary time[s] (training step[s]) t. Note that these time[s] (step[s]) are continuous and are interpreted in units of the learning_rate so absolute_time = learning_rate * t, and the scales of learning_rate and t are interchangeable.

Note that first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as either eigendecomposition (t is a scalar or an array) or Cholesky factorization (t=None) of k_train_train is performed and cached for future invocations (or both, if the function is called on both finite and infinite (t=None) times).

Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)

Parameters
• k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.

• y_train (ndarray) – targets for the training data.

• learning_rate (float) – learning rate, step size.

• diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization or eigendecomposition.

• diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).

• trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as $$\Theta \otimes I$$, i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Return type

Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray]]]

Returns

A function of signature predict_fn(t, fx_train_0, fx_test_0, k_test_train) that returns output train [and test] set[s] predictions at time[s] t.
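The closed form behind the returned predict_fn can be sketched in plain NumPy via an eigendecomposition of the train-train kernel: the train outputs follow y + exp(-lr * K * t) (f_0 - y), and the test outputs receive the corresponding K^{-1}-weighted correction. Names below are illustrative (NTK case, trace_axes=()):

```python
# Plain-NumPy sketch of the closed-form MSE gradient-flow solution
# (illustrative stand-in kernels; not the library implementation).
import numpy as np

def mse_predict(t, fx_train_0, fx_test_0, y_train,
                k_train_train, k_test_train, learning_rate=1.0):
    evals, evecs = np.linalg.eigh(k_train_train)
    # exp(-learning_rate * K * t) applied through the eigenbasis; note that
    # only the product learning_rate * t matters, as stated above.
    decay = evecs @ np.diag(np.exp(-learning_rate * evals * t)) @ evecs.T
    fx_train_t = y_train + decay @ (fx_train_0 - y_train)
    correction = np.linalg.solve(
        k_train_train, (np.eye(len(evals)) - decay) @ (y_train - fx_train_0))
    fx_test_t = fx_test_0 + k_test_train @ correction
    return fx_train_t, fx_test_t

rng = np.random.RandomState(0)
a = rng.randn(4, 4)
k_tt = a @ a.T + 0.1 * np.eye(4)  # positive-definite stand-in NTK
k_ts = rng.randn(2, 4)            # test-train kernel stand-in
y, f0_tr, f0_te = rng.randn(4, 1), rng.randn(4, 1), rng.randn(2, 1)
fx_train_t, fx_test_t = mse_predict(5.0, f0_tr, f0_te, y, k_tt, k_ts)
```

The eigendecomposition here plays the role of the cached factorization mentioned above: it is computed once and reused for every queried time t.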

neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)[source]

Predicts the Gaussian embedding induced by gradient descent on MSE loss.

This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if kernel_fn is the kernel function of the infinite-width network. Note that kernel_fn can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).

Note that first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as the kernel computation, and either eigendecomposition (t is a scalar or an array) or Cholesky factorization (t=None) of kernel_fn(x_train, None, get) is performed and cached for future invocations (or both, if the function is called on both finite and infinite (t=None) times).

Parameters
• kernel_fn – a kernel function that computes the NNGP and/or NTK kernels requested by get, with signature kernel_fn(x1, x2, get, **kernel_fn_kwargs).

• x_train (ndarray) – training inputs.

• y_train (ndarray) – training targets.

• learning_rate (float) – learning rate, step size.

• diag_reg (float) – a scalar representing the strength of the diagonal regularization for the train-train kernel, i.e. computing kernel_fn(x_train, None, get) + diag_reg * I during Cholesky factorization or eigendecomposition.

• diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False for it to be relative to the mean trace of the train-train kernel.

• trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that the kernels lack these pairs of dimensions and are to be interpreted as $$\Theta \otimes I$$, i.e. block-diagonal along trace_axes. See gp_inference for details.

• **kernel_fn_train_train_kwargs – optional keyword arguments passed to kernel_fn when computing the train-train kernel.
Returns

A function with signature predict_fn(t, x_test, get, compute_cov) returning either mean or mean and covariance of the infinite ensemble of infinite-width networks outputs on x_test at time[s] t, in the get regime ("nngp", "ntk", or ("nngp", "ntk")).
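To make the t dependence concrete: for get="ntk", marginalizing out a zero-mean initialization leaves a posterior mean on the test set of K_test_train K^{-1} (I - exp(-lr * K * t)) y, which recovers the gp_inference NTK mean as t goes to infinity. A hedged plain-NumPy sketch with stand-in kernels (ensemble_mean is illustrative, not a library function):

```python
# NumPy sketch of the time-dependent NTK ensemble mean
# mu(t) = K_ts K^{-1} (I - exp(-lr * K * t)) y (stand-in kernels).
import numpy as np

def ensemble_mean(t, y_train, k_train_train, k_test_train, learning_rate=1.0):
    evals, evecs = np.linalg.eigh(k_train_train)
    decay = evecs @ np.diag(np.exp(-learning_rate * evals * t)) @ evecs.T
    return k_test_train @ np.linalg.solve(
        k_train_train, (np.eye(len(evals)) - decay) @ y_train)

rng = np.random.RandomState(1)
a = rng.randn(4, 4)
k_tt = a @ a.T + 0.1 * np.eye(4)  # positive-definite stand-in NTK
k_ts = rng.randn(2, 4)            # test-train kernel stand-in
y = rng.randn(4, 1)
mean_t = ensemble_mean(3.0, y, k_tt, k_ts)
```

At t=0 the mean is zero (the untrained ensemble average), and it interpolates toward the infinite-time GP mean as training proceeds.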

neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)[source]

Computes the maximal feasible learning rate for infinite width NNs.

The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2. For vanilla SGD (i.e. momentum = 0) the maximal feasible learning rate is the largest \eta such that the operator (I - \eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK). When momentum > 0, we use 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in https://distill.pub/2017/momentum/).

Parameters
• ntk_train_train (ndarray) – analytic or empirical NTK on the training data.

• y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size=None it is inferred from ntk_train_train.shape assuming trace_axes=().

• momentum (float) – the momentum for momentum optimizers.

• eps (float) – a float to avoid zero divisor.

Return type

float

Returns

The maximal feasible learning rate for infinite width NNs.
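A plain-NumPy sketch of this bound, together with a numerical check that it marks the stability boundary of discrete gradient descent on the stated MSE loss (max_lr, run_gd, and the toy kernel are illustrative; momentum = 0, trace_axes=()):

```python
# NumPy sketch: eta_max = 2 * (1 + momentum) * y_train_size / lambda_max(NTK),
# checked against discrete GD on 1/(2n) * ||f - y||^2.
import numpy as np

def max_lr(ntk_train_train, y_train_size, momentum=0.0):
    lam_max = np.linalg.eigvalsh(ntk_train_train).max()
    # Step sizes below this keep I - eta / y_train_size * NTK a contraction.
    return 2.0 * (1.0 + momentum) * y_train_size / lam_max

def run_gd(k, y, f0, lr, n_steps):
    f = f0.copy()
    for _ in range(n_steps):
        f -= (lr / y.size) * (k @ (f - y))  # GD step on 1/(2n)||f - y||^2
    return f

rng = np.random.RandomState(0)
a = rng.randn(4, 4)
k = a @ a.T + 0.5 * np.eye(4)  # positive-definite stand-in NTK
y, f0 = rng.randn(4, 1), rng.randn(4, 1)
eta = max_lr(k, y.size)
f_stable = run_gd(k, y, f0, 0.9 * eta, 1000)   # below the bound: converges
f_unstable = run_gd(k, y, f0, 1.1 * eta, 100)  # above the bound: diverges
```

The diverging run blows up along the top NTK eigendirection, which is exactly the mode that ceases to contract once eta exceeds the bound.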