Typing
Common Type Definitions.
- class neural_tangents._src.utils.typing.AnalyticKernelFn(*args, **kwds)
A type alias for analytic kernel functions.
A kernel function that computes an analytic kernel. It takes either Kernel or jax.numpy.ndarray inputs and a get argument that specifies which quantities the kernel should compute. It returns either a Kernel object or jax.numpy.ndarray arrays for the kernels specified by get.
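As a rough illustration of this contract, here is a hypothetical analytic kernel function for a single fully connected layer acting on raw jax.numpy.ndarray inputs (the Kernel-object input path is omitted). It assumes the NTK parameterization f(x) = (W_std / sqrt(d)) W x + b_std b with standard normal W and b, under which the NNGP and NTK of a single layer coincide. The name dense_analytic_kernel_fn and its W_std/b_std defaults are assumptions for this sketch, not the library's API.

```python
import jax.numpy as jnp


def dense_analytic_kernel_fn(x1, x2, get, W_std=1.0, b_std=0.0):
    """Hypothetical analytic kernel of one Dense layer (NTK parameterization).

    x1: [n1, d] inputs, x2: [n2, d] inputs.
    get: "nngp", "ntk", or a tuple of those names.
    """
    d = x1.shape[-1]
    # NNGP: covariance of the layer outputs over random W and b.
    nngp = W_std**2 * (x1 @ x2.T) / d + b_std**2
    # For a single layer in this parameterization the NTK equals the NNGP:
    # the gradients w.r.t. W and b reproduce exactly the same two terms.
    kernels = {"nngp": nngp, "ntk": nngp}
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[name] for name in get)


x1 = jnp.array([[1.0, 0.0], [0.0, 1.0]])
x2 = jnp.array([[1.0, 1.0]])
nngp, ntk = dense_analytic_kernel_fn(x1, x2, get=("nngp", "ntk"), b_std=0.5)
```

Passing a single string such as get="ntk" returns one array, while a tuple returns one array per requested kernel; the real library returns richer objects, so treat this only as a picture of the get dispatch.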
- class neural_tangents._src.utils.typing.ApplyFn(*args, **kwds)
A type alias for apply functions.
Apply functions perform computations with finite-width neural networks: they take a PyTree of parameters and an array of inputs and produce an array of outputs.
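A minimal sketch of a function with this shape, assuming a single Dense layer whose parameter PyTree is the tuple (W, b) and the NTK parameterization with W_std = b_std = 1. The name dense_apply_fn is hypothetical.

```python
import jax
import jax.numpy as jnp


def dense_apply_fn(params, inputs):
    """Hypothetical apply function for one Dense layer.

    params: PyTree (W, b) with W of shape [d_in, d_out] and b of shape [d_out].
    inputs: array of shape [batch, d_in].
    Returns outputs of shape [batch, d_out].
    """
    W, b = params
    d_in = inputs.shape[-1]
    # NTK parameterization: the 1/sqrt(d_in) scale is applied at call time.
    return inputs @ W / jnp.sqrt(d_in) + b


key_W, key_b = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_W, (3, 4)), jax.random.normal(key_b, (4,)))
outputs = dense_apply_fn(params, jnp.ones((5, 3)))
```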
- neural_tangents._src.utils.typing.Axes
Axes specification, which can be an integer (axis=-1) or a sequence of integers (axis=(1, 3)).
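The helper below is a hypothetical sketch (not part of neural_tangents) showing how a function accepting either form might normalize an Axes value into a sorted tuple of non-negative axis indices.

```python
from typing import Sequence, Tuple, Union

# Mirrors the description above: a single int or a sequence of ints.
Axes = Union[int, Sequence[int]]


def normalize_axes(axes: Axes, ndim: int) -> Tuple[int, ...]:
    """Hypothetical helper: turn an Axes spec into sorted non-negative ints."""
    if isinstance(axes, int):
        axes = (axes,)
    normalized = []
    for axis in axes:
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} out of range for ndim={ndim}")
        # Python's modulo maps negative axes onto their positive equivalents.
        normalized.append(axis % ndim)
    return tuple(sorted(set(normalized)))
```

For a 4-dimensional array, normalize_axes(-1, 4) gives (3,) and normalize_axes((-1, 1), 4) gives (1, 3).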
- class neural_tangents._src.utils.typing.EmpiricalGetKernelFn(*args, **kwds)
A type alias for empirical kernel functions accepting a get argument.
A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.
Equivalent to EmpiricalKernelFn, but accepts a get argument, which can be, for example, get=("nngp", "ntk"), to compute both kernels together.
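To make "empirical kernel from a single instantiation" concrete, here is a hedged sketch built from jax.jacobian. The empirical NNGP is taken as the product of the network outputs and the empirical NTK as the product of the parameter Jacobians, both averaged over the output dimension; that averaging choice, the one-layer network, and all function names are assumptions for illustration, and the library's own implementation may differ.

```python
import jax
import jax.numpy as jnp


def dense_apply_fn(params, x):
    # Hypothetical one-layer network in NTK parameterization.
    W, b = params
    return x @ W / jnp.sqrt(x.shape[-1]) + b


def empirical_kernel_fn(x1, x2, get, params, apply_fn=dense_apply_fn):
    """Hypothetical empirical kernel with a get argument.

    Kernels are averaged over the output (logit) dimension.
    """
    def flat_jacobian(x):
        # Jacobian PyTree: each leaf has shape [n, d_out, *param_shape].
        jac = jax.jacobian(lambda p: apply_fn(p, x))(params)
        leaves = jax.tree_util.tree_leaves(jac)
        n, d_out = leaves[0].shape[:2]
        # Flatten every parameter leaf and concatenate: [n, d_out, n_params].
        return jnp.concatenate(
            [leaf.reshape(n, d_out, -1) for leaf in leaves], axis=-1)

    def compute_nngp():
        f1, f2 = apply_fn(params, x1), apply_fn(params, x2)
        return f1 @ f2.T / f1.shape[-1]

    def compute_ntk():
        j1, j2 = flat_jacobian(x1), flat_jacobian(x2)
        # Contract over parameters, average over outputs.
        return jnp.einsum("iok,jok->ij", j1, j2) / j1.shape[1]

    compute = {"nngp": compute_nngp, "ntk": compute_ntk}
    if isinstance(get, str):
        return compute[get]()
    return tuple(compute[name]() for name in get)


key_W, key_b = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_W, (3, 8)), jax.random.normal(key_b, (8,)))
x1, x2 = jnp.ones((2, 3)), jnp.full((4, 3), 2.0)
nngp, ntk = empirical_kernel_fn(x1, x2, ("nngp", "ntk"), params)
```

For this single linear layer the NTK does not depend on the sampled parameters (it equals x1 · x2 / d + 1, here 3.0 everywhere), while the NNGP does; deeper or nonlinear networks make both kernels parameter dependent.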
- class neural_tangents._src.utils.typing.EmpiricalKernelFn(*args, **kwds)
A type alias for empirical kernel functions computing either NTK or NNGP.
A kernel function that produces an empirical kernel from a single instantiation of a neural network specified by its parameters.
Equivalent to EmpiricalGetKernelFn with get="nngp" or get="ntk".
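The equivalence can be illustrated by binding get with functools.partial. The sketch below uses a hypothetical get-accepting empirical kernel for a bias-free linear model f(x) = x @ W, averaging over outputs; it is not the library's implementation.

```python
import functools

import jax.numpy as jnp


def linear_empirical_kernel_fn(x1, x2, get, W):
    """Hypothetical get-accepting empirical kernel for f(x) = x @ W."""
    f1, f2 = x1 @ W, x2 @ W
    kernels = {
        # Output covariance, averaged over the output dimension.
        "nngp": f1 @ f2.T / W.shape[1],
        # d f_o / d W[:, o] = x for every output o, so averaging the
        # Jacobian products over outputs leaves x1 @ x2.T.
        "ntk": x1 @ x2.T,
    }
    if isinstance(get, str):
        return kernels[get]
    return tuple(kernels[name] for name in get)


# Fixing get turns the get-accepting function into a single-kernel one.
ntk_fn = functools.partial(linear_empirical_kernel_fn, get="ntk")
W = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
ntk = ntk_fn(x, x, W=W)
```

W is passed by keyword because get is already bound; passing it positionally would collide with the get slot.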
- class neural_tangents._src.utils.typing.InitFn(*args, **kwds)
A type alias for initialization functions.
Initialization functions construct parameters for neural networks given a random key and an input shape. Specifically, they produce a tuple giving the output shape and a PyTree of parameters.
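A minimal sketch of an initialization function with this contract, assuming a single Dense layer with a hypothetical out_dim parameter and standard normal weights and biases. Leading dimensions of the input shape (for example a batch dimension) pass through unchanged.

```python
import jax
import jax.numpy as jnp


def dense_init_fn(rng, input_shape, out_dim=4):
    """Hypothetical init function for one Dense layer.

    Returns (output_shape, params) where params is the PyTree (W, b).
    """
    key_W, key_b = jax.random.split(rng)
    in_dim = input_shape[-1]
    W = jax.random.normal(key_W, (in_dim, out_dim))
    b = jax.random.normal(key_b, (out_dim,))
    # Only the last (feature) dimension changes.
    output_shape = tuple(input_shape[:-1]) + (out_dim,)
    return output_shape, (W, b)


output_shape, (W, b) = dense_init_fn(jax.random.PRNGKey(0), (5, 3))
```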
- neural_tangents._src.utils.typing.Kernels
Kernel inputs/outputs of FanOut, FanInSum, etc.
- class neural_tangents._src.utils.typing.LayerKernelFn(*args, **kwds)
A type alias for pure kernel functions.
A pure kernel function takes a PyTree of Kernel object(s) and produces a PyTree of Kernel object(s). These functions are used to define new layer types.
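The sketch below shows what "pure" means here: no parameters and no randomness, only a map from one kernel to the next. It uses a hypothetical ToyKernel NamedTuple as a stand-in for the library's Kernel object and assumes the standard NTK-parameterization recursion for a Dense layer (new NNGP = W_std² · NNGP + b_std², new NTK = new NNGP + W_std² · NTK).

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyKernel(NamedTuple):
    """Hypothetical stand-in for the library's Kernel object."""
    nngp: jnp.ndarray
    ntk: jnp.ndarray


def dense_layer_kernel_fn(k: ToyKernel, W_std=1.0, b_std=0.0) -> ToyKernel:
    """Pure kernel function of a Dense layer (NTK parameterization)."""
    nngp = W_std**2 * k.nngp + b_std**2
    # The NTK gains the new NNGP plus the rescaled incoming NTK.
    ntk = nngp + W_std**2 * k.ntk
    return ToyKernel(nngp=nngp, ntk=ntk)


x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
# Input-layer kernel: NNGP is the normalized Gram matrix, NTK starts at zero.
k0 = ToyKernel(nngp=x @ x.T / x.shape[-1], ntk=jnp.zeros((2, 2)))
k2 = dense_layer_kernel_fn(dense_layer_kernel_fn(k0, b_std=0.5), b_std=0.5)
```

Because each layer is a plain function of the incoming kernel, stacking layers is just function composition, as the two nested calls above show.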
- class neural_tangents._src.utils.typing.MaskFn(*args, **kwds)
A type alias for masking functions.
Forward-propagates a mask through a layer of a finite-width network.
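As a hedged illustration of forward-propagating a mask, here is a hypothetical mask function for a layer that sums over one axis (for example global sum pooling over a sequence). It assumes the convention that True marks a masked-out entry and that a pooled output is masked only when every entry it sums over is masked; the library's actual signature and conventions may differ.

```python
import jax.numpy as jnp


def sum_pool_mask_fn(mask, axis=1):
    """Hypothetical mask function for a layer that sums over `axis`.

    Convention assumed here: True marks a masked-out (missing) entry.
    The pooled output is masked only if every entry it sums over is masked.
    """
    return jnp.all(mask, axis=axis)


# Batch of 2 sequences of length 3 with 1 feature; the second is fully padded.
mask = jnp.array([[[False], [True], [True]],
                  [[True], [True], [True]]])
out_mask = sum_pool_mask_fn(mask)
```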
- class neural_tangents._src.utils.typing.MonteCarloKernelFn(*args, **kwds)
A type alias for Monte Carlo kernel functions.
A kernel function that produces an estimate of an analytic kernel by Monte Carlo sampling, given a PRNGKey.
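A minimal sketch of the idea, assuming a single Dense layer in NTK parameterization: split the PRNGKey into several keys, draw one random network per key, compute each network's empirical NNGP, and average. The function name, n_samples, and width are hypothetical; the estimate is compared against the same layer's analytic NNGP.

```python
import jax
import jax.numpy as jnp


def mc_nngp(key, x1, x2, n_samples=16, width=256, W_std=1.0, b_std=0.5):
    """Hypothetical Monte Carlo NNGP estimate for one Dense layer.

    Draws n_samples random networks from `key`, computes each one's
    empirical NNGP (averaged over `width` outputs), and averages them.
    """
    d = x1.shape[-1]
    estimates = []
    for k in jax.random.split(key, n_samples):
        key_W, key_b = jax.random.split(k)
        W = jax.random.normal(key_W, (d, width))
        b = jax.random.normal(key_b, (width,))
        f1 = W_std * (x1 @ W) / jnp.sqrt(d) + b_std * b
        f2 = W_std * (x2 @ W) / jnp.sqrt(d) + b_std * b
        estimates.append(f1 @ f2.T / width)
    return jnp.mean(jnp.stack(estimates), axis=0)


x = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
estimate = mc_nngp(jax.random.PRNGKey(0), x, x)
# Analytic NNGP of the same layer, for comparison.
exact = 1.0**2 * (x @ x.T) / x.shape[-1] + 0.5**2
```

Increasing n_samples or width shrinks the sampling error, at a proportional cost in compute.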
- neural_tangents._src.utils.typing.NTTree
Neural Tangents Tree.
Trees of kernels and arrays naturally emerge in certain neural network computations (for example, when neural networks have nested parallel layers).
Mimicking JAX, we use a lightweight tree structure called an NTTree. An NTTree has internal nodes that are either lists or tuples and leaves that are either jax.numpy.ndarray or Kernel objects.
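Since lists and tuples are also internal nodes for JAX pytrees, an NTTree whose leaves are arrays can be traversed with jax.tree_util. The example below uses only array leaves and a hypothetical nesting; it makes no assumption about how Kernel leaves interact with jax.tree_util.

```python
import jax
import jax.numpy as jnp

# A hypothetical NTTree-shaped value: a list holding an array and a tuple of arrays.
tree = [jnp.ones((2, 2)), (jnp.zeros(3), jnp.ones(1))]

# tree_map applies the function to every array leaf and rebuilds the same
# list/tuple nesting around the results.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, tree)
leaves = jax.tree_util.tree_leaves(doubled)
total = sum(float(leaf.sum()) for leaf in leaves)
```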