Common Type Definitions.

class neural_tangents._src.utils.typing.AnalyticKernelFn(*args, **kwargs)

A type alias for analytic kernel functions.

A kernel function that computes an analytic kernel. It takes either Kernel or jax.numpy.ndarray inputs, plus a get argument that specifies which quantities the kernel should compute. It returns either a Kernel object or jax.numpy.ndarrays for the kernels specified by get.
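The get convention can be illustrated with a small pure-Python sketch. It assumes a hypothetical Kernel stand-in with only nngp and ntk fields (the real Kernel class carries more) and a toy kernel for a single bias-free dense layer, where NNGP and NTK coincide. Only array-like inputs (nested lists) are handled here, and a tuple get returns a plain tuple; both are simplifications.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Union

Matrix = List[List[float]]


@dataclass
class Kernel:
    """Hypothetical stand-in for the Kernel class: only two fields."""
    nngp: Matrix
    ntk: Matrix


def _gram(x1: Matrix, x2: Matrix) -> Matrix:
    # Normalized inner products x1[i] . x2[j] / d.
    d = len(x1[0])
    return [[sum(a * b for a, b in zip(u, v)) / d for v in x2] for u in x1]


def kernel_fn(x1: Matrix, x2: Optional[Matrix] = None,
              get: Union[None, str, Sequence[str]] = None):
    """Toy analytic kernel of one bias-free dense layer (NNGP == NTK)."""
    if x2 is None:
        x2 = x1
    gram = _gram(x1, x2)
    kernel = Kernel(nngp=gram, ntk=gram)
    if get is None:
        return kernel                                  # the whole Kernel object
    if isinstance(get, str):
        return getattr(kernel, get)                    # a single matrix
    return tuple(getattr(kernel, g) for g in get)      # one matrix per name
```

For example, kernel_fn(x, get="ntk") returns one matrix, while kernel_fn(x, get=("nngp", "ntk")) returns both at once.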

class neural_tangents._src.utils.typing.ApplyFn(*args, **kwargs)

A type alias for apply functions.

Apply functions perform computations with finite-width neural networks. They take a PyTree of parameters and an array of inputs and produce an array of outputs.
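A minimal sketch of an apply function, assuming a toy dense layer whose parameter PyTree is a dict with hypothetical keys "W" and "b", and using nested lists in place of jax.numpy arrays:

```python
from __future__ import annotations

from typing import List

Matrix = List[List[float]]


def apply_fn(params: dict, inputs: Matrix) -> Matrix:
    """Toy dense layer: outputs[n][k] = sum_i inputs[n][i] * W[i][k] + b[k]."""
    W, b = params["W"], params["b"]  # hypothetical parameter names
    return [
        [sum(x[i] * W[i][k] for i in range(len(x))) + b[k] for k in range(len(b))]
        for x in inputs
    ]
```

The parameters stay separate from the function, so the same apply_fn can be evaluated at any parameter setting.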


Axes specification: either a single integer (axis=-1) or a sequence of integers (axis=(1, 3)).

alias of Union[int, Sequence[int]]
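Code that accepts this alias usually normalizes it before use. The following is a sketch with a hypothetical helper, canonicalize_axes, that turns either form into a sorted tuple of non-negative axis indices:

```python
from __future__ import annotations

from typing import Sequence, Union

Axes = Union[int, Sequence[int]]


def canonicalize_axes(axis: Axes, ndim: int) -> tuple[int, ...]:
    """Turn an Axes value into a sorted tuple of non-negative axis indices."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    out = []
    for a in axes:
        if not -ndim <= a < ndim:
            raise ValueError(f"axis {a} is out of range for ndim={ndim}")
        out.append(a % ndim)  # map negative indices, e.g. -1 -> ndim - 1
    return tuple(sorted(set(out)))
```

For a 4-dimensional array, axis=-1 becomes (3,) and axis=(1, 3) stays (1, 3).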

class neural_tangents._src.utils.typing.EmpiricalGetKernelFn(*args, **kwargs)

A type alias for empirical kernel functions accepting a get argument.

A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.

Equivalent to EmpiricalKernelFn, but accepts a get argument, for example get=("nngp", "ntk"), to compute both kernels together.
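To show what "empirical" means here, the sketch below computes both kernels for one fixed parameter setting of a toy scalar linear model f(x) = w . x. The argument order (x1, x2, get, params) and all names are assumptions for illustration. For this model the gradient df/dw is just x, so the empirical NTK is a plain inner product of inputs, while the empirical NNGP is a product of outputs.

```python
from __future__ import annotations

from typing import List, Optional, Sequence, Union

Vector = List[float]


def apply_fn(params: dict, x: Vector) -> float:
    # Toy scalar linear model f(x) = w . x (hypothetical).
    return sum(w * xi for w, xi in zip(params["w"], x))


def empirical_kernel_fn(x1: List[Vector], x2: Optional[List[Vector]],
                        get: Union[str, Sequence[str]], params: dict):
    """Empirical NNGP and NTK of the toy model at fixed `params`."""
    if x2 is None:
        x2 = x1
    # Empirical NNGP: products of outputs f(x1[i]) * f(x2[j]).
    nngp = [[apply_fn(params, a) * apply_fn(params, b) for b in x2] for a in x1]
    # Empirical NTK: <df/dw(x1[i]), df/dw(x2[j])>; for f = w . x, df/dw = x.
    ntk = [[sum(p * q for p, q in zip(a, b)) for b in x2] for a in x1]
    kernels = {"nngp": nngp, "ntk": ntk}
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[g] for g in get)
```

Passing get=("nngp", "ntk") returns both matrices from one call.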

class neural_tangents._src.utils.typing.EmpiricalKernelFn(*args, **kwargs)

A type alias for empirical kernel functions computing either NTK or NNGP.

A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.

Equivalent to EmpiricalGetKernelFn with get="nngp" or get="ntk".
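The equivalence can be sketched by fixing get in a get-style function. Both functions below are hypothetical: a toy get-style kernel for the scalar model f(x) = w * x, and a wrapper fix_get that returns a single-kernel function with an assumed (x1, x2, params) signature.

```python
from __future__ import annotations

from typing import Callable, Sequence, Union


def empirical_get_kernel_fn(x1, x2, get: Union[str, Sequence[str]], params: dict):
    """Hypothetical get-style kernel fn for the scalar model f(x) = w * x."""
    x2 = x1 if x2 is None else x2
    w = params["w"]
    kernels = {
        "nngp": [[(w * a) * (w * b) for b in x2] for a in x1],  # f(a) * f(b)
        "ntk": [[a * b for b in x2] for a in x1],               # df/dw(a) * df/dw(b)
    }
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[g] for g in get)


def fix_get(get_kernel_fn: Callable, get: str) -> Callable:
    """Derive a single-kernel fn(x1, x2, params) by fixing `get`."""
    def kernel_fn(x1, x2, params):
        return get_kernel_fn(x1, x2, get, params)
    return kernel_fn
```

Here fix_get(empirical_get_kernel_fn, "ntk") behaves like a kernel function that only ever computes the NTK.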

class neural_tangents._src.utils.typing.InitFn(*args, **kwargs)

A type alias for initialization functions.

Initialization functions construct parameters for neural networks given a random key and an input shape. Specifically, they produce a tuple giving the output shape and a PyTree of parameters.
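A minimal sketch of an initialization function for a toy dense layer. It assumes an integer seed (fed to random.Random) in place of a JAX PRNGKey, a hypothetical factory make_dense_init_fn, and a leading -1 as a placeholder batch dimension; it returns the (output_shape, params) pair described above.

```python
from __future__ import annotations

import random


def make_dense_init_fn(out_dim: int):
    """Return a toy init_fn for a dense layer with `out_dim` outputs."""
    def init_fn(key: int, input_shape: tuple[int, ...]):
        # An integer seed stands in for a JAX PRNGKey in this sketch.
        rng = random.Random(key)
        in_dim = input_shape[-1]
        W = [[rng.gauss(0.0, 1.0) for _ in range(out_dim)] for _ in range(in_dim)]
        b = [rng.gauss(0.0, 1.0) for _ in range(out_dim)]
        output_shape = input_shape[:-1] + (out_dim,)
        return output_shape, {"W": W, "b": b}
    return init_fn
```

Reusing the same key reproduces the same parameters, which mirrors how a fixed PRNGKey makes initialization deterministic.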


Kernel inputs/outputs of FanOut, FanInSum, etc.

alias of Union[list[Kernel], tuple[Kernel, …]]

class neural_tangents._src.utils.typing.LayerKernelFn(*args, **kwargs)

A type alias for pure kernel functions.

A pure kernel function takes a PyTree of Kernel object(s) and produces a PyTree of Kernel object(s). These functions are used to define new layer types.
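As an illustration, the sketch below is a pure kernel function for a dense layer. It assumes a hypothetical Kernel stand-in holding scalar NNGP/NTK entries for one input pair, and it uses the usual NTK-parameterization recursion nngp' = W_std^2 * nngp + b_std^2 and ntk' = W_std^2 * ntk + nngp'. Lists and tuples of Kernels are mapped leaf by leaf.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Kernel:
    """Hypothetical stand-in: scalar NNGP/NTK entries for one input pair."""
    nngp: float
    ntk: float


def dense_kernel_fn(k, W_std: float = 1.0, b_std: float = 0.0):
    """Pure kernel fn of a dense layer, applied leaf-wise over lists/tuples."""
    if isinstance(k, (list, tuple)):
        return type(k)(dense_kernel_fn(x, W_std, b_std) for x in k)
    nngp = W_std ** 2 * k.nngp + b_std ** 2
    ntk = W_std ** 2 * k.ntk + nngp
    return replace(k, nngp=nngp, ntk=ntk)  # a new Kernel; the input is untouched
```

Because the function has no side effects and no parameters, layers can be composed simply by chaining such functions.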

class neural_tangents._src.utils.typing.MaskFn(*args, **kwargs)

A type alias for masking functions.

A masking function forward-propagates a mask through a layer of a finite-width network.
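Two toy masking functions give a feel for what forward propagation of a mask means. This is a sketch under assumptions: the (mask, input_shape) signature, the convention that True marks a masked-out entry, and both layer choices are hypothetical.

```python
from __future__ import annotations

from typing import List

Mask = List[List[bool]]  # True marks a masked-out (padding) entry, by assumption


def elementwise_mask_fn(mask: Mask, input_shape: tuple[int, ...]) -> Mask:
    # An elementwise layer (e.g. a nonlinearity) leaves the mask unchanged.
    return mask


def sum_pool_mask_fn(mask: Mask, input_shape: tuple[int, ...]) -> List[bool]:
    # Summing over the last axis: an output is masked only if every
    # contributing input entry was masked.
    return [all(row) for row in mask]
```

Layers that mix entries, such as the pooling example, are where the output mask genuinely differs from the input mask.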

class neural_tangents._src.utils.typing.MonteCarloKernelFn(*args, **kwargs)

A type alias for Monte Carlo kernel functions.

A kernel function that produces an estimate of an analytic kernel by Monte Carlo sampling, given a PRNGKey.
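The idea can be sketched with a random linear network f(x) = w . x / sqrt(d), w ~ N(0, I), whose exact NNGP is x1 . x2 / d. Averaging f(x1) * f(x2) over sampled networks estimates that value. The integer seed standing in for a PRNGKey and the name monte_carlo_nngp are assumptions of this sketch.

```python
from __future__ import annotations

import math
import random
from typing import List


def monte_carlo_nngp(x1: List[float], x2: List[float], key: int,
                     n_samples: int = 10_000) -> float:
    """Monte Carlo estimate of E_w[f(x1) f(x2)] for f(x) = w . x / sqrt(d)."""
    rng = random.Random(key)  # integer seed stands in for a JAX PRNGKey
    d = len(x1)
    total = 0.0
    for _ in range(n_samples):
        w = [rng.gauss(0.0, 1.0) for _ in range(d)]  # one sampled network
        f1 = sum(wi * a for wi, a in zip(w, x1)) / math.sqrt(d)
        f2 = sum(wi * b for wi, b in zip(w, x2)) / math.sqrt(d)
        total += f1 * f2
    return total / n_samples  # the analytic value is x1 . x2 / d
```

The estimate converges to the analytic kernel at the usual 1/sqrt(n_samples) rate, and the same key always yields the same estimate.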


Neural Tangents Tree.

Trees of kernels and arrays naturally emerge in certain neural network computations (for example, when neural networks have nested parallel layers).

Mimicking JAX, we use a lightweight tree structure called an NTTree. NTTree has internal nodes that are either lists or tuples and leaves which are either jax.numpy.ndarray or Kernel objects.

alias of Union[list[T], tuple[T, …], T]
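A minimal sketch of mapping a function over an NTTree as described above: lists and tuples are treated as internal nodes and everything else as a leaf. The helper name nt_tree_map is hypothetical.

```python
from __future__ import annotations

from typing import Callable, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def nt_tree_map(fn: Callable[[T], S], tree):
    """Apply `fn` to every leaf, keeping the list/tuple structure intact."""
    if isinstance(tree, list):
        return [nt_tree_map(fn, t) for t in tree]
    if isinstance(tree, tuple):
        return tuple(nt_tree_map(fn, t) for t in tree)
    return fn(tree)  # a leaf: an array or Kernel object in practice
```

For example, nt_tree_map(lambda x: x * 2, [1, (2, 3), [4]]) keeps the nesting and doubles each leaf.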


A list or tuple of NTTrees.

alias of Union[list[T], tuple[T, …]]

neural_tangents._src.utils.typing.PyTree = typing.Any

A PyTree; see the JAX documentation for details.


A shape: a tuple of integers, or an NTTree of such tuples.

alias of Union[list[tuple[int, …]], tuple[tuple[int, …], …], tuple[int, …]]
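To make the admitted values concrete, here is a sketch of a hypothetical checker that follows the alias as written: a single tuple of integers, or a list or tuple of such tuples (one level of nesting).

```python
from __future__ import annotations


def is_shape(s) -> bool:
    # A single shape: a tuple of integers (the empty tuple is a scalar shape).
    return isinstance(s, tuple) and all(isinstance(d, int) for d in s)


def is_shapes(x) -> bool:
    """Check a value against the alias: a shape, or a list/tuple of shapes."""
    if is_shape(x):
        return True
    return isinstance(x, (list, tuple)) and all(is_shape(s) for s in x)
```

So (2, 3) and [(2, 3), (2, 5)] qualify, while a list of integers such as [2, 3] does not.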


Specifies (input, output, kwargs) axes for vmap in empirical NTK.

alias of Union[Any, None, tuple[Optional[Any], Optional[Any], dict[str, Optional[Any]]]]
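The alias admits either a single PyTree of axes or an explicit (input, output, kwargs) triple. The sketch below splits a value into the triple under one assumed convention (hypothetical, not taken from this page): a 3-tuple whose last entry is a dict is read as an explicit triple, and anything else is used for both inputs and outputs, with no batched keyword arguments.

```python
from __future__ import annotations

from typing import Any, Optional


def split_vmap_axes(vmap_axes: Any) -> tuple[Any, Any, dict[str, Optional[Any]]]:
    """Split a VMapAxes value into (in_axes, out_axes, kwargs_axes).

    Assumed convention (hypothetical): a 3-tuple ending in a dict is an
    explicit triple; any other value applies to both inputs and outputs.
    """
    if (isinstance(vmap_axes, tuple) and len(vmap_axes) == 3
            and isinstance(vmap_axes[2], dict)):
        in_axes, out_axes, kwargs_axes = vmap_axes
        return in_axes, out_axes, kwargs_axes
    return vmap_axes, vmap_axes, {}
```

Under this convention, vmap_axes=0 means "batch along axis 0 of both inputs and outputs", and (0, 1, {"mask": 0}) spells out each part separately.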