nt.experimental – prototypes

Warning

This module contains new, highly experimental prototypes. Beware that they are not properly tested, not supported, and may suffer from sub-optimal performance. Use at your own risk!

Kernel functions

Finite-width NTK kernel function in TensorFlow. See the Python and Colab usage examples.

neural_tangents.experimental.empirical_ntk_fn_tf(f, trace_axes=(-1,), diagonal_axes=(), vmap_axes=None, implementation=NtkImplementation.JACOBIAN_CONTRACTION, _j_rules=True, _s_rules=True, _fwd=None)

Returns a function to draw a single sample of the NTK of a given network f.

This function follows the API of neural_tangents.empirical_ntk_fn, but is applicable to TensorFlow tf.Module, tf.keras.Model, or tf.function objects, via a TF->JAX->TF roundtrip using tf2jax and jax2tf. The docstring below is adapted from neural_tangents.empirical_ntk_fn.
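To make "a single sample of the NTK" concrete, the sketch below computes the empirical NTK of a tiny network straight from its definition: the Jacobian of the outputs with respect to the parameters at x1, contracted over parameters with the Jacobian at x2. This is a NumPy-only illustration under stated assumptions, not the library's implementation: the two-layer tanh network and the helper names (mlp_apply, jacobian, empirical_ntk) are hypothetical, the Jacobian is taken by finite differences, and no trace_axes or diagonal_axes reduction is applied, so the result keeps the full (N1, N2, K, K) shape.

```python
import numpy as np


def mlp_apply(params, x):
    """Hypothetical two-layer tanh network in functional form f(params, x)."""
    return np.tanh(x @ params["w1"]) @ params["w2"]


def jacobian(params, x, eps=1e-6):
    """Jacobian of mlp_apply w.r.t. every parameter entry, shape (N, K, P).

    Uses central finite differences, one parameter entry at a time; entries are
    ordered by sorted dict key, each array traversed in row-major order.
    """
    cols = []
    for name in sorted(params):
        for idx in np.ndindex(params[name].shape):
            plus = {k: v.copy() for k, v in params.items()}
            minus = {k: v.copy() for k, v in params.items()}
            plus[name][idx] += eps
            minus[name][idx] -= eps
            cols.append((mlp_apply(plus, x) - mlp_apply(minus, x)) / (2 * eps))
    return np.stack(cols, axis=-1)


def empirical_ntk(params, x1, x2):
    """Full empirical NTK, shape (N1, N2, K, K): sum_p J(x1)[i,a,p] * J(x2)[j,b,p]."""
    j1, j2 = jacobian(params, x1), jacobian(params, x2)
    return np.einsum("iap,jbp->ijab", j1, j2)


rng = np.random.default_rng(0)
params = {"w1": rng.normal(size=(3, 5)), "w2": rng.normal(size=(5, 2))}
x1, x2 = rng.normal(size=(4, 3)), rng.normal(size=(6, 3))
ntk = empirical_ntk(params, x1, x2)
print(ntk.shape)  # (4, 6, 2, 2)
```

A new random parameter draw gives a different kernel, which is why the returned function is said to draw a single sample.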

Warning

This function is experimental and risks returning wrong results or performing slowly. It is intended to demonstrate the usage of neural_tangents.empirical_ntk_fn in TensorFlow, but has not been extensively tested. Specifically, it appears to have very long compile times (but acceptable runtime), is prone to triggering XLA errors, and does not distinguish between trainable and non-trainable parameters of the model.

TODO(romann): support division between trainable and non-trainable variables.

TODO(romann): investigate slow compile times.

Parameters:
  • f (Union[Module, PolymorphicFunction]) –

    tf.Module or tf.function whose NTK we are computing. Must satisfy the following:

    • if a tf.function, must have the signature of f(params, x).

    • if a tf.Module, must be either a tf.keras.Model, or be callable.

    • input signature (f.input_shape for tf.Module or tf.keras.Model, or f.input_signature for tf.function) must be known.

  • trace_axes (Union[int, Sequence[int]]) – output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each axis in trace_axes). This saves space and compute if you are only interested in the respective trace, and can also improve approximation accuracy if you know that the covariance along these pairs of axes converges to a constant * identity matrix in the limit of interest (e.g. infinite width or infinite n_samples). A common use case is the channel / feature / logit axis: activation slices along such an axis are i.i.d., so the covariance along the corresponding pair of axes indeed converges to a constant-diagonal matrix in the infinite width or infinite n_samples limit. Also related to “contracting dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • diagonal_axes (Union[int, Sequence[int]]) – output axes to diagonalize the output kernel over, i.e. compute only the diagonal of the covariance along the respective pair of axes (one pair for each axis in diagonal_axes). This saves space and compute if off-diagonal values along these axes are not needed, and can also improve approximation accuracy if their limiting value is known theoretically, e.g. if they vanish in the limit of interest (e.g. infinite width or infinite n_samples). If you further know that on-diagonal values converge to the same constant in your limit of interest, specify these axes in trace_axes instead, to save even more compute and gain even more accuracy. A common use case is computing the variance (instead of the covariance) along certain axes. Also related to “batch dimensions” in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral)

  • vmap_axes (Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]) –

    A triple of (in_axes, out_axes, kwargs_axes) passed to vmap to evaluate the empirical NTK in parallel over these axes. Precisely, providing this argument implies that f.call(x, **kwargs) equals the concatenation along out_axes of f applied to slices of x and **kwargs along in_axes and kwargs_axes. In other words, it certifies that f can be evaluated as a vmap with out_axes=out_axes over x (along in_axes) and over those arguments in **kwargs that are present in kwargs_axes.keys() (along kwargs_axes.values()).

    This allows us to evaluate Jacobians much more efficiently. If vmap_axes is not a triple, it is interpreted as in_axes = out_axes = vmap_axes, kwargs_axes = {}. For example, a very common use case is vmap_axes=0 for a neural network with a leading (0) batch dimension for both inputs and outputs, and no interactions between different elements of the batch (e.g. no BatchNorm, and, in the case of nt.stax, also no Dropout). However, if there is interaction between batch elements or no concept of a batch axis at all, vmap_axes must be set to None to avoid wrong (and potentially silent) results.

  • implementation (Union[NtkImplementation, int]) – An NtkImplementation value (or an int 0, 1, 2, or 3). See the NtkImplementation docstring for details.

  • _j_rules (bool) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow custom Jacobian rules for intermediary primitive dy/dw computations for MJJMPs (matrix-Jacobian-Jacobian-matrix products). Set to False to use JVPs or VJPs, via JAX’s jax.jacfwd or jax.jacrev. Custom Jacobian rules (True) are expected to perform no worse than, and sometimes better than, the automated alternatives, but if a custom rule is implemented suboptimally, setting this to False could improve performance.

  • _s_rules (bool) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow efficient MJJMP rules for structured dy/dw primitive Jacobians. In practice it should be set to True; setting it to False can dramatically degrade performance.

  • _fwd (Optional[bool]) – Internal debugging parameter, applicable only when implementation is STRUCTURED_DERIVATIVES (3) or AUTO (0). Set to True to allow jax.jvp in intermediary primitive Jacobian dy/dw computations, to False to always use jax.vjp, or to None to decide automatically based on input/output sizes. Applicable when _j_rules=False, or when a primitive does not have a Jacobian rule. Should be set to None for best performance.
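To make the effect of trace_axes and diagonal_axes on the output concrete, the NumPy sketch below contracts per-example Jacobians in three ways, with the reduction applied to the last (logit) axis. The Jacobians are random placeholders of an assumed shape (N, K, P), and the traced kernel is taken as a plain sum over the diagonal; this illustrates the semantics described above, not the library's exact code path or normalization.

```python
import numpy as np

# Placeholder per-example Jacobians for K=3 outputs and P=10 parameters,
# evaluated on N1=4 and N2=5 inputs: shapes (N, K, P).
rng = np.random.default_rng(0)
j1 = rng.normal(size=(4, 3, 10))
j2 = rng.normal(size=(5, 3, 10))

# trace_axes=(), diagonal_axes=(): full kernel over every pair of outputs.
full = np.einsum("iap,jbp->ijab", j1, j2)  # (N1, N2, K, K)

# trace_axes=(-1,): the output axis of both Jacobians is contracted together
# ("contracting dimension"), so the K x K block is never materialized.
traced = np.einsum("iap,jap->ij", j1, j2)  # (N1, N2)

# diagonal_axes=(-1,): the output axis is kept as a shared index
# ("batch dimension"), giving only the on-diagonal entries.
diag = np.einsum("iap,jap->ija", j1, j2)  # (N1, N2, K)

print(full.shape, traced.shape, diag.shape)  # (4, 5, 3, 3) (4, 5) (4, 5, 3)
```

Computing traced or diag directly skips the off-diagonal output pairs entirely, which is where the savings in space and compute come from.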

Return type:

Callable[..., Any]

Returns:

A function ntk_fn that computes the empirical NTK.

Helper functions

A helper function to convert stateful TensorFlow models into a functional-style, stateless apply_fn(params, x) forward-pass function and to extract the respective params.

neural_tangents.experimental.get_apply_fn_and_params(f)

Converts a tf.Module into a forward-pass apply_fn and params.

Use this function to extract params to pass to the TensorFlow empirical NTK kernel function.

Warning

This function does not distinguish between trainable and non-trainable parameters of the model.

Parameters:

f (Module) – a tf.Module to convert to an apply_fn(params, x) function. Must have an input_shape attribute set (specifying the shape of x), and be callable or be a tf.keras.Model.

Returns:

A tuple of (apply_fn, params), where params is a PyTree[tf.Tensor].
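The idea of turning a stateful model into a stateless apply_fn(params, x) plus explicit params can be illustrated without TensorFlow. In the sketch below, Dense and to_apply_fn_and_params are hypothetical names, and temporarily swapping the module's attributes is just one simple way to express the idea; it is not how get_apply_fn_and_params is implemented. The returned params is a plain dict, which is a valid PyTree.

```python
import numpy as np


class Dense:
    """Hypothetical stateful layer that owns its weights, like a tf.Module."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return x @ self.w + self.b


def to_apply_fn_and_params(module, names):
    """Return (apply_fn, params) for the module attributes listed in `names`.

    apply_fn(params, x) temporarily installs `params` on the module, runs the
    forward pass, then restores the original attributes, so callers see a
    stateless function of (params, x).
    """
    params = {name: getattr(module, name) for name in names}

    def apply_fn(params, x):
        saved = {name: getattr(module, name) for name in params}
        try:
            for name, value in params.items():
                setattr(module, name, value)
            return module(x)
        finally:
            for name, value in saved.items():
                setattr(module, name, value)

    return apply_fn, params


rng = np.random.default_rng(0)
layer = Dense(rng.normal(size=(3, 2)), np.zeros(2))
apply_fn, params = to_apply_fn_and_params(layer, ["w", "b"])
x = rng.normal(size=(4, 3))
print(np.allclose(apply_fn(params, x), layer(x)))  # True
```

Like get_apply_fn_and_params (see the warning above), this sketch makes no distinction between trainable and non-trainable attributes: everything listed in `names` ends up in params.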