nt.experimental – prototypes
Warning
This module contains new, highly experimental prototypes. Be aware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk!
Kernel functions
Finite-width NTK kernel function in TensorFlow. See the Python and Colab usage examples.
- neural_tangents.experimental.empirical_ntk_fn_tf(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)[source]
Returns a function to draw a single sample of the NTK of a given network f.

This function follows the API of neural_tangents.empirical_ntk_fn, but is applicable to a TensorFlow tf.Module, tf.keras.Model, or tf.function, via a TF->JAX->TF roundtrip using tf2jax and jax2tf. The docstring below is adapted from neural_tangents.empirical_ntk_fn.

Warning
This function is experimental and risks returning wrong results or performing slowly. It is intended to demonstrate the usage of neural_tangents.empirical_ntk_fn in TensorFlow, but has not been extensively tested. Specifically, it appears to have very long compile times (but OK runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model.

TODO(romann): support distinguishing between trainable and non-trainable variables.
TODO(romann): investigate slow compile times.
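A minimal usage sketch (the Keras architecture, input shapes, and random data below are illustrative, not part of the API):

    import tensorflow as tf
    import neural_tangents as nt

    # Any tf.keras.Model with a known input shape should work; this small
    # MLP is just an illustration.
    f = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(2),
    ])
    f.build((None, 8))  # Makes the input signature known.

    # Extract params with the helper described under "Helper functions" below.
    _, params = nt.experimental.get_apply_fn_and_params(f)

    x1 = tf.random.normal((6, 8))
    x2 = tf.random.normal((3, 8))

    # Per the nt.empirical_ntk_fn API, the returned function is called as
    # ntk_fn(x1, x2, params); with the default trace_axes=(-1,) the result
    # has shape (6, 3).
    ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, vmap_axes=0)
    k = ntk_fn(x1, x2, params)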
- Parameters:
f (Union[Module, GenericFunction]) – tf.Module or tf.function whose NTK we are computing. Must satisfy the following:
- if a tf.function, it must have the signature of f(params, x);
- if a tf.Module, it must be either a tf.keras.Model, or be callable;
- the input signature (f.input_shape for tf.Module or tf.keras.Model, or f.input_signature for tf.function) must be known.
trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This allows you to save space and compute if you are only interested in the respective trace, but can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis, since activation slices along such an axis are i.i.d., and the covariance along the respective pair of axes indeed converges to a constant-diagonal matrix in the infinite-width or infinite-n_samples limit. Also related to "contracting dimensions" in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral). See the shape sketch after the Returns section below.
diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This allows you to save space and compute if off-diagonal values along these axes are not needed, but can also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, you should specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of covariance) along certain axes. Also related to "batch dimensions" in XLA terms (https://www.tensorflow.org/xla/operation_semantics#dotgeneral).
vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) – a triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f.call(x, **kwargs) equals a concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()). This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension, both for inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout), as in the usage sketch above. However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None, to avoid wrong (and potentially silent) results.
implementation (Union[NtkImplementation, int]) – an NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details, and the brief sketch after this parameter list.
_j_rules (bool) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX's jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to be no worse, and sometimes better, than the automated alternatives, but in case of a suboptimal implementation, setting it to False could improve performance.
_s_rules (bool) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice this should be set to True, and setting it to False can lead to a dramatic deterioration of performance.
_fwd (Optional[bool]) – internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, False to always use jax.vjp, and None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
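For concreteness, a brief sketch of passing implementation, reusing f from the usage sketch above (per the parameter description, the integers 0–3 are interchangeable with the corresponding NtkImplementation enum values):

    from neural_tangents import NtkImplementation

    # Enum and integer forms select the same implementation;
    # STRUCTURED_DERIVATIVES is 3, as noted in the parameter docs above.
    ntk_fn_sd = nt.experimental.empirical_ntk_fn_tf(
        f, vmap_axes=0, implementation=NtkImplementation.STRUCTURED_DERIVATIVES)
    ntk_fn_3 = nt.experimental.empirical_ntk_fn_tf(
        f, vmap_axes=0, implementation=3)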
- Returns:
A function ntk_fn that computes the empirical NTK.
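To make the trace_axes and diagonal_axes semantics concrete, here is a sketch reusing f, params, x1, and x2 from the usage example above (model output shape (batch, 2); the stated output shapes assume the semantics of nt.empirical_ntk_fn carry over, as the docstring says):

    # Default trace_axes=(-1,): trace over the logit axis -> shape (6, 3).
    k_trace = nt.experimental.empirical_ntk_fn_tf(
        f, vmap_axes=0)(x1, x2, params)

    # trace_axes=(): full covariance over the logit axis -> shape (6, 3, 2, 2).
    k_full = nt.experimental.empirical_ntk_fn_tf(
        f, trace_axes=(), vmap_axes=0)(x1, x2, params)

    # diagonal_axes=(-1,): only the diagonal of that covariance -> (6, 3, 2).
    k_diag = nt.experimental.empirical_ntk_fn_tf(
        f, trace_axes=(), diagonal_axes=(-1,), vmap_axes=0)(x1, x2, params)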
Helper functions
A helper function to convert stateful TensorFlow models into a functional-style, stateless apply_fn(params, x) forward-pass function, and to extract the respective params.
- neural_tangents.experimental.get_apply_fn_and_params(f)[source]
Converts a tf.Module into a forward-pass apply_fn and params.

Use this function to extract params to pass to the TensorFlow empirical NTK kernel function.

Warning
This function does not distinguish between trainable and non-trainable parameters of the model.
- Parameters:
f (Module) – a tf.Module to convert to an apply_fn(params, x) function. Must have an input_shape attribute set (specifying the shape of x), and be callable or be a tf.keras.Model.
- Returns:
A tuple of (apply_fn, params), where params is a PyTree[tf.Tensor].
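A minimal sketch of the conversion (the model is illustrative; the exact tensor types accepted by apply_fn depend on the underlying TF->JAX conversion):

    import tensorflow as tf
    import neural_tangents as nt

    f = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    f.build((None, 8))  # Sets input_shape, which must be known.

    apply_fn, params = nt.experimental.get_apply_fn_and_params(f)

    # Stateless forward pass: parameters are passed explicitly, so the
    # extracted params can be fed to empirical_ntk_fn_tf above.
    x = tf.random.normal((2, 8))
    y = apply_fn(params, x)  # Expected to match f(x) up to numerical precision.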