nt.empirical – finite NNGP and NTK
Compute empirical NNGP and NTK; approximate functions via Taylor series.
All functions in this module are applicable to any JAX functions of proper signatures (not only those from stax).
NNGP and NTK are computed using empirical_nngp_fn, empirical_ntk_fn, or empirical_kernel_fn (for both). The kernels have a very specific output shape convention that may be unexpected. Further, NTK has multiple implementations that may perform differently depending on the task. Please read the individual functions’ docstrings.
For details, please see “Fast Finite Width Neural Tangent Kernel”.
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis, no BatchNorm. Use default
>>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
>>> kernel_fn = nt.empirical_kernel_fn(
...     f, trace_axes=(1,), vmap_axes=0,
...     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
>>> #
>>> # (5, 20) jnp.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>> #
>>> # Full kernel: not reducing over logits. Use structured derivatives
>>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
>>> # typically faster computation and lower memory cost.
>>> kernel_fn = nt.empirical_kernel_fn(
...     f, trace_axes=(), vmap_axes=0,
...     implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> #
>>> # Namedtuple of (5, 20, 10, 10) jnp.ndarray test-train NNGP and NTK.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>> #
>>> # A wide FCN with lots of parameters and many (`100`) outputs.
>>> init_fn, f, _ = stax.serial(
...     stax.Flatten(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(100)
... )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Use NTK-vector products
>>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
>>> # network has many parameters relative to the cost of the forward pass,
>>> # and large outputs.
>>> ntk_fn = nt.empirical_ntk_fn(
...     f, vmap_axes=0,
...     implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
>>> #
>>> # (5, 5) jnp.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>> #
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>> #
>>> # (20,) jnp.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
Kernel functions
Finite-width NNGP and/or NTK kernel functions.

Returns a function that computes single draws from NNGP and NT kernels.

Returns a function to draw a single sample of the NNGP of a given network.

Returns a function to draw a single sample of the NTK of a given network.
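To make the output shape convention concrete, here is a minimal pure-JAX sketch of a single empirical NNGP draw: the products of the network outputs on two batches for one fixed set of parameters. It is an illustration under stated assumptions, not the neural_tangents implementation. The two-layer network and the names init_params, f and nngp_sketch are hypothetical, and tracing is done with a plain sum over the logit axis (the library may normalize differently).

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=3, width=16, d_out=2):
    """Hypothetical two-layer MLP parameters, standing in for a stax init_fn."""
    k1, k2 = jax.random.split(key)
    return {"W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
            "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width)}


def f(params, x):
    """Apply function with the (params, inputs) signature; outputs (N, K)."""
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def nngp_sketch(params, x1, x2=None, trace=True, diagonal=False):
    """Single empirical NNGP draw: products f(x1)[i, a] * f(x2)[j, b].

    trace=True    sums over the logit axis (analogous to trace_axes=(-1,)).
    trace=False   keeps the full (K, K) output covariance (like trace_axes=()).
    diagonal=True keeps only pairs (x1[i], x2[i]) (like diagonal_axes=(0,)).
    """
    fx1 = f(params, x1)
    fx2 = fx1 if x2 is None else f(params, x2)
    if diagonal:
        full = jnp.einsum("ia,ib->iab", fx1, fx2)   # (N, K, K)
        return jnp.trace(full, axis1=1, axis2=2) if trace else full
    full = jnp.einsum("ia,jb->ijab", fx1, fx2)      # (N1, N2, K, K)
    return jnp.trace(full, axis1=2, axis2=3) if trace else full


params = init_params(jax.random.PRNGKey(0))
x_train = jax.random.normal(jax.random.PRNGKey(1), (20, 3))
x_test = jax.random.normal(jax.random.PRNGKey(2), (5, 3))

print(nngp_sketch(params, x_test, x_train).shape)               # (5, 20)
print(nngp_sketch(params, x_test, x_train, trace=False).shape)  # (5, 20, 2, 2)
print(nngp_sketch(params, x_train, diagonal=True).shape)        # (20,)
```

The three calls mirror the three kinds of NNGP output in the example above: traced over logits, full, and diagonal only.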
NTK implementation
An enum.IntEnum specifying the NTK implementation method.
class neural_tangents.NtkImplementation(value)
Implementation method of the underlying finite width NTK computation.
Below is a very brief summary of each method. For details, please see “Fast Finite Width Neural Tangent Kernel”.
AUTO (or 0) evaluates the FLOPs of all other methods at compilation time and selects the fastest one. However, at the moment it only works correctly on TPUs, and on CPU/GPU it can return wrong results, which is why it is not the default. TODO(romann): revisit based on http://b/202218145.
JACOBIAN_CONTRACTION (or 1) computes the NTK as the outer product of two Jacobians, each computed using reverse-mode Autodiff (vector-Jacobian products, VJPs). When JITted, the contraction is performed in a layer-wise fashion, so that entire Jacobians aren’t necessarily instantiated in memory at once, and the memory usage of the method can be lower than the memory needed to instantiate the two Jacobians. This method is best suited for networks with small outputs (such as scalar outputs for binary classification or regression, as opposed to 1000 ImageNet classes), and an expensive forward pass relative to the number of parameters (such as CNNs, where the forward pass reuses a small filter bank many times). It is also the most reliable method, since its implementation is the simplest, and reverse-mode Autodiff is the most commonly used and best tested elsewhere. For this reason it is set as the default.
NTK_VECTOR_PRODUCTS (or 2) computes the NTK as a sequence of NTK-vector products, similarly to how a Jacobian is computed as a sequence of Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs). This amounts to using both forward-mode (JVPs) and reverse-mode (VJPs) Autodiff, and eliminates the Jacobian contraction at the expense of additional forward passes. Therefore this method is recommended for networks with a cheap forward pass relative to the number of parameters (e.g. fully-connected networks, where each parameter matrix is used only once in the forward pass), and networks with large outputs (e.g. 1000 ImageNet classes). Memory requirements of this method are the same as for JACOBIAN_CONTRACTION (1). Due to its reliance on forward-mode Autodiff, this method is slightly more prone to JAX and XLA bugs than JACOBIAN_CONTRACTION (1), but overall is quite simple and reliable.
STRUCTURED_DERIVATIVES (or 3) uses a custom JAX interpreter to compute the NTK more efficiently than other methods. It traverses the computational graph of a function in the same order as during reverse-mode Autodiff, but instead of computing VJPs, it directly computes MJJMPs, “matrix-Jacobian-Jacobian-matrix” products, which arise in the computation of an NTK. Each MJJMP computation relies on the structure in the Jacobians, hence the name. This method can be dramatically faster (up to several orders of magnitude) than other methods on fully-connected networks, and is usually faster or equivalent on CNNs, Transformers, and other architectures, but the exact speedup (e.g. from no speedup to 10X) depends on each specific setting. It can also use less memory than other methods. In our experience it consistently outperforms other methods in most settings. However, its implementation is significantly more complex (hence bug-prone), and it doesn’t yet support functions using more exotic JAX primitives (e.g. jax.checkpoint, parallel collectives such as jax.lax.psum, compiled loops like jax.lax.scan, etc.), which is why it is highly recommended to try, but not yet set as the default.
NTK-vector products
A function to compute NTK-vector products without instantiating the NTK.

Returns an NTK-vector product function.
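An NTK-vector product can be written as Θ(x1, x2) v = J(x1) (J(x2)ᵀ v): one VJP of the network on x2 followed by one JVP on x1, so neither the Jacobians nor the NTK are materialized. The sketch below, assuming a hypothetical two-layer MLP f(params, x) and hypothetical helpers ntk_vp and ntk_from_products, shows that product and, as a small check, rebuilds the full (N1, N2, K, K) NTK by applying it to every basis vector of the output space, which is the NTK_VECTOR_PRODUCTS idea in miniature. It is not the neural_tangents implementation.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=3, width=16, d_out=2):
    """Hypothetical two-layer MLP parameters."""
    k1, k2 = jax.random.split(key)
    return {"W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
            "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width)}


def f(params, x):
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def ntk_vp(params, x1, x2, v):
    """Return NTK(x1, x2) @ v for v of shape (N2, K); result has shape (N1, K)."""
    # Reverse mode: u = J(x2)^T v, a pytree with the same structure as params.
    _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
    (u,) = vjp_fn(v)
    # Forward mode: J(x1) u.
    _, out = jax.jvp(lambda p: f(p, x1), (params,), (u,))
    return out


def ntk_from_products(params, x1, x2):
    """Full NTK[i, j, a, b] built column by column from NTK-vector products."""
    n2, k = jax.eval_shape(f, params, x2).shape
    basis = jnp.eye(n2 * k).reshape(n2 * k, n2, k)  # basis[j * k + b] = e_(j, b)
    cols = jax.vmap(lambda v: ntk_vp(params, x1, x2, v))(basis)  # (N2*K, N1, K)
    return cols.reshape(n2, k, -1, k).transpose(2, 0, 3, 1)      # (N1, N2, K, K)


params = init_params(jax.random.PRNGKey(0))
x1 = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
x2 = jax.random.normal(jax.random.PRNGKey(2), (4, 3))

v = jnp.ones((4, 2))
print(ntk_vp(params, x1, x2, v).shape)          # (5, 2)
print(ntk_from_products(params, x1, x2).shape)  # (5, 4, 2, 2)
```

Rebuilding the whole kernel this way is only for checking; the point of the product is that algorithms needing Θ v (for example iterative solvers) never have to store Θ.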
Linearization and Taylor expansion
Decorators to Taylor-expand around function parameters.

Returns a function 

Returns a function 
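As a rough idea of what Taylor-expanding around parameters means, here is a first-order sketch: f_lin(params, x) = f(params0, x) + J(params0, x) (params − params0), computed with a single jax.jvp so the Jacobian is never formed. The helper linearize_sketch and the toy network are hypothetical; this does not reproduce the module’s own decorators or their signatures, and higher-order terms are omitted.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=3, width=16, d_out=2):
    """Hypothetical two-layer MLP parameters."""
    k1, k2 = jax.random.split(key)
    return {"W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
            "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width)}


def f(params, x):
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def linearize_sketch(f, params0):
    """Return f_lin, the first-order Taylor expansion of f around params0."""
    def f_lin(params, x):
        # Parameter displacement params - params0, leaf by leaf.
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        # One JVP yields f(params0, x) and J(params0, x) @ dparams together.
        f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return f0 + df
    return f_lin


params0 = init_params(jax.random.PRNGKey(0))
direction = init_params(jax.random.PRNGKey(1))
params = jax.tree_util.tree_map(lambda p, d: p + 1e-2 * d, params0, direction)
x = jax.random.normal(jax.random.PRNGKey(2), (5, 3))

f_lin = linearize_sketch(f, params0)
# The linear model's error is second order in the 1e-2 step, so it should be
# much smaller than the first-order change of the outputs.
print(jnp.max(jnp.abs(f_lin(params, x) - f(params, x))))
print(jnp.max(jnp.abs(f(params, x) - f(params0, x))))
```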