neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))

Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.

  • NNGP: the exact posterior of an infinitely wide Bayesian neural network.

  • NTK: the exact distribution of an infinite ensemble of infinitely wide neural networks trained with gradient flow for infinite time.

  • NTKGP: the posterior of a Gaussian process (GP) with the NTK as its covariance (see “Bayesian Deep Ensembles via the Neural Tangent Kernel” for how this can also correspond to infinite ensembles of infinitely wide neural networks).
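For reference, the three posteriors have the following standard closed forms, written here for a zero prior mean and with diag_reg omitted. The notation is ours, not the library's: K is the NNGP kernel, \(\Theta\) the NTK, X the training inputs, x the test inputs, and Y = y_train. Note that the NTK covariance involves both kernels.

```latex
\begin{aligned}
\text{NNGP:}\quad & \mu(x) = K_{xX} K_{XX}^{-1} Y, \qquad
  \Sigma(x) = K_{xx} - K_{xX} K_{XX}^{-1} K_{Xx} \\
\text{NTKGP:}\quad & \mu(x) = \Theta_{xX} \Theta_{XX}^{-1} Y, \qquad
  \Sigma(x) = \Theta_{xx} - \Theta_{xX} \Theta_{XX}^{-1} \Theta_{Xx} \\
\text{NTK:}\quad & \mu(x) = \Theta_{xX} \Theta_{XX}^{-1} Y, \qquad
  \Sigma(x) = K_{xx} + A K_{XX} A^{\top} - \left( A K_{Xx} + K_{xX} A^{\top} \right),
  \qquad A = \Theta_{xX} \Theta_{XX}^{-1}
\end{aligned}
```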

Note that the first invocation of the returned predict_fn will be slow and will allocate a large amount of memory that persists for the lifetime of predict_fn, because a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed on that call and cached for future invocations.
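The caching behaviour described above can be sketched in plain JAX as a closure that factorizes the regularized train-train kernel on its first call and reuses the factor afterwards. This is a minimal NNGP-only illustration under stated assumptions, not the library's implementation: make_cached_predict_fn and rbf are hypothetical names, and the kernels are plain 2D arrays rather than Kernel objects.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def make_cached_predict_fn(k_train_train, y_train, diag_reg=0.0):
    """Hypothetical sketch: NNGP posterior with a lazily cached Cholesky factor."""
    cache = {}  # lives as long as the returned closure does

    def get_factor():
        # The O(N^3) factorization runs only on the first invocation.
        if "chol" not in cache:
            n = k_train_train.shape[0]
            k_reg = k_train_train + diag_reg * jnp.eye(n)
            cache["chol"] = cho_factor(k_reg, lower=True)
        return cache["chol"]

    def predict_fn(k_test_train, k_test_test=None):
        chol = get_factor()
        # Posterior mean: k_test_train @ (K + reg * I)^{-1} @ y_train.
        mean = k_test_train @ cho_solve(chol, y_train)
        if k_test_test is None:
            return mean
        # Posterior covariance: k_test_test - k_tx @ (K + reg * I)^{-1} @ k_xt.
        cov = k_test_test - k_test_train @ cho_solve(chol, k_test_train.T)
        return mean, cov

    predict_fn.cache = cache  # exposed only so the caching can be inspected
    return predict_fn


def rbf(a, b, length_scale=1.0):
    """Hypothetical stand-in kernel: squared-exponential on 1D inputs."""
    return jnp.exp(-0.5 * ((a - b.T) / length_scale) ** 2)


x_train = jnp.linspace(-1.0, 1.0, 6)[:, None]
y_train = jnp.sin(3.0 * x_train)
x_test = jnp.array([[0.1], [0.5]])

predict_fn = make_cached_predict_fn(rbf(x_train, x_train), y_train, diag_reg=0.1)
# predict_fn.cache is still empty here: nothing has been factorized yet.
mean, cov = predict_fn(rbf(x_test, x_train), rbf(x_test, x_test))  # factorizes
mean_again = predict_fn(rbf(x_test, x_train))  # reuses the cached factor
```

Deferring the factorization to the first call mirrors the note above: constructing predict_fn is cheap, and the cost is paid once, then amortized over all later test sets.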

Parameters

  • k_train_train – train-train kernel. Can be (a) a jax.numpy.ndarray, (b) a Kernel namedtuple, or (c) a Kernel object. It must contain the nngp and/or ntk kernels needed by the arguments later passed to the returned predict_fn. For example, even if future predict_fn invocations request only the posterior test NTK covariance, k_train_train must contain both the ntk and nngp kernels.

  • y_train (ndarray) – train targets.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.

  • diag_reg_absolute_scale (bool) – True to interpret diag_reg in absolute units; False to interpret it relative to the scale of the kernel, i.e. as diag_reg * jnp.mean(jnp.trace(k_train_train)).

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train (and, if provided, k_test_test) lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve the approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.
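A minimal sketch of how diag_reg, diag_reg_absolute_scale and trace_axes interact, assuming 2D kernels whose trace_axes have already been removed. The helper regularize is hypothetical, and its relative scaling by the mean diagonal entry, trace(k) / N, is our assumption about what "relative to the kernel's scale" means. The second half checks numerically that treating the kernel as \(\Theta \otimes I\) lets a single N x N solve stand in for an (N·C) x (N·C) one.

```python
import jax.numpy as jnp


def regularize(k, diag_reg, diag_reg_absolute_scale):
    """Hypothetical helper: return k + reg * I for a square 2D kernel k."""
    n = k.shape[0]
    if not diag_reg_absolute_scale:
        # Assumption: "relative" means scaled by the mean diagonal entry of k.
        diag_reg = diag_reg * jnp.trace(k) / n
    return k + diag_reg * jnp.eye(n)


n, c = 5, 3  # N training points, C output channels (the traced axis)
x = jnp.linspace(0.0, 1.0, n)[:, None]
k = jnp.exp(-0.5 * (x - x.T) ** 2)  # N x N kernel with trace_axes removed
y = jnp.sin(x + jnp.arange(c))      # N x C targets

k_abs = regularize(2.0 * k, 0.1, diag_reg_absolute_scale=True)   # adds 0.1
k_rel = regularize(2.0 * k, 0.1, diag_reg_absolute_scale=False)  # adds 0.1 * 2.0

# Traced form: one N x N solve shared by all C channels.
alpha_traced = jnp.linalg.solve(k_rel, y)

# Explicit form: Theta ⊗ I_C acting on the row-major flattened (N * C) targets.
k_full = jnp.kron(k_rel, jnp.eye(c))
alpha_full = jnp.linalg.solve(k_full, y.reshape(-1)).reshape(n, c)
```

Both solves agree, but the traced one factorizes an N x N matrix instead of an (N·C) x (N·C) one, which is where the space and compute savings come from. With trace_axes=() the full block structure would have to be kept.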


Returns

A function of signature predict_fn(get, k_test_train, k_test_test) computing the ‘posterior’ Gaussian distribution (mean, or mean and covariance) on a given test set.
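To make the contract of the returned function concrete, here is a self-contained JAX sketch of a predict_fn(get, k_test_train, k_test_test) that implements the three modes with the standard closed-form expressions. It is hypothetical throughout: make_predict_fn, Gaussian and toy_kernels are our own names, kernels are passed as dicts with "nngp" / "ntk" entries in place of the library's Kernel objects, trace_axes is ignored, and the library's actual numerics (for instance, how diag_reg enters the NTK covariance) may differ.

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


class Gaussian(NamedTuple):
    """Hypothetical container for the posterior (covariance may be None)."""
    mean: jnp.ndarray
    covariance: Optional[jnp.ndarray]


def make_predict_fn(k_train_train, y_train, diag_reg=0.0):
    """Sketch for dict-valued 2D kernels {'nngp': ..., 'ntk': ...}."""
    eye = jnp.eye(y_train.shape[0])
    # Factorize every provided train-train kernel once, up front.
    chol = {name: cho_factor(k + diag_reg * eye, lower=True)
            for name, k in k_train_train.items()}

    def predict_fn(get, k_test_train, k_test_test=None):
        if get not in ("nngp", "ntk", "ntkgp"):
            raise ValueError(f"unsupported get: {get!r}")
        name = "nngp" if get == "nngp" else "ntk"  # kernel used for the mean
        k_tx = k_test_train[name]
        # weights = k_tx @ (k_train_train[name] + diag_reg * I)^{-1}
        weights = cho_solve(chol[name], k_tx.T).T
        mean = weights @ y_train
        if k_test_test is None:
            return Gaussian(mean, None)
        if get in ("nngp", "ntkgp"):
            # GP posterior covariance with the selected kernel.
            cov = k_test_test[name] - weights @ k_tx.T
        else:
            # NTK ensemble after infinite-time gradient flow: needs NNGP too.
            k_nngp_tx = k_test_train["nngp"]
            cov = (k_test_test["nngp"]
                   + weights @ k_train_train["nngp"] @ weights.T
                   - weights @ k_nngp_tx.T
                   - k_nngp_tx @ weights.T)
        return Gaussian(mean, cov)

    return predict_fn


def toy_kernels(a, b):
    """Hypothetical stand-in for NNGP and NTK kernels on 1D inputs."""
    rbf = jnp.exp(-0.5 * (a - b.T) ** 2)
    return {"nngp": rbf, "ntk": 2.0 * rbf}


x_train = jnp.linspace(-1.0, 1.0, 5)[:, None]
y_train = jnp.cos(2.0 * x_train)
x_test = jnp.array([[0.0], [0.3]])

predict_fn = make_predict_fn(toy_kernels(x_train, x_train), y_train, diag_reg=0.1)
ntk = predict_fn("ntk", toy_kernels(x_test, x_train), toy_kernels(x_test, x_test))
nngp_mean_only = predict_fn("nngp", toy_kernels(x_test, x_train))
```

In this sketch, omitting k_test_test yields only the mean, and the "ntk" mode reads the nngp entries even though its mean uses only the NTK, consistent with the requirement that k_train_train contain both kernels in that case.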