nt.stax – infinite NNGP and NTK

Closed-form NNGP and NTK library.

This library contains layers mimicking those in jax.example_libraries.stax, with a similar API apart from the following differences:

1) Instead of (init_fn, apply_fn) tuple, layers return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (nngp, ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn. For details, please see “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”.

2) In layers with random weights, NTK parameterization is used by default (see page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”). Standard parameterization can be specified for Conv and Dense layers by a keyword argument parameterization. For details, please see “On the infinite width limit of neural networks with a standard parameterization”.

3) Some functionality may be missing (e.g. jax.example_libraries.stax.BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes etc.).


>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                       y_train)
>>> #
>>> # (5, 10) jnp.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) jnp.ndarray NTK test prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')


Combinators

Layers to combine multiple other layers into one.


parallel(*layers)

Combinator for composing layers in parallel.

repeat(layer, n)

Compose layer in a compiled loop n times.


serial(*layers)

Combinator for composing layers in serial.


Branching

Layers to split outputs into many, or combine many into one.


FanInConcat([axis])

Fan-in concatenation.


FanInProd()

Fan-in product.


FanInSum()

Fan-in sum.



Linear parametric

Linear layers with trainable parameters.

Conv(out_chan, filter_shape[, strides, ...])

General convolution.

ConvLocal(out_chan, filter_shape[, strides, ...])

General unshared convolution.

ConvTranspose(out_chan, filter_shape[, ...])

General transpose convolution.

Dense(out_dim[, W_std, b_std, batch_axis, ...])

Dense (fully-connected, matrix product).

GlobalSelfAttention(n_chan_out, n_chan_key, ...)

Global scaled dot-product self-attention.

Linear nonparametric

Linear layers without any trainable parameters.

Aggregate([aggregate_axis, batch_axis, ...])

Aggregation operator (graph neural network).

AvgPool(window_shape[, strides, padding, ...])

Average pooling.

DotGeneral(*[, lhs, rhs, dimension_numbers, ...])

Constant (non-trainable) rhs/lhs Dot General.

Dropout(rate[, mode])

Dropout layer.

Flatten([batch_axis, batch_axis_out])

Flattening all non-batch dimensions.

GlobalAvgPool([batch_axis, channel_axis])

Global average pooling.

GlobalSumPool([batch_axis, channel_axis])

Global sum pooling.


Identity()

Identity (no-op).

ImageResize(shape, method[, antialias, ...])

Image resize function mimicking jax.image.resize.

Index(idx[, batch_axis, channel_axis])

Index into the array mimicking numpy.ndarray indexing.

LayerNorm([axis, eps, batch_axis, channel_axis])

Layer normalisation.

SumPool(window_shape[, strides, padding, ...])

Sum pooling.

Elementwise nonlinear

Pointwise nonlinear layers. For details, please see “Fast Neural Kernel Embeddings for General Activations”.

ABRelu(a, b[, do_stabilize])

ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).


Abs([do_stabilize])

Absolute value nonlinearity.

Cos([a, b, c])

Affine transform of Cos nonlinearity, i.e. a cos(b*x + c).

Elementwise([fn, nngp_fn, d_nngp_fn])

Elementwise application of fn using provided nngp_fn.

ElementwiseNumerical(fn, deg[, df])

Activation function using numerical integration.

Erf([a, b, c])

Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.

Exp([a, b])

Elementwise natural exponent function a * jnp.exp(b * x).

ExpNormalized([gamma, shift, do_clip])

Simulates the "Gaussian normalized kernel".


Gabor()

Gabor function exp(-x^2) * sin(x).

Gaussian([a, b])

Elementwise Gaussian function a * jnp.exp(b * x**2).


Gelu([approximate])

Gelu function.


Hermite(degree)

Hermite polynomials.

LeakyRelu(alpha[, do_stabilize])

Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).


Monomial(degree)

Monomials, i.e. x^degree.


Polynomial(coef)

Polynomials, i.e. coef[0] + coef[1] * x + ... + coef[n] * x**n.


Rbf([gamma])

Dual activation function for the normalized RBF or squared exponential kernel.


RectifiedMonomial(degree)

Rectified monomials, i.e. (x >= 0) * x^degree.


Relu([do_stabilize])

ReLU nonlinearity.


Sigmoid_like()

A sigmoid-like function f(x) = .5 * erf(x / 2.4020563531719796) + .5.


Sign()

Sign function.

Sin([a, b, c])

Affine transform of Sin nonlinearity, i.e. a sin(b*x + c).

Helper classes

Utility classes for specifying layer properties. For enums, strings can be passed in their place.


AggregateImplementation(value)

Implementation of the Aggregate layer.


AttentionMechanism(value)

Type of nonlinearity to use in a GlobalSelfAttention layer.


Padding(value)

Type of padding in pooling and convolutional layers.


PositionalEmbedding(value)

Type of positional embeddings to use in a GlobalSelfAttention layer.


For developers

Classes and decorators helpful for constructing your own layers.


Bool(value)

Helper trinary logic class.

Diagonal([input, output])

Helps decide whether to allow the kernel to contain diagonal entries only.


layer(layer_fn)

A convenience decorator to be added to all public layers.


requires(**static_reqs)

Returns a decorator that augments kernel_fn with consistency checks.


supports_masking(remask_kernel)

Returns a decorator that turns layers into layers supporting masking.