nt.empirical – finite NNGP and NTK
Compute empirical NNGP and NTK; approximate functions via Taylor series.
All functions in this module are applicable to any JAX functions of proper
signatures (not only those from
NNGP and NTK are computed using
empirical_kernel_fn (for both). The kernels have a very
specific output shape convention that may be unexpected. Further, NTK has
multiple implementations that may perform differently depending on the task.
Please read individual functions’ docstrings.
For details, please see “Fast Finite Width Neural Tangent Kernel”.
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> # A narrow CNN.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Default setting: reducing over logits; pass `vmap_axes=0` because the
>>> # network is iid along the batch axis (no BatchNorm). Use the default
>>> # `implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION` (`1`).
>>> kernel_fn = nt.empirical_kernel_fn(
...     f, trace_axes=(-1,), vmap_axes=0,
...     implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
>>> #
>>> # (5, 20) jnp.ndarray test-train NNGP/NTK
>>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params)
>>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
>>> #
>>> # Full kernel: not reducing over logits. Use structured derivatives
>>> # `implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES` (`3`) for
>>> # typically faster computation and lower memory cost.
>>> kernel_fn = nt.empirical_kernel_fn(
...     f, trace_axes=(), vmap_axes=0,
...     implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
>>> #
>>> # Namedtuple of (5, 20, 10, 10) jnp.ndarray test-train NNGP/NTK.
>>> k_test_train = kernel_fn(x_test, x_train, None, params)
>>> #
>>> # A wide FCN with lots of parameters and many (`100`) outputs.
>>> init_fn, f, _ = stax.serial(
...     stax.Flatten(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(1024),
...     stax.Relu(),
...     stax.Dense(100)
... )
>>> #
>>> _, params = init_fn(key3, x_train.shape)
>>> #
>>> # Use NTK-vector products
>>> # (`implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS`) since the
>>> # network has many parameters relative to the cost of its forward pass,
>>> # and large outputs.
>>> ntk_fn = nt.empirical_ntk_fn(
...     f, vmap_axes=0,
...     implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
>>> #
>>> # (5, 5) jnp.ndarray test-test NTK
>>> ntk_test_test = ntk_fn(x_test, None, params)
>>> #
>>> # Compute only output variances:
>>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,))
>>> #
>>> # (20,) jnp.ndarray train-train diagonal NNGP
>>> nngp_train_train_diag = nngp_fn(x_train, None, params)
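To make the output shape convention in the example concrete, here is a minimal pure-JAX sketch (no `neural_tangents` dependency) that computes a full empirical NTK by contracting two reverse-mode Jacobians, the idea behind implementation `1`. The two-layer MLP, `apply_fn`, and `ntk_jacobian_contraction` are hypothetical illustrations, not the library's code. The full kernel has shape `(N1, N2, K, K)`, analogous to `trace_axes=()`; tracing over the paired logit axes gives `(N1, N2)`, analogous to `trace_axes=(-1,)`. The library may additionally normalize traced kernels, so consult the individual docstrings for exact values.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Hypothetical two-layer MLP: (N, d_in) -> (N, K) logits.
    return jax.nn.relu(x @ params["w1"]) @ params["w2"]


def ntk_jacobian_contraction(params, x1, x2):
    """Full empirical NTK of shape (N1, N2, K, K) from two reverse-mode Jacobians."""
    # Each Jacobian leaf has shape (N, K, *param_leaf.shape).
    j1 = jax.jacrev(apply_fn)(params, x1)
    j2 = jax.jacrev(apply_fn)(params, x2)

    def contract(a, b):
        # Flatten the parameter axes and sum over them:
        # out[i, j, k, l] = sum_p a[i, k, p] * b[j, l, p]
        a = a.reshape(a.shape[0], a.shape[1], -1)
        b = b.reshape(b.shape[0], b.shape[1], -1)
        return jnp.einsum("ikp,jlp->ijkl", a, b)

    # The NTK is the sum of contributions from every parameter leaf.
    per_leaf = jax.tree_util.tree_map(contract, j1, j2)
    return sum(jax.tree_util.tree_leaves(per_leaf))


key = jax.random.PRNGKey(0)
k1, k2, kx1, kx2 = jax.random.split(key, 4)
params = {
    "w1": jax.random.normal(k1, (8, 32)) / jnp.sqrt(8.0),
    "w2": jax.random.normal(k2, (32, 4)) / jnp.sqrt(32.0),
}
x1 = jax.random.normal(kx1, (5, 8))
x2 = jax.random.normal(kx2, (3, 8))

ntk_full = ntk_jacobian_contraction(params, x1, x2)  # (5, 3, 4, 4)
ntk_traced = jnp.trace(ntk_full, axis1=2, axis2=3)   # (5, 3)
```

This naive version instantiates both Jacobians in full; per the description of implementation `1` below, the library's JIT-compiled contraction can avoid holding entire Jacobians in memory at once.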
Finite-width NNGP and/or NTK kernel functions.
Returns a function that computes single draws from NNGP and NT kernels.
Returns a function to draw a single sample of the NNGP of a given network.
Returns a function to draw a single sample of the NTK of a given network.
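A "single sample" of the empirical NNGP is the outer product of the network's outputs at one fixed draw of the parameters; averaging such samples over many independent parameter draws approximates the NNGP itself. The pure-JAX sketch below assumes a hypothetical two-layer MLP (`init_params`, `apply_fn`, and `nngp_single_draw` are illustrative names, not library functions) and shows the traced `(N1, N2)` and full `(N1, N2, K, K)` forms. It reports the plain trace; the library may additionally divide by the size of the traced axes.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in, width, d_out):
    # Hypothetical two-layer MLP with 1/sqrt(fan_in) weight scaling.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def apply_fn(params, x):
    # (N, d_in) -> (N, d_out)
    return jax.nn.relu(x @ params["w1"]) @ params["w2"]


def nngp_single_draw(params, x1, x2, trace_logits=True):
    """One sample of the empirical NNGP: outer product of outputs at fixed params."""
    f1 = apply_fn(params, x1)  # (N1, K)
    f2 = apply_fn(params, x2)  # (N2, K)
    if trace_logits:
        # Sum over the paired logit axis -> (N1, N2).
        return f1 @ f2.T
    # Keep both logit axes -> (N1, N2, K, K).
    return jnp.einsum("ik,jl->ijkl", f1, f2)


kp, kx1, kx2 = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(kp, 8, 64, 10)
x1 = jax.random.normal(kx1, (5, 8))
x2 = jax.random.normal(kx2, (3, 8))

nngp_traced = nngp_single_draw(params, x1, x2)                    # (5, 3)
nngp_full = nngp_single_draw(params, x1, x2, trace_logits=False)  # (5, 3, 10, 10)
```

Drawing `params` with a different key yields a different sample of the same random kernel.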
enum.IntEnum specifying NTK implementation method.
- class neural_tangents.NtkImplementation(value)
Implementation method of the underlying finite width NTK computation.
Below is a very brief summary of each method. For details, please see “Fast Finite Width Neural Tangent Kernel”.
0) evaluates the FLOPs of all other methods at compilation time and selects the fastest one. However, at the moment it only works correctly on TPUs; on CPU/GPU it can return wrong results, which is why it is not the default.
1) computes the NTK as the outer product of two Jacobians, each computed using reverse-mode Autodiff (vector-Jacobian products, VJPs). When JITted, the contraction is performed in a layerwise fashion, so that entire Jacobians aren’t necessarily instantiated in memory at once, and the memory usage of the method can be lower than the memory needed to instantiate the two Jacobians. This method is best suited for networks with small outputs (such as scalar outputs for binary classification or regression, as opposed to 1000 ImageNet classes), and an expensive forward pass relative to the number of parameters (such as CNNs, where the forward pass reuses a small filter bank many times). It is also the most reliable method, since its implementation is the simplest, and reverse-mode Autodiff is the most commonly used and well tested elsewhere. For this reason it is set as the default.
2) computes the NTK as a sequence of NTK-vector products, similarly to how a Jacobian is computed as a sequence of Jacobian-vector products (JVPs) or vector-Jacobian products (VJPs). This amounts to using both forward-mode (JVPs) and reverse-mode (VJPs) Autodiff, and eliminates the Jacobian contraction at the expense of additional forward passes. Therefore this method is recommended for networks with a cheap forward pass relative to the number of parameters (e.g. fully-connected networks, where each parameter matrix is used only once in the forward pass), and networks with large outputs (e.g. 1000 ImageNet classes). Memory requirements of this method are the same as those of 1). Due to its reliance on forward-mode Autodiff, this method is slightly more prone to JAX and XLA bugs than 1), but overall it is quite simple and reliable.
3) uses a custom JAX interpreter to compute the NTK more efficiently than the other methods. It traverses the computational graph of a function in the same order as during reverse-mode Autodiff, but instead of computing VJPs, it directly computes MJJMPs, “matrix-Jacobian-Jacobian-matrix” products, which arise in the computation of an NTK. Each MJJMP computation relies on the structure in the Jacobians, hence the name. This method can be dramatically faster (up to several orders of magnitude) than other methods on fully-connected networks, and is usually faster or equivalent on CNNs, Transformers, and other architectures, but the exact speedup (e.g. from no speedup to 10X) depends on each specific setting. It can also use less memory than other methods. In our experience it consistently outperforms other methods in most settings. However, its implementation is significantly more complex (hence more bug-prone), and it doesn’t yet support functions using more exotic JAX primitives (e.g. jax.checkpoint, parallel collectives such as jax.lax.psum, compiled loops like jax.lax.scan, etc.), which is why it is highly recommended to try, but not yet set as the default.
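As a small illustration of the kind of Jacobian structure method 3) exploits (this is not the custom interpreter itself), consider a single bias-free dense layer f(x) = x @ w. Its Jacobian satisfies ∂f[i, k]/∂w[a, b] = x[i, a] δ(k, b), so the NTK collapses to (x1 @ x2.T)[i, j] δ(k, l) and never requires instantiating Jacobians of size N·K·d·K. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def dense(w, x):
    # Single bias-free linear layer: (N, d) @ (d, K) -> (N, K).
    return x @ w


def ntk_dense_naive(w, x1, x2):
    # Instantiates Jacobians of shape (N, K, d, K), then contracts parameter axes.
    j1 = jax.jacrev(dense)(w, x1)
    j2 = jax.jacrev(dense)(w, x2)
    return jnp.einsum("ikab,jlab->ijkl", j1, j2)


def ntk_dense_structured(w, x1, x2):
    # d f[i, k] / d w[a, b] = x[i, a] * delta(k, b), so the contraction collapses to
    # (x1 @ x2.T)[i, j] * delta(k, l); no Jacobian is ever materialized.
    num_out = w.shape[1]
    return (x1 @ x2.T)[:, :, None, None] * jnp.eye(num_out)


kw, kx1, kx2 = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(kw, (16, 3))
x1 = jax.random.normal(kx1, (4, 16))
x2 = jax.random.normal(kx2, (2, 16))

naive = ntk_dense_naive(w, x1, x2)
structured = ntk_dense_structured(w, x1, x2)  # same (4, 2, 3, 3) kernel
```

The structured version costs a single N1×N2 Gram matrix, which hints at why structure-aware contraction can pay off most on fully-connected layers.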
A function to compute NTK-vector products without instantiating the NTK.
Returns an NTK-vector product function.
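An NTK-vector product Θ v = J (Jᵀ v) needs one VJP (reverse mode) and one JVP (forward mode), so the NTK itself is never formed. The sketch below assumes a hypothetical two-layer MLP and illustrates one possible use: solving the kernel-regression system (Θ + ridge·I) α = y matrix-free with `jax.scipy.sparse.linalg.cg`. The helper names `make_ntk_vp_fn` and `solve_ntk_regression`, and the `ridge` value, are illustrative assumptions rather than the library's NTK-vector product API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def apply_fn(params, x):
    # Hypothetical two-layer MLP: (N, d_in) -> (N, K).
    return jax.nn.relu(x @ params["w1"]) @ params["w2"]


def make_ntk_vp_fn(params, x):
    """Return v -> Theta(x, x) v for v shaped like f(x); the NTK is never instantiated."""
    f = lambda p: apply_fn(p, x)
    _, vjp_fn = jax.vjp(f, params)  # reverse-mode pass, reused for every product

    def ntk_vp(v):
        (jt_v,) = vjp_fn(v)                          # J^T v: parameter-shaped pytree (VJP)
        _, j_jt_v = jax.jvp(f, (params,), (jt_v,))   # J (J^T v), shaped like f(x) (JVP)
        return j_jt_v

    return ntk_vp


def solve_ntk_regression(params, x_train, y_train, ridge=1.0):
    """Solve (Theta + ridge * I) alpha = y with conjugate gradients, matrix-free."""
    ntk_vp = make_ntk_vp_fn(params, x_train)
    alpha, _ = cg(lambda v: ntk_vp(v) + ridge * v, y_train, maxiter=500)
    return alpha


k1, k2, kx, ky = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": jax.random.normal(k1, (4, 16)) / jnp.sqrt(4.0),
    "w2": jax.random.normal(k2, (16, 2)) / jnp.sqrt(16.0),
}
x_train = jax.random.normal(kx, (6, 4))
y_train = jax.random.uniform(ky, (6, 2))

alpha = solve_ntk_regression(params, x_train, y_train)  # shape (6, 2), like y_train
```

Conjugate gradients fits here because Θ + ridge·I is symmetric positive definite, and each iteration touches the NTK only through one product.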
Linearization and Taylor expansion
Decorators to Taylor-expand around function parameters.
Returns a function
Returns a function
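A first-order Taylor expansion around parameters `params0` replaces the network with f_lin(params, x) = f(params0, x) + J(params0, x) (params − params0). The minimal pure-JAX sketch below computes this with a single `jax.jvp` call; `linearize_sketch` and the tanh network are hypothetical illustrations, not this module's functions. For a perturbation of size eps, the linearization error should shrink like O(eps²) while the change in outputs is O(eps).

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Hypothetical smooth two-layer network: (N, d_in) -> (N, K).
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def linearize_sketch(f, params0):
    """Return f_lin(params, x) = f(params0, x) + J(params0, x) @ (params - params0)."""
    def f_lin(params, x):
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        # A single forward-mode pass yields f(params0, x) and the JVP J @ dparams.
        f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return f0 + df
    return f_lin


k1, k2, kd1, kd2, kx = jax.random.split(jax.random.PRNGKey(0), 5)
params0 = {
    "w1": jax.random.normal(k1, (4, 32)) / 2.0,
    "w2": jax.random.normal(k2, (32, 3)) / jnp.sqrt(32.0),
}
x = jax.random.normal(kx, (5, 4))
f_lin = linearize_sketch(apply_fn, params0)

# Small parameter perturbation: the linear model tracks f up to O(eps^2).
eps = 1e-2
params = {
    "w1": params0["w1"] + eps * jax.random.normal(kd1, (4, 32)),
    "w2": params0["w2"] + eps * jax.random.normal(kd2, (32, 3)),
}
change = jnp.abs(apply_fn(params, x) - apply_fn(params0, x)).max()
lin_error = jnp.abs(f_lin(params, x) - apply_fn(params, x)).max()
```

Higher-order Taylor expansions would add further derivative terms around the same `params0`; this sketch stops at first order.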