neural_tangents.empirical_ntk_vp_fn

neural_tangents.empirical_ntk_vp_fn(f, x1, x2, params, **apply_fn_kwargs)

Returns an NTK-vector product function.

The returned function computes the NTK-vector product without instantiating the NTK. Its runtime is equivalent to (N1 + N2) forward passes through f, where N1 and N2 are the numbers of inputs in x1 and x2, and its memory cost is equivalent to evaluating a vector-Jacobian product of f.

For details, please see section L of “Fast Finite Width Neural Tangent Kernel”.
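Conceptually, the empirical NTK is Θ = J(x1) J(x2)^T, where J(x) is the Jacobian of f(params, x) with respect to params. Θv can therefore be formed as a vector-Jacobian product at x2 followed by a Jacobian-vector product at x1, so the NTK itself is never materialized. The sketch below assumes exactly this VJP-then-JVP structure and is written in plain JAX; `ntk_vp_sketch` is a hypothetical helper, not the library's implementation, and the toy linear model is chosen only because its NTK has a closed form to check against.

```python
import jax
import jax.numpy as jnp


def ntk_vp_sketch(f, x1, x2, params):
    """Hypothetical helper: returns v -> J(x1) @ J(x2)^T @ v without forming the NTK.

    J(x) is the Jacobian of f(params, x) with respect to params.
    """
    def ntk_vp(cotangents):
        # Step 1 (VJP at x2): pull output cotangents back to a params-shaped
        # PyTree, i.e. J(x2)^T v. Memory is that of one vector-Jacobian product.
        _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
        (param_tangents,) = vjp_fn(cotangents)
        # Step 2 (JVP at x1): push that PyTree forward through f at x1,
        # i.e. J(x1) (J(x2)^T v), giving an output shaped like f(params, x1).
        _, out = jax.jvp(lambda p: f(p, x1), (params,), (param_tangents,))
        return out
    return ntk_vp


# Toy linear model f(p, x) = x @ W + b. Its NTK between inputs i and j is
# (x1_i . x2_j + 1) times the identity over output units.
def f(params, x):
    return x @ params["w"] + params["b"]


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x1 = jax.random.normal(k1, (5, 3))
x2 = jax.random.normal(k2, (4, 3))
params = {"w": jax.random.normal(k3, (3, 2)), "b": jnp.zeros(2)}
v = jax.random.normal(k4, (4, 2))  # cotangents shaped like f(params, x2)

# The product function is a pure function of v, so it can be JIT-compiled.
out = jax.jit(ntk_vp_sketch(f, x1, x2, params))(v)
expected = (x1 @ x2.T + 1.0) @ v  # closed-form NTK-vector product for this model
```

Doing the VJP first means the backward pass touches only x2, and the forward-mode JVP on x1 adds roughly one more forward pass worth of work, which is consistent with the cost described above.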

Example

>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
>>> x1 = random.normal(k1, (20, 32, 32, 3))
>>> x2 = random.normal(k2, (10, 32, 32, 3))
>>> #
>>> # Define a forward-pass function `f`.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> # Initialize parameters.
>>> _, params = init_fn(k3, x1.shape)
>>> #
>>> # NTK-vp function. It can (and should) be JIT-compiled.
>>> ntk_vp_fn = nt.empirical_ntk_vp_fn(f, x1, x2, params)
>>> #
>>> # Cotangent vector
>>> cotangents = random.normal(k4, f(params, x2).shape)
>>> #
>>> # NTK-vp output
>>> ntk_vp = ntk_vp_fn(cotangents)
>>> #
>>> # Output has same shape as `f(params, x1)`.
>>> assert ntk_vp.shape == f(params, x1).shape

Parameters:
  • f (ApplyFn) – forward-pass function of signature f(params, x).

  • x1 (Any) – first batch of inputs.

  • x2 (Optional[Any]) – second batch of inputs. x2=None means x2=x1.

  • params (Any) – a PyTree of parameters with respect to which the neural tangent kernel is computed.

  • **apply_fn_kwargs – keyword arguments passed to f. The split_kwargs function splits apply_fn_kwargs into apply_fn_kwargs1 and apply_fn_kwargs2, which are passed to f when it is applied to x1 and x2, respectively. In particular, an rng key in apply_fn_kwargs is split into two different keys (if x1 != x2) or reused as the same key (if x1 == x2). See the _read_key function for more details.
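The rng behavior described for apply_fn_kwargs can be illustrated with a small sketch. `split_rng_kwargs` is hypothetical and is not the library's split_kwargs or _read_key; it assumes the key is passed under the name `rng`, that the "x1 == x2" case is signalled by x2 being None, and that all other keyword arguments (here a hypothetical `scale`) are forwarded unchanged to both calls.

```python
import jax
import jax.numpy as jnp


def split_rng_kwargs(x2, **apply_fn_kwargs):
    """Hypothetical illustration of the documented rng-splitting behavior.

    Assumptions: the key lives under "rng", and x2 is None means x2 = x1.
    """
    kwargs1 = dict(apply_fn_kwargs)
    kwargs2 = dict(apply_fn_kwargs)
    rng = apply_fn_kwargs.get("rng")
    if rng is not None and x2 is not None:
        # Different inputs: derive two distinct keys, one per batch.
        kwargs1["rng"], kwargs2["rng"] = jax.random.split(rng)
    # Otherwise (x2 is None, i.e. x2 = x1) both dicts keep the same key.
    return kwargs1, kwargs2


key = jax.random.PRNGKey(0)
x2 = jnp.ones((10, 8))
# `scale` is a hypothetical extra kwarg, forwarded unchanged to both calls.
kw1, kw2 = split_rng_kwargs(x2, rng=key, scale=2.0)
same1, same2 = split_rng_kwargs(None, rng=key, scale=2.0)
```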

Return type:

Callable[[Any], Any]

Returns:

An NTK-vector product function that accepts a PyTree of cotangents with the same shape and structure as f(params, x2), and returns the NTK-vector product with the same shape and structure as f(params, x1).
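The shape-and-structure contract also covers models whose output is a PyTree rather than a single array. The sketch below uses a hypothetical plain-JAX stand-in, `ntk_vp_sketch` (a VJP at x2 followed by a JVP at x1, not the library's implementation), together with a toy two-headed model that returns a dict, to show that the cotangents mirror f(params, x2) and the result mirrors f(params, x1). With x2=None, both sides take the shape of f(params, x1).

```python
import jax
import jax.numpy as jnp


def ntk_vp_sketch(f, x1, x2, params):
    # Hypothetical stand-in: v -> J(x1) J(x2)^T v, where x2=None means x2 = x1.
    x2 = x1 if x2 is None else x2

    def ntk_vp(cotangents):
        _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
        (param_tangents,) = vjp_fn(cotangents)  # PyTree shaped like params
        _, out = jax.jvp(lambda p: f(p, x1), (params,), (param_tangents,))
        return out  # PyTree shaped like f(params, x1)
    return ntk_vp


# Toy two-headed model whose output is a dict PyTree.
def f(params, x):
    h = jnp.tanh(x @ params["w"])
    return {"logits": h @ params["v"], "value": h.sum(axis=-1)}


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x1 = jax.random.normal(k1, (6, 4))
x2 = jax.random.normal(k2, (3, 4))
params = {"w": jax.random.normal(k3, (4, 8)), "v": jax.random.normal(k4, (8, 2))}

# Cotangents mirror f(params, x2): same dict keys, same leaf shapes.
cotangents = jax.tree_util.tree_map(jnp.ones_like, f(params, x2))
out = ntk_vp_sketch(f, x1, x2, params)(cotangents)

# With x2=None, cotangents and outputs both mirror f(params, x1).
square_vp = ntk_vp_sketch(f, x1, None, params)
out_sq = square_vp(jax.tree_util.tree_map(jnp.ones_like, f(params, x1)))
```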