- neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))
Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.
NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time. NTKGP - the posterior of a GP (Gaussian process) with the NTK covariance (see “Bayesian Deep Ensembles via the Neural Tangent Kernel” for how this can also correspond to infinite ensembles of infinitely wide NNs).
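The three posteriors can be written down directly from the kernels. The sketch below is a NumPy illustration, not the library's code: the helper names (gp_posterior, ntk_ensemble_posterior) are hypothetical, and the NTK case assumes mean squared error loss and a zero prior mean. NNGP and NTKGP share the exact GP formulas (with NNGP or NTK kernels respectively), while the NTK ensemble covariance mixes the NTK with the NNGP kernel.

```python
import numpy as np


def _cho_solve(k_train_train, b):
    """Return k_train_train^{-1} b via a Cholesky factor (k must be positive definite)."""
    chol = np.linalg.cholesky(k_train_train)  # k = chol @ chol.T
    return np.linalg.solve(chol.T, np.linalg.solve(chol, b))


def gp_posterior(k_train_train, y_train, k_test_train, k_test_test):
    """Exact GP posterior. NNGP: pass NNGP kernels. NTKGP: pass NTK kernels."""
    mean = k_test_train @ _cho_solve(k_train_train, y_train)
    cov = k_test_test - k_test_train @ _cho_solve(k_train_train, k_test_train.T)
    return mean, cov


def ntk_ensemble_posterior(ntk_train_train, nngp_train_train, y_train,
                           ntk_test_train, nngp_test_train, nngp_test_test):
    """Infinite ensemble trained by gradient flow on MSE for infinite time."""
    # a = Theta(x_test, X) Theta(X, X)^{-1}, shape (n_test, n_train).
    a = _cho_solve(ntk_train_train, ntk_test_train.T).T
    mean = a @ y_train
    # cross = Theta(x_test, X) Theta^{-1} K(X, x_test); the NNGP kernel enters here.
    cross = a @ nngp_test_train.T
    cov = nngp_test_test + a @ nngp_train_train @ a.T - (cross + cross.T)
    return mean, cov
```

Setting the NTK and NNGP kernels equal makes ntk_ensemble_posterior reduce to gp_posterior, which is a quick sanity check on the formulas.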
Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory for its whole lifetime, as a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.
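The lazy, cached factorization described in this note follows a standard memoization pattern. A minimal sketch, assuming a plain dict of NumPy kernels keyed by "nngp" / "ntk" and a hypothetical make_predict_fn helper that only returns the posterior mean (this is not the library's implementation): functools.lru_cache ensures each train-train kernel is factorized at most once, and the factor stays in memory as long as the returned closure is alive.

```python
import functools

import numpy as np


def make_predict_fn(kernels, y_train):
    """Hypothetical helper: returns a predict function with a lazily cached Cholesky.

    `kernels` is assumed to map "nngp" / "ntk" to (n, n) train-train matrices.
    """

    @functools.lru_cache(maxsize=None)
    def cholesky(get):
        # Executed only on the first request for a given kernel name; the
        # factor is then kept alive as long as the returned closure is.
        return np.linalg.cholesky(kernels[get])

    def predict_mean(get, k_test_train):
        chol = cholesky(get)
        # Two solves with the cached factor: chol z = y, then chol^T alpha = z.
        alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
        return k_test_train @ alpha

    predict_mean.cache_info = cholesky.cache_info  # expose hit/miss counters
    return predict_mean


k = np.array([[2.0, 0.5], [0.5, 1.0]])
predict = make_predict_fn({"nngp": k, "ntk": 2.0 * k}, np.array([1.0, -1.0]))
predict("nngp", k)  # first call: factorizes k
predict("nngp", k)  # second call: reuses the cached factor
print(predict.cache_info().misses)  # 1
```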
Parameters:
k_train_train – train-train kernel. Can be (among other types) a Kernel object. Must contain the necessary nngp and/or ntk kernels for the arguments provided to the returned predict_fn function. For example, if you request to compute the posterior test [only] NTK covariance in future predict_fn invocations, k_train_train must contain both nngp and ntk kernels.
y_train (ndarray) – train targets.
diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.
diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False for the regularization strength to be diag_reg * jnp.mean(jnp.trace(k_train_train)).
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train [, and k_test_test] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve the approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.
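The \(\Theta \otimes I\) interpretation is what lets trace_axes save compute: with k outputs per input, a kernel that kept the traced axis would be (n*k, n*k), yet solving against \(\Theta \otimes I_k\) reduces to an (n, n) solve with k right-hand sides. The NumPy check below illustrates this for a single traced logit axis; the stand-in kernel, the shapes, and the function names are assumptions chosen for illustration.

```python
import numpy as np


def solve_full(theta, y):
    """Solve against the full (n*k, n*k) covariance kron(Theta, I_k)."""
    n, k = y.shape
    # np.kron places theta[i, j] * (a == b) at entry [i*k + a, j*k + b],
    # matching the row-major flattening y.reshape(-1)[i*k + a] == y[i, a].
    full = np.kron(theta, np.eye(k))
    return np.linalg.solve(full, y.reshape(-1)).reshape(n, k)


def solve_traced(theta, y):
    """Same solve using only the (n, n) kernel left after tracing out the logit axis."""
    return np.linalg.solve(theta, y)  # k right-hand sides, one per logit


rng = np.random.default_rng(0)
n, k = 4, 3
x = rng.normal(size=(n, 5))
theta = x @ x.T + np.eye(n)  # hypothetical positive-definite (n, n) kernel
y = rng.normal(size=(n, k))  # targets f(x_train) with a trailing logit axis

print(np.allclose(solve_full(theta, y), solve_traced(theta, y)))  # True
```

The full solve costs O((n*k)^3) time and O((n*k)^2) memory, versus O(n^3) and O(n^2) for the traced kernel.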
Returns: A function of signature predict_fn(get, k_test_train, k_test_test) computing the ‘posterior’ Gaussian distribution (mean, or mean and covariance) on a given test set.
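The diag_reg and diag_reg_absolute_scale parameters amount to a diagonal shift of k_train_train before it is factorized. A minimal NumPy sketch with a hypothetical regularized_cholesky helper, which follows the relative-scale expression from the parameter description literally; it is an illustration, not the library's code. A rank-deficient kernel shows why the shift matters.

```python
import numpy as np


def regularized_cholesky(k_train_train, diag_reg=0.0, diag_reg_absolute_scale=False):
    """Hypothetical helper: Cholesky factor of k_train_train + diag_reg * I."""
    if not diag_reg_absolute_scale:
        # Relative units, as documented: diag_reg * mean(trace(k_train_train)).
        # For a 2-D kernel the trace is a scalar, so the mean is just the trace.
        diag_reg = diag_reg * np.mean(np.trace(k_train_train))
    n = k_train_train.shape[0]
    return np.linalg.cholesky(k_train_train + diag_reg * np.eye(n))


x = np.array([[1.0], [2.0], [3.0]])
k = x @ x.T  # rank-1, hence singular, 3 x 3 kernel

try:
    regularized_cholesky(k)  # diag_reg=0.0: no shift, factorization fails
except np.linalg.LinAlgError:
    print("not positive definite")

# Relative scale: the shift is 1e-3 * trace(k) = 1e-3 * 14 = 0.014.
chol = regularized_cholesky(k, diag_reg=1e-3)
```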