neural_tangents.empirical_kernel_fn

neural_tangents.empirical_kernel_fn(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)

Returns a function that computes single draws from NNGP and NT kernels.

Warning

The resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only hold in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain a kernel whose shape is exactly zip(f(x1).shape, f(x2).shape).

For networks with multiple outputs (i.e. lists, tuples, or PyTrees of outputs), the empirical kernels would in principle have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.
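To make the shape in the warning concrete, here is a from-scratch sketch in plain JAX (not the library's implementation) of an untraced, undiagonalized empirical NTK for a hypothetical two-layer network f(params, x) whose params is a dict PyTree. Contracting the parameter Jacobians of f(x1) and f(x2) over every parameter leaf gives a kernel of shape (N1, N2, O, O), i.e. zip(f(x1).shape, f(x2).shape), which is the layout we would expect with trace_axes=() and diagonal_axes=(). The network, its parameter names, and the helper empirical_ntk are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def f(params, x):
    # Hypothetical two-layer network; `params` is a dict PyTree.
    hidden = jnp.tanh(x @ params["w1"] + params["b1"])
    return hidden @ params["w2"]


def empirical_ntk(f, params, x1, x2):
    # Per-leaf Jacobians with respect to params, each of shape (N, O, *leaf.shape).
    j1 = jax.jacobian(f)(params, x1)
    j2 = jax.jacobian(f)(params, x2)

    def contract(a, b):
        # Flatten the parameter dimensions, then contract over them:
        # result[i, j, o1, o2] = sum_p a[i, o1, p] * b[j, o2, p].
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum("iap,jbp->ijab", a, b)

    # Sum the contributions of all parameter leaves.
    per_leaf = jax.tree_util.tree_map(contract, j1, j2)
    return sum(jax.tree_util.tree_leaves(per_leaf))


k_w1, k_w2, k_x1, k_x2 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": jax.random.normal(k_w1, (4, 32)) / jnp.sqrt(4.0),
    "b1": jnp.zeros(32),
    "w2": jax.random.normal(k_w2, (32, 10)) / jnp.sqrt(32.0),
}
x1 = jax.random.normal(k_x1, (3, 4))
x2 = jax.random.normal(k_x2, (5, 4))

ntk = empirical_ntk(f, params, x1, x2)
print(ntk.shape)  # (3, 5, 10, 10)
```

The full kernel stores O * O entries per input pair; the trace_axes and diagonal_axes parameters exist precisely to avoid materializing all of them when that is not needed.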

Parameters:
  • f (ApplyFn) – the function whose kernel(s) (NNGP and/or NTK) we are computing. It should have the signature f(params, x, **kwargs), where params and x are PyTrees, and f should also return a PyTree.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This saves space and compute if you are only interested in the respective trace, and can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., and the covariance along the corresponding pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This saves space and compute if off-diagonal values along these axes are not needed, and can also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (such as infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of the covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) –

    Applicable only to NTK.

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    For example, if _, f, _ = nt.stax.Aggregate(), f is called via f(params, x, pattern=pattern). By default, inputs x, patterns pattern, and outputs of f are all batched along the leading 0 dimension, and each output f(params, x, pattern=pattern)[i] depends only on the inputs x[i] and pattern[i]. In this case, we can pass vmap_axes=(0, 0, dict(pattern=0)) to specify along which dimensions inputs, outputs, and keyword arguments are batched, respectively.

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.

  • implementation (Union[NtkImplementation, int]) – Applicable only to NTK, an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.

  • _j_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than automated alternatives, but in case of a suboptimal implementation, setting it to False could improve performance.

  • _s_rules (bool) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice, this should be set to True; setting it to False can lead to a dramatic deterioration of performance.

  • _fwd (Optional[bool]) – Internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp, or None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
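As a rough illustration of what the trace_axes parameter keeps, the plain-JAX sketch below starts from a hypothetical full kernel of shape (N1, N2, O, O) and reduces the last pair of axes to one value per (x1, x2) pair. It normalizes the trace by the number of outputs O (i.e. takes the mean of the diagonal), so that a kernel whose output covariance is exactly c * I reduces to c; whether the library reports this normalized trace or the raw sum is an assumption to check against its source. The array sizes and the helper name traced are illustrative.

```python
import jax
import jax.numpy as jnp


def traced(full_kernel):
    # full_kernel has shape (N1, N2, O, O): trace over the last pair of axes
    # and divide by O, leaving one value per (x1, x2) pair.
    n_out = full_kernel.shape[-1]
    return jnp.trace(full_kernel, axis1=2, axis2=3) / n_out


n1, n2, n_out = 3, 5, 10
# Hypothetical per-pair scalars c[i, j].
c = jax.random.normal(jax.random.PRNGKey(0), (n1, n2))

# A kernel whose output covariance is exactly c[i, j] * identity, i.e. the
# constant * identity form that the channel / logit covariance converges to.
full = c[:, :, None, None] * jnp.eye(n_out)

print(full.shape)          # (3, 5, 10, 10)
print(traced(full).shape)  # (3, 5)
```

The traced result holds 15 numbers instead of 1500, and in this idealized case it loses no information, which is the situation in which the parameter description says tracing can also improve accuracy.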
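The saving that the diagonal_axes parameter describes can be sketched in plain JAX with an NNGP-style product of hypothetical outputs y1 of shape (N1, P, O) and y2 of shape (N2, P, O), where P stands for, e.g., a spatial axis. The full covariance over P needs P * P entries per (x1, x2, o1, o2), while the diagonal along P (the per-position variance rather than the covariance across positions) can be computed directly with one einsum that never builds the full tensor. The shapes and the einsum formulation are illustrative assumptions, not the library's code.

```python
import jax
import jax.numpy as jnp

n1, n2, p, n_out = 2, 3, 4, 5
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
y1 = jax.random.normal(k1, (n1, p, n_out))  # hypothetical f(x1)
y2 = jax.random.normal(k2, (n2, p, n_out))  # hypothetical f(x2)

# Full covariance over the P axis: shape (N1, N2, P, P, O, O).
full = jnp.einsum("ipa,jqb->ijpqab", y1, y2)

# Diagonal along P only, computed without forming `full`:
# shape (N1, N2, O, O, P).
diag_direct = jnp.einsum("ipa,jpb->ijabp", y1, y2)

# Same values as extracting the P-diagonal from the full tensor
# (jnp.diagonal moves the diagonal axis to the end).
diag_from_full = jnp.diagonal(full, axis1=2, axis2=3)

print(full.shape)         # (2, 3, 4, 4, 5, 5)
print(diag_direct.shape)  # (2, 3, 5, 5, 4)
```

Here the diagonal axis ends up last only because of how the einsum output is written; the library may place diagonalized axes elsewhere in its result.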

Return type:

EmpiricalGetKernelFn

Returns:

A function to draw a single sample of the NNGP and NTK empirical kernels of a given network f.