neural_tangents.empirical_ntk_vp_fn
- neural_tangents.empirical_ntk_vp_fn(f, x1, x2, params, **apply_fn_kwargs)
Returns an NTK-vector product function.
The function computes the NTK-vector product without instantiating the NTK. Its runtime is equivalent to (N1 + N2) forward passes through f, where N1 and N2 are the batch sizes of x1 and x2, and its memory cost is equivalent to evaluating a vector-Jacobian product of f. For details, see Section L of “Fast Finite Width Neural Tangent Kernel”.
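To illustrate how an NTK-vector product can be computed without ever forming the NTK, the sketch below uses one standard construction: a VJP at x2 followed by a JVP at x1, so that Θ(x1, x2) v = J(x1) [J(x2)^T v], where J(x) is the Jacobian of f(params, x) with respect to params. This is a pure-JAX sketch under that assumption, not necessarily how neural_tangents implements empirical_ntk_vp_fn; the helper ntk_vp_sketch and the toy mlp model are hypothetical. The result is checked against an NTK assembled from explicit Jacobians, which is only affordable for tiny models.

```python
import jax
import jax.numpy as jnp
from jax import random


def ntk_vp_sketch(f, x1, x2, params, cotangents):
    """Hypothetical matrix-free NTK-vector product: J(x1) @ (J(x2)^T @ v)."""
    # VJP at x2 maps output-space cotangents (shaped like f(params, x2))
    # to a vector in parameter space, i.e. J(x2)^T v.
    _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)
    (params_tangent,) = vjp_fn(cotangents)
    # JVP at x1 pushes that parameter-space vector forward: J(x1) (J(x2)^T v).
    _, ntk_vp = jax.jvp(lambda p: f(p, x1), (params,), (params_tangent,))
    return ntk_vp


def mlp(params, x):
    # Toy two-layer network used only for this check.
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2


k1, k2, k3, k4, k5 = random.split(random.PRNGKey(0), 5)
params = (random.normal(k1, (4, 8)), random.normal(k2, (8, 3)))
x1 = random.normal(k3, (5, 4))
x2 = random.normal(k4, (6, 4))
v = random.normal(k5, (6, 3))  # same shape as mlp(params, x2)

vp = ntk_vp_sketch(mlp, x1, x2, params, v)  # same shape as mlp(params, x1)


def flat_jacobian(x):
    # Explicit Jacobian of the flattened outputs w.r.t. all parameters,
    # shape (N * 3, num_params). Only affordable for tiny models.
    jac = jax.jacobian(mlp)(params, x)  # tuple matching the params structure
    return jnp.concatenate([j.reshape(x.shape[0] * 3, -1) for j in jac], axis=1)


ntk = flat_jacobian(x1) @ flat_jacobian(x2).T  # explicit NTK, shape (15, 18)
ref = (ntk @ v.reshape(-1)).reshape(vp.shape)
ok = bool(jnp.allclose(vp, ref, rtol=1e-4, atol=1e-3))
print(ok)
```

The VJP only touches the x2 batch and the JVP only the x1 batch, so the cost grows with N1 + N2 instead of requiring an explicit (N1·O) × (N2·O) kernel matrix, where O is the output size per example.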
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
>>> x1 = random.normal(k1, (20, 32, 32, 3))
>>> x2 = random.normal(k2, (10, 32, 32, 3))
>>> #
>>> # Define a forward-pass function `f`.
>>> init_fn, f, _ = stax.serial(
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Relu(),
...     stax.Conv(32, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> # Initialize parameters.
>>> _, params = init_fn(k3, x1.shape)
>>> #
>>> # NTK-vp function. Can/should be JITted.
>>> ntk_vp_fn = nt.empirical_ntk_vp_fn(f, x1, x2, params)
>>> #
>>> # Cotangent vector with the shape of `f(params, x2)`.
>>> cotangents = random.normal(k4, f(params, x2).shape)
>>> #
>>> # NTK-vp output
>>> ntk_vp = ntk_vp_fn(cotangents)
>>> #
>>> # Output has same shape as `f(params, x1)`.
>>> assert ntk_vp.shape == f(params, x1).shape
- Parameters:
f (ApplyFn) – forward-pass function of signature f(params, x).
x1 (Any) – first batch of inputs.
x2 (Optional[Any]) – second batch of inputs. x2=None means x2=x1.
params (Any) – a PyTree of parameters about which we would like to compute the neural tangent kernel.
**apply_fn_kwargs – keyword arguments passed to f. apply_fn_kwargs is split into apply_fn_kwargs1 and apply_fn_kwargs2 by the split_kwargs function, and these are passed to f when it is evaluated on x1 and x2, respectively. In particular, the rng key in apply_fn_kwargs is split into two different keys (if x1 != x2) or into two identical keys (if x1 == x2). See the _read_key function for more details.
- Return type:
Callable[[Any], Any]
- Returns:
An NTK-vector product function accepting a PyTree of cotangents with the shape and structure of f(params, x2), and returning the NTK-vector product with the shape and structure of f(params, x1).
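The return contract above (cotangents shaped like f(params, x2) in, an output shaped like f(params, x1) out) also applies when f returns a PyTree rather than a single array. The following hypothetical sketch, make_ntk_vp_sketch, assumes a VJP-then-JVP construction (not the library's code) and applies it to a toy model that returns a tuple of two arrays, checking only the structure and shapes of the result.

```python
import jax
import jax.numpy as jnp
from jax import random


def make_ntk_vp_sketch(f, x1, x2, params):
    # Hypothetical factory mirroring the documented contract: the returned
    # function takes cotangents shaped like f(params, x2) and produces an
    # output shaped like f(params, x1).
    _, vjp_fn = jax.vjp(lambda p: f(p, x2), params)

    def ntk_vp(cotangents):
        (params_tangent,) = vjp_fn(cotangents)  # J(x2)^T v, a PyTree like params
        _, out = jax.jvp(lambda p: f(p, x1), (params,), (params_tangent,))
        return out

    return ntk_vp


def two_head_net(params, x):
    # Toy model whose output is a PyTree: a tuple of two arrays.
    h = jnp.tanh(x @ params["w"])
    return h @ params["a"], jnp.sum(h, axis=-1)


k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
params = {"w": random.normal(k1, (4, 8)), "a": random.normal(k2, (8, 2))}
x1 = random.normal(k3, (5, 4))
x2 = random.normal(k4, (7, 4))

ntk_vp = make_ntk_vp_sketch(two_head_net, x1, x2, params)

# Cotangents take the shape and structure of f(params, x2): ((7, 2), (7,)).
cotangents = jax.tree_util.tree_map(jnp.ones_like, two_head_net(params, x2))
out = ntk_vp(cotangents)

# The output takes the shape and structure of f(params, x1): ((5, 2), (5,)).
ref = two_head_net(params, x1)
same_structure = jax.tree_util.tree_structure(out) == jax.tree_util.tree_structure(ref)
same_shapes = all(
    a.shape == b.shape
    for a, b in zip(jax.tree_util.tree_leaves(out), jax.tree_util.tree_leaves(ref))
)
print(same_structure, same_shapes)
```

In this sketch the VJP closure is built once, outside ntk_vp, so repeated calls reuse the forward pass at x2 at the cost of keeping its residuals in memory; this is a choice made for the sketch, not a documented property of empirical_ntk_vp_fn.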
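The **apply_fn_kwargs entry states that an rng key is split into two different keys when x1 != x2 and into two identical keys when x1 == x2. The hypothetical helper below, split_rng_kwargs_sketch, only mimics that documented behavior; it is not the library's split_kwargs or _read_key. It assumes the key is passed under the keyword rng (the keyword used by stax layers such as Dropout) and treats x2=None or an element-wise equal x2 as the same inputs.

```python
import jax.numpy as jnp
from jax import random


def split_rng_kwargs_sketch(kwargs, x1, x2=None):
    # Hypothetical helper: mirrors the documented rng behavior only.
    kwargs1, kwargs2 = dict(kwargs), dict(kwargs)
    rng = kwargs.get("rng")  # assumed keyword name
    if rng is None:
        return kwargs1, kwargs2
    same_inputs = x2 is None or (
        x1.shape == x2.shape and bool(jnp.array_equal(x1, x2)))
    if not same_inputs:
        # Different inputs: give each batch its own independent key.
        kwargs1["rng"], kwargs2["rng"] = random.split(rng)
    # Same inputs: both dicts keep the original key, so e.g. dropout masks match.
    return kwargs1, kwargs2


key = random.PRNGKey(0)
x1 = jnp.ones((2, 3))
x2 = jnp.zeros((4, 3))
kw1, kw2 = split_rng_kwargs_sketch({"rng": key}, x1, x2)
same1, same2 = split_rng_kwargs_sketch({"rng": key}, x1, None)
print(bool(jnp.array_equal(kw1["rng"], kw2["rng"])),
      bool(jnp.array_equal(same1["rng"], same2["rng"])))
```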