neural_tangents.empirical_kernel_fn
- neural_tangents.empirical_kernel_fn(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)
Returns a function that computes single draws from NNGP and NT kernels.
Warning
Resulting kernel shape is nearly zip(f(x1).shape, f(x2).shape), subject to the trace_axes and diagonal_axes parameters, which make certain assumptions about the outputs f(x) that may only be true in the infinite width / infinite number of samples limit, or may not apply to your architecture. For the most precise results in the context of linearized training dynamics of a specific finite-width network, set both trace_axes=() and diagonal_axes=() to obtain the kernel exactly of shape zip(f(x1).shape, f(x2).shape).
For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principle the empirical kernels will have terms measuring the covariance between the outputs. Here, we ignore these cross-terms and consider each output separately. Please raise an issue if this feature is important to you.
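To make the shape statement in the warning concrete, here is a minimal sketch in plain JAX. It does not call neural_tangents; the network f, its sizes, and the helper flat_jacobian are hypothetical, and dividing the trace by the axis size is an assumption of this sketch that should be checked against the library before relying on it.

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Hypothetical two-layer network: x is (batch, 3), output is (batch, 2).
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

k1, k2, kx1, kx2 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(k1, (3, 8)), jax.random.normal(k2, (8, 2)))
x1 = jax.random.normal(kx1, (4, 3))
x2 = jax.random.normal(kx2, (5, 3))

def flat_jacobian(x):
    # Jacobian of every output entry w.r.t. every parameter,
    # flattened to shape f(x).shape + (n_params,).
    jac = jax.jacobian(f)(params, x)
    out_ndim = f(params, x).ndim
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate(
        [j.reshape(j.shape[:out_ndim] + (-1,)) for j in leaves], axis=-1)

j1, j2 = flat_jacobian(x1), flat_jacobian(x2)  # (4, 2, P) and (5, 2, P)

# No tracing or diagonalizing: axes are interleaved like
# zip(f(x1).shape, f(x2).shape), giving shape (4, 5, 2, 2).
ntk_full = jnp.einsum('icp,jdp->ijcd', j1, j2)

# Tracing over the last output axis collapses the (2, 2) block to a scalar.
# Normalizing by the axis size is an assumption of this sketch.
ntk_traced = jnp.trace(ntk_full, axis1=-2, axis2=-1) / ntk_full.shape[-1]

# Diagonalizing keeps only matching output indices: shape (4, 5, 2) here.
# The axis order in the library's output may differ.
ntk_diag = jnp.diagonal(ntk_full, axis1=-2, axis2=-1)
```

The traced and diagonal forms discard the off-diagonal (cross-logit) entries of each block, which is why they only match the full kernel when those entries are negligible or not needed.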
- Parameters:
f (ApplyFn) – the function whose kernel(s) (NNGP and/or NTK) we are computing. It should have the signature f(params, x, **kwargs) where params is a PyTree, x is a PyTree, and f should also return a PyTree.
trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This saves space and compute if you are only interested in the respective trace, and also improves approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d. and the covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This saves space and compute if off-diagonal values along these axes are not needed, and also improves approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to “batch dimensions” in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – applicable only to NTK. A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f(params, x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).
For example, if _, f, _ = nt.stax.Aggregate(), f is called via f(params, x, pattern=pattern). By default, inputs x, patterns pattern, and outputs of f are all batched along the leading 0 dimension, and each output f(params, x, pattern=pattern)[i] only depends on the inputs x[i] and pattern[i]. In this case, we can pass vmap_axes=(0, 0, dict(pattern=0)) to specify along which dimensions inputs, outputs, and keyword arguments are batched respectively.
This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.
implementation (Union[NtkImplementation, int]) – applicable only to NTK. An NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.
_j_rules (bool) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than automated alternatives, but if a rule is implemented suboptimally, setting this to False could improve performance.
_s_rules (bool) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice it should be set to True; setting it to False can lead to a dramatic deterioration of performance.
_fwd (Optional[bool]) – internal debugging parameter, applicable only to NTK when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp, and None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
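The condition that vmap_axes certifies can be checked numerically. The sketch below uses plain JAX rather than neural_tangents; the networks f and f_coupled and all helper names are hypothetical. It compares an NTK built from whole-batch Jacobians with one built from per-example Jacobians under jax.vmap, first for a network without batch interaction and then for one that subtracts the batch mean from its inputs.

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # Hypothetical network with no interaction between batch elements.
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

def f_coupled(params, x):
    # Hypothetical network whose outputs depend on the whole batch.
    return f(params, x - x.mean(axis=0, keepdims=True))

def flatten(jac, out_ndim):
    # Merge all parameter axes of a Jacobian PyTree into one trailing axis.
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate(
        [j.reshape(j.shape[:out_ndim] + (-1,)) for j in leaves], axis=-1)

def jac_whole_batch(fn, params, x):
    # Differentiate the full batched output at once: (batch, out, P).
    return flatten(jax.jacobian(fn)(params, x), out_ndim=2)

def jac_vmapped(fn, params, x):
    # Differentiate one example at a time, mapped over axis 0 of x.
    def single(xi):
        fn_single = lambda p: fn(p, xi[None])[0]
        return flatten(jax.jacobian(fn_single)(params), out_ndim=1)
    return jax.vmap(single)(x)

def ntk(j):
    return jnp.einsum('icp,jdp->ijcd', j, j)

k1, k2, kx = jax.random.split(jax.random.PRNGKey(1), 3)
params = (jax.random.normal(k1, (3, 8)), jax.random.normal(k2, (8, 2)))
x = jax.random.normal(kx, (4, 3))

# Independent batch elements: both routes agree up to float rounding.
err_independent = jnp.max(jnp.abs(
    ntk(jac_whole_batch(f, params, x)) - ntk(jac_vmapped(f, params, x))))

# Coupled batch elements: the per-example route gives a different kernel.
err_coupled = jnp.max(jnp.abs(
    ntk(jac_whole_batch(f_coupled, params, x))
    - ntk(jac_vmapped(f_coupled, params, x))))
```

In the coupled case no error is raised; the per-example result is simply different, which is the kind of silent mismatch the vmap_axes description warns about.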
- Return type:
- Returns:
A function to draw a single sample of the NNGP and NTK empirical kernels of a given network f.
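What "a single sample" means can be illustrated with a short plain-JAX sketch. It does not call neural_tangents; init_params, nngp_draw, the width, and the number of draws are all hypothetical, and averaging the output outer product over the last axis is an assumption of this sketch, not a statement about the library's exact normalization. Each parameter draw yields one empirical NNGP kernel, and averaging many draws gives a Monte Carlo estimate.

```python
import jax
import jax.numpy as jnp

def init_params(key, width=64):
    # Hypothetical initialization with 1/sqrt(fan_in) scaling.
    k1, k2 = jax.random.split(key)
    return (jax.random.normal(k1, (3, width)) / jnp.sqrt(3.0),
            jax.random.normal(k2, (width, 2)) / jnp.sqrt(width))

def f(params, x):
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2

def nngp_draw(params, x1, x2):
    # One empirical NNGP sample: outer product of the outputs,
    # averaged over the last (logit) axis. Shape (n1, n2).
    y1, y2 = f(params, x1), f(params, x2)
    return y1 @ y2.T / y1.shape[-1]

x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
keys = jax.random.split(jax.random.PRNGKey(1), 200)

# Each key gives a different parameter draw, hence a different kernel sample.
draws = jax.vmap(lambda k: nngp_draw(init_params(k), x, x))(keys)  # (200, 4, 4)

# Averaging over draws reduces the sample-to-sample variation.
nngp_mean = draws.mean(axis=0)
```

Individual draws differ from one another because they depend on the sampled parameters; only the average over many draws (or a very wide network) approaches a deterministic kernel.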