neural_tangents.predict.gp_inference
- neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]
Compute the mean and variance of the ‘posterior’ of NNGP, NTK, or NTKGP:

- NNGP – the exact posterior of an infinitely wide Bayesian NN.
- NTK – the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time.
- NTKGP – the posterior of a GP (Gaussian process) with the NTK covariance (see “Bayesian Deep Ensembles via the Neural Tangent Kernel” for how this, too, can correspond to infinite ensembles of infinitely wide NNs).
Note that the first invocation of the returned `predict_fn` will be slow and will allocate a lot of memory for its whole lifetime, as a Cholesky factorization of `k_train_train.nngp` or `k_train_train.ntk` (or both) is performed and cached for future invocations.

- Parameters:
  - `k_train_train` – train-train kernel. Can be (a) `jax.numpy.ndarray`, (b) `Kernel` namedtuple, (c) `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels for the arguments provided to the returned `predict_fn` function. For example, if you request to compute the posterior test [only] NTK covariance in future `predict_fn` invocations, `k_train_train` must contain both `ntk` and `nngp` kernels.
  - `y_train` (`ndarray`) – train targets.
  - `diag_reg` (`float`) – a scalar representing the strength of the diagonal regularization for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during Cholesky factorization.
  - `diag_reg_absolute_scale` (`bool`) – `True` for `diag_reg` to represent regularization in absolute units, `False` for it to be `diag_reg * jnp.mean(jnp.trace(k_train_train))` (see the sketch after the Returns section below).
  - `trace_axes` (`Union[int, Sequence[int]]`) – `f(x_train)` axes such that `k_train_train`, `k_test_train`[, and `k_test_test`] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve the approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result.
- Returns:
  A function of signature `predict_fn(get, k_test_train, k_test_test)` computing the ‘posterior’ Gaussian distribution (mean or mean and covariance) on a given test set.
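
For illustration, here is a minimal sketch of the two `diag_reg` scaling modes described above. The helper name `regularized_kernel` is hypothetical; the library applies this regularization internally as part of the cached Cholesky factorization:

```python
import jax.numpy as jnp

def regularized_kernel(k_train_train, diag_reg, diag_reg_absolute_scale):
    # Hypothetical helper mirroring the `diag_reg` semantics described above.
    if diag_reg_absolute_scale:
        eps = diag_reg  # absolute units
    else:
        # Relative units: scaled by the mean trace of the train-train kernel.
        eps = diag_reg * jnp.mean(jnp.trace(k_train_train))
    return k_train_train + eps * jnp.eye(k_train_train.shape[0])
```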
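
A minimal end-to-end usage sketch follows; the two-layer ReLU architecture, data shapes, and `diag_reg` value are illustrative assumptions, not prescribed by this function:

```python
from jax import random
import neural_tangents as nt
from neural_tangents import stax

# Illustrative infinite-width network: a 2-layer ReLU MLP.
_, _, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(0), 3)
x_train = random.normal(key1, (20, 10))
y_train = random.normal(key2, (20, 1))
x_test = random.normal(key3, (5, 10))

# Requesting both kernels yields a `Kernel` namedtuple with `nngp` and `ntk`
# fields, so `predict_fn` can later compute any of the supported posteriors.
k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'))
k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
k_test_test = kernel_fn(x_test, None, ('nngp', 'ntk'))

predict_fn = nt.predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)

# First invocation performs and caches the Cholesky factorization(s).
nngp_posterior = predict_fn('nngp', k_test_train, k_test_test)
print(nngp_posterior.mean.shape)        # (5, 1)
print(nngp_posterior.covariance.shape)  # (5, 5), with default trace_axes=(-1,)

# Subsequent invocations reuse the cached factorization; `get` can also be
# 'ntk' or 'ntkgp' for the other posteriors described above.
ntk_posterior = predict_fn('ntk', k_test_train, k_test_test)
```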