neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)

Predicts the Gaussian embedding induced by gradient descent on MSE loss.

This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if kernel_fn is the kernel function of the infinite-width network. Note that kernel_fn can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).
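For intuition, the mean of this Gaussian in the NTK regime has a closed form. Assuming plain gradient flow \(df/dt = -\eta \Theta (f - y)\) on the training outputs, with zero ensemble mean at initialization, the mean on test inputs is \(\mu(t) = \Theta(x_{test}, X)\, \Theta^{-1} (I - e^{-\eta \Theta t})\, y\), which tends to kernel regression as \(t \to \infty\). The library may normalize the MSE loss differently (for example by the number of training targets), which would only rescale \(\eta\). The NumPy sketch below evaluates this formula through an eigendecomposition; rbf_kernel and ntk_mean are hypothetical names, and the RBF kernel is only a stand-in for an infinite-width NTK.

```python
import numpy as np


def rbf_kernel(x1, x2):
    # Hypothetical stand-in for an infinite-width NTK; any symmetric positive-definite kernel works.
    sq_dists = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists)


def ntk_mean(t, x_test, x_train, y_train, learning_rate=1.0):
    """Ensemble mean at time t under gradient flow df/dt = -learning_rate * K (f - y), E[f_0] = 0."""
    k_train = rbf_kernel(x_train, x_train)        # Theta(X, X), shape (n, n)
    k_test_train = rbf_kernel(x_test, x_train)    # Theta(x_test, X), shape (m, n)
    evals, evecs = np.linalg.eigh(k_train)        # Theta = V diag(evals) V^T
    # Theta^{-1} (I - exp(-lr * Theta * t)) in the eigenbasis; equals 1 / evals when t = inf.
    coeffs = -np.expm1(-learning_rate * evals * t) / evals
    return k_test_train @ (evecs @ (coeffs[:, None] * (evecs.T @ y_train)))


rng = np.random.default_rng(0)
x_train = rng.normal(size=(5, 2))
y_train = rng.normal(size=(5, 1))
x_test = rng.normal(size=(3, 2))
mean_early = ntk_mean(0.5, x_test, x_train, y_train)
mean_converged = ntk_mean(np.inf, x_test, x_train, y_train)  # kernel regression limit
```

Note that learning_rate and t enter only through their product, so halving the learning rate and doubling t gives the same mean.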

Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory that is held for the whole lifetime of predict_fn. On that call, the kernel kernel_fn(x_train, None, get) is computed and factorized, using an eigendecomposition when t is a scalar or an array, or a Cholesky factorization when t=None, and the result is cached for future invocations. If predict_fn is called with both finite and infinite (t=None) times, both factorizations are computed and cached.
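The caching behavior described above can be sketched as a closure that computes the train-train kernel and each factorization lazily, at most once. This is a hypothetical NumPy illustration, not the library's implementation: make_predict_mean_fn and linear_kernel are invented names, the kernel is a plain two-argument function rather than a full kernel_fn, only the mean is returned, and t must be None or a scalar. It also uses diag_reg in absolute units, which matters here because linear_kernel is singular when there are more training points than features plus one.

```python
import numpy as np


def make_predict_mean_fn(kernel, x_train, y_train, learning_rate=1.0, diag_reg=0.0):
    """Hypothetical sketch of lazy kernel computation and factorization caching."""
    cache = {}
    calls = {"kernel": 0, "eigh": 0, "cholesky": 0}

    def train_kernel():
        # Computed on first use only, then kept alive for the lifetime of predict_mean.
        if "k_reg" not in cache:
            calls["kernel"] += 1
            k = kernel(x_train, x_train)
            cache["k_reg"] = k + diag_reg * np.eye(k.shape[0])  # absolute-scale diag_reg
        return cache["k_reg"]

    def predict_mean(t, x_test):
        k_test_train = kernel(x_test, x_train)
        if t is None:
            # Infinite time: solve (K + diag_reg * I) alpha = y with a cached Cholesky factor.
            if "chol" not in cache:
                calls["cholesky"] += 1
                cache["chol"] = np.linalg.cholesky(train_kernel())
            low = cache["chol"]
            alpha = np.linalg.solve(low.T, np.linalg.solve(low, y_train))
            return k_test_train @ alpha
        # Finite (or np.inf) t: reuse a cached eigendecomposition.
        if "eigh" not in cache:
            calls["eigh"] += 1
            cache["eigh"] = np.linalg.eigh(train_kernel())
        evals, evecs = cache["eigh"]
        coeffs = -np.expm1(-learning_rate * evals * t) / evals  # (1 - exp(-lr*lambda*t)) / lambda
        return k_test_train @ (evecs @ (coeffs[:, None] * (evecs.T @ y_train)))

    predict_mean.calls = calls  # exposed only so the caching can be inspected
    return predict_mean


def linear_kernel(x1, x2):
    # Hypothetical kernel; singular when there are more points than features + 1.
    return x1 @ x2.T + 1.0


rng = np.random.default_rng(0)
x_train = rng.normal(size=(6, 3))
y_train = rng.normal(size=(6, 2))
x_test = rng.normal(size=(4, 3))

predict = make_predict_mean_fn(linear_kernel, x_train, y_train, diag_reg=1e-3)
mean_none = predict(None, x_test)    # computes the kernel and its Cholesky factor
mean_again = predict(None, x_test)   # reuses the cached Cholesky factor
mean_inf = predict(np.inf, x_test)   # computes the eigendecomposition once
mean_t1 = predict(1.0, x_test)       # reuses the cached eigendecomposition
```

Both factorizations stay in the closure once created, which is why memory use persists after the first call.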

  • kernel_fn (Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn]) – A kernel function that computes NNGP and/or NTK. Must have the signature kernel_fn(x1, x2, get, **kernel_fn_kwargs) and return a Kernel object or a namedtuple with nngp and/or ntk attributes. It can therefore be an AnalyticKernelFn, a MonteCarloKernelFn, or an EmpiricalKernelFn (but only nt.empirical_kernel_fn, not nt.empirical_ntk_fn or nt.empirical_nngp_fn, since the latter two do not accept a get argument). Note that for meaningful outputs, the kernel function must represent or at least approximate the infinite-width kernel.

  • x_train (ndarray) – training inputs.

  • y_train (ndarray) – training targets.

  • learning_rate (float) – learning rate, step size.

  • diag_reg (float) – a scalar representing the strength of the diagonal regularization applied to kernel_fn(x_train, None, get): the Cholesky factorization or eigendecomposition is performed on kernel_fn(x_train, None, get) + diag_reg * I.

  • diag_reg_absolute_scale (bool) – if True, diag_reg is the regularization in absolute units; if False, the regularization strength is diag_reg * jnp.mean(jnp.trace(kernel_fn(x_train, None, get))), i.e. relative to the scale of the train-train kernel.

  • trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that kernel_fn(x_train, None, get), kernel_fn(x_test, x_train, get) [and kernel_fn(x_test, None, get)] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve the approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate result.

  • **kernel_fn_train_train_kwargs – optional keyword arguments passed to kernel_fn. For the train-train kernel, these are passed to kernel_fn unchanged. For the test-test kernel, they are passed to kernel_fn unless overridden by arguments with the same names in **kernel_fn_test_test_kwargs passed to the predict_fn call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from **kernel_fn_train_train_kwargs and **kernel_fn_test_test_kwargs, and all other values must match.
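The kernel_fn contract above can be illustrated with a toy function that follows the documented signature kernel_fn(x1, x2, get, **kernel_fn_kwargs). It computes the exact kernels of a single linear layer f(x) = w · x / sqrt(d) with w ~ N(0, I), for which the NNGP and the NTK both equal x1 @ x2.T / d. The names linear_kernel_fn and Kernels are hypothetical, and returning a bare array for a string get but a namedtuple for a tuple get is our assumption about the convention, modeled on the "Kernel object or namedtuple" description.

```python
from collections import namedtuple

import numpy as np

# Hypothetical container mirroring the documented "namedtuple with nngp and/or ntk attributes".
Kernels = namedtuple("Kernels", ["nngp", "ntk"])


def linear_kernel_fn(x1, x2=None, get=("nngp", "ntk"), **kernel_fn_kwargs):
    """Toy kernel_fn for f(x) = w . x / sqrt(d), w ~ N(0, I): NNGP = NTK = x1 @ x2.T / d.

    kernel_fn_kwargs are accepted to match the signature but unused in this toy.
    """
    if x2 is None:  # x2=None means the kernel of x1 with itself
        x2 = x1
    k = x1 @ x2.T / x1.shape[-1]
    if isinstance(get, str):  # a single regime: return the array itself
        return k
    return Kernels(nngp=k if "nngp" in get else None, ntk=k if "ntk" in get else None)


rng = np.random.default_rng(0)
x_train = rng.normal(size=(4, 3))
x_test = rng.normal(size=(2, 3))
k_train_train = linear_kernel_fn(x_train, None, "ntk")             # shape (4, 4)
k_test_train = linear_kernel_fn(x_test, x_train, ("nngp", "ntk"))  # Kernels(nngp=..., ntk=...)
```

A function like this satisfies the interface, but a meaningful ensemble prediction still requires that it represent or approximate an infinite-width kernel, as noted above.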
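The \(\Theta \otimes I\) interpretation of trace_axes can be checked numerically. The NumPy sketch below (the random kernel theta and the sizes n and k are hypothetical) builds the full (n*k, n*k) covariance for n training points with k output logits, assuming f(x_train) is flattened point-major, and shows that solving against it gives the same result as solving against the (n, n) kernel that is stored when trace_axes=(-1,).

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 5, 3                                   # n training points, k output logits
a = rng.normal(size=(n, n))
theta = a @ a.T + n * np.eye(n)               # hypothetical positive-definite kernel, shape (n, n)
y_train = rng.normal(size=(n, k))             # targets, shape (n, k)

# Full covariance over all (point, logit) pairs, i.e. kron(Theta, I); flat index = point * k + logit.
theta_full = np.kron(theta, np.eye(k))        # shape (n*k, n*k)
alpha_full = np.linalg.solve(theta_full, y_train.reshape(-1)).reshape(n, k)

# Same result from the small (n, n) kernel, solving for all k logits at once.
alpha_small = np.linalg.solve(theta, y_train)
```

Storing theta instead of theta_full takes k^2 times less memory, and the linear solves shrink accordingly.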


Returns a function with signature predict_fn(t, x_test, get, compute_cov) that computes either the mean, or the mean and covariance, of the infinite ensemble of infinite-width network outputs on x_test at time(s) t, in the get regime ("nngp", "ntk", or ("nngp", "ntk")).
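To make the return value concrete, the NumPy sketch below reproduces one specific case: t=None, get="nngp", compute_cov=True. There, to our understanding, the returned Gaussian is the standard Gaussian-process posterior with the NNGP kernel K, with mean K(x_test, X) K^{-1} y and covariance K(x_test, x_test) - K(x_test, X) K^{-1} K(X, x_test), where K^{-1} is taken of the train-train kernel plus diag_reg * I. The names rbf_kernel and make_nngp_posterior_fn are hypothetical, and the RBF kernel only stands in for an infinite-width NNGP kernel. With the library itself, the analogous call would be predict_fn(t=None, x_test=x_test, get="nngp", compute_cov=True) on the function returned by nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train).

```python
import numpy as np


def rbf_kernel(x1, x2):
    # Hypothetical stand-in for an infinite-width NNGP kernel.
    sq_dists = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists)


def make_nngp_posterior_fn(x_train, y_train, diag_reg=1e-8):
    """Hypothetical sketch of predict_fn(t=None, x_test, get="nngp", compute_cov) as a GP posterior."""
    k_train = rbf_kernel(x_train, x_train) + diag_reg * np.eye(len(x_train))
    low = np.linalg.cholesky(k_train)  # factorized once, reused by every call

    def solve(b):
        # (K + diag_reg * I)^{-1} b via the Cholesky factor.
        return np.linalg.solve(low.T, np.linalg.solve(low, b))

    def predict_fn(x_test, compute_cov=False):
        k_test_train = rbf_kernel(x_test, x_train)
        mean = k_test_train @ solve(y_train)
        if not compute_cov:
            return mean
        cov = rbf_kernel(x_test, x_test) - k_test_train @ solve(k_test_train.T)
        return mean, cov

    return predict_fn


rng = np.random.default_rng(0)
x_train = rng.normal(size=(5, 2))
y_train = rng.normal(size=(5, 1))
x_test = rng.normal(size=(3, 2))
predict_fn = make_nngp_posterior_fn(x_train, y_train)
mean, cov = predict_fn(x_test, compute_cov=True)  # shapes (3, 1) and (3, 3)
```

At the training inputs, this posterior mean reproduces y_train and the covariance nearly vanishes, up to the small diag_reg.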