- neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)
Predicts the Gaussian embedding induced by gradient descent on MSE loss.

This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if `kernel_fn` is the kernel function of the infinite-width network. Note that `kernel_fn` can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).
Note that the first invocation of the returned `predict_fn` will be slow and allocate a lot of memory for its whole lifetime, as the kernel computation and either an eigendecomposition (`t` is a scalar or an array) or a Cholesky factorization (`t` is `None`) of `kernel_fn(x_train, None, get)` are performed and cached for future invocations (or both decompositions, if the function is called at both finite and infinite (`t=None`) times).
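A minimal usage sketch (the architecture, input shapes, and `diag_reg` value below are illustrative, not prescribed by the API):

```python
import neural_tangents as nt
from neural_tangents import stax
from jax import random

# An infinite-width fully-connected network and its analytic kernel function.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(0), 3)
x_train = random.normal(key1, (20, 5))
y_train = random.normal(key2, (20, 1))
x_test = random.normal(key3, (10, 5))

# Build the predictor for the infinite ensemble trained by gradient descent.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4)

# Mean and covariance of the NTK-regime outputs at training time t=100.
mean, cov = predict_fn(t=100., x_test=x_test, get='ntk', compute_cov=True)

# t=None corresponds to the infinite-training-time limit.
mean_inf = predict_fn(t=None, x_test=x_test, get='ntk')
```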
Parameters:

- kernel_fn (`AnalyticKernelFn`, `EmpiricalKernelFn`, or `MonteCarloKernelFn`) – A kernel function that computes NNGP and/or NTK. Must have a signature `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object or a namedtuple with `nngp` and `ntk` attributes. Therefore, it can be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn` or an `EmpiricalKernelFn` (but not `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since the latter two do not accept a `get` argument). Note that for meaningful outputs, the kernel function must represent or at least approximate the infinite-width kernel.
- x_train (ndarray) – training inputs.
- y_train (ndarray) – training targets.
- learning_rate (float) – learning rate, step size.
- diag_reg (float) – a scalar representing the strength of the diagonal regularization for `kernel_fn(x_train, None, get)`, i.e. computing `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky factorization or eigendecomposition.
- diag_reg_absolute_scale (bool) – `True` for `diag_reg` to represent regularization in absolute units, `False` for it to be scaled relative to the kernel, i.e. `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))` (a sketch of the two scaling modes follows the parameter list).
- trace_axes (int or sequence of ints) – `f(x_train)` axes such that `kernel_fn(x_train, None, get)`, `kernel_fn(x_test, x_train, get)`, and `kernel_fn(x_test, None, get)` lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result (see the sketch after this list).
- **kernel_fn_train_train_kwargs – optional keyword arguments passed to `kernel_fn`. For the train-train kernel, these are passed to `kernel_fn` without changes. For the test-test kernel, they are also passed to `kernel_fn`, unless overwritten by similar `**kernel_fn_test_test_kwargs` arguments passed to the `predict_fn` function call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from `**kernel_fn_train_train_kwargs` and `**kernel_fn_test_test_kwargs`, and all other values must match (see the example after this list).
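To illustrate `diag_reg_absolute_scale`, here is a hedged sketch of the effective scalar added to the kernel's diagonal in each mode (the helper `effective_diag_reg` is hypothetical, not part of the API):

```python
import jax.numpy as jnp

def effective_diag_reg(k_train_train, diag_reg, diag_reg_absolute_scale):
  """Illustrative: the scalar multiplying I on the train-train diagonal."""
  if diag_reg_absolute_scale:
    # Absolute units: `diag_reg * I` is added as-is.
    return diag_reg
  # Relative units: scaled by the mean trace of the train-train kernel.
  return diag_reg * jnp.mean(jnp.trace(k_train_train))
```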
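For `trace_axes`, continuing the usage sketch above: to keep the full covariance structure along the output axis (e.g. when targeting one specific finite-width network's linearized dynamics), pass an empty tuple:

```python
# Full (non-traced) covariance along all output axes; more expensive,
# but most faithful to a single finite-width network's linearization.
predict_fn_full = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, trace_axes=())
```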
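As an illustration of the keyword-argument routing, suppose `kernel_fn` accepts a kwarg whose value is a tuple of arrays, one per input set (e.g. the `pattern` argument of a `stax.Aggregate` layer); the names `pattern_train` / `pattern_test` below are hypothetical placeholders:

```python
# Train-train kernel is computed with (pattern_train, pattern_train).
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, pattern=(pattern_train, pattern_train))

# Test-test kernel uses (pattern_test, pattern_test), while the test-train
# kernel combines the two into (pattern_test, pattern_train).
mean = predict_fn(t=1., x_test=x_test, get='ntk',
                  pattern=(pattern_test, pattern_test))
```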
Returns:

A function with signature `predict_fn(t, x_test, get, compute_cov)` returning either the mean, or the mean and covariance, of the infinite ensemble of infinite-width networks outputs on `x_test` at time(s) `t`, in the requested `get` regime (`'nngp'`, `'ntk'`, or both).
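For instance, requesting both regimes with covariances (continuing the usage sketch above; the output's field names follow the `get` strings):

```python
# With a tuple `get`, the output is a namedtuple with `nngp` and `ntk`
# fields; with `compute_cov=True`, each field is a (mean, covariance) pair.
out = predict_fn(t=1., x_test=x_test, get=('nngp', 'ntk'), compute_cov=True)
ntk_mean, ntk_cov = out.ntk
nngp_mean, nngp_cov = out.nngp
```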