neural_tangents.predict.gradient_descent_mse_ensemble
- neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)[source]
Predicts the Gaussian embedding induced by gradient descent on MSE loss.
This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if kernel_fn is the kernel function of the infinite-width network. Note that kernel_fn can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).
Note that the first invocation of the returned predict_fn will be slow and will allocate a lot of memory for its whole lifetime, as the kernel computation, and either the eigendecomposition (t is a scalar or an array) or the Cholesky factorization (t=None) of kernel_fn(x_train, None, get), is performed and cached for future invocations (or both, if the function is called at both finite and infinite (t=None) times).
- Parameters:
  - kernel_fn (Union[AnalyticKernelFn, EmpiricalKernelFn, EmpiricalGetKernelFn, MonteCarloKernelFn]) – a kernel function that computes NNGP and/or NTK. Must have the signature kernel_fn(x1, x2, get, **kernel_fn_kwargs) and return a Kernel object or a namedtuple with nngp and/or ntk attributes. It can therefore be an AnalyticKernelFn, but also a MonteCarloKernelFn or an EmpiricalKernelFn (but only nt.empirical_kernel_fn, and not nt.empirical_ntk_fn or nt.empirical_nngp_fn, since the latter two do not accept a get argument). Note that for meaningful outputs, the kernel function must represent, or at least approximate, the infinite-width kernel.
  - x_train (ndarray) – training inputs.
  - y_train (ndarray) – training targets.
  - learning_rate (float) – learning rate (step size).
  - diag_reg (float) – a scalar representing the strength of the diagonal regularization of kernel_fn(x_train, None, get), i.e. computing kernel_fn(x_train, None, get) + diag_reg * I during Cholesky factorization or eigendecomposition.
  - diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units; False for it to be diag_reg * jnp.mean(jnp.trace(kernel_fn(x_train, None, get))).
  - trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that kernel_fn(x_train, None, get), kernel_fn(x_test, x_train, get), and kernel_fn(x_test, None, get) lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target the linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results.
  - **kernel_fn_train_train_kwargs – optional keyword arguments passed to kernel_fn. For the train-train kernel, these are passed to kernel_fn without changes. For the test-test kernel, they are passed to kernel_fn unless overwritten by similar **kernel_fn_test_test_kwargs arguments passed to the predict_fn function call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from **kernel_fn_train_train_kwargs and **kernel_fn_test_test_kwargs, and all other values must match.
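The diag_reg / diag_reg_absolute_scale semantics above can be sketched in plain JAX. This is a simplified illustration for a 2D train-train kernel (trace_axes already traced out), where the relative scale reduces to the mean diagonal entry jnp.trace(k) / dim; the regularize helper is hypothetical and not part of the library:

```python
import jax.numpy as jnp

def regularize(k_train_train, diag_reg, diag_reg_absolute_scale):
    """Hypothetical helper illustrating diag_reg for a 2D train-train kernel."""
    dim = k_train_train.shape[0]
    if not diag_reg_absolute_scale:
        # Relative scale: multiply diag_reg by the mean diagonal entry,
        # so the regularization is invariant to the kernel's overall scale.
        diag_reg = diag_reg * jnp.trace(k_train_train) / dim
    # Add diag_reg * I before Cholesky / eigendecomposition.
    return k_train_train + diag_reg * jnp.eye(dim)

k = jnp.array([[2., 0.], [0., 4.]])
print(regularize(k, 0.1, True))   # adds 0.1 to the diagonal
print(regularize(k, 0.1, False))  # adds 0.1 * mean(diag) = 0.3 to the diagonal
```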
- Returns:
  A function with signature predict_fn(t, x_test, get, compute_cov), returning either the mean, or the mean and covariance, of the outputs of the infinite ensemble of infinite-width networks on x_test at time[s] t, in the get regime ("nngp", "ntk", or ("nngp", "ntk")).