Predict

Functions to make predictions on the train/test set using NTK/NNGP.

Most functions in this module accept training data as inputs and return a new function `predict_fn` that computes predictions on the train set / given test set / timesteps.
WARNING: the `trace_axes` parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same `trace_axes` used to compute the empirical kernel (`utils/empirical.py`; `diagonal_axes` must be `()`), or `channel_axis` in the output of the top layer used to compute the closed-form kernel (`stax.py`; note that closed-form kernels currently only support a single `channel_axis`).
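To make the shape convention concrete, here is a small numpy sketch (using a made-up Jacobian array, not anything produced by the library) of how tracing over the output axis collapses the full empirical NTK:

```python
import numpy as np

# Toy "Jacobian" of a network with n=3 inputs, o=2 outputs, p=5 parameters.
rng = np.random.default_rng(0)
jac = rng.normal(size=(3, 2, 5))  # (inputs, outputs, params)

# Full empirical NTK (trace_axes=()): contract over parameters only,
# giving one entry per pair of (input, output) coordinates -> (3, 2, 3, 2).
ntk_full = np.einsum('iap,jbp->iajb', jac, jac)

# With trace_axes=(-1,): average the diagonal over the output axis -> (3, 3).
# The kernel is then interpreted as block-diagonal, i.e. ntk_traced ⊗ I.
ntk_traced = np.einsum('iap,jap->ij', jac, jac) / jac.shape[1]
```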

class neural_tangents.predict.Gaussian(mean: jax.numpy.ndarray, covariance: jax.numpy.ndarray)

A `(mean, covariance)` convenience namedtuple.

property covariance
Alias for field number 1

property mean
Alias for field number 0
class neural_tangents.predict.ODEState(fx_train=None, fx_test=None, qx_train=None, qx_test=None)

ODE state dataclass holding outputs and auxiliary variables.

neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))

Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.

NNGP – the exact posterior of an infinitely wide Bayesian NN. NTK – the exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time. NTKGP – the posterior of a GP (Gaussian process) with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this can correspond to infinite ensembles of infinitely wide NNs as well).

Note that the first invocation of the returned `predict_fn` will be slow and allocate a lot of memory for its whole lifetime, as a Cholesky factorization of `k_train_train.nngp` or `k_train_train.ntk` (or both) is performed and cached for future invocations.

Parameters
k_train_train – train-train kernel. Can be (a) `np.ndarray`, (b) `Kernel` namedtuple, (c) `Kernel` object. Must contain the necessary `nngp` and/or `ntk` kernels for arguments provided to the returned `predict_fn` function. For example, if you request to compute posterior test [only] NTK covariance in future `predict_fn` invocations, `k_train_train` must contain both `ntk` and `nngp` kernels.

y_train (`ndarray`) – train targets.

diag_reg (`float`) – a scalar representing the strength of the diagonal regularization for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during Cholesky factorization.

diag_reg_absolute_scale (`bool`) – `True` for `diag_reg` to represent regularization in absolute units, `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.

trace_axes (`Union[int, Sequence[int]]`) – `f(x_train)` axes such that `k_train_train`, `k_test_train` [, and `k_test_test`] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result.
Returns

A function of signature `predict_fn(get, k_test_train, k_test_test)` computing the ‘posterior’ Gaussian distribution (mean or mean and covariance) on a given test set.
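As a point of reference for what the returned `predict_fn` computes in the NNGP regime, the standard GP posterior formulas can be sketched in plain numpy (an illustrative toy with a hypothetical `gp_posterior` helper and an RBF kernel; not the library's cached-Cholesky implementation):

```python
import numpy as np

def gp_posterior(k_train_train, k_test_train, k_test_test, y_train,
                 diag_reg=0.0):
    # Regularized train-train kernel: K + diag_reg * I.
    n = k_train_train.shape[0]
    k_reg = k_train_train + diag_reg * np.eye(n)
    # Posterior mean: K_test,train @ K_train,train^{-1} @ y_train.
    mean = k_test_train @ np.linalg.solve(k_reg, y_train)
    # Posterior covariance: K_test,test - K_test,train @ K^{-1} @ K_train,test.
    cov = k_test_test - k_test_train @ np.linalg.solve(k_reg, k_test_train.T)
    return mean, cov

# Toy RBF kernel on 1-D inputs (stand-in for an NNGP kernel).
rbf = lambda a, b: np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2)
x_train, y_train = np.array([0.0, 1.0, 2.0]), np.array([0.0, 1.0, 0.0])
x_test = np.array([0.5, 1.5])
mean, cov = gp_posterior(rbf(x_train, x_train), rbf(x_test, x_train),
                         rbf(x_test, x_test), y_train, diag_reg=1e-6)
```

With `diag_reg=0` the posterior mean interpolates the training targets exactly, which mirrors the ridgeless limit of the library's `diag_reg` parameter.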

neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))

Predicts the outcome of function space training using gradient descent.

Uses an ODE solver. If `momentum != None`, solves a continuous-time version of gradient descent with momentum (note: this case uses standard momentum as opposed to Nesterov momentum).

Solves the function space ODE for [momentum] gradient descent with a given `loss` (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) `t`. Note that for gradient descent `absolute_time = learning_rate * t` and the scales of the learning rate and query step[s] `t` are interchangeable. However, the momentum gradient descent ODE is solved in the units of `learning_rate**0.5`, and therefore `absolute_time = learning_rate**0.5 * t`, hence the `learning_rate` and training time[s] (step[s]) `t` scales are not interchangeable.

[*] https://arxiv.org/abs/1902.06720
Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>>
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> from jax.experimental import stax
>>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
>>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
>>>                                       y_train, learning_rate, momentum)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
Parameters

loss (`Callable[[ndarray, ndarray], float]`) – a loss function whose signature is `loss(f(x_train), y_train)`. Note: the loss function should treat the batch and output dimensions symmetrically.

k_train_train (`ndarray`) – kernel on the training data. Must have the shape of `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.

y_train (`ndarray`) – targets for the training data.

learning_rate (`float`) – learning rate, step size.

trace_axes (`Union[int, Sequence[int]]`) – `f(x_train)` axes such that `k_train_train` lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result.
Return type

`Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray, ODEState], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray], ODEState]]`

Returns

A function that returns output train [and test] set[s] predictions at time[s] `t`.
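The function-space dynamics being integrated can be sketched with a naive explicit-Euler loop in numpy (`gradient_descent_ode_sketch` is a hypothetical helper for illustration only; the library uses a proper ODE solver rather than fixed-step Euler):

```python
import numpy as np

def gradient_descent_ode_sketch(grad_loss, k_train_train, fx_train_0,
                                learning_rate=1.0, t=1.0, n_steps=20_000):
    # Integrate df/dtau = -K @ grad_loss(f) with explicit Euler up to
    # absolute_time tau = learning_rate * t (vanilla GD, no momentum).
    fx = np.array(fx_train_0, dtype=float)
    dtau = learning_rate * t / n_steps
    for _ in range(n_steps):
        fx = fx - dtau * k_train_train @ grad_loss(fx)
    return fx

# For MSE loss 0.5 * ||f - y||^2 the gradient is (f - y), and the ODE has
# the closed form f(tau) = y + exp(-K * tau) @ (f0 - y) to compare against.
rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
k = a @ a.T  # PSD stand-in for an NTK (trace_axes=() convention)
y = rng.normal(size=4)
f0 = np.zeros(4)
f_t = gradient_descent_ode_sketch(lambda f: f - y, k, f0,
                                  learning_rate=0.1, t=5.0)
```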

neural_tangents.predict.gradient_descent_mse(k_train_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))

Predicts the outcome of function space gradient descent training on MSE.

Solves in closed form for the continuous-time version of gradient descent.

Uses the closed-form solution for gradient descent on an MSE loss in function space detailed in [*, **] given a Neural Tangent or Neural Network Gaussian Process Kernel over the dataset. Given NNGP or NTK, this function will return a function that predicts the time evolution for function space points at arbitrary time[s] (training step[s]) `t`. Note that these time[s] (step[s]) are continuous and are interpreted in units of the `learning_rate`, so `absolute_time = learning_rate * t`, and the scales of `learning_rate` and `t` are interchangeable.

Note that the first invocation of the returned `predict_fn` will be slow and allocate a lot of memory for its whole lifetime, as either an eigendecomposition (`t` is a scalar or an array) or a Cholesky factorization (`t=None`) of `k_train_train` is performed and cached for future invocations (or both, if the function is called on both finite and infinite (`t=None`) times).

[*] https://arxiv.org/abs/1806.07572
[**] https://arxiv.org/abs/1902.06720
Example

>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
Parameters

k_train_train (`ndarray`) – kernel on the training data. Must have the shape of `zip(y_train.shape, y_train.shape)` with `trace_axes` absent.

y_train (`ndarray`) – targets for the training data.

learning_rate (`float`) – learning rate, step size.

diag_reg (`float`) – a scalar representing the strength of the diagonal regularization for `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during Cholesky factorization or eigendecomposition.

diag_reg_absolute_scale (`bool`) – `True` for `diag_reg` to represent regularization in absolute units, `False` to be `diag_reg * np.mean(np.trace(k_train_train))`.

trace_axes (`Union[int, Sequence[int]]`) – `f(x_train)` axes such that `k_train_train` lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result.
Return type

`Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray]]]`

Returns

A function of signature `predict_fn(t, fx_train_0, fx_test_0, k_test_train)` that returns output train [and test] set[s] predictions at time[s] `t`.
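The closed-form solution referenced above can be sketched via an eigendecomposition of the train-train kernel (a numpy toy with a hypothetical `mse_gd_closed_form` helper, assuming `trace_axes=()` so kernels are plain 2-D matrices; not the library's implementation):

```python
import numpy as np

def mse_gd_closed_form(k_train_train, k_test_train, y_train,
                       fx_train_0, fx_test_0, t, learning_rate=1.0):
    # Continuous-time GD on MSE: the train residual decays along the
    # kernel's eigenmodes, and test outputs follow via k_test_train.
    evals, evecs = np.linalg.eigh(k_train_train)
    tau = learning_rate * t
    residual_0 = fx_train_0 - y_train
    # f_train(t) = y + exp(-K * tau) @ (f_train(0) - y).
    decay = evecs @ (np.exp(-evals * tau) * (evecs.T @ residual_0))
    fx_train_t = y_train + decay
    # f_test(t) = f_test(0) - k_test_train @ K^{-1} @ (I - exp(-K tau)) @ residual_0.
    inv_update = evecs @ (((1.0 - np.exp(-evals * tau)) / evals)
                          * (evecs.T @ residual_0))
    fx_test_t = fx_test_0 - k_test_train @ inv_update
    return fx_train_t, fx_test_t

rng = np.random.default_rng(1)
a = rng.normal(size=(5, 8))
k_all = a @ a.T  # PSD stand-in for an NTK on 3 train + 2 test points
k_tt, k_st = k_all[:3, :3], k_all[3:, :3]
y = rng.normal(size=3)
# At very large t the train outputs reach the targets and the test outputs
# reach the kernel-regression solution k_st @ K^{-1} @ y.
f_train_inf, f_test_inf = mse_gd_closed_form(k_tt, k_st, y,
                                             np.zeros(3), np.zeros(2), t=1e6)
```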

neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)

Predicts the Gaussian embedding induced by gradient descent on MSE loss.

This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if `kernel_fn` is the kernel function of the infinite-width network. Note that `kernel_fn` can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).

Note that the first invocation of the returned `predict_fn` will be slow and allocate a lot of memory for its whole lifetime, as the kernel computation, and either an eigendecomposition (`t` is a scalar or an array) or a Cholesky factorization (`t=None`) of `kernel_fn(x_train, None, get)`, is performed and cached for future invocations (or both, if the function is called on both finite and infinite (`t=None`) times).

Parameters
kernel_fn (`Union[AnalyticKernelFn, MonteCarloKernelFn, EmpiricalKernelFn]`) – A kernel function that computes NNGP and/or NTK. Must have a signature `kernel_fn(x1, x2, get, **kernel_fn_kwargs)` and return a `Kernel` object or a `namedtuple` with `nngp` and/or `ntk` attributes. Therefore, it can be an `AnalyticKernelFn`, but also a `MonteCarloKernelFn`, or an `EmpiricalKernelFn` (but only `nt.empirical_kernel_fn` and not `nt.empirical_ntk_fn` or `nt.empirical_nngp_fn`, since the latter two do not accept a `get` argument). Note that for meaningful outputs, the kernel function must represent or at least approximate the infinite-width kernel.

x_train (`ndarray`) – training inputs.

y_train (`ndarray`) – training targets.

learning_rate (`float`) – learning rate, step size.

diag_reg (`float`) – a scalar representing the strength of the diagonal regularization for `kernel_fn(x_train, None, get)`, i.e. computing `kernel_fn(x_train, None, get) + diag_reg * I` during Cholesky factorization or eigendecomposition.

diag_reg_absolute_scale (`bool`) – `True` for `diag_reg` to represent regularization in absolute units, `False` to be `diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get)))`.

trace_axes (`Union[int, Sequence[int]]`) – `f(x_train)` axes such that `kernel_fn(x_train, None, get)`, `kernel_fn(x_test, x_train, get)` [, and `kernel_fn(x_test, None, get)`] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along `trace_axes`. These can be specified either to save space and compute, or even to improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, `trace_axes=()` will yield the most accurate result.

**kernel_fn_train_train_kwargs – optional keyword arguments passed to `kernel_fn`. For the train-train kernel, these are passed to `kernel_fn` without changes. For the test-test kernel, they are passed to `kernel_fn`, unless overwritten by similar `**kernel_fn_test_test_kwargs` arguments passed to the `predict_fn` function call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from `**kernel_fn_train_train_kwargs` and `**kernel_fn_test_test_kwargs`, and all other values must match.
Returns

A function with signature `predict_fn(t, x_test, get, compute_cov)` returning either the mean or the mean and covariance of the infinite ensemble of infinite-width networks outputs on `x_test` at time[s] `t`, in the `get` regime (`"nngp"`, `"ntk"`, or `("nngp", "ntk")`).
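For intuition, the mean of this ensemble under NTK dynamics (assuming zero-mean initialization and `trace_axes=()` so kernels are 2-D) can be sketched in numpy with a hypothetical `ensemble_mean_ntk` helper; at `t=None` (infinite time) it reduces to kernel regression on the NTK:

```python
import numpy as np

def ensemble_mean_ntk(theta_train_train, theta_test_train, y_train,
                      t=None, learning_rate=1.0):
    # Eigendecompose the train-train NTK once (mirroring the caching the
    # docs describe); all time dependence then acts on the eigenvalues.
    evals, evecs = np.linalg.eigh(theta_train_train)
    if t is None:
        # t = infinity: mean reduces to (ridgeless) kernel regression.
        coeff = 1.0 / evals
    else:
        # Theta^{-1} @ (I - exp(-Theta * learning_rate * t)) on eigenvalues.
        coeff = (1.0 - np.exp(-evals * learning_rate * t)) / evals
    return theta_test_train @ (evecs @ (coeff * (evecs.T @ y_train)))

rng = np.random.default_rng(2)
a = rng.normal(size=(5, 9))
theta = a @ a.T  # PSD stand-in for an NTK on 3 train + 2 test points
theta_tt, theta_st = theta[:3, :3], theta[3:, :3]
y = rng.normal(size=3)
mean_inf = ensemble_mean_ntk(theta_tt, theta_st, y)        # t = None
mean_t = ensemble_mean_ntk(theta_tt, theta_st, y, t=1e9)   # large finite t
```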

neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)

Computes the maximal feasible learning rate for infinite-width NNs.

The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form `1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2`. For vanilla SGD (i.e. `momentum = 0`) the maximal feasible learning rate is the largest `\eta` such that the operator `(I - \eta / (batch_size * output_size) * NTK)` is a contraction, which is `2 * batch_size * output_size / lambda_max(NTK)`. When `momentum > 0`, we use `2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK)` (see the "The Dynamics of Momentum" section in https://distill.pub/2017/momentum/).

Parameters
ntk_train_train (`ndarray`) – analytic or empirical NTK on the training data.

y_train_size (`Optional[int]`) – total training set output size, i.e. `f(x_train).size == y_train.size`. If `y_train_size=None` it is inferred from `ntk_train_train.shape` assuming `trace_axes=()`.

momentum – the `momentum` for momentum optimizers.

eps (`float`) – a float to avoid zero divisor.
Return type

`float`

Returns

The maximal feasible learning rate for infinite-width NNs.
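The bound above can be sketched in a few lines of numpy (a hypothetical `max_learning_rate_sketch` helper, assuming `trace_axes=()` so `ntk_train_train` is a 2-D matrix and `y_train_size` equals `batch_size * output_size`; not the library function):

```python
import numpy as np

def max_learning_rate_sketch(ntk_train_train, y_train_size=None,
                             momentum=0.0, eps=1e-12):
    # Stability bound: 2 * (1 + momentum) * size / lambda_max(NTK),
    # where size = batch_size * output_size (full-batch here).
    if y_train_size is None:
        # Infer the output size from the (size x size) NTK matrix.
        y_train_size = ntk_train_train.shape[0]
    lambda_max = np.linalg.eigvalsh(ntk_train_train)[-1]
    return 2.0 * (1.0 + momentum) * y_train_size / (lambda_max + eps)

# For NTK = I on 4 outputs, lambda_max = 1, so the bound is 2 * 4 = 8.
lr = max_learning_rate_sketch(np.eye(4))
```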