Predict

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the training set and/or a given test set, at given training time(s) (step(s)).

WARNING: trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).
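
For example (a minimal sketch, assuming an apply function f, its params, and arrays x_train and y_train are already defined), the same trace_axes should be passed to the empirical kernel function and to the prediction function:

>>> import neural_tangents as nt
>>> from neural_tangents import predict
>>>
>>> # trace_axes given to the empirical kernel function...
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(-1,))
>>> k_train_train = kernel_fn(x_train, None, 'ntk', params)
>>>
>>> # ...must match trace_axes given to the prediction function.
>>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train,
>>>                                           trace_axes=(-1,))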

class neural_tangents.predict.Gaussian(mean: jax.numpy.ndarray, covariance: jax.numpy.ndarray)[source]

A (mean, covariance) convenience namedtuple.

property covariance

Alias for field number 1

property mean

Alias for field number 0

class neural_tangents.predict.ODEState(fx_train=None, fx_test=None, qx_train=None, qx_test=None)[source]

ODE state dataclass holding outputs and auxiliary variables.

neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]

Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.

NNGP - the exact posterior of an infinitely wide Bayesian NN.

NTK - exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time.

NTKGP - posterior of a GP (Gaussian process) with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this can correspond to infinite ensembles of infinitely wide NNs as well).

Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.

Parameters
  • k_train_train – train-train kernel. Can be (a) np.ndarray, (b) Kernel namedtuple, (c) Kernel object. Must contain the necessary nngp and/or ntk kernels for arguments provided to the returned predict_fn function. For example, if you request to compute posterior test [only] NTK covariance in future predict_fn invocations, k_train_train must contain both ntk and nngp kernels.

  • y_train (ndarray) – train targets.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.

  • diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train[, and k_test_test] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Returns

A function of signature predict_fn(get, k_test_train, k_test_test) computing ‘posterior’ Gaussian distribution (mean or mean and covariance) on a given test set.
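
Example (a minimal sketch; kernel_fn is assumed to be a closed-form kernel function from stax.py, and x_train, y_train, x_test are assumed to be defined):

>>> from neural_tangents import predict
>>>
>>> # Requesting the posterior NTK covariance below requires both nngp and
>>> # ntk to be present in k_train_train and k_test_train.
>>> k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
>>> k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
>>> k_test_test = kernel_fn(x_test, None, 'nngp')
>>>
>>> predict_fn = predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)
>>>
>>> # Posterior mean and covariance in the NTK regime on the test set.
>>> ntk_mean, ntk_cov = predict_fn('ntk', k_test_train, k_test_test)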

neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))[source]

Predicts the outcome of function space training using gradient descent.

Uses an ODE solver. If momentum is not None, solves a continuous-time version of gradient descent with momentum (note: this case uses standard momentum as opposed to Nesterov momentum).

Solves the function space ODE for [momentum] gradient descent with a given loss (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) t. Note that for gradient descent absolute_time = learning_rate * t and the scales of the learning rate and query step[s] t are interchangeable. However, the momentum gradient descent ODE is solved in the units of learning_rate**0.5, and therefore absolute_time = learning_rate**0.5 * t, hence the learning_rate and training time[s] (step[s]) t scales are not interchangeable.

[*] https://arxiv.org/abs/1902.06720

Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>> import jax.numpy as np
>>>
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>>
>>> # `f`, `params`, `x_train`, `x_test`, `y_train` are assumed to be defined.
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> from jax.experimental import stax
>>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
>>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
>>>                                       y_train, learning_rate, momentum)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
Parameters
  • loss (Callable[[ndarray, ndarray], float]) – a loss function whose signature is loss(f(x_train), y_train). Note: the loss function should treat the batch and output dimensions symmetrically.

  • k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.

  • y_train (ndarray) – targets for the training data.

  • learning_rate (float) – learning rate, step size.

  • momentum (Optional[float]) – momentum scalar.

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Return type

Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray, ODEState], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray], ODEState]]

Returns

A function that returns output train [and test] set[s] predictions at time[s] t.

neural_tangents.predict.gradient_descent_mse(k_train_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]

Predicts the outcome of function space gradient descent training on MSE.

Solves in closed form for the continuous-time version of gradient descent.

Uses the closed-form solution for gradient descent on an MSE loss in function space detailed in [*, **] given a Neural Tangent or Neural Network Gaussian Process Kernel over the dataset. Given NNGP or NTK, this function will return a function that predicts the time evolution for function space points at arbitrary time[s] (training step[s]) t. Note that these time[s] (step[s]) are continuous and are interpreted in units of the learning_rate, so absolute_time = learning_rate * t, and the scales of learning_rate and t are interchangeable.

Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as either an eigendecomposition (t is a scalar or an array) or a Cholesky factorization (t=None) of k_train_train is performed and cached for future invocations (or both, if the function is called at both finite and infinite (t=None) times).

[*] https://arxiv.org/abs/1806.07572
[**] https://arxiv.org/abs/1902.06720

Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
Parameters
  • k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.

  • y_train (ndarray) – targets for the training data.

  • learning_rate (float) – learning rate, step size.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization or eigendecomposition.

  • diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

Return type

Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray]]]

Returns

A function of signature predict_fn(t, fx_train_0, fx_test_0, k_test_train) that returns output train [and test] set[s] predictions at time[s] t.
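
Passing t=None (continuing the sketch in the example above) requests the infinite-training-time predictions, for which the cached Cholesky factorization mentioned above is used instead of an eigendecomposition:

>>> # Infinite training time (t=None): the fully-trained solution.
>>> fx_train_inf, fx_test_inf = predict_fn(None, fx_train_0, fx_test_0,
>>>                                        k_test_train)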

neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)[source]

Predicts the Gaussian embedding induced by gradient descent on MSE loss.

This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if kernel_fn is the kernel function of the infinite-width network. Note that kernel_fn can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).

Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as the kernel computation, and either an eigendecomposition (t is a scalar or an array) or a Cholesky factorization (t=None) of kernel_fn(x_train, None, get), is performed and cached for future invocations (or both, if the function is called at both finite and infinite (t=None) times).

Parameters
  • kernel_fn (Union[Callable[[Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[Kernel], Tuple[Kernel, …], Kernel, List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]], Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None]], Union[List[ndarray], Tuple[ndarray, …], ndarray, Generator[Union[List[ndarray], Tuple[ndarray, …], ndarray], None, None]]]]) – A kernel function that computes NNGP and/or NTK. Must have a signature kernel_fn(x1, x2, get, **kernel_fn_kwargs) and return a Kernel object or a namedtuple with nngp and/or ntk attributes. Therefore, it can be an AnalyticKernelFn, but also a MonteCarloKernelFn, or an EmpiricalKernelFn (but only nt.empirical_kernel_fn and not nt.empirical_ntk_fn or nt.empirical_nngp_fn, since the latter two do not accept a get argument). Note that for meaningful outputs, the kernel function must represent or at least approximate the infinite-width kernel.

  • x_train (ndarray) – training inputs.

  • y_train (ndarray) – training targets.

  • learning_rate (float) – learning rate, step size.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization for kernel_fn(x_train, None, get), i.e. computing kernel_fn(x_train, None, get) + diag_reg * I during Cholesky factorization or eigendecomposition.

  • diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get))).

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that kernel_fn(x_train, None, get), kernel_fn(x_test, x_train, get)[, and kernel_fn(x_test, None, get)] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

  • **kernel_fn_train_train_kwargs – optional keyword arguments passed to kernel_fn. For the train-train kernel, these are passed to kernel_fn without changes. For the test-test kernel, they are passed to kernel_fn unless overridden by matching **kernel_fn_test_test_kwargs arguments passed to the predict_fn call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from **kernel_fn_train_train_kwargs and **kernel_fn_test_test_kwargs, while all other values must match.

Returns

A function with signature predict_fn(t, x_test, get, compute_cov) returning either mean or mean and covariance of the infinite ensemble of infinite-width networks outputs on x_test at time[s] t, in the get regime ("nngp", "ntk", or ("nngp", "ntk")).
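
Example (a minimal sketch; the two-layer ReLU architecture is arbitrary, and x_train, y_train, x_test are assumed to be defined):

>>> from neural_tangents import predict, stax
>>>
>>> # Closed-form kernel of an infinitely wide fully-connected ReLU network.
>>> _, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
>>>
>>> predict_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                    y_train, diag_reg=1e-4)
>>>
>>> # Mean and covariance of the NTK ensemble on x_test at infinite time.
>>> ntk_mean, ntk_cov = predict_fn(t=None, x_test=x_test, get='ntk',
>>>                                compute_cov=True)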

neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)[source]

Computes the maximal feasible learning rate for infinite width NNs.

The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2. For vanilla SGD (i.e. momentum = 0) the maximal feasible learning rate is the largest \eta such that the operator (I - \eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK). When momentum > 0, we use 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in https://distill.pub/2017/momentum/).

Parameters
  • ntk_train_train (ndarray) – analytic or empirical NTK on the training data.

  • y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size=None, it is inferred from ntk_train_train.shape assuming trace_axes=().

  • momentum (float) – the momentum for momentum optimizers.

  • eps (float) – a float to avoid zero divisor.

Return type

float

Returns

The maximal feasible learning rate for infinite width NNs.
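
Example (a minimal sketch, assuming an apply function f, its params, and x_train are already defined):

>>> import neural_tangents as nt
>>>
>>> # Empirical NTK on the training data, computed with trace_axes=() as
>>> # assumed when y_train_size is inferred from ntk_train_train.shape.
>>> kernel_fn = nt.empirical_ntk_fn(f, trace_axes=())
>>> ntk_train_train = kernel_fn(x_train, None, params)
>>>
>>> lr = nt.predict.max_learning_rate(ntk_train_train, momentum=0.9)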