nt.stax
– infinite NNGP and NTK
Closed-form NNGP and NTK library.
This library contains layers mimicking those in jax.example_libraries.stax, with a similar API apart from:

1) Instead of an (init_fn, apply_fn) tuple, layers return a triple (init_fn, apply_fn, kernel_fn), where the added kernel_fn maps a Kernel to a new Kernel, and represents the change in the analytic NTK and NNGP kernels (nngp, ntk). These functions are chained / stacked together within the serial or parallel combinators, similarly to init_fn and apply_fn. For details, please see “Neural Tangents: Fast and Easy Infinite Neural Networks in Python”.
2) In layers with random weights, the NTK parameterization is used by default (see page 3 in “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”). The standard parameterization can be specified for Conv and Dense layers via the parameterization keyword argument; a minimal sketch follows this list. For details, please see “On the infinite width limit of neural networks with a standard parameterization”.
3) Some functionality may be missing (e.g. jax.example_libraries.stax.BatchNorm), and some may be present only in our library (e.g. CIRCULAR padding, LayerNorm, GlobalAvgPool, GlobalSelfAttention, flexible batch and channel axes, etc.).
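A minimal sketch of point 2, requesting the standard parameterization per layer (the width, W_std, and b_std values below are arbitrary):
>>> from neural_tangents import stax
>>> #
>>> # NTK parameterization is the default; 'standard' can be requested per layer.
>>> dense = stax.Dense(512, W_std=1.5, b_std=0.05, parameterization='standard')
>>> conv = stax.Conv(128, (3, 3), W_std=1.5, b_std=0.05, parameterization='standard')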
Example
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2 = random.split(random.PRNGKey(1), 2)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key1, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
>>>     stax.Conv(128, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(256, (3, 3)),
>>>     stax.Relu(),
>>>     stax.Conv(512, (3, 3)),
>>>     stax.Flatten(),
>>>     stax.Dense(10)
>>> )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
>>>                                                        y_train)
>>> #
>>> # (5, 10) np.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) np.ndarray NTK prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
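The kernel_fn returned by stax.serial above can also be called directly on data to obtain the exact infinite-width kernels, without going through nt.predict (a sketch continuing the example above):
>>> # (5, 20) NNGP and NTK covariance matrices between test and train inputs.
>>> kernels = kernel_fn(x_test, x_train, ('nngp', 'ntk'))
>>> k_nngp = kernels.nngp
>>> k_ntk = kernels.ntk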
Combinators
Layers to combine multiple other layers into one.
parallel: Combinator for composing layers in parallel.
repeat: Compose a layer with itself repeatedly, in a compiled loop.
serial: Combinator for composing layers in serial.
Branching
Layers to split outputs into many, or combine many into one; a residual-connection sketch follows the list below.
FanInConcat: Fan-in concatenation.
FanInProd: Fan-in product.
FanInSum: Fan-in sum.
FanOut: Fan-out.
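A minimal sketch of how the branching layers combine with parallel to express a residual connection (the layer widths are arbitrary):
>>> from neural_tangents import stax
>>> #
>>> # A block computing x + Dense(Relu(Dense(x))), i.e. a residual connection.
>>> ResBlock = stax.serial(
>>>     stax.FanOut(2),
>>>     stax.parallel(
>>>         stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(512)),
>>>         stax.Identity()
>>>     ),
>>>     stax.FanInSum()
>>> )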
Linear parametric
Linear layers with trainable parameters.
Conv: General convolution.
ConvLocal: General unshared convolution.
ConvTranspose: General transpose convolution.
Dense: Dense (fully-connected, matrix product).
GlobalSelfAttention: Global scaled dot-product self-attention.
Linear nonparametric
Linear layers without any trainable parameters.
Aggregate: Aggregation operator (graphical neural network).
AvgPool: Average pooling.
DotGeneral: Constant (non-trainable) rhs/lhs Dot General.
Dropout: Dropout.
Flatten: Flattening all non-batch dimensions.
GlobalAvgPool: Global average pooling.
GlobalSumPool: Global sum pooling.
Identity: Identity (no-op).
ImageResize: Image resize function mimicking jax.image.resize.
Index: Index into the array, mimicking numpy indexing.
LayerNorm: Layer normalisation.
SumPool: Sum pooling.
Elementwise nonlinear
Pointwise nonlinear layers. For details, please see “Fast Neural Kernel Embeddings for General Activations”.
ABRelu: ABReLU nonlinearity, i.e. a * min(x, 0) + b * max(x, 0).
Abs: Absolute value nonlinearity.
Cos: Affine transform of Cos nonlinearity, i.e. a cos(b * x + c).
Elementwise: Elementwise application of fn using the provided nngp_fn.
ElementwiseNumerical: Activation function using numerical integration.
Erf: Affine transform of Erf nonlinearity, i.e. a * Erf(b * x) + c.
Exp: Elementwise natural exponent function.
ExpNormalized: Simulates the "Gaussian normalized kernel".
Gabor: Gabor function.
Gaussian: Elementwise Gaussian function.
Gelu: Gelu function.
Hermite: Hermite polynomials.
LeakyRelu: Leaky ReLU nonlinearity, i.e. alpha * min(x, 0) + max(x, 0).
Monomial: Monomials, i.e. x**degree.
Polynomial: Polynomials, i.e. weighted sums of monomials, sum_k coef[k] * x**k.
Rbf: Dual activation function for normalized RBF or squared exponential kernel.
RectifiedMonomial: Rectified monomials, i.e. x**degree for x >= 0 and 0 otherwise.
Relu: ReLU nonlinearity.
Sigmoid_like: A sigmoid-like function.
Sign: Sign function.
Sin: Affine transform of Sin nonlinearity, i.e. a sin(b * x + c).
Helper classes
Utility classes for specifying layer properties. For enums, strings can be passed in their place; a short sketch follows the list below.

AggregateImplementation: Implementation of the Aggregate layer.
AttentionMechanism: Type of nonlinearity to use in a GlobalSelfAttention layer.
Padding: Type of padding in pooling and convolutional layers.
PositionalEmbedding: Type of positional embeddings to use in a GlobalSelfAttention layer.
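A minimal sketch, assuming (as stated above) that an enum member and its string name are interchangeable wherever a layer expects one:
>>> from neural_tangents import stax
>>> #
>>> # Passing the Padding enum member and its string name are equivalent.
>>> conv_enum = stax.Conv(128, (3, 3), padding=stax.Padding.CIRCULAR)
>>> conv_str = stax.Conv(128, (3, 3), padding='CIRCULAR')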
For developers
Classes and decorators helpful for constructing your own layers.
Bool: Helper trinary logic class.
Diagonal: Helps decide whether to allow the kernel to contain diagonal entries only.
layer: A convenience decorator to be added to all public layers.
requires: Returns a decorator that augments kernel_fn with consistency checks.
supports_masking: Returns a decorator that turns layers into layers supporting masking.
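A minimal sketch of a custom parameter-free layer built with the layer decorator. This is an illustration only: the exact set of Kernel fields a kernel_fn should update, and whether requires / supports_masking are also needed, depends on the layer and should be checked against the library source.
>>> from neural_tangents import stax
>>> #
>>> # Hypothetical layer scaling its inputs by a constant factor. Scaling
>>> # inputs by `factor` scales all second-moment kernel entries by factor**2.
>>> @stax.layer
>>> def Scale(factor):
>>>   def init_fn(rng, input_shape):
>>>     return input_shape, ()
>>>   def apply_fn(params, inputs, **kwargs):
>>>     return factor * inputs
>>>   def kernel_fn(k, **kwargs):
>>>     scale = lambda mat: None if mat is None else factor**2 * mat
>>>     return k.replace(nngp=scale(k.nngp), ntk=scale(k.ntk),
>>>                      cov1=scale(k.cov1), cov2=scale(k.cov2))
>>>   return init_fn, apply_fn, kernel_fn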