nt.stax – infinite NNGP and NTK

Closed-form NNGP and NTK library.

This library contains layers mimicking those in jax.example_libraries.stax, with a similar API apart from the following:

1) Instead of an (init_fn, apply_fn) tuple, layers return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel and represents the change in the analytic NTK and NNGP kernels (the ntk and nngp fields). These functions are chained (stacked) together within the serial or parallel combinators, just like init_fn and apply_fn. For details, please see “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”.

2) In layers with random weights, NTK parameterization is used by default (see page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”). Standard parameterization can be selected for Conv and Dense layers with the keyword argument parameterization='standard'. For details, please see “On the infinite width limit of neural networks with a standard parameterization”.

3) Some functionality may be missing (e.g. jax.example_libraries.stax.BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes, etc.).
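To make point 1 concrete, here is a minimal NumPy sketch (not the library's implementation) of kernel maps being chained. Each hypothetical helper maps an NNGP covariance matrix to the covariance of the next layer's outputs in the infinite-width limit, assuming the standard NTK-parameterized Dense rule K -> W_std^2 K + b_std^2 and the ReLU arc-cosine kernel. The names dense_kernel, relu_kernel and serial_kernel are illustrative only, and the ntk part of the Kernel is omitted for brevity.

```python
import numpy as np

# Hypothetical helpers for illustration; they are NOT part of neural_tangents.
# Each maps an NNGP covariance matrix (n x n) to the next layer's covariance.


def dense_kernel(k, w_std=1.0, b_std=0.0):
    """Infinite-width Dense layer (NTK parameterization): K -> W_std^2 K + b_std^2."""
    return w_std**2 * k + b_std**2


def relu_kernel(k):
    """ReLU dual activation: the order-1 arc-cosine kernel."""
    norms = np.sqrt(np.outer(np.diag(k), np.diag(k)))  # sqrt(q_i q_j)
    cos_t = np.clip(k / norms, -1.0, 1.0)  # correlation between inputs i and j
    t = np.arccos(cos_t)
    return norms / (2 * np.pi) * (np.sin(t) + (np.pi - t) * cos_t)


def serial_kernel(*kernel_fns):
    """Chains kernel maps in order, mirroring how serial chains layers."""
    def kernel_fn(k):
        for fn in kernel_fns:
            k = fn(k)
        return k
    return kernel_fn


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))  # 4 inputs with 16 features each
k0 = x @ x.T / x.shape[1]  # input covariance, normalized by input width

kernel_fn = serial_kernel(
    lambda k: dense_kernel(k, w_std=1.5, b_std=0.1),
    relu_kernel,
    lambda k: dense_kernel(k, w_std=1.5, b_std=0.1),
)
nngp = kernel_fn(k0)
print(nngp.shape)  # (4, 4)
```

A quick sanity check on the ReLU rule: on the diagonal the correlation is 1, so the map reduces to q -> q / 2, the second moment of a rectified zero-mean Gaussian.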


>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     stax.Conv(256, (3, 3)),
...     stax.Relu(),
...     stax.Conv(512, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
...                                                       y_train)
>>> #
>>> # (5, 10) jnp.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) jnp.ndarray NTK test prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
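For intuition about what predict_fn returns: assuming infinite training time (t=None) and no diagonal regularization, the mean prediction is, to our understanding, k(x_test, x_train) k(x_train, x_train)^-1 y_train, with k the NNGP kernel for get='nngp' and the NTK for get='ntk'. The NumPy sketch below evaluates that formula directly. The linear kernel is a hypothetical stand-in for kernel_fn(x_test, x_train, 'nngp'), and the helper name infinite_time_mean is ours.

```python
import numpy as np


def infinite_time_mean(k_train_train, k_test_train, y_train):
    """Mean prediction k_td @ inv(k_dd) @ y, computed with a linear solve."""
    return k_test_train @ np.linalg.solve(k_train_train, y_train)


def kernel(a, b):
    """Stand-in kernel: a dot-product kernel normalized by input width."""
    return a @ b.T / a.shape[1]


rng = np.random.default_rng(1)
x_train = rng.normal(size=(20, 64))
y_train = rng.uniform(size=(20, 10))
x_test = rng.normal(size=(5, 64))

y_test = infinite_time_mean(kernel(x_train, x_train), kernel(x_test, x_train), y_train)
print(y_test.shape)  # (5, 10)
```

Using np.linalg.solve rather than forming the inverse explicitly is the usual, numerically safer choice. Evaluated on the training inputs themselves, the formula reproduces y_train, as expected of a fully converged MSE fit.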


Layers to combine multiple other layers into one.


Combinator for composing layers in parallel.

repeat(layer, n)

Compose layer in a compiled loop n times.


Combinator for composing layers in serial.
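The relation between serial and repeat can be sketched at the level of kernel maps. repeat(layer, n) computes the same thing as serial applied to n copies of layer, but expresses it as a loop, which only makes sense when the layer's output has the same shape as its input. In this NumPy sketch a plain Python loop stands in for the compiled loop, and the Dense-then-ReLU NNGP map (hypothetical helper block_kernel) is an assumption for illustration. With W_std = sqrt(2) and b_std = 0 this map preserves the kernel diagonal, which makes the result easy to check.

```python
import numpy as np


def block_kernel(k, w_std=np.sqrt(2.0), b_std=0.0):
    """Hypothetical NNGP map of serial(Dense, Relu) at infinite width."""
    k = w_std**2 * k + b_std**2  # Dense (NTK parameterization)
    norms = np.sqrt(np.outer(np.diag(k), np.diag(k)))
    cos_t = np.clip(k / norms, -1.0, 1.0)
    t = np.arccos(cos_t)
    return norms / (2 * np.pi) * (np.sin(t) + (np.pi - t) * cos_t)  # ReLU


def serial_kernel(*fns):
    """Applies the given kernel maps one after another."""
    def kernel_fn(k):
        for fn in fns:
            k = fn(k)
        return k
    return kernel_fn


def repeat_kernel(fn, n):
    """Applies the same kernel map n times; a Python loop stands in for a compiled one."""
    def kernel_fn(k):
        for _ in range(n):
            k = fn(k)
        return k
    return kernel_fn


x = np.random.default_rng(0).normal(size=(3, 8))
k0 = x @ x.T / x.shape[1]
k_repeat = repeat_kernel(block_kernel, 3)(k0)
k_serial = serial_kernel(block_kernel, block_kernel, block_kernel)(k0)
```

The loop form matters mostly for compilation: a compiled loop traces the repeated layer once instead of n times.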


Layers to split outputs into many, or combine many into ones.


Fan-in concatenation.


Fan-in product.


Fan-in sum.
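As a finite-width illustration of how fan-in concatenation acts on kernels: concatenating two branches along the channel axis gives an empirical NNGP kernel (feature inner products divided by width) equal to the width-weighted average of the branch kernels. This NumPy sketch checks that identity. It makes no claim about how the library's layer weighs branches in the infinite-width limit, and empirical_nngp is a hypothetical helper. A fan-in sum of independent zero-mean branches, by contrast, adds the branch kernels only in expectation, since the cross terms average to zero.

```python
import numpy as np


def empirical_nngp(h):
    """Empirical NNGP kernel: inner products of features divided by the width."""
    return h @ h.T / h.shape[1]


rng = np.random.default_rng(0)
h1 = rng.normal(size=(3, 8))   # branch 1: 3 inputs, 8 channels
h2 = rng.normal(size=(3, 24))  # branch 2: same 3 inputs, 24 channels

# Kernel of the channel-wise concatenation vs. width-weighted branch kernels.
k_concat = empirical_nngp(np.concatenate([h1, h2], axis=1))
k_weighted = (8 * empirical_nngp(h1) + 24 * empirical_nngp(h2)) / (8 + 24)
```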



Linear parametric

Linear layers with trainable parameters.

Conv(out_chan, filter_shape[, strides, ...])

General convolution.

ConvLocal(out_chan, filter_shape[, strides, ...])

General unshared convolution.

ConvTranspose(out_chan, filter_shape[, ...])

General transpose convolution.

Dense(out_dim[, W_std, b_std, batch_axis, ...])

Dense (fully-connected, matrix product).

GlobalSelfAttention(n_chan_out, n_chan_key, ...)

Global scaled dot-product self-attention.
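The difference between the two parameterizations available for Conv and Dense can be seen in a minimal NumPy sketch of a single dense layer, assuming the usual definitions. Under NTK parameterization the trainable weights are N(0, 1) and the factor W_std / sqrt(fan_in) is applied in the forward pass. Under standard parameterization that factor is folded into the initialization instead. The two give the same function at initialization but different gradients with respect to the weights, and hence different NTKs. Variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
fan_in, out_dim, w_std, b_std = 512, 4, 1.5, 0.1
x = rng.normal(size=fan_in)

# NTK parameterization: weights and biases are N(0, 1), scaling happens in the forward pass.
w = rng.normal(size=(out_dim, fan_in))
b = rng.normal(size=out_dim)
y_ntk = (w_std / np.sqrt(fan_in)) * (w @ x) + b_std * b

# Standard parameterization: the same scaling is folded into the initial values.
w_sp = (w_std / np.sqrt(fan_in)) * w  # W ~ N(0, w_std**2 / fan_in)
b_sp = b_std * b                      # b ~ N(0, b_std**2)
y_sp = w_sp @ x + b_sp

# Gradient of output i with respect to weight (i, j) is a multiple of x[j]:
grad_ntk = (w_std / np.sqrt(fan_in)) * x  # shrinks as fan_in grows
grad_sp = x                               # does not depend on fan_in
```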

Linear nonparametric

Linear layers without any trainable parameters.

Aggregate([aggregate_axis, batch_axis, ...])

Aggregation operator (graph neural network).

AvgPool(window_shape[, strides, padding, ...])

Average pooling.

DotGeneral(*[, lhs, rhs, dimension_numbers, ...])

Dot General with a constant (non-trainable) rhs or lhs.

Dropout(rate[, mode])

Dropout.

Flatten([batch_axis, batch_axis_out])

Flattening all non-batch dimensions.

GlobalAvgPool([batch_axis, channel_axis])

Global average pooling.

GlobalSumPool([batch_axis, channel_axis])

Global sum pooling.


Identity (no-op).

ImageResize(shape, method[, antialias, ...])

Image resize function mimicking jax.image.resize.

Index(idx[, batch_axis, channel_axis])

Index into the array mimicking numpy.ndarray indexing.

LayerNorm([axis, eps, batch_axis, channel_axis])

Layer normalisation.

SumPool(window_shape[, strides, padding, ...])

Sum pooling.
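Because global average and sum pooling are linear, their effect on a kernel is linear too: the pooled kernel between two inputs is the mean (or sum) of the kernel over all pairs of spatial positions. The NumPy sketch below checks this on empirical, finite-width kernels for one pair of feature maps. The shapes and the normalization by channel count are assumptions for illustration, and the batch_axis and channel_axis arguments of the library's layers are not modelled.

```python
import numpy as np

rng = np.random.default_rng(0)
positions, channels = 6, 32
x1 = rng.normal(size=(positions, channels))  # feature map of input 1
x2 = rng.normal(size=(positions, channels))  # feature map of input 2

# Kernel between every pair of spatial positions, normalized by channel count.
k_positions = x1 @ x2.T / channels  # shape (positions, positions)

# Kernels after pooling each feature map over its spatial positions.
k_avg_pool = x1.mean(axis=0) @ x2.mean(axis=0) / channels
k_sum_pool = x1.sum(axis=0) @ x2.sum(axis=0) / channels
```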

Elementwise nonlinear

Pointwise nonlinear layers. For details, please see “Fast Neural Kernel Embeddings for General Activations”.

ABRelu(a, b[, do_stabilize])

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).


Absolute value nonlinearity.

Cos([a, b, c])

Affine transform of Cos nonlinearity, i.e. a * cos(b * x + c).

Elementwise([fn, nngp_fn, d_nngp_fn])

Elementwise application of fn using provided nngp_fn.

ElementwiseNumerical(fn, deg[, df])

Activation function using numerical integration.

Erf([a, b, c])

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Exp([a, b])

Elementwise natural exponential function a * jnp.exp(b * x).

ExpNormalized([gamma, shift, do_clip])

Simulates the "Gaussian normalized kernel".


Gabor function exp(-x^2) * sin(x).

Gaussian([a, b])

Elementwise Gaussian function a * jnp.exp(b * x**2).


Gelu function.


Hermite polynomials.

LeakyRelu(alpha[, do_stabilize])

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).


Monomials, i.e. x^degree.


Polynomials, i.e. coef[0] + coef[1] * x + ... + coef[n] * x**n.


Dual activation function for normalized RBF or squared exponential kernel.


Rectified monomials, i.e. (x >= 0) * x^degree.


ReLU nonlinearity.


A sigmoid-like function f(x) = .5 * erf(x / 2.4020563531719796) + .5.


Sign function.

Sin([a, b, c])

Affine transform of Sin nonlinearity, i.e. a * sin(b * x + c).
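Several entries above describe an elementwise nonlinearity through its dual kernel E[f(u) f(v)] for jointly Gaussian (u, v) with variances q11, q22 and covariance q12. The NumPy sketch below estimates that expectation with tensor-product Gauss-Hermite quadrature. This is a technique in the spirit of numerical integration of activations; whether ElementwiseNumerical's deg argument is exactly a quadrature degree is our assumption. The helper dual_kernel is hypothetical. It is checked against two closed forms that follow from E[cos(w)] = exp(-Var(w) / 2):

- For Sin with a = b = 1 and c = 0, E[sin(u) sin(v)] = exp(-(q11 + q22) / 2) sinh(q12).
- The function sqrt(2) sin(sqrt(2 gamma) x + pi / 4) is one choice whose dual kernel is the squared exponential exp(-gamma (q11 + q22 - 2 q12)), the kernel named in the normalized RBF entry.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def dual_kernel(fn, q11, q22, q12, deg=40):
    """Estimates E[fn(u) fn(v)] for (u, v) ~ N(0, [[q11, q12], [q12, q22]])."""
    z, w = hermegauss(deg)      # nodes/weights for the weight exp(-z**2 / 2)
    w = w / np.sqrt(2 * np.pi)  # normalize to standard-normal expectations
    z1, z2 = np.meshgrid(z, z, indexing="ij")
    w12 = np.outer(w, w)        # product rule over two independent normals
    # Build correlated Gaussians from independent ones (2x2 Cholesky factor).
    u = np.sqrt(q11) * z1
    v = q12 / np.sqrt(q11) * z1 + np.sqrt(q22 - q12**2 / q11) * z2
    return np.sum(w12 * fn(u) * fn(v))


q11, q22, q12 = 1.0, 0.5, 0.3
gamma = 0.5

sin_numeric = dual_kernel(np.sin, q11, q22, q12)
sin_closed = np.exp(-(q11 + q22) / 2) * np.sinh(q12)


def rbf_fn(x):
    """Activation whose dual kernel is exp(-gamma * (q11 + q22 - 2 * q12))."""
    return np.sqrt(2) * np.sin(np.sqrt(2 * gamma) * x + np.pi / 4)


rbf_numeric = dual_kernel(rbf_fn, q11, q22, q12)
rbf_closed = np.exp(-gamma * (q11 + q22 - 2 * q12))
```

Quadrature of this kind is exact for polynomial activations up to the rule's degree and converges quickly for smooth ones; activations with kinks, such as ReLU, converge more slowly, which is one reason closed forms are preferred where they exist.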

Helper classes

Utility classes for specifying layer properties. For enums, strings can be passed in their place.


Implementation of the Aggregate layer.


Type of nonlinearity to use in a GlobalSelfAttention layer.


Type of padding in pooling and convolutional layers.


Type of positional embeddings to use in a GlobalSelfAttention layer.


For developers

Classes and decorators helpful for constructing your own layers.


Helper trinary logic class.

Diagonal([input, output])

Helps decide whether to allow the kernel to contain diagonal entries only.


A convenience decorator to be added to all public layers.


Returns a decorator that augments kernel_fn with consistency checks.


Returns a decorator that turns layers into layers supporting masking.