neural_tangents.predict.gp_inference
- neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))
Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.
NNGP - the exact posterior of an infinitely wide Bayesian NN.
NTK - the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time.
NTKGP - the posterior of a GP (Gaussian process) with the NTK covariance (see “Bayesian Deep Ensembles via the Neural Tangent Kernel” for how this can correspond to infinite ensembles of infinitely wide NNs as well).
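For reference, these three ‘posteriors’ usually take the closed forms below. The following assumptions are ours, not this page's: mean-squared-error loss and zero-mean network outputs at initialization for the NTK case. Notation: \(K\) is the NNGP kernel, \(\Theta\) is the NTK, \(D\) is the training set, \(*\) is the test set, and \(Y\) are the training targets. When diag_reg is nonzero, it is added to the matrix being inverted. Only the NTK covariance mixes both kernels.

```latex
\begin{aligned}
\text{NNGP:}  \quad \mu &= K_{*D} K_{DD}^{-1} Y,
  & \Sigma &= K_{**} - K_{*D} K_{DD}^{-1} K_{D*}, \\
\text{NTK:}   \quad \mu &= \Theta_{*D} \Theta_{DD}^{-1} Y,
  & \Sigma &= K_{**} + \Theta_{*D} \Theta_{DD}^{-1} K_{DD} \Theta_{DD}^{-1} \Theta_{D*}
    - \bigl( \Theta_{*D} \Theta_{DD}^{-1} K_{D*} + K_{*D} \Theta_{DD}^{-1} \Theta_{D*} \bigr), \\
\text{NTKGP:} \quad \mu &= \Theta_{*D} \Theta_{DD}^{-1} Y,
  & \Sigma &= \Theta_{**} - \Theta_{*D} \Theta_{DD}^{-1} \Theta_{D*}.
\end{aligned}
```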
Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory for its whole lifetime. This is because a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.
- Parameters:
k_train_train – train-train kernel. Can be (a) a jax.numpy.ndarray, (b) a Kernel namedtuple, or (c) a Kernel object. It must contain the nngp and/or ntk kernels needed by the arguments later passed to the returned predict_fn function. For example, if future predict_fn invocations request only the NTK posterior covariance on the test set, k_train_train must still contain both the ntk and the nngp kernels.
y_train (ndarray) – train targets.
diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.
diag_reg_absolute_scale (bool) – True if diag_reg represents regularization in absolute units; False if it is relative, i.e. the added diagonal is diag_reg * jnp.mean(jnp.trace(k_train_train)).
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train (and, if provided, k_test_test) lack these pairs of dimensions. The kernels are then interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. There are two reasons to specify these axes. They save space and compute. They can also improve the approximation accuracy of the infinite-width or infinite-samples limit, because in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.
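The sketch below illustrates what “interpreted as \(\Theta \otimes I\)” means. It assumes targets of shape (n, k), a traced logit axis, and outputs flattened example-major. Under those assumptions the full (n·k)×(n·k) covariance is jnp.kron(theta, jnp.eye(k)). Solving against it gives the same answer as solving against the n×n kernel with the k logits as k right-hand sides, at roughly O(n³) instead of O((n·k)³) cost. The sizes and kernel values are made up for illustration.

```python
import jax.numpy as jnp

# Toy sizes (hypothetical): n training points, k logits along the traced axis.
n, k = 4, 3
a = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2) / 10.0
theta = a @ a.T + jnp.eye(n)                 # (n, n) kernel, logit axis traced out
y_train = jnp.arange(n * k, dtype=jnp.float32).reshape(n, k)

# Full covariance over (example, logit) pairs, block-diagonal along logits:
# entry [i * k + c, j * k + d] equals theta[i, j] if c == d, else 0.
full = jnp.kron(theta, jnp.eye(k))           # (n * k, n * k)
alpha_full = jnp.linalg.solve(full, y_train.reshape(-1)).reshape(n, k)

# Same result from the small kernel, solving all k columns at once.
alpha_small = jnp.linalg.solve(theta, y_train)
```

Storing theta instead of full also cuts kernel memory from (n·k)² to n² entries.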
- Returns:
A function of signature predict_fn(get, k_test_train, k_test_test) computing the ‘posterior’ Gaussian distribution (mean, or mean and covariance) on a given test set.
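The pure-JAX sketch below illustrates two behaviours described above. First, the Cholesky factorization is cached on first use. Second, an 'ntk' covariance needs the nngp kernel as well as the ntk. This is not the library's implementation, and several details are assumptions:
- make_predict_fn, the dict-of-arrays kernel format and the factorizations attribute are hypothetical.
- functools.lru_cache stands in for whatever caching gp_inference actually uses.
- diag_reg is treated as an absolute value, and trace_axes handling is omitted.
- The 'ntk' covariance uses the infinite-time gradient-flow formula for MSE loss, assuming zero-mean outputs at initialization.

```python
import functools

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def make_predict_fn(k_train_train, y_train, diag_reg=0.0):
    """Hypothetical sketch; kernels are dicts with 'nngp' and/or 'ntk' arrays."""
    factorizations = []  # records which kernels were factorized, for inspection

    @functools.lru_cache(maxsize=None)
    def cholesky(name):
        # Runs once per kernel name; the factor then lives as long as predict_fn.
        factorizations.append(name)
        k_dd = k_train_train[name]
        return cho_factor(k_dd + diag_reg * jnp.eye(k_dd.shape[0]), lower=True)

    def predict_fn(get, k_test_train, k_test_test=None):
        # 'nngp' conditions on the NNGP kernel; 'ntk' and 'ntkgp' on the NTK.
        name = "nngp" if get == "nngp" else "ntk"
        factor = cholesky(name)
        # a = K_td K_dd^{-1}; K_dd is symmetric, so solve and transpose.
        a = cho_solve(factor, k_test_train[name].T).T
        mean = a @ y_train
        if k_test_test is None:
            return mean
        if get == "ntk":
            # Infinite ensemble: the covariance also needs the NNGP kernel.
            k_dd, k_td = k_train_train["nngp"], k_test_train["nngp"]
            cov = (k_test_test["nngp"] + a @ k_dd @ a.T
                   - (a @ k_td.T + k_td @ a.T))
        else:
            # Plain GP conditioning ('nngp' or 'ntkgp').
            cov = k_test_test[name] - a @ k_test_train[name].T
        return mean, cov

    predict_fn.factorizations = factorizations
    return predict_fn


# Toy joint kernels over 4 train + 2 test points (made-up numbers).
b = jnp.arange(18.0).reshape(6, 3) / 10.0
nngp = b @ b.T + jnp.eye(6)
ntk = 2.0 * nngp + jnp.eye(6)
k_dd = {"nngp": nngp[:4, :4], "ntk": ntk[:4, :4]}
k_td = {"nngp": nngp[4:, :4], "ntk": ntk[4:, :4]}
k_tt = {"nngp": nngp[4:, 4:], "ntk": ntk[4:, 4:]}
y = jnp.array([[1.0], [0.0], [-1.0], [2.0]])

predict_fn = make_predict_fn(k_dd, y)
nngp_mean, nngp_cov = predict_fn("nngp", k_td, k_tt)     # factorizes nngp
ntk_mean, ntk_cov = predict_fn("ntk", k_td, k_tt)        # factorizes ntk
ntkgp_mean, ntkgp_cov = predict_fn("ntkgp", k_td, k_tt)  # reuses the ntk factor
```

Because the factor is cached per kernel, asking for 'ntkgp' after 'ntk' triggers no new factorization. The 'ntk' branch reads the 'nngp' entries of all three kernel arguments, which is why they must be present even when only the NTK prediction is requested.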