Empirical

Compute empirical NNGP and NTK; approximate functions via Taylor series.

All functions in this module apply to any JAX function with the proper signature (not only those from nt.stax).

NNGP and NTK are computed using empirical_nngp_fn, empirical_ntk_fn, or empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, NTK has multiple implementations that may perform differently depending on the task. Please read individual functions’ docstrings.

Example

>>>  from jax import random
>>>  import neural_tangents as nt
>>>  from neural_tangents import stax
>>>
>>>  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>>  x_train = random.normal(key1, (20, 32, 32, 3))
>>>  y_train = random.uniform(key1, (20, 10))
>>>  x_test = random.normal(key2, (5, 32, 32, 3))
>>>
>>>  # A narrow CNN.
>>>  init_fn, f, _ = stax.serial(
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Relu(),
>>>      stax.Conv(32, (3, 3)),
>>>      stax.Flatten(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  _, params = init_fn(key3, x_train.shape)
>>>
>>>  # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>>  # network is iid along the batch axis, no BatchNorm. Use default
>>>  # `implementation=1` since the network has few trainable parameters.
>>>  kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(-1,),
>>>                                     vmap_axes=0, implementation=1)
>>>
>>>  # (5, 20) np.ndarray test-train NNGP/NTK
>>>  nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>>  ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>>
>>>  # Full kernel: not reducing over logits.
>>>  kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0)
>>>
>>>  # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple.
>>>  k_test_train = kernel_fn(x_test, x_train, params)
>>>
>>>  # A wide FCN with lots of parameters
>>>  init_fn, f, _ = stax.serial(
>>>      stax.Flatten(),
>>>      stax.Dense(1024),
>>>      stax.Relu(),
>>>      stax.Dense(1024),
>>>      stax.Relu(),
>>>      stax.Dense(10)
>>>  )
>>>
>>>  _, params = init_fn(key3, x_train.shape)
>>>
>>>  # Use implicit differentiation in NTK: `implementation=2` to reduce
>>>  # memory cost, since the network has many trainable parameters.
>>>  ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2)
>>>
>>>  # (5, 5) np.ndarray test-test NTK
>>>  ntk_test_test = ntk_fn(x_test, None, params)
>>>
>>>  # Compute only output variances:
>>>  nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>>
>>>  # (20,) np.ndarray train-train diagonal NNGP
>>>  nngp_train_train_diag = nngp_fn(x_train, None, params)
neural_tangents.utils.empirical.empirical_kernel_fn(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=1)

Returns a function that computes single draws from NNGP and NT kernels.

WARNING: the resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only be true in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain the kernel of shape exactly zip(f(x1).shape, f(x2).shape).
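
To make the shape convention concrete, here is a small sketch with hypothetical input and output sizes, checking that trace_axes=() and diagonal_axes=() yield the full kernel of shape zip(f(x1).shape, f(x2).shape):

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>>
>>> # Hypothetical sizes, chosen only for illustration.
>>> key1, key2, key3 = random.split(random.PRNGKey(0), 3)
>>> x1 = random.normal(key1, (4, 8))
>>> x2 = random.normal(key2, (6, 8))
>>>
>>> init_fn, f, _ = stax.serial(stax.Dense(16), stax.Relu(), stax.Dense(3))
>>> _, params = init_fn(key3, x1.shape)
>>>
>>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), diagonal_axes=())
>>> k = kernel_fn(x1, x2, 'ntk', params)
>>>
>>> # f(params, x1).shape == (4, 3) and f(params, x2).shape == (6, 3),
>>> # so the full kernel has shape (4, 6, 3, 3).
>>> assert k.shape == (4, 6, 3, 3)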

For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.

Parameters
  • f (Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]]) – the function whose NTK we are computing. f should have the signature f(params, inputs, **kwargs) and should return an np.ndarray of outputs.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, and also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, and also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Optional[Tuple[Union[List[int], Tuple[int, …], int, None], Union[List[int], Tuple[int, …], int, None], Dict[str, Union[List[int], Tuple[int, …], int, None]]]]) –

    applicable only to NTK.

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) is equal to a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    For example, if _, f, _ = nt.stax.Aggregate(), f is called via f(params, x, pattern=pattern). By default, inputs x, patterns pattern, and outputs of f are all batched along the leading 0 dimension, and each output f(params, x, pattern=pattern)[i] only depends on the inputs x[i] and pattern[i]. In this case, we can pass vmap_axes=(0, 0, dict(pattern=0)) to specify along which dimensions inputs, outputs, and keyword arguments are batched, respectively.

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results. A minimal sketch of the batching contract that vmap_axes certifies appears after this entry.

  • implementation (int) –

    applicable only to NTK.

    1 or 2.

    1 directly instantiates Jacobians and computes their outer product.

    2 uses implicit differentiation to avoid instantiating whole Jacobians at once. The implicit kernel is derived by observing that: \(\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)\), i.e. a linear function \([J(X_1) J(X_2)^T]\) applied to an identity matrix \(I\). This allows the computation of the NTK to be phrased as: \(a(v) = J(X_2)^T v\), which is computed by a vector-Jacobian product; \(b(v) = J(X_1) a(v)\) which is computed by a Jacobian-vector product; and \(\Theta = [d b(v) / d v^T](I)\) which is computed via a vmap of \(b(v)\) over columns of the identity matrix \(I\).

    It is best to benchmark each method on your specific task. We suggest using 1 unless you get OOMs due to a large number of trainable parameters; otherwise, use 2.

Return type

Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Union[Tuple[str, …], str, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]]

Returns

A function to draw a single sample of the NNGP and NTK empirical kernels of a given network f.
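
As a complement to the vmap_axes description above, the following sketch illustrates the contract it certifies, using a hypothetical function with a per-example mask argument (all names and shapes here are illustrative, not part of the library): evaluating the function on the whole batch must agree with a jax.vmap of it over per-example slices. In empirical_kernel_fn, this batching structure would be expressed as vmap_axes=(0, 0, dict(mask=0)).

>>> import jax
>>> import jax.numpy as np
>>> from jax import random
>>>
>>> # Hypothetical per-example network: `mask` is batched along axis 0.
>>> def f(params, x, mask):
>>>   return (x * mask) @ params
>>>
>>> key = random.PRNGKey(0)
>>> params = random.normal(key, (8, 4))
>>> x = random.normal(key, (3, 8))
>>> mask = np.ones((3, 8))
>>>
>>> # What `vmap_axes=(0, 0, dict(mask=0))` certifies: batched evaluation
>>> # coincides with a vmap of `f` over per-example slices.
>>> f_vmapped = jax.vmap(f, in_axes=(None, 0, 0), out_axes=0)
>>> assert np.allclose(f(params, x, mask), f_vmapped(params, x, mask))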

neural_tangents.utils.empirical.empirical_nngp_fn(f, trace_axes=(-1,), diagonal_axes=())

Returns a function to draw a single sample of the NNGP of a given network f.

The Neural Network Gaussian Process (NNGP) kernel is defined as \(f(X_1) f(X_2)^T\), i.e. the outer product of the function outputs.
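
As a minimal illustration of this definition, a hand-rolled single draw might look as follows, assuming 2D outputs of shape (batch, logits) and mirroring the default trace_axes=(-1,) by averaging the outer product over the logit axis (the averaging convention is an assumption worth checking against empirical_nngp_fn on your setup):

>>> def nngp_by_hand(f, params, x1, x2):
>>>   # Outer product of outputs, averaged over the logit axis.
>>>   out1 = f(params, x1)                   # (N1, logits)
>>>   out2 = f(params, x2)                   # (N2, logits)
>>>   return out1 @ out2.T / out1.shape[-1]  # (N1, N2)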

WARNING: the resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only be true in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain the kernel of shape exactly zip(f(x1).shape, f(x2).shape).

For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.

Parameters
  • f (Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]]) – the function whose NNGP we are computing. f should have the signature f(params, inputs[, rng]) and should return an np.ndarray of outputs.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, and also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, and also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

Return type

Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]]

Returns

A function to draw a single sample of the NNGP of a given network f.

neural_tangents.utils.empirical.empirical_ntk_fn(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=1)

Returns a function to draw a single sample of the NTK of a given network f.

The Neural Tangent Kernel is defined as \(J(X_1) J(X_2)^T\) where \(J\) is the Jacobian \(df/dparams\) of shape full_output_shape + params.shape.
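
As an illustrative sketch (not the library's implementation), this definition can be written with explicit Jacobians in the spirit of implementation=1 described below, assuming 2D outputs of shape (batch, logits) and averaging over the logit axis as with trace_axes=(-1,):

>>> import jax
>>> import jax.numpy as np
>>>
>>> def ntk_by_hand(f, params, x1, x2):
>>>   # Jacobians w.r.t. `params`: PyTrees with leaves of shape
>>>   # (batch, logits, *param_leaf.shape).
>>>   j1 = jax.jacobian(f)(params, x1)
>>>   j2 = jax.jacobian(f)(params, x2)
>>>
>>>   def contract(a, b):
>>>     # Sum over logits and flattened parameters.
>>>     a = a.reshape(a.shape[0], a.shape[1], -1)
>>>     b = b.reshape(b.shape[0], b.shape[1], -1)
>>>     return np.einsum('ikp,jkp->ij', a, b)
>>>
>>>   leaves1 = jax.tree_util.tree_leaves(j1)
>>>   leaves2 = jax.tree_util.tree_leaves(j2)
>>>   ntk = sum(contract(a, b) for a, b in zip(leaves1, leaves2))
>>>   return ntk / leaves1[0].shape[1]  # average over the traced logit axis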

For best performance: 1) pass x2=None if x1 == x2; 2) prefer square batches (i.e. x1.shape == x2.shape); 3) make sure to set vmap_axes correctly; 4) try different implementation values.

WARNING: the resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only be true in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain the kernel of shape exactly zip(f(x1).shape, f(x2).shape).

For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.

Parameters
  • f (Callable[[Any, Union[List[ndarray], Tuple[ndarray, …], ndarray]], Union[List[ndarray], Tuple[ndarray, …], ndarray]]) – the function whose NTK we are computing. f should have the signature f(params, inputs[, rng]) and should return an np.ndarray of outputs.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, and also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d. and the respective covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, and also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Optional[Tuple[Union[List[int], Tuple[int, …], int, None], Union[List[int], Tuple[int, …], int, None], Dict[str, Union[List[int], Tuple[int, …], int, None]]]]) –

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) is equal to a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    For example, if _, f, _ = nt.stax.Aggregate(), f is called via f(params, x, pattern=pattern). By default, inputs x, patterns pattern, and outputs of f are all batched along the leading 0 dimension, and each output f(params, x, pattern=pattern)[i] only depends on the inputs x[i] and pattern[i]. In this case, we can pass vmap_axes=(0, 0, dict(pattern=0)) to specify along which dimensions inputs, outputs, and keyword arguments are batched, respectively.

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.

  • implementation (int) –

    1 or 2.

    1 directly instantiates Jacobians and computes their outer product.

    2 uses implicit differentiation to avoid instantiating whole Jacobians at once. The implicit kernel is derived by observing that: \(\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)\), i.e. a linear function \([J(X_1) J(X_2)^T]\) applied to an identity matrix \(I\). This allows the computation of the NTK to be phrased as: \(a(v) = J(X_2)^T v\), which is computed by a vector-Jacobian product; \(b(v) = J(X_1) a(v)\) which is computed by a Jacobian-vector product; and \(\Theta = [d b(v) / d v^T](I)\) which is computed via a vmap of \(b(v)\) over columns of the identity matrix \(I\). A sketch of this implicit computation appears after this entry.

    It is best to benchmark each method on your specific task. We suggest using 1 unless you get OOMs due to a large number of trainable parameters; otherwise, use 2.

Return type

Callable[[Union[List[ndarray], Tuple[ndarray, …], ndarray], Union[List[ndarray], Tuple[ndarray, …], ndarray, None], Any], Union[List[ndarray], Tuple[ndarray, …], ndarray]]

Returns

A function ntk_fn that computes the empirical NTK.
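
For concreteness, the implicit computation described under implementation=2 above can be sketched as follows (full kernel over 2D outputs of shape (batch, logits), no traced axes; a sketch under these assumptions, not the library's implementation):

>>> import jax
>>> import jax.numpy as np
>>>
>>> def ntk_implicit(f, params, x1, x2):
>>>   def f1(p):
>>>     return f(p, x1)
>>>
>>>   def f2(p):
>>>     return f(p, x2)
>>>
>>>   def delta_vjp_jvp(v):
>>>     # a(v) = J(X_2)^T v via a vector-Jacobian product,
>>>     # b(v) = J(X_1) a(v) via a Jacobian-vector product.
>>>     a = jax.vjp(f2, params)[1](v)
>>>     return jax.jvp(f1, (params,), a)[1]
>>>
>>>   out1, out2 = f1(params), f2(params)
>>>   # vmap `b` over the columns of the identity, reshaped as cotangents.
>>>   eye = np.eye(out2.size).reshape((-1,) + out2.shape)
>>>   theta = jax.vmap(delta_vjp_jvp)(eye)            # (N2 * logits, N1, logits)
>>>   theta = theta.reshape(out2.shape + out1.shape)  # (N2, logits, N1, logits)
>>>   return np.transpose(theta, (2, 0, 3, 1))        # (N1, N2, logits, logits)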

neural_tangents.utils.empirical.linearize(f, params)

Returns a function f_lin, the first-order Taylor approximation to f.

Example

>>> # Compute the MSE of the first order Taylor series of a function.
>>> f_lin = linearize(f, params)
>>> mse = np.mean((f(new_params, x) - f_lin(new_params, x)) ** 2)
Parameters
  • f (Callable[…, Any]) – A function that we would like to linearize. It should have the signature f(params, *args, **kwargs) where params is a PyTree and f should return a PyTree.

  • params (Any) – Initial parameters to the function that we would like to take the Taylor series about. This can be any structure that is compatible with the JAX tree operations.

Return type

Callable[…, Any]

Returns

A function f_lin(new_params, *args, **kwargs) whose signature is the same as f. Here f_lin implements the first-order Taylor series of f about params.
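
As a sketch of the underlying idea (not necessarily the library's implementation), a first-order linearization can be built from a single jax.jvp call:

>>> import jax
>>>
>>> def linearize_sketch(f, params):
>>>   # f_lin(p, ...) = f(params, ...) + J(params) · (p - params).
>>>   def f_lin(new_params, *args, **kwargs):
>>>     dparams = jax.tree_util.tree_map(lambda a, b: a - b, new_params, params)
>>>     f0, df = jax.jvp(lambda p: f(p, *args, **kwargs), (params,), (dparams,))
>>>     return jax.tree_util.tree_map(lambda a, b: a + b, f0, df)
>>>   return f_lin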

neural_tangents.utils.empirical.taylor_expand(f, params, degree)

Returns a function f_tayl, the Taylor approximation to f of order degree.

Example

>>> # Compute the MSE of the third order Taylor series of a function.
>>> f_tayl = taylor_expand(f, params, 3)
>>> mse = np.mean((f(new_params, x) - f_tayl(new_params, x)) ** 2)
Parameters
  • f (Callable[…, Any]) – A function that we would like to Taylor expand. It should have the signature f(params, *args, **kwargs) where params is a PyTree, and f returns a PyTree.

  • params (Any) – Initial parameters to the function that we would like to take the Taylor series about. This can be any structure that is compatible with the JAX tree operations.

  • degree (int) – The degree of the Taylor expansion.

Return type

Callable[…, Any]

Returns

A function f_tayl(new_params, *args, **kwargs) whose signature is the same as f. Here f_tayl implements the degree-order Taylor series of f about params.
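
For intuition, a fixed degree-2 version of this expansion can be sketched with nested jax.jvp calls; this is illustrative only, while taylor_expand handles arbitrary degree:

>>> import jax
>>>
>>> def taylor_expand_2(f, params):
>>>   # Second-order Taylor expansion of `f` about `params` (a sketch).
>>>   def f_tayl(new_params, *args, **kwargs):
>>>     dp = jax.tree_util.tree_map(lambda a, b: a - b, new_params, params)
>>>
>>>     def g(p):
>>>       return f(p, *args, **kwargs)
>>>
>>>     # Zeroth- and first-order terms from a single JVP.
>>>     f0, df = jax.jvp(g, (params,), (dp,))
>>>     # Second-order term: differentiate the directional derivative again.
>>>     _, d2f = jax.jvp(lambda p: jax.jvp(g, (p,), (dp,))[1], (params,), (dp,))
>>>     return jax.tree_util.tree_map(lambda a, b, c: a + b + 0.5 * c, f0, df, d2f)
>>>   return f_tayl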