Predict¶
Functions to make predictions on the train/test set using NTK/NNGP.
Most functions in this module accept training data as inputs and return a new function predict_fn that computes predictions on the train set / given test set / timesteps.
WARNING: the trace_axes parameter supplied to prediction functions must match the respective parameter supplied to the function used to compute the kernel. Namely, this is the same trace_axes used to compute the empirical kernel (utils/empirical.py; diagonal_axes must be ()), or channel_axis in the output of the top layer used to compute the closed-form kernel (stax.py; note that closed-form kernels currently only support a single channel_axis).
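For illustration, a minimal sketch of keeping trace_axes consistent between the empirical kernel and the prediction function (f, params, x_train, x_test, and y_train are assumed to be defined elsewhere):
>>> import neural_tangents as nt
>>>
>>> # trace_axes used to compute the empirical kernel...
>>> kernel_fn = nt.empirical_ntk_fn(f, trace_axes=(-1,))
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> # ...must match the trace_axes passed to the prediction function.
>>> predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train,
>>>                                              trace_axes=(-1,))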
- class neural_tangents.predict.Gaussian(mean: jax.numpy.ndarray, covariance: jax.numpy.ndarray)[source]¶
A (mean, covariance) convenience namedtuple.
- property covariance¶
Alias for field number 1
- property mean¶
Alias for field number 0
- class neural_tangents.predict.ODEState(fx_train=None, fx_test=None, qx_train=None, qx_test=None)[source]¶
ODE state dataclass holding outputs and auxiliary variables.
- neural_tangents.predict.gp_inference(k_train_train, y_train, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]¶
Compute the mean and variance of the ‘posterior’ of NNGP/NTK/NTKGP.
NNGP - the exact posterior of an infinitely wide Bayesian NN. NTK - exact distribution of an infinite ensemble of infinitely wide NNs trained with gradient flow for infinite time. NTKGP - posterior of a GP (Gaussian process) with the NTK covariance (see https://arxiv.org/abs/2007.05864 for how this can correspond to infinite ensembles of infinitely wide NNs as well).
Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as a Cholesky factorization of k_train_train.nngp or k_train_train.ntk (or both) is performed and cached for future invocations.
- Parameters
k_train_train – train-train kernel. Can be (a) np.ndarray, (b) Kernel namedtuple, (c) Kernel object. Must contain the necessary nngp and/or ntk kernels for arguments provided to the returned predict_fn function. For example, if you request to compute posterior test [only] NTK covariance in future predict_fn invocations, k_train_train must contain both ntk and nngp kernels.
y_train (ndarray) – train targets.
diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization.
diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train, k_test_train [, and k_test_test] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results.
- Returns
A function of signature predict_fn(get, k_test_train, k_test_test) computing the ‘posterior’ Gaussian distribution (mean or mean and covariance) on a given test set.
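A minimal usage sketch (f, params, x_train, x_test, and y_train are assumed to be defined elsewhere; the empirical kernel and the diag_reg value are illustrative choices):
>>> from neural_tangents import empirical_kernel_fn, predict
>>>
>>> kernel_fn = empirical_kernel_fn(f)
>>> k_train_train = kernel_fn(x_train, None, ('nngp', 'ntk'), params)
>>> k_test_train = kernel_fn(x_test, x_train, ('nngp', 'ntk'), params)
>>> k_test_test = kernel_fn(x_test, None, ('nngp', 'ntk'), params)
>>>
>>> predict_fn = predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)
>>>
>>> # 'Posterior' Gaussian (mean, covariance) of the NNGP on the test set.
>>> nngp_posterior = predict_fn('nngp', k_test_train, k_test_test)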
- neural_tangents.predict.gradient_descent(loss, k_train_train, y_train, learning_rate=1.0, momentum=None, trace_axes=(-1,))[source]¶
Predicts the outcome of function space training using gradient descent.
Uses an ODE solver. If momentum != None, solves a continuous-time version of gradient descent with momentum (note: this case uses standard momentum as opposed to Nesterov momentum).
Solves the function space ODE for [momentum] gradient descent with a given loss (detailed in [*]) given a Neural Tangent Kernel[s] over the dataset[s] at arbitrary time[s] (step[s]) t. Note that for gradient descent absolute_time = learning_rate * t and the scales of the learning rate and query step[s] t are interchangeable. However, the momentum gradient descent ODE is solved in the units of learning_rate**0.5, and therefore absolute_time = learning_rate**0.5 * t, hence the learning_rate and training time[s] (step[s]) t scales are not interchangeable.
[*] https://arxiv.org/abs/1902.06720
Example
>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> learning_rate = 1e-2
>>> momentum = 0.9
>>>
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> from jax.experimental import stax
>>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
>>> predict_fn = predict.gradient_descent(cross_entropy, k_train_train,
>>>                                       y_train, learning_rate, momentum)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
- Parameters
loss (Callable[[ndarray, ndarray], float]) – a loss function whose signature is loss(f(x_train), y_train). Note: the loss function should treat the batch and output dimensions symmetrically.
k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.
y_train (ndarray) – targets for the training data.
learning_rate (float) – learning rate, step size.
momentum (Optional[float]) – momentum scalar.
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results.
- Return type
Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray, ODEState], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray], ODEState]]
- Returns
A function that returns output train [and test] set[s] predictions at time[s] t.
- neural_tangents.predict.gradient_descent_mse(k_train_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,))[source]¶
Predicts the outcome of function space gradient descent training on MSE.
Solves in closed form for the continuous-time version of gradient descent.
Uses the closed-form solution for gradient descent on an MSE loss in function space detailed in [*, **] given a Neural Tangent or Neural Network Gaussian Process Kernel over the dataset. Given NNGP or NTK, this function will return a function that predicts the time evolution for function space points at arbitrary time[s] (training step[s]) t. Note that these time[s] (step[s]) are continuous and are interpreted in units of the learning_rate, so absolute_time = learning_rate * t, and the scales of learning_rate and t are interchangeable.
Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as either eigendecomposition (t is a scalar or an array) or Cholesky factorization (t=None) of k_train_train is performed and cached for future invocations (or both, if the function is called on both finite and infinite (t=None) times).
[*] https://arxiv.org/abs/1806.07572
[**] https://arxiv.org/abs/1902.06720
Example
>>> from neural_tangents import empirical_ntk_fn
>>> from neural_tangents import predict
>>>
>>> t = 1e-7
>>> kernel_fn = empirical_ntk_fn(f)
>>> k_train_train = kernel_fn(x_train, None, params)
>>> k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>> predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
>>>
>>> fx_train_0 = f(params, x_train)
>>> fx_test_0 = f(params, x_test)
>>>
>>> fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0,
>>>                                    k_test_train)
- Parameters
k_train_train (ndarray) – kernel on the training data. Must have the shape of zip(y_train.shape, y_train.shape) with trace_axes absent.
y_train (ndarray) – targets for the training data.
learning_rate (float) – learning rate, step size.
diag_reg (float) – a scalar representing the strength of the diagonal regularization for k_train_train, i.e. computing k_train_train + diag_reg * I during Cholesky factorization or eigendecomposition.
diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(k_train_train)).
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that k_train_train lacks these pairs of dimensions and is to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results.
- Return type
Callable[[Union[None, int, float, ndarray], Union[None, int, float, ndarray], Union[None, int, float, ndarray], Optional[ndarray]], Union[ndarray, Tuple[ndarray, ndarray]]]
- Returns
A function of signature predict_fn(t, fx_train_0, fx_test_0, k_test_train) that returns output train [and test] set[s] predictions at time[s] t.
- neural_tangents.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, learning_rate=1.0, diag_reg=0.0, diag_reg_absolute_scale=False, trace_axes=(-1,), **kernel_fn_train_train_kwargs)[source]¶
Predicts the Gaussian embedding induced by gradient descent on MSE loss.
This is equivalent to an infinite ensemble of infinite-width networks after marginalizing out the initialization, if kernel_fn is the kernel function of the infinite-width network. Note that kernel_fn can in principle also be an empirical / Monte Carlo finite-width kernel function, but in this case the returned output will not have a simple interpretation (unless these functions are used to approximate the infinite-width kernel).
Note that the first invocation of the returned predict_fn will be slow and allocate a lot of memory for its whole lifetime, as the kernel computation, and either eigendecomposition (t is a scalar or an array) or Cholesky factorization (t=None) of kernel_fn(x_train, None, get) is performed and cached for future invocations (or both, if the function is called on both finite and infinite (t=None) times).
- Parameters
kernel_fn (Union[AnalyticKernelFn, EmpiricalKernelFn, MonteCarloKernelFn]) – a kernel function that computes NNGP and/or NTK. Must have a signature kernel_fn(x1, x2, get, **kernel_fn_kwargs) and return a Kernel object or a namedtuple with nngp and/or ntk attributes. Therefore, it can be an AnalyticKernelFn, but also a MonteCarloKernelFn, or an EmpiricalKernelFn (but only nt.empirical_kernel_fn and not nt.empirical_ntk_fn or nt.empirical_nngp_fn, since the latter two do not accept a get argument). Note that for meaningful outputs, the kernel function must represent or at least approximate the infinite-width kernel.
x_train (ndarray) – training inputs.
y_train (ndarray) – training targets.
learning_rate (float) – learning rate, step size.
diag_reg (float) – a scalar representing the strength of the diagonal regularization for kernel_fn(x_train, None, get), i.e. computing kernel_fn(x_train, None, get) + diag_reg * I during Cholesky factorization or eigendecomposition.
diag_reg_absolute_scale (bool) – True for diag_reg to represent regularization in absolute units, False to be diag_reg * np.mean(np.trace(kernel_fn(x_train, None, get))).
trace_axes (Union[int, Sequence[int]]) – f(x_train) axes such that kernel_fn(x_train, None, get), kernel_fn(x_test, x_train, get) [, and kernel_fn(x_test, None, get)] lack these pairs of dimensions and are to be interpreted as \(\Theta \otimes I\), i.e. block-diagonal along trace_axes. These can be specified either to save space and compute, or to even improve approximation accuracy of the infinite-width or infinite-samples limit, since in these limits the covariance along channel / feature / logit axes indeed converges to a constant-diagonal matrix. However, if you target linearized dynamics of a specific finite-width network, trace_axes=() will yield the most accurate results.
**kernel_fn_train_train_kwargs – optional keyword arguments passed to kernel_fn. For the train-train kernel, these are passed to kernel_fn without changes. For the test-test kernel, they are passed to kernel_fn, unless overwritten by similar **kernel_fn_test_test_kwargs arguments passed to the predict_fn function call. Finally, for the test-train kernel, values that are tuples of arrays (destined for calls of the finite-width network on training and testing data) will be tuples of values combined from **kernel_fn_train_train_kwargs and **kernel_fn_test_test_kwargs, and all other values must match.
- Returns
A function with signature predict_fn(t, x_test, get, compute_cov) returning either mean or mean and covariance of the infinite ensemble of infinite-width networks outputs on x_test at time[s] t, in the get regime ("nngp", "ntk", or ("nngp", "ntk")).
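A minimal usage sketch (x_train, y_train, and x_test are assumed to be defined elsewhere; the architecture and the diag_reg value are illustrative choices):
>>> from neural_tangents import stax, predict
>>>
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Dense(512), stax.Relu(), stax.Dense(1))
>>>
>>> predict_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                     y_train, diag_reg=1e-4)
>>>
>>> # Mean and covariance of the infinite ensemble at training time t=100.
>>> mean, cov = predict_fn(t=100., x_test=x_test, get='ntk', compute_cov=True)
>>>
>>> # NNGP posterior mean at infinite training time (t=None).
>>> nngp_mean = predict_fn(t=None, x_test=x_test, get='nngp')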
- neural_tangents.predict.max_learning_rate(ntk_train_train, y_train_size=None, momentum=0.0, eps=1e-12)[source]¶
Computes the maximal feasible learning rate for infinite width NNs.
The network is assumed to be trained using mini-/full-batch GD + momentum with mean squared loss. The loss is assumed to have the form 1/(2 * batch_size * output_size) \|f(train_x) - train_y\|^2. For vanilla SGD (i.e. momentum = 0) the maximal feasible learning rate is the largest \eta such that the operator (I - \eta / (batch_size * output_size) * NTK) is a contraction, which is 2 * batch_size * output_size / lambda_max(NTK). When momentum > 0, we use 2 * (1 + momentum) * batch_size * output_size / lambda_max(NTK) (see The Dynamics of Momentum section in https://distill.pub/2017/momentum/).
- Parameters
ntk_train_train (ndarray) – analytic or empirical NTK on the training data.
y_train_size (Optional[int]) – total training set output size, i.e. f(x_train).size == y_train.size. If y_train_size=None it is inferred from ntk_train_train.shape assuming trace_axes=().
momentum – the momentum for momentum optimizers.
eps (float) – a float to avoid zero divisor.
- Return type
float
- Returns
The maximal feasible learning rate for infinite width NNs.
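A minimal usage sketch (f, params, x_train, and y_train are assumed to be defined elsewhere; the empirical NTK is an illustrative choice):
>>> from neural_tangents import empirical_ntk_fn, predict
>>>
>>> kernel_fn = empirical_ntk_fn(f, trace_axes=())
>>> ntk_train_train = kernel_fn(x_train, None, params)
>>>
>>> # Largest stable step size for full-batch gradient descent on MSE.
>>> lr = predict.max_learning_rate(ntk_train_train, y_train_size=y_train.size)
>>>
>>> predict_fn = predict.gradient_descent_mse(ntk_train_train, y_train,
>>>                                           learning_rate=lr, trace_axes=())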